├── doc ├── _templates │ └── .dummy ├── _static │ └── img │ │ ├── mpi4torch-logo.png │ │ ├── mpi4torch-logo-extrawhitespace.png │ │ ├── mpi4torch-logo-extrawhitespace.svg │ │ └── mpi4torch-logo.svg ├── glossary.rst ├── api_reference.rst ├── Makefile ├── index.rst ├── conf.py ├── examples.rst └── basic_usage.rst ├── version.txt ├── requirements.txt ├── .gitignore ├── MANIFEST.in ├── pyproject.toml ├── examples ├── isend-recv-wait.py └── simple_linear_regression.py ├── .readthedocs.yaml ├── tests ├── test_mpi4pyinterop.py ├── test_joindummies.py ├── test_nonblocking.py └── test_collectives.py ├── LICENSE ├── .github └── workflows │ └── test.yml ├── README.md ├── setup.py ├── src └── __init__.py └── csrc └── extension.cpp /doc/_templates/.dummy: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.1.3 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | *.egg-info/ 4 | *.egg 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include version.txt 2 | include pyproject.toml 3 | include doc/_static/img/mpi4torch-logo-extrawhitespace.png 4 | -------------------------------------------------------------------------------- /doc/_static/img/mpi4torch-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/helmholtz-analytics/mpi4torch/HEAD/doc/_static/img/mpi4torch-logo.png -------------------------------------------------------------------------------- /doc/_static/img/mpi4torch-logo-extrawhitespace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/helmholtz-analytics/mpi4torch/HEAD/doc/_static/img/mpi4torch-logo-extrawhitespace.png -------------------------------------------------------------------------------- /doc/glossary.rst: -------------------------------------------------------------------------------- 1 | ******************* 2 | Glossary 3 | ******************* 4 | 5 | .. glossary:: 6 | 7 | AD 8 | automatic differentiation 9 | 10 | DAG 11 | directed acyclic graph 12 | 13 | MPI 14 | Message Passing Interface 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # PEP 517 and PEP 518 build requirements 3 | # cf. 
https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 4 | 5 | # we need torch as a build requirement, since we explicitly import it within setup.py 6 | requires = [ 7 | "torch>=1.9.0", 8 | 'importlib-metadata;python_version<"3.8"', #PEP 508-style dependency: Only install for cpython<3.8 9 | "setuptools>=40.8.0", 10 | "wheel" 11 | ] 12 | build-backend = "setuptools.build_meta" 13 | 14 | -------------------------------------------------------------------------------- /examples/isend-recv-wait.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | 4 | comm = mpi4torch.COMM_WORLD 5 | 6 | a = torch.tensor([1.0 + comm.rank]).requires_grad_() 7 | 8 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0) 9 | recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy]) 10 | b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0) 11 | wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b])) 12 | 13 | res = mpi4torch.JoinDummies(a+b, [wait_ret]) 14 | print(res) 15 | 16 | res.backward() 17 | print(a.grad) 18 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: ubuntu-20.04 10 | tools: 11 | python: "3.10" 12 | apt_packages: 13 | - libopenmpi-dev 14 | - openmpi-bin 15 | - graphviz 16 | 17 | # Build documentation in the doc/ directory with Sphinx 18 | sphinx: 19 | configuration: doc/conf.py 20 | 21 | # If using Sphinx, optionally build your docs in additional formats such as PDF 22 | # formats: 23 | # - pdf 24 | 25 | # Python requirements to build the docs 26 | python: 27 | install: 28 | - method: pip 29 | path: . 30 | 31 | -------------------------------------------------------------------------------- /doc/api_reference.rst: -------------------------------------------------------------------------------- 1 | ******************** 2 | API Reference 3 | ******************** 4 | 5 | .. automodule:: mpi4torch 6 | :members: 7 | :undoc-members: 8 | 9 | .. autofunction:: JoinDummies(loopthrough: torch.Tensor, dummies:List[torch.Tensor]) -> torch.Tensor 10 | .. autofunction:: JoinDummiesHandle(handle: mpi4torch.WaitHandle, dummies:List[torch.Tensor]) -> mpi4torch.WaitHandle 11 | .. data:: MPI_MAX 12 | .. data:: MPI_MIN 13 | .. data:: MPI_SUM 14 | .. data:: MPI_PROD 15 | .. data:: MPI_LAND 16 | .. data:: MPI_BAND 17 | .. data:: MPI_LOR 18 | .. data:: MPI_BOR 19 | .. data:: MPI_LXOR 20 | .. data:: MPI_BXOR 21 | .. data:: MPI_MINLOC 22 | .. data:: MPI_MAXLOC 23 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/test_mpi4pyinterop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | import unittest 4 | import mpi4py.MPI as MPI 5 | 6 | class TestMpi4PyInteroperability(unittest.TestCase): 7 | def test_rank_and_size(self): 8 | comm1 = MPI.COMM_WORLD 9 | comm2 = mpi4torch.comm_from_mpi4py(comm1) 10 | self.assertEqual(comm1.rank, comm2.rank) 11 | self.assertEqual(comm1.size, comm2.size) 12 | 13 | def test_simple_allreduce(self): 14 | comm1 = MPI.COMM_WORLD 15 | comm2 = mpi4torch.comm_from_mpi4py(comm1) 16 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 17 | res = comm2.Allreduce(tmp,mpi4torch.MPI_SUM) 18 | res.sum().backward() 19 | self.assertTrue((tmp.grad == comm2.size * torch.ones(10, dtype=torch.double)).all()) 20 | 21 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. image:: _static/img/mpi4torch-logo-extrawhitespace.svg 2 | 3 | mpi4torch is an automatic-differentiable wrapper of MPI functions for the pytorch tensor library. 4 | 5 | MPI stands for Message Passing Interface and is the de facto standard communication interface on 6 | high-performance computing resources. To facilitate the usage of pytorch on these resources an MPI wrapper 7 | that is transparent to pytorch's automatic differentiation (AD) engine is much in need. This library tries 8 | to bridge this gap. 9 | 10 | .. 
toctree:: 11 | :maxdepth: 3 12 | :caption: Table of Contents 13 | 14 | basic_usage 15 | examples 16 | api_reference 17 | glossary 18 | 19 | Indices and tables 20 | ================== 21 | 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | * :ref:`search` 25 | -------------------------------------------------------------------------------- /tests/test_joindummies.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | import unittest 4 | 5 | comm = mpi4torch.COMM_WORLD 6 | 7 | class TestJoinDummies(unittest.TestCase): 8 | def test_simple_allreduce(self): 9 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 10 | tmp2 = torch.rand(10, dtype=torch.double).requires_grad_() 11 | tmp3 = torch.rand(10, dtype=torch.double).requires_grad_() 12 | res = comm.Allreduce(tmp,mpi4torch.MPI_SUM) 13 | res2 = mpi4torch.JoinDummies(res,[tmp2,tmp3]) 14 | res2.sum().backward() 15 | self.assertTrue((tmp2.grad == torch.zeros(10, dtype=torch.double)).all()) 16 | self.assertTrue((tmp3.grad == torch.zeros(10, dtype=torch.double)).all()) 17 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all()) 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020-2022 Philipp Knechtges and contributors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and 4 | associated documentation files (the "Software"), to deal in the Software without restriction, including 5 | without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 6 | copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the 7 | following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial 10 | portions of the Software. 11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT 13 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN 14 | NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 15 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 16 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
17 | -------------------------------------------------------------------------------- /examples/simple_linear_regression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | import mpi4py.MPI 4 | 5 | comm = mpi4torch.COMM_WORLD 6 | 7 | torch.manual_seed(42) 8 | 9 | num_points = 10000 10 | chunk_size = num_points // comm.size 11 | rest = num_points % comm.size 12 | if comm.rank < rest: 13 | chunk_size += 1 14 | offset = chunk_size * comm.rank 15 | else: 16 | offset = chunk_size * comm.rank + rest 17 | 18 | xinput = 2.0 * torch.rand([num_points],dtype=torch.double)[offset:offset+chunk_size] 19 | 20 | def some_parametrized_function(inp, params): 21 | return (params[2] * inp + params[1]) * inp + params[0] 22 | 23 | gen_params = torch.tensor([0.1, 1.0, -2.0]) 24 | 25 | youtput = some_parametrized_function(xinput, gen_params) 26 | 27 | def lossfunction(params): 28 | # average initial params to bring all ranks on the same page 29 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size 30 | 31 | # compute local loss 32 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params))) 33 | 34 | # sum up the loss among all ranks 35 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM) 36 | 37 | params = torch.arange(3, dtype=torch.double).requires_grad_() 38 | 39 | # LBFGS only needs one outer iteration for a linear problem 40 | # with so few parameters 41 | num_iterations = 1 42 | optimizer = torch.optim.LBFGS([params], 1) 43 | 44 | for i in range(num_iterations): 45 | def closure(): 46 | loss = lossfunction(params) 47 | optimizer.zero_grad() 48 | loss.backward() 49 | if comm.rank == 0: 50 | print("Params: ", params) 51 | print("Loss : ", loss) 52 | return loss 53 | optimizer.step(closure) 54 | 55 | # only print output on rank 0 56 | if comm.rank == 0: 57 | print("Final parameters: ", params) 58 | -------------------------------------------------------------------------------- /tests/test_nonblocking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | import unittest 4 | 5 | comm = mpi4torch.COMM_WORLD 6 | 7 | class TestNonBlocking(unittest.TestCase): 8 | def test_simple_isendirecv(self): 9 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_() 10 | req = comm.Isend(tmp,(comm.rank+1)%comm.size,0) 11 | req2 = comm.Irecv(mpi4torch.JoinDummies(torch.empty_like(tmp),[req.dummy]),(comm.rank+comm.size-1)%comm.size,0) 12 | res = comm.Wait(mpi4torch.JoinDummiesHandle(req,[req2.dummy])) 13 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req2,[res])) 14 | res3 = res2 * comm.rank 15 | res3.sum().backward() 16 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all()) 17 | 18 | def test_simple_isendrecv(self): 19 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_() 20 | req = comm.Isend(tmp,(comm.rank+1)%comm.size,0) 21 | res = comm.Recv(mpi4torch.JoinDummies(torch.empty_like(tmp),[req.dummy]),(comm.rank+comm.size-1)%comm.size,0) 22 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req,[res])) 23 | res3 = mpi4torch.JoinDummies(res,[res2]) * comm.rank 24 | res3.sum().backward() 25 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all()) 26 | 27 | def test_simple_irecvsend(self): 28 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_() 29 | req = 
comm.Irecv(mpi4torch.JoinDummies(torch.empty_like(tmp),[tmp]),(comm.rank+comm.size-1)%comm.size,0) 30 | res = comm.Send(tmp,(comm.rank+1)%comm.size,0) 31 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req,[res])) 32 | res3 = res2 * comm.rank 33 | res3.sum().backward() 34 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all()) 35 | 36 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | import sphinx_rtd_theme 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'mpi4torch' 22 | copyright = '2020, Philipp Knechtges' 23 | author = 'Philipp Knechtges' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | with open(os.path.join(os.path.dirname(__file__), '../version.txt'), encoding='utf-8') as filehandle: 27 | release = filehandle.read() 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | 'sphinx.ext.autodoc', 37 | 'sphinx.ext.intersphinx', 38 | 'sphinx_rtd_theme', 39 | 'sphinx.ext.napoleon', 40 | 'sphinx.ext.graphviz', 41 | ] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ['_templates'] 45 | 46 | # List of patterns, relative to source directory, that match files and 47 | # directories to ignore when looking for source files. 48 | # This pattern also affects html_static_path and html_extra_path. 49 | exclude_patterns = [] 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | 54 | # The theme to use for HTML and HTML Help pages. See the documentation for 55 | # a list of builtin themes. 56 | # 57 | html_theme = 'sphinx_rtd_theme' 58 | 59 | # Add any paths that contain custom static files (such as style sheets) here, 60 | # relative to this directory. They are copied after the builtin static files, 61 | # so a file named "default.css" will overwrite the builtin "default.css". 62 | html_static_path = ['_static'] 63 | 64 | 65 | # -- Extension configuration ------------------------------------------------- 66 | 67 | # -- Options for intersphinx extension --------------------------------------- 68 | 69 | # Example configuration for intersphinx: refer to the Python standard library. 
70 | intersphinx_mapping = {'https://docs.python.org/3/': None, 71 | 'https://pytorch.org/docs/stable/': None, 72 | } 73 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | tests: 9 | name: ${{ matrix.os }} cpy${{ matrix.python }} pytorch-${{ matrix.pytorch }} ${{ matrix.mpi }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | #os: [ "ubuntu-20.04", "ubuntu-22.04" ] 15 | os: [ "ubuntu-20.04" ] 16 | python: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] 17 | pytorch: [ "1.9.1", "1.10.2", "1.11.0", "1.12.1", "1.13.1", "2.0.0" ] 18 | mpi: [ "openmpi", "mpich" ] 19 | exclude: 20 | - python: 3.7 21 | pytorch: 2.0.0 22 | - python: 3.10 23 | pytorch: 1.8.1 24 | - python: 3.10 25 | pytorch: 1.9.1 26 | - python: 3.10 27 | pytorch: 1.10.2 28 | - python: 3.11 29 | pytorch: 1.8.1 30 | - python: 3.11 31 | pytorch: 1.9.1 32 | - python: 3.11 33 | pytorch: 1.10.2 34 | - python: 3.11 35 | pytorch: 1.11.0 36 | - python: 3.11 37 | pytorch: 1.12.1 38 | steps: 39 | - name: Checkout 40 | uses: actions/checkout@v3 41 | - name: Install CPython ${{ matrix.python }} 42 | uses: actions/setup-python@v4 43 | with: 44 | python-version: "${{ matrix.python }}" 45 | architecture: x64 46 | - name: Install MPI ${{ matrix.mpi }} 47 | run: | 48 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then 49 | sudo apt install libopenmpi-dev openmpi-bin 50 | elif [[ "${{ matrix.mpi }}" == "mpich" ]]; then 51 | sudo apt install mpich 52 | else 53 | exit 1 54 | fi 55 | mpirun --version 56 | - name: Setup virtual environment 57 | run: | 58 | python -m venv venv 59 | - name: Install mpi4torch 60 | run: | 61 | . venv/bin/activate 62 | echo "torch==${{ matrix.pytorch }}" >> constraints.txt 63 | PIP_CONSTRAINT="constraints.txt" pip install -v . nose2 mpi4py 64 | - name: Run Tests (np=2) 65 | run: | 66 | . venv/bin/activate 67 | #mpirun -np 2 python -c 'from mpi4torch import COMM_WORLD; print("Communicator size:", COMM_WORLD.size) if COMM_WORLD.rank == 0 else exit(0)' 68 | mpirun -np 2 nose2 69 | - name: Run Tests (np=5) 70 | run: | 71 | . venv/bin/activate 72 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then 73 | mpirun -np 5 --oversubscribe nose2 74 | else 75 | mpirun -np 5 nose2 76 | fi 77 | - name: Run Tests (np=7) 78 | run: | 79 | . venv/bin/activate 80 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then 81 | mpirun -np 7 --oversubscribe nose2 82 | else 83 | mpirun -np 7 nose2 84 | fi 85 | 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![mpi4torch Logo](./doc/_static/img/mpi4torch-logo-extrawhitespace.png) 2 | 3 | -------------------------------------------------------------------------------- 4 | 5 | mpi4torch is an automatic-differentiable wrapper of MPI functions for the pytorch tensor library. 6 | 7 | MPI stands for Message Passing Interface and is the de facto standard communication interface on 8 | high-performance computing resources. To facilitate the usage of pytorch on these resources an MPI wrapper 9 | that is transparent to pytorch's automatic differentiation (AD) engine is much in need. This library tries 10 | to bridge this gap. 11 | 12 | # Installation 13 | 14 | mpi4torch is also hosted on PyPI. 
However, due to the ABI-incompatibility of the different MPI implementations it 15 | is not provided as a binary wheel and needs to be built locally. Hence, you should have an appropriate C++ compiler 16 | installed, as well as the **development files of your MPI library** be present. The latter are usually provided 17 | through the *module system* of your local cluster, and you should consult the manuals of your cluster for this, 18 | or through the package manager of your Linux distribution. 19 | 20 | Once the dependencies have been satisfied the installation can be triggered by the usual 21 | ``` 22 | pip install mpi4torch 23 | ``` 24 | 25 | # Usage 26 | 27 | It is **highly advised** to first read [the basic usage chapter of the documentation](https://mpi4torch.readthedocs.io/en/latest/basic_usage.html) 28 | before jumping into action, since there are some implications of the pytorch AD design on the usage of mpi4torch. 29 | In other words, there are some footguns lurking! 30 | 31 | You have been warned, but if you insist on an easy usage example, consider the following code snippet, 32 | which is an excerpt from [examples/simple_linear_regression.py](examples/simple_linear_regression.py) 33 | 34 | ```python 35 | comm = mpi4torch.COMM_WORLD 36 | 37 | def lossfunction(params): 38 | # average initial params to bring all ranks on the same page 39 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size 40 | 41 | # compute local loss 42 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params))) 43 | 44 | # sum up the loss among all ranks 45 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM) 46 | ``` 47 | 48 | Here we have parallelized a loss function simply by adding two calls to `Allreduce`. For a more thorough 49 | discussion of the example see [here](https://mpi4torch.readthedocs.io/en/latest/examples.html#simple-data-parallel-example). 50 | 51 | # Tests 52 | 53 | Running tests is as easy as 54 | ``` 55 | mpirun -np 2 nose2 56 | ``` 57 | 58 | # Project Status 59 | 60 | [![Tests](https://github.com/helmholtz-analytics/mpi4torch/actions/workflows/test.yml/badge.svg?branch=master)](https://github.com/helmholtz-analytics/mpi4torch/actions/workflows/test.yml) 61 | [![Documentation Status](https://readthedocs.org/projects/mpi4torch/badge/?version=latest)](https://mpi4torch.readthedocs.io/en/latest/?badge=latest) 62 | -------------------------------------------------------------------------------- /doc/examples.rst: -------------------------------------------------------------------------------- 1 | ******************* 2 | Examples 3 | ******************* 4 | 5 | Simple data parallel example 6 | ============================ 7 | 8 | Let us assume we are in typical supervised learning situation. We have plenty of data (``xinput``, ``youtput``), 9 | and we search for unknown parameters minimizing some norm or general functional, simply referred to as the 10 | loss function. Furthermore, we assume that the 11 | loss function is just a summation of losses per data point. E.g. consider the following squared error: 12 | 13 | .. code-block:: python 14 | 15 | def lossfunction(params): 16 | # compute local loss 17 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params))) 18 | return localloss 19 | 20 | This function is usually fed into an gradient-based optimizer to find the optimal parameters. 
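For concreteness, a purely serial run might feed this loss into an optimizer roughly
as follows (a sketch; the parameter initialization and the choice of LBFGS are merely
illustrative and mirror the full example script further below):

.. code-block:: python

    params = torch.zeros(3, dtype=torch.double).requires_grad_()
    optimizer = torch.optim.LBFGS([params], 1)

    def closure():
        loss = lossfunction(params)
        optimizer.zero_grad()
        loss.backward()
        return loss

    optimizer.step(closure)
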
21 | We want to argue in the following that parallelizing this code in a data-parallel way is often as easy as 22 | adding two calls to :py:meth:`mpi4torch.MPI_Communicator.Allreduce`: 23 | 24 | .. code-block:: python 25 | :emphasize-lines: 3,9 26 | 27 | def lossfunction(params): 28 | # average initial params to bring all ranks on the same page 29 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size 30 | 31 | # compute local loss 32 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params))) 33 | 34 | # sum up the loss among all ranks 35 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM) 36 | 37 | :py:meth:`mpi4torch.MPI_Communicator.Allreduce` is used once to compute the average of the incoming parameters 38 | and once to collect the total loss. 39 | 40 | Embedded in a whole program this may look like (the code is also available in the git repository 41 | in the examples folder): 42 | 43 | .. literalinclude:: ../examples/simple_linear_regression.py 44 | :linenos: 45 | 46 | Note that although the averaging in line 29 might seem superfluous at first --- since all ranks start 47 | off with the same initial set of parameters --- having the adjoint of :py:meth:`mpi4torch.MPI_Communicator.Allreduce` 48 | in the backward pass 49 | is essential for all instances of the LBFGS optimizer to perform the same update on all ranks. 50 | 51 | For the second call to :py:meth:`mpi4torch.MPI_Communicator.Allreduce` in line 35 it is actually the other way 52 | around: Here the forward pass is crucial, but the backward pass merely adds up the ones coming from the 53 | different ranks, which (surprise) results in a vector of length 1 that just contains the communicator size. 54 | 55 | It is easy to see that the forward pass is indpendent of the number of ranks used to compute the result. That 56 | the parallelized backward pass also gives the same result may at first seem a bit surprising, 57 | as we already saw that the gradient with respect to ``localloss`` will just store the size 58 | of the MPI communicator. However, 59 | the corresponding backward code of the averaging in line 29 divides again through ``comm.size``, such that 60 | in total all gradients from all ranks are simply added up. The final 61 | gradient as stored in ``params.grad`` is thus also independent of the number of processes. 62 | 63 | Starting off with the same 64 | parameters on all ranks, it is thus ensured that all local LBFGS instances see the same parameters, the same losses 65 | and the same gradients, and thus perform the identical operations and give the same result. 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import CppExtension, BuildExtension 4 | import torch 5 | import copy 6 | import os 7 | import sys 8 | 9 | try: 10 | # importlib only got added in cpython 3.8 11 | if sys.version_info >= (3, 8): 12 | from importlib import metadata as importlib_metadata 13 | else: 14 | # cf. pyproject.toml file, which makes pip install importlib-metadata if necessary 15 | import importlib_metadata 16 | 17 | torchversion = importlib_metadata.distribution("torch").version 18 | except: 19 | # Fallback, but this should never happen. 
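    # torch.__version__ may carry a local build tag (e.g. "1.13.1+cpu");
    # keep only the public version part.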
20 | torchversion = torch.__version__.split("+")[0] 21 | 22 | class MpiBuildExtension(BuildExtension): 23 | def __init__(self, *args,**kwargs): 24 | super(MpiBuildExtension,self).__init__(*args,use_ninja=False,**kwargs) 25 | 26 | def build_extensions(self): 27 | """ 28 | This code makes a lot assumptions on distutils internal implementation of 29 | UnixCCompiler class. However, it seems to be standard to make these assumptions, 30 | as PyTorch and mpi4py also make these assumptions. 31 | 32 | TODO: Obviously this only works for unix systems 33 | """ 34 | 35 | # Save original compiler and reset it later on 36 | original_compiler = self.compiler.compiler_so 37 | new_compiler = copy.deepcopy(original_compiler) 38 | new_compiler[0] = 'mpicc' 39 | # Save original CXX compiler and reset it later on 40 | 41 | # distutils' UnixCCompiler likes to use the C++ compiler for linking, so we set it manually 42 | original_cxx_compiler = self.compiler.compiler_cxx 43 | new_cxx_compiler = copy.deepcopy(original_cxx_compiler) 44 | new_cxx_compiler[0] = 'mpicxx' 45 | # Save original linker and reset it later on 46 | # should not be used, but we set it anyway 47 | original_linker = self.compiler.linker_so 48 | new_linker = copy.deepcopy(original_linker) 49 | new_linker[0] = 'mpicc' 50 | try: 51 | self.compiler.set_executable('compiler_so', new_compiler) 52 | self.compiler.set_executable('compiler_cxx', new_cxx_compiler) 53 | self.compiler.set_executable('linker_so', new_linker) 54 | BuildExtension.build_extensions(self) 55 | finally: 56 | self.compiler.set_executable('compiler_so', original_compiler) 57 | self.compiler.set_executable('compiler_cxx', original_cxx_compiler) 58 | self.compiler.set_executable('linker_so', original_linker) 59 | 60 | with open(os.path.join(os.path.dirname(__file__), 'README.md'), encoding='utf-8') as filehandle: 61 | long_description = filehandle.read() 62 | 63 | with open(os.path.join(os.path.dirname(__file__), 'version.txt'), encoding='utf-8') as filehandle: 64 | versiontext = filehandle.read().rstrip() 65 | 66 | setup( 67 | name='mpi4torch', 68 | version=versiontext, 69 | description='AD-compatible implementation of several MPI functions for pytorch tensors', 70 | author='Philipp Knechtges', 71 | author_email='philipp.knechtges@dlr.de', 72 | long_description=long_description, 73 | long_description_content_type='text/markdown', 74 | classifiers=[ 75 | "Development Status :: 3 - Alpha", 76 | "Programming Language :: Python :: 3", 77 | "Programming Language :: Python :: 3.7", 78 | "Programming Language :: Python :: 3.8", 79 | "Programming Language :: Python :: 3.9", 80 | "Programming Language :: Python :: 3.10", 81 | "Programming Language :: Python :: 3.11", 82 | "Programming Language :: C++", 83 | "License :: OSI Approved :: MIT License", 84 | "Operating System :: POSIX :: Linux", 85 | "Intended Audience :: Science/Research", 86 | "Topic :: Scientific/Engineering", 87 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 88 | "Topic :: Software Development :: Libraries :: Python Modules" 89 | ], 90 | package_dir = {'mpi4torch': 'src'}, 91 | packages = ['mpi4torch'], 92 | ext_modules=[ 93 | CppExtension( 94 | name='mpi4torch._mpi', 95 | sources=['csrc/extension.cpp'], 96 | extra_compile_args=['-g']), 97 | ], 98 | cmdclass={ 99 | 'build_ext': MpiBuildExtension 100 | }, 101 | install_requires=[ 102 | # Pin the required pytorch version of the final binary wheels 103 | # to the pytorch version used at build-time. This way we 104 | # avoid possible ABI-incompatibilities. 
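        # e.g. a wheel built against torch 1.13.1 ends up requiring torch==1.13.1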
105 | 'torch==' + torchversion, 106 | ] 107 | ) 108 | -------------------------------------------------------------------------------- /doc/_static/img/mpi4torch-logo-extrawhitespace.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/_static/img/mpi4torch-logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_collectives.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mpi4torch 3 | import unittest 4 | 5 | comm = mpi4torch.COMM_WORLD 6 | 7 | class TestAllreduce(unittest.TestCase): 8 | def test_simple(self): 9 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 10 | res = comm.Allreduce(tmp,mpi4torch.MPI_SUM) 11 | res.sum().backward() 12 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all()) 13 | 14 | def test_torchscript(self): 15 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 16 | @torch.jit.script 17 | def myfunc(x,comm_: mpi4torch.MPI_Communicator): 18 | return comm_.Allreduce(x,mpi4torch.MPI_SUM) 19 | res = myfunc(tmp,comm) 20 | res.sum().backward() 21 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all()) 22 | 23 | class TestReduce(unittest.TestCase): 24 | def test_simple_inplace(self): 25 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 26 | res = comm.Reduce_(tmp,mpi4torch.MPI_SUM,0) 27 | res.sum().backward() 28 | self.assertTrue((tmp.grad == torch.ones(10,dtype=torch.double)).all()) 29 | 30 | def test_noinplace_exception(self): 31 | # the 0. addition is just to make the resulting tmp variable not a leaf in the DAG, 32 | # sine we currently dont see a way to add this safeguard for leaf nodes. 33 | tmp = 0. 
+ torch.rand(10, dtype=torch.double).requires_grad_() 34 | res = tmp + comm.Reduce_(tmp,mpi4torch.MPI_SUM,0) 35 | with self.assertRaises(RuntimeError): 36 | res.sum().backward() 37 | 38 | class TestBcast(unittest.TestCase): 39 | def test_simple_inplace(self): 40 | tmp = torch.rand(10, dtype=torch.double).requires_grad_() 41 | res = comm.Bcast_(tmp,0) 42 | res.sum().backward() 43 | if comm.rank == 0: 44 | self.assertTrue((tmp.grad == comm.size * torch.ones(10,dtype=torch.double)).all()) 45 | else: 46 | self.assertTrue((tmp.grad == torch.zeros(10,dtype=torch.double)).all()) 47 | 48 | class TestGather(unittest.TestCase): 49 | def test_basic_functionality(self): 50 | numdim = 4 51 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double) 52 | tmp[0,0,:,0,0] = comm.rank 53 | res = comm.Gather(tmp, 2, 0) 54 | if comm.rank == 0: 55 | tmp2 = torch.squeeze(torch.sum(res[0,0,:,0,0])) 56 | self.assertTrue((tmp2 == numdim * (comm.size - 1) * comm.size // 2).all()) 57 | 58 | def test_basic_ad(self): 59 | numdim = 4 60 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double).requires_grad_() 61 | res = comm.Gather(tmp, 2, 0) 62 | res.sum().backward() 63 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all()) 64 | 65 | class TestAllgather(unittest.TestCase): 66 | def test_basic_functionality(self): 67 | numdim = 4 68 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double) 69 | tmp[0,0,:,0,0] = comm.rank 70 | res = comm.Allgather(tmp, 2) 71 | tmp2 = torch.squeeze(torch.sum(res[0,0,:,0,0])) 72 | self.assertTrue((tmp2 == numdim * (comm.size - 1) * comm.size // 2).all()) 73 | 74 | def test_basic_ad(self): 75 | numdim = 4 76 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double).requires_grad_() 77 | res = comm.Allgather(tmp, 2) 78 | res.sum().backward() 79 | self.assertTrue((tmp.grad == comm.size * torch.ones_like(tmp)).all()) 80 | 81 | class TestScatter(unittest.TestCase): 82 | def test_basic_functionality(self): 83 | if comm.rank == 0: 84 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double) 85 | for i in range(comm.size): 86 | tmp[0,0,i,0,0] = i 87 | else: 88 | tmp = torch.rand([1],dtype=torch.double) 89 | res = comm.Scatter(tmp, 2, 1, 0) 90 | self.assertTrue((res[0,0,:,0,0] == comm.rank).all()) 91 | 92 | def test_scattergather(self): 93 | if comm.rank == 0: 94 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double) 95 | else: 96 | tmp = torch.rand([1],dtype=torch.double) 97 | res = comm.Scatter(tmp, 2, 1, 0) 98 | res2 = comm.Gather(res, 2, 0) 99 | if comm.rank == 0: 100 | self.assertTrue((res2 == tmp).all()) 101 | 102 | def test_basic_ad(self): 103 | if comm.rank == 0: 104 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double).requires_grad_() 105 | else: 106 | tmp = torch.rand([1],dtype=torch.double).requires_grad_() 107 | res = comm.Scatter(tmp, 2, 1, 0) 108 | res.sum().backward() 109 | if comm.rank == 0: 110 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all()) 111 | else: 112 | self.assertTrue((tmp.grad == torch.zeros_like(tmp)).all()) 113 | 114 | class TestAlltoall(unittest.TestCase): 115 | def test_gatherscatter_equivalence(self): 116 | tmp = torch.rand([3,4,1,4,comm.size,2],dtype=torch.double) 117 | res1 = comm.Scatter(comm.Gather(tmp,2,0),4,1,0) 118 | res2 = comm.Alltoall(tmp,2,4,1) 119 | self.assertTrue((res2 == res1).all()) 120 | 121 | def test_gatherscatter_equivalence_varying_numelem(self): 122 | tmp = torch.rand([3,4,comm.rank+1,4,comm.size*(comm.size+1)//2,2],dtype=torch.double) 123 | res1 = comm.Scatter(comm.Gather(tmp,2,0),4,comm.rank+1,0) 124 | res2 = 
comm.Alltoall(tmp,2,4,comm.rank+1) 125 | self.assertTrue((res2 == res1).all()) 126 | 127 | def test_gatheraxis_scatteraxis_equal(self): 128 | tmp = torch.rand([3,4,comm.rank+1,2],dtype=torch.double) 129 | tmp[0,0,:,0] = torch.arange(comm.rank*(comm.rank+1)//2, (comm.rank+1)*(comm.rank+2)//2) 130 | res = comm.Alltoall(tmp,2,2,comm.size-comm.rank) 131 | total_numelem = comm.size*(comm.size+1)//2 132 | correct_res = torch.arange(total_numelem - (comm.size-comm.rank)*(comm.size-comm.rank+1)//2, 133 | total_numelem - (comm.size-comm.rank-1)*(comm.size-comm.rank)//2, 134 | dtype=torch.double) 135 | self.assertTrue((res[0,0,:,0] == correct_res).all()) 136 | 137 | def test_identity_equivalence(self): 138 | tmp = torch.rand([3,4,2,4,3*comm.size,2],dtype=torch.double) 139 | res = comm.Alltoall(tmp,2,4,3) 140 | res2 = comm.Alltoall(res,4,2,2) 141 | self.assertTrue((res2 == tmp).all()) 142 | 143 | def test_basic_ad(self): 144 | tmp = torch.rand([3,4,2,4,comm.size,2],dtype=torch.double).requires_grad_() 145 | res = comm.Alltoall(tmp,2,4,1) 146 | res.sum().backward() 147 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all()) 148 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._mpi import * 3 | from typing import List 4 | 5 | __all__ = [ 6 | "MPI_MAX", 7 | "MPI_MIN", 8 | "MPI_SUM", 9 | "MPI_PROD", 10 | "MPI_LAND", 11 | "MPI_BAND", 12 | "MPI_LOR", 13 | "MPI_BOR", 14 | "MPI_LXOR", 15 | "MPI_BXOR", 16 | "MPI_MINLOC", 17 | "MPI_MAXLOC", 18 | "WaitHandle", 19 | "JoinDummies", 20 | "JoinDummiesHandle", 21 | "MPI_Communicator", 22 | "COMM_WORLD", 23 | "comm_from_mpi4py", 24 | "deactivate_cuda_aware_mpi_support" 25 | ] 26 | 27 | @torch.jit.script 28 | class WaitHandle: 29 | """Class representing a wait handle, as they are returned from one of the non-blocking MPI calls.""" 30 | 31 | def __init__(self, raw_handle: List[torch.Tensor]): 32 | self._handle = raw_handle 33 | 34 | @property 35 | def dummy(self): 36 | """A dummy variable that allows for the usage of the ``WaitHandle`` as one of the 37 | second arguments of :py:func:`mpi4torch.JoinDummies` and :py:func:`mpi4torch.JoinDummiesHandle`. 38 | """ 39 | 40 | return self._handle[0] 41 | 42 | @torch.jit.script 43 | def JoinDummies(loopthrough: torch.Tensor, dummies:List[torch.Tensor]) -> torch.Tensor: 44 | """This function joins multiple dummy dependencies with the DAG. 45 | 46 | From the perspective of the forward pass, this function is mostly a no-op, since it simply 47 | loops through its first argument, and discards the ``dummies`` argument. 48 | 49 | However, for the backward pass, the AD engine still considers the ``dummies`` as actual 50 | dependencies. The main use of this function is thus to manually encode dependencies 51 | that the AD engine does not see on its own. See also the introductory text in 52 | the :ref:`section_implications_mpi4torch` section on how to use this function. 53 | 54 | Parameters 55 | ---------- 56 | loopthrough: 57 | Variable to pass through. 58 | dummies: 59 | List of tensors that are added as dummy dependencies to the DAG. 60 | 61 | Returns 62 | ------- 63 | :py:class:`torch.tensor`: 64 | Tensor that is a shallow copy of ``loopthrough``, but whose ``grad_fn`` 65 | is ``JoinDummiesBackward``. 
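    Example
    -------
    A short sketch, borrowed from the Isend-Recv-Wait example shipped with this
    repository, in which the dummy of an ``Isend`` wait handle is joined with the
    receive buffer to encode the ordering of the two calls in the DAG::

        handle = comm.Isend(a, (comm.rank + 1) % comm.size, 0)
        recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
        b = comm.Recv(recvbuffer, (comm.rank - 1 + comm.size) % comm.size, 0)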
66 | """ 67 | return torch.ops.mpi4torch.JoinDummies(loopthrough, dummies) 68 | 69 | @torch.jit.script 70 | def JoinDummiesHandle(handle: WaitHandle, dummies:List[torch.Tensor]) -> WaitHandle: 71 | """This function has the same purpose as :py:func:`JoinDummies`, but accepts :py:class:`mpi4torch.WaitHandle` 72 | as its first argument. 73 | 74 | Parameters 75 | ---------- 76 | handle: 77 | :py:class:`mpi4torch.WaitHandle` to pass through. 78 | dummies: 79 | List of tensors that are added as dummy dependencies to the DAG. 80 | 81 | Returns 82 | ------- 83 | :py:class:`mpi4torch.WaitHandle`: 84 | A wait handle with the additional dummy dependenices added. 85 | """ 86 | raw_handle = handle._handle 87 | return WaitHandle([ torch.ops.mpi4torch.JoinDummies(raw_handle[0], dummies), raw_handle[1], raw_handle[2] ]) 88 | 89 | @torch.jit.script 90 | class MPI_Communicator: 91 | """MPI communicator wrapper class 92 | 93 | The only supported ways to construct an ``MPI_Communicator`` are currently either through :py:const:`mpi4torch.COMM_WORLD` or 94 | :py:func:`mpi4torch.comm_from_mpi4py`. 95 | 96 | Note 97 | ---- 98 | All methods with an underscore suffix are in-place operations. 99 | """ 100 | 101 | def __init__(self, comm: torch.classes.mpi4torch.MPI_Comm_Wrapper): 102 | self._comm = comm 103 | 104 | @property 105 | def rank(self) -> int: 106 | """The rank or identification number of the local process with respect to this communicator. 107 | 108 | The processes participating in a communicator are consecutively given ranks 109 | in the interval [0, :py:attr:`mpi4torch.MPI_Communicator.size` - 1]. 110 | """ 111 | return self._comm.GetRank() 112 | 113 | @property 114 | def size(self) -> int: 115 | """The size of the MPI communicator, i.e. the number of processes involved.""" 116 | return self._comm.GetSize() 117 | 118 | # This is currently not supported by torch.jit.script: 119 | #def __getattr__(self, attrName): 120 | # if attrName in self.__dict__["_comm"]._method_names(): 121 | # return self.__dict__["_comm"].__getattr__(attrName) 122 | # return self.__dict__[attrName] 123 | # So we need to write out every function by hand 124 | 125 | def Allreduce(self, tensor: torch.Tensor, op: int) -> torch.Tensor: 126 | """Combines values from all processes and distributes the result back to all processes. 127 | 128 | The combination operation is performed element-wise on the tensor. 129 | 130 | This is the wrapper function of `MPI_Allreduce `_. 131 | 132 | Parameters 133 | ---------- 134 | tensor: 135 | :py:class:`torch.Tensor` that shall be combined. It needs to have the same shape on all processes. 136 | op: 137 | Operation to combine the results. Only supported operations are :py:const:`mpi4torch.MPI_MAX`, 138 | :py:const:`mpi4torch.MPI_MIN`, :py:const:`mpi4torch.MPI_SUM`, :py:const:`mpi4torch.MPI_PROD`, :py:const:`mpi4torch.MPI_LAND`, 139 | :py:const:`mpi4torch.MPI_BAND`, :py:const:`mpi4torch.MPI_LOR`, :py:const:`mpi4torch.MPI_BOR`, :py:const:`mpi4torch.MPI_LXOR`, 140 | :py:const:`mpi4torch.MPI_BXOR`, :py:const:`mpi4torch.MPI_MINLOC`, 141 | :py:const:`mpi4torch.MPI_MAXLOC` 142 | 143 | Returns 144 | ------- 145 | :py:class:`torch.Tensor`: 146 | Combined tensor of the same shape as the input `tensor`. 147 | 148 | Note 149 | ---- 150 | Only :py:const:`mpi4torch.MPI_SUM` is supported in the backwards pass at the moment. 
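        Example
        -------
        A minimal sketch, assuming ``comm`` is an :py:class:`mpi4torch.MPI_Communicator`
        such as :py:const:`mpi4torch.COMM_WORLD`::

            x = torch.ones(4, dtype=torch.double).requires_grad_()
            y = comm.Allreduce(x, mpi4torch.MPI_SUM)
            y.sum().backward()
            # x.grad now contains comm.size in every entry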
151 | """ 152 | return self._comm.Allreduce(tensor, op) 153 | 154 | def Bcast_(self, tensor: torch.Tensor, root: int) -> torch.Tensor: 155 | """Broadcasts a tensor from the `root` process to all other processes. 156 | 157 | This is an in-place operation. 158 | 159 | This is the wrapper function of `MPI_Bcast `_. 160 | 161 | Parameters 162 | ---------- 163 | tensor: 164 | :py:class:`torch.Tensor` that shall be broadcasted. The tensor needs to have the same shape on all processes, 165 | since it is an in-place operation. 166 | root: 167 | The root process, whose tensor shall be broadcasted to the others. 168 | 169 | Returns 170 | ------- 171 | :py:class:`torch.Tensor`: 172 | For `rank == root` this is the same as the input `tensor`. For all other processes this is the input `tensor` filled with the content 173 | from the `root` process. 174 | """ 175 | return self._comm.Bcast_(tensor, root) 176 | 177 | def Reduce_(self, tensor: torch.Tensor, op: int, root: int) -> torch.Tensor: 178 | """Reduces multiple tensors of the same shape, scattered over all processes, to a single tensor of the same shape stored on the `root` process. 179 | 180 | The combination operation is performed element-wise on the tensor. 181 | 182 | This is an in-place operation. 183 | 184 | This is the wrapper function of `MPI_Reduce `_. 185 | 186 | Parameters 187 | ---------- 188 | tensor: 189 | :py:class:`torch.Tensor` that shall be reduced. The tensor needs to have the same shape on all processes, 190 | since it is an element-wise operation. 191 | op: 192 | Operation to combine the results. Only supported operations are :py:const:`mpi4torch.MPI_MAX`, 193 | :py:const:`mpi4torch.MPI_MIN`, :py:const:`mpi4torch.MPI_SUM`, :py:const:`mpi4torch.MPI_PROD`, :py:const:`mpi4torch.MPI_LAND`, 194 | :py:const:`mpi4torch.MPI_BAND`, :py:const:`mpi4torch.MPI_LOR`, :py:const:`mpi4torch.MPI_BOR`, :py:const:`mpi4torch.MPI_LXOR`, 195 | :py:const:`mpi4torch.MPI_BXOR`, :py:const:`mpi4torch.MPI_MINLOC`, 196 | :py:const:`mpi4torch.MPI_MAXLOC` 197 | root: 198 | The root process, where the resulting tensor shall be gathered. 199 | 200 | Returns 201 | ------- 202 | :py:class:`torch.Tensor`: 203 | For `rank == root` the result stores the reduced tensor. For all other processes the content of the resulting tensor is undefined, 204 | with the exception that the result shall still suffice as input for the second argument of :py:func:`mpi4torch.JoinDummies`. 205 | 206 | Note 207 | ---- 208 | Only :py:const:`mpi4torch.MPI_SUM` is supported in the backwards pass at the moment. 
209 | """ 210 | return self._comm.Reduce_(tensor, op, root) 211 | 212 | def Gather(self, tensor: torch.Tensor, gatheraxis: int, root: int) -> torch.Tensor: 213 | return self._comm.Gather(tensor, gatheraxis, root) 214 | 215 | def Allgather(self, tensor: torch.Tensor, gatheraxis: int) -> torch.Tensor: 216 | return self._comm.Allgather(tensor, gatheraxis) 217 | 218 | def Scatter(self, tensor: torch.Tensor, scatteraxis: int, numelem: int, root: int) -> torch.Tensor: 219 | return self._comm.Scatter(tensor, scatteraxis, numelem, root) 220 | 221 | def Alltoall(self, tensor: torch.Tensor, gatheraxis: int, scatteraxis: int, 222 | numelem: int) -> torch.Tensor: 223 | return self._comm.Alltoall(tensor, gatheraxis, scatteraxis, numelem) 224 | 225 | def Isend(self, tensor: torch.Tensor, dest: int, tag: int) -> WaitHandle: 226 | return WaitHandle(self._comm.Isend(tensor, dest, tag)) 227 | 228 | def Irecv(self, tensor: torch.Tensor, source: int, tag: int) -> WaitHandle: 229 | return WaitHandle(self._comm.Irecv(tensor, source, tag)) 230 | 231 | def Wait(self, waithandle: WaitHandle) -> torch.Tensor: 232 | return self._comm.Wait(waithandle._handle) 233 | 234 | def Send(self, tensor: torch.Tensor, dest: int, tag: int) -> torch.Tensor: 235 | handle = self._comm.Isend(tensor, dest, tag) 236 | return self._comm.Wait(handle) 237 | 238 | def Recv(self, tensor: torch.Tensor, source: int, tag: int) -> torch.Tensor: 239 | handle = self._comm.Irecv(tensor, source, tag) 240 | return self._comm.Wait(handle) 241 | 242 | COMM_WORLD = MPI_Communicator(torch.ops.mpi4torch.COMM_WORLD()) 243 | """ 244 | World communicator ``MPI_COMM_WORLD``. 245 | """ 246 | 247 | try: 248 | from mpi4py import MPI as __mpi4py_MPI 249 | 250 | def comm_from_mpi4py(comm: __mpi4py_MPI.Comm) -> MPI_Communicator: 251 | """Converts an ``mpi4py`` communicator to an :py:class:`mpi4torch.MPI_Communicator`. 252 | """ 253 | 254 | fortran_handle = comm.py2f(); 255 | return MPI_Communicator(torch.ops.mpi4torch.comm_from_fortran(fortran_handle)) 256 | except ModuleNotFoundError: 257 | def comm_from_mpi4py(comm) -> MPI_Communicator: 258 | """Converts an ``mpi4py`` communicator to an :py:class:`mpi4torch.MPI_Communicator`. 259 | """ 260 | 261 | raise RuntimeError("mpi4py is not available!") 262 | -------------------------------------------------------------------------------- /doc/basic_usage.rst: -------------------------------------------------------------------------------- 1 | ******************** 2 | Basic Usage 3 | ******************** 4 | 5 | In the following we are going to discuss the different options and caveats that come with mixing MPI 6 | and pytorch's automatic differentiation (AD) functionality, and what consequences this has for using 7 | the mpi4torch library. 8 | 9 | Note that although we will within this document mostly talk about the interplay of mpi4torch with pytorch's AD, 10 | this does not mean that mpi4torch could not in principle be used as one might expect coming from other 11 | MPI libraries. The main difference, 12 | however, is that if one plans to use mpi4torch as a building block in some automatic differentiable code, 13 | the usage of mpi4torch actually differs a lot to these "classical" programming paradigms. 14 | It is thus *highly* recommended 15 | for everybody to read this document before, e.g., literally translating MPI calls to mpi4torch. 
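Before diving into the details, the following minimal sketch (assumed to be launched
with ``mpirun`` on a couple of ranks) shows the benign end of the spectrum: it sums a
tensor over all ranks and then differentiates right through the communication call:

.. code-block:: python

    import torch
    import mpi4torch

    comm = mpi4torch.COMM_WORLD

    x = torch.ones(3, dtype=torch.double).requires_grad_()

    # Allreduce returns a new tensor that is itself part of the AD graph
    y = comm.Allreduce(x, mpi4torch.MPI_SUM)

    y.sum().backward()
    print(x.grad)  # every entry equals comm.size

Apart from the ``Allreduce`` call this is ordinary pytorch code; the rest of this
document explains why less trivial communication patterns need a bit more care.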
16 | 17 | How pytorch's AD works 18 | ====================== 19 | 20 | Since it is important for what follows, we start with a quick reminder on how the AD engine in 21 | pytorch is used. Consider the following code 22 | 23 | .. code-block:: python 24 | 25 | import torch 26 | 27 | a = torch.tensor([0.5]).requires_grad_() 28 | b = torch.exp(a) 29 | b.backward() 30 | assert(a.grad == b) 31 | 32 | This code simply computes the derivative of the function :math:`f(x) = e^x` at the point :math:`x=0.5`. 33 | In the code we do so by initializing a torch tensor ``a`` that has the flag ``requires_grad = True`` set, 34 | which we do here by calling the ``requires_grad_()`` method. This flag is in some sense contagious: Allmost 35 | all torch functions that are called with ``a`` as their argument, pass this flag also to their output. In 36 | the example above this is the exponential function, which returns a tensor ``b`` that has also this flag set. 37 | In addition to this flag ``b`` comes with the info that it was computed from ``a`` and 38 | a property, which is called the gradient function ``grad_fn``. 39 | This is the function that tells pytorch what to do in the backward automatic differentiation pass. 40 | 41 | To illustrate this a bit more, consider the following directed acyclic graph (DAG) that represents the 42 | computational flow in the forward phase: 43 | 44 | .. graphviz:: 45 | :caption: Forward DAG 46 | :align: center 47 | 48 | digraph forwarddag { 49 | exp [shape=rectangle]; 50 | "a" -> exp -> "b"; 51 | } 52 | 53 | What now happens in the backward pass is that pytorch executes a reversed DAG of the foward DAG, just with 54 | the functions replaced by their gradient functions. E.g. in the example above this would look like 55 | 56 | .. graphviz:: 57 | :caption: Backward DAG 58 | :align: center 59 | 60 | digraph backwarddag { 61 | rankdir=BT; 62 | ExpBackward [shape=rectangle]; 63 | "1" -> ExpBackward -> "a.grad"; 64 | } 65 | 66 | In particular, pytorch starts off with :math:`1`, which is obviously the derivative of ``b`` with respect to 67 | itself. It then executes the ``grad_fn`` function, which in this example is the ``ExpBackward`` function. 68 | Not shown in the illustration is that the ``grad_fn`` function internally has a reference to the result of the 69 | forward calculation, which is then muliplied with ``1`` and defines the output of ``ExpBackward``. 70 | Finally pytorch stores this result in ``a.grad``, which now contains the derviative of ``b`` with respect to ``a``. 71 | 72 | This principle of course can be generalized to more complicated DAGs. pytorch in these situations still builds 73 | up the backwards DAG by recording the gradient functions and the dependencies on the go, and then executes this 74 | graph when the ``backward`` method is called. However, there are still some important implications for the 75 | usage in the following, which we want to highlight: 76 | 77 | .. _section_pure_functions: 78 | 79 | Automatic differentiable functions should at best be pure functions 80 | ------------------------------------------------------------------- 81 | 82 | This statement is --- if written out like that --- probably not any news, 83 | since the concept of a differentiable function 84 | is a mathematical one, and all mathematical functions are pure in a procedural sense. Hence, a programmatic 85 | representation of a mathematically differentiable function should at best also be pure. 86 | 87 | This has some implications. 
One of the more important ones is, as obvious as it may seem, that this function 88 | needs to have an input and an output. Without an input (and without explicitly modifying the autograd meta data) 89 | the output of a function is from the perspective of the AD engine a constant. The same applies for functions 90 | with no output, whose branch in the backward DAG execution is simply omitted by the AD engine. 91 | 92 | Since this is so important, we repeat it: 93 | 94 | .. warning:: 95 | 96 | **All automatic differentiable functions need to depend on some input tensor, 97 | and need to return an output tensor**. 98 | 99 | DAG edges can only be pytorch tensors of floating point type 100 | ------------------------------------------------------------ 101 | 102 | This goes into the same direction as the last remark. Obviously differentiability is from its mathematical 103 | definition strongly tied to of the real numbers, and the floating point numbers are the only approximation 104 | to them we have in pytorch. 105 | 106 | As such we can only exchange floating point tensors along the edges in the DAG. 107 | 108 | That some form of additivity is required for the structures that are transported along the DAG edges 109 | can also be seen from the following example 110 | 111 | .. code-block:: python 112 | 113 | a = ... 114 | tmp1 = F(a) 115 | tmp2 = G1(tmp1) 116 | tmp3 = G2(tmp2) 117 | b = H(tmp1, tmp2) 118 | b.backward() 119 | 120 | Note in particular that the output from the node ``F`` is used twice: once as the input for ``G1`` 121 | and once as the input for ``G2``. The forward DAG would then look like 122 | 123 | .. graphviz:: 124 | :caption: Forward DAG with bifurcation 125 | :align: center 126 | 127 | digraph foo { 128 | F [shape=rectangle]; 129 | G1 [shape=rectangle]; 130 | G2 [shape=rectangle]; 131 | H [shape=rectangle]; 132 | "a" -> F -> G1 -> H -> "b"; 133 | F -> G2 -> H; 134 | } 135 | 136 | The corresponding backward DAG would by simply inverting the arrows and substituting 137 | the function calls by the respective backward function calls, have the form 138 | 139 | .. graphviz:: 140 | :caption: Backward DAG with bifurcation 141 | :align: center 142 | 143 | digraph foo2 { 144 | rankdir=BT; 145 | FBackward [shape=rectangle]; 146 | G1Backward [shape=rectangle]; 147 | G2Backward [shape=rectangle]; 148 | HBackward [shape=rectangle]; 149 | "1" -> HBackward -> G1Backward -> FBackward -> "a.grad"; 150 | HBackward -> G2Backward -> FBackward; 151 | } 152 | 153 | However, what this picture does not show is that the bifurctation in the forward evaluation of ``b`` 154 | becomes an addition in the backward pass. A more detailed representation of the backward DAG 155 | would thus be 156 | 157 | .. graphviz:: 158 | :caption: Backward DAG with explicit addition 159 | :align: center 160 | 161 | digraph foo2 { 162 | rankdir=BT; 163 | FBackward [shape=rectangle]; 164 | "+" [shape=rectangle]; 165 | G1Backward [shape=rectangle]; 166 | G2Backward [shape=rectangle]; 167 | HBackward [shape=rectangle]; 168 | "1" -> HBackward -> G1Backward -> "+" -> FBackward -> "a.grad"; 169 | HBackward -> G2Backward -> "+"; 170 | } 171 | 172 | To sum up: 173 | 174 | .. warning:: 175 | 176 | The edges in the DAG representation can only be pytorch tensors of floating point type. 177 | 178 | 179 | .. 
_section_implications_mpi4torch: 180 | 181 | Implications for mpi4torch 182 | ========================== 183 | 184 | mpi4torch is a MPI wrapper library for pytorch tensors that tries to be as *transparent* as possible 185 | to pytorch's AD engine. By transparent we in particular mean that we do not touch the AD engine, but 186 | rather provide the MPI functions as nodes in the DAG that pytorch composes. To be more precise, one should 187 | say the DAGs that pytorch composes, which brings us already to one of the ramifications of this design 188 | decision: When parallelizing your program with mpi4torch it is still the case that each MPI rank has its 189 | individual DAG that is run during the backward step. Most importantly, these DAGs do not know anything 190 | about each other, and thus cannot resolve any dependencies with ``requires_grad`` set from any other rank. 191 | As a consequence **it is the sole responsibility of the user to manage these dependencies**. 192 | 193 | We will come to it in a minute how the user actually can encode these dependencies, but first start 194 | with an example. Consider the following code, which shows the often used Isend-Recv-Wait idiom. 195 | It from a communication perspective 196 | simply receives a tensor from the left process and passes its own tensor to the right, if all 197 | ranks are imagined to be arranged in a circle. 198 | 199 | .. code-block:: python 200 | 201 | import torch 202 | import mpi4torch 203 | 204 | comm = mpi4torch.COMM_WORLD 205 | 206 | a = torch.tensor([1.0 + comm.rank]).requires_grad_() 207 | 208 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0) 209 | b = comm.Recv(torch.empty_like(a), (comm.rank-1+comm.size)%comm.size, 0) 210 | comm.Wait(handle) 211 | 212 | res = a+b 213 | print(res) 214 | 215 | This code follows usual MPI coding paradigms and works as expected. However, when we would start asking 216 | for the gradient of (the sum of all) ``res`` with respect to the individual ``a`` s, we would get an incorrect 217 | result. 218 | 219 | .. code-block:: python 220 | 221 | res.backward() 222 | print(a.grad) # <- this would print tensor([1.]) 223 | 224 | The print function would actually display 1 as the result, whereas taking the derivative of 225 | the sum of all ``res`` variables on all ranks with respect to that specific ``a`` variable should be 2. 226 | 227 | This is just one of the things that could happen. There are many more situations, in which the program would 228 | run flawlessly in forward mode, but would e.g. deadlock in the backward pass. To exemplify 229 | how this happens we will look once more at a graphical representation of the DAG. 230 | 231 | .. 
This is just one of the things that could happen. There are many more situations in which the program
would run flawlessly in forward mode, but would e.g. deadlock in the backward pass. To exemplify
how this happens, we will look once more at a graphical representation of the DAG.

.. graphviz::
    :caption: Forward DAG for the Isend-Recv-Wait idiom
    :align: center
    :name: naiveforwardisendrecvwaitgraph

    digraph foo2 {
        rank = same;
        subgraph clusterrankm1 {
            a1 [label="a"];
            res1 [label="res"];
            node [shape=rectangle];
            Isend1 [label="Isend"];
            Wait1 [label="Wait"];
            Recv1 [label="Recv"];
            p1 [label="+"];
            a1 -> Isend1 -> Wait1;
            Recv1 -> p1 -> res1;
            a1 -> p1;
            label = "rank - 1";
            color = black;
        };
        subgraph clusterrank {
            a2 [label="a"];
            res2 [label="res"];
            node [shape=rectangle];
            Isend2 [label="Isend"];
            Wait2 [label="Wait"];
            Recv2 [label="Recv"];
            p2 [label="+"];
            a2 -> Isend2 -> Wait2;
            Recv2 -> p2 -> res2;
            a2 -> p2;
            label = "rank";
        }
        subgraph clusterrankp1 {
            a3 [label="a"];
            res3 [label="res"];
            node [shape=rectangle];
            Isend3 [label="Isend"];
            Wait3 [label="Wait"];
            Recv3 [label="Recv"];
            p3 [label="+"];
            a3 -> Isend3 -> Wait3;
            Recv3 -> p3 -> res3;
            a3 -> p3;
            label = "rank + 1";
        }

        Isend1 -> Recv2 [style=dotted, constraint=false];
        Isend2 -> Recv3 [style=dotted, constraint=false];
        #Isend3 -> Recv1 [style=dotted, constraint=false];
    }

The graph above shows the dependencies between the different computations as seen from pytorch's
perspective, with some dotted arrows added to indicate the actual communication that is happening.

If we now invert the arrows in order to obtain the corresponding backward DAG, we get

.. graphviz::
    :caption: Backward DAG for the Isend-Recv-Wait idiom for a single rank
    :align: center

    digraph foo2 {
        rankdir=BT;
        subgraph clusterrankm1 {
            a1 [label="a.grad"];
            res1 [label="1"];
            node [shape=rectangle];
            Isend1 [label="IsendBackward", style=filled, color=gray];
            Wait1 [label="WaitBackward", style=filled, color=gray];
            Recv1 [label="RecvBackward", style=filled, color=gray];
            p1 [label="AddBackward0"];
            Wait1 -> Isend1 -> a1;
            res1 -> p1 -> Recv1;
            p1 -> a1;
            label = "rank";
        };
    }

This graph immediately makes clear why ``a.grad`` contains 1 in the end. All grayed-out nodes are omitted
--- or, to be more precise, not even generated --- by pytorch's AD engine, such that only ``AddBackward0``
is called, which just passes through 1 to ``a.grad``.

From this discussion and the :ref:`naiveforwardisendrecvwaitgraph` it becomes apparent that some parts
that are implicit in the program code are missing from the DAG representation:

#. As noted earlier, the DAGs are local to each MPI rank, and they do not resolve any dependencies that
   are the effect of communication.
#. The DAGs also lack any information that was present in the linear ordering of commands in the source
   code file. For example, the ``Recv`` call has to happen after ``Isend``, and ``Wait`` has to happen
   after ``Recv``.

**It is the user's responsibility to encode these dependencies in the DAG!**
This brings us to the tools mpi4torch provides to mitigate this situation.

The first one is a direct consequence of the discussion in the section on
:ref:`pure functions `: all DAG nodes need an input and an output.
In our example above, this would e.g. concern the :py:meth:`mpi4torch.MPI_Communicator.Wait`
call. In principle, ``MPI_Wait`` does not return a floating-point tensor. However, mpi4torch
returns a floating-point tensor, giving the user the possibility to use it to encode
any other dependencies on the ``Wait`` call. These tensors are called **dummies** in mpi4torch.
They do not convey any information other than that there is some (virtual/artificial)
dependency to be encoded in the DAG.

The dummies themselves are not really useful without a way to join them with the DAG. This is
what the :py:func:`mpi4torch.JoinDummies` function is actually for. The call signature of
:py:func:`mpi4torch.JoinDummies` is given by

.. code-block:: python

    def JoinDummies(loopthrough: torch.Tensor, dummies: List[torch.Tensor]) -> torch.Tensor

The function takes two arguments: the loopthrough variable and a list of dummies. From a forward
execution perspective the ``JoinDummies`` function is a no-op; it simply --- as the name suggests ---
loops through the ``loopthrough`` variable. The ``dummies`` are discarded and not used.

However, pytorch does not know about this behaviour of the ``JoinDummies`` function and considers
the result of the function to actually depend on the dummies. Consequently, pytorch will also
respect this dependency in the backward DAG.
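
A small sketch (purely illustrative, not taken from the mpi4torch documentation or tests) of this
behaviour: the returned tensor carries the same values as the ``loopthrough`` argument, while its
``grad_fn`` additionally records the dependency on the dummy:

.. code-block:: python

    import torch
    import mpi4torch

    x = torch.ones(3, requires_grad=True)
    dummy = torch.zeros(1, requires_grad=True)

    y = mpi4torch.JoinDummies(2 * x, [dummy])

    print(torch.equal(y, 2 * x))  # True: the forward pass just loops the value through
    print(y.grad_fn)              # a JoinDummies backward node that also depends on `dummy`
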
The :py:func:`mpi4torch.JoinDummies` function also has a sister function,
:py:func:`mpi4torch.JoinDummiesHandle`, which is intended for situations in which the ``loopthrough``
variable is a :py:class:`mpi4torch.WaitHandle` from a non-blocking MPI call, as e.g. returned by
:py:func:`mpi4torch.MPI_Communicator.Isend`. The signature of :py:func:`mpi4torch.JoinDummiesHandle` is

.. code-block:: python

    def JoinDummiesHandle(handle: WaitHandle, dummies: List[torch.Tensor]) -> WaitHandle

Returning to the Isend-Recv-Wait example, we now want to put these tools to use. Starting with
the call to :py:func:`mpi4torch.MPI_Communicator.Recv`, we want this call to happen after
:py:func:`mpi4torch.MPI_Communicator.Isend`. Note that ``Isend`` returns a ``WaitHandle``, which
cannot directly be passed to ``JoinDummies``. For these situations we use the
:py:attr:`mpi4torch.WaitHandle.dummy` property, which gives us a means to convert a ``WaitHandle``
into a dummy tensor. In the example from above this could then look like

.. code-block:: python

    handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
    recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
    #                                  ~~~~~~~~~~~~~~~~~~~
    #                                  This is what we
    #                                  originally wanted
    #                                  to pass to Recv
    #                                                       ~~~~~~~~~~~~~~
    #                                                       This adds the handle
    #                                                       from the previous Isend call
    #                                                       as a dummy dependency to the DAG
    b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)

For the ``Wait`` call we also want it to happen after the ``Recv`` call. This time we make use of
:py:func:`mpi4torch.JoinDummiesHandle`.

.. code-block:: python

    b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
    wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))

Note that we already added a return variable for ``Wait``, since we still want to encode
that our end result, the (implicit) sum of all ``res`` on all ranks, depends on the ``Isend``
having finished. For that we introduce another call to :py:func:`mpi4torch.JoinDummies`.

.. code-block:: python

    wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))

    res = mpi4torch.JoinDummies(a+b, [wait_ret])

The full code example now looks like

.. code-block:: python

    import torch
    import mpi4torch

    comm = mpi4torch.COMM_WORLD

    a = torch.tensor([1.0 + comm.rank]).requires_grad_()

    handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
    recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
    b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
    wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))

    res = mpi4torch.JoinDummies(a+b, [wait_ret])
    print(res)

    res.backward()
    print(a.grad) # <- this now correctly prints tensor([2.])

This code now prints the correct result for ``a.grad``. To highlight the differences from the
first version of the code, we also look at the DAG of the new version

.. graphviz::
    :caption: Forward DAG for the Isend-Recv-Wait idiom with dummy dependencies
    :align: center

    digraph foo2 {
        subgraph clusterrankm1 {
            a [label="a"];
            res [label="res"];
            node [shape=rectangle];
            JoinDummies1 [label="JoinDummies"];
            JoinDummiesHandle [label="JoinDummiesHandle"];
            JoinDummies2 [label="JoinDummies"];
            Isend [label="Isend"];
            Wait [label="Wait"];
            Recv [label="Recv"];
            p1 [label="+"];
            a -> Isend;
            Isend -> JoinDummies1 -> Recv;
            Recv -> JoinDummiesHandle -> Wait;
            Isend -> JoinDummiesHandle;
            Recv -> p1 -> JoinDummies2;
            a -> p1;
            Wait -> JoinDummies2;
            JoinDummies2 -> res;
            label = "rank";
            color = black;
        };
    }

The important point to note is that all communication is part of a path between
``a`` and ``res``, and in comparison to the first version of the code there are no "dead branches".
pytorch's AD engine thus has to call the respective backward methods when it propagates the gradient
back from ``res`` to ``a.grad``.

.. warning::

    In general, if you write a function that uses mpi4torch internally and that shall be automatically
    differentiable, make sure that all communication primitives are, in one way or another, part of a
    DAG path that connects the input and the output of that function.
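
To close the loop with the warning above, here is a sketch of how such a self-contained,
differentiable building block could look. The function name ``ring_shift_add`` is made up for this
illustration; everything else uses only the mpi4torch calls introduced in this section:

.. code-block:: python

    import torch
    import mpi4torch

    comm = mpi4torch.COMM_WORLD

    def ring_shift_add(x: torch.Tensor) -> torch.Tensor:
        """Returns x plus the tensor received from the left neighbor in a ring.

        Every communication primitive lies on the DAG path from the input ``x``
        to the returned tensor, so the function remains differentiable.
        """
        handle = comm.Isend(x, (comm.rank + 1) % comm.size, 0)
        recvbuffer = mpi4torch.JoinDummies(torch.empty_like(x), [handle.dummy])
        b = comm.Recv(recvbuffer, (comm.rank - 1 + comm.size) % comm.size, 0)
        wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle, [b]))
        return mpi4torch.JoinDummies(x + b, [wait_ret])

As with all the snippets in this section, the script has to be started through an MPI launcher,
e.g. ``mpirun -np 4 python yourscript.py``.
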
464 | 465 | 466 | -------------------------------------------------------------------------------- /csrc/extension.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | //#include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #if defined(OPEN_MPI) && OPEN_MPI 13 | // Needed for checking cuda-awareness 14 | #include 15 | #endif 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | using torch::Tensor; 22 | using torch::ScalarType; 23 | using torch::autograd::variable_list; 24 | 25 | namespace 26 | { 27 | 28 | #if defined(MPIX_CUDA_AWARE_SUPPORT) 29 | // if it is at compiletime already clear that OpenMPI has now support, we deactivate 30 | // the cuda-aware mpi support directly 31 | #define MPI4TORCH_BUILT_WITH_CUDA_AWARENESS MPIX_CUDA_AWARE_SUPPORT 32 | #else 33 | #define MPI4TORCH_BUILT_WITH_CUDA_AWARENESS 0 34 | #endif 35 | 36 | #if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS 37 | 38 | bool have_cuda_aware_mpi_support = false; 39 | 40 | void inline __setup_have_cuda_aware_mpi_support() 41 | { 42 | // OpenMPI (and presumably also Parastation MPI) provides this runtime query function 43 | #if defined(MPIX_CUDA_AWARE_SUPPORT) 44 | have_cuda_aware_mpi_support = MPIX_Query_cuda_support(); 45 | #else 46 | have_cuda_aware_mpi_support = false; 47 | #endif 48 | } 49 | 50 | #else 51 | const bool have_cuda_aware_mpi_support = false; 52 | #endif 53 | 54 | void deactivate_cuda_aware_mpi_support() 55 | { 56 | #if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS 57 | have_cuda_aware_mpi_support = false; 58 | #endif 59 | } 60 | 61 | struct MPIDeviceHelper 62 | { 63 | MPIDeviceHelper(const Tensor& input) 64 | : device(input.device()), devicetype(device.type()), mpidevice(c10::kCPU) 65 | { 66 | setup(); 67 | } 68 | 69 | MPIDeviceHelper(const c10::Device& _device) 70 | : device(_device), devicetype(device.type()), mpidevice(c10::kCPU) 71 | { 72 | setup(); 73 | } 74 | 75 | Tensor fromDeviceToMPI(const Tensor& input) 76 | { 77 | if (input.device() == mpidevice) { 78 | return input; 79 | } 80 | return input.to(mpidevice); 81 | } 82 | 83 | Tensor fromMPIToDevice(const Tensor& output) 84 | { 85 | if (output.device() == device) { 86 | return output; 87 | } 88 | return output.to(device); 89 | } 90 | 91 | c10::Device device; 92 | c10::DeviceType devicetype; 93 | c10::Device mpidevice; 94 | 95 | private: 96 | void setup() 97 | { 98 | if (devicetype == c10::kCPU) { 99 | mpidevice = device; 100 | } else if (devicetype == c10::kCUDA && have_cuda_aware_mpi_support) { 101 | mpidevice = device; 102 | } 103 | } 104 | }; 105 | 106 | MPI_Datatype torch2mpitype(ScalarType in) 107 | { 108 | switch(in) 109 | { 110 | case ScalarType::Byte: 111 | return MPI_BYTE; 112 | case ScalarType::Char: 113 | return MPI_CHAR; 114 | case ScalarType::Short: 115 | return MPI_SHORT; 116 | case ScalarType::Int: 117 | return MPI_INT; 118 | case ScalarType::Long: 119 | return MPI_LONG; 120 | case ScalarType::Float: 121 | return MPI_FLOAT; 122 | case ScalarType::Double: 123 | return MPI_DOUBLE; 124 | default: 125 | break; 126 | // just to silence compiler warnings of unhandeled switch cases 127 | } 128 | throw std::invalid_argument("Failure to match torch::ScalarType to MPI_Datatype!"); 129 | } 130 | 131 | void check_mpi_return_value(int ierr) 132 | { 133 | if (ierr != MPI_SUCCESS) { 134 | std::ostringstream oss; 135 | oss << ierr; 136 | throw std::runtime_error("MPI call failed with error code " + oss.str()); 137 | } 138 | } 139 | 140 | struct MPI_Comm_Wrapper : 
torch::CustomClassHolder 141 | { 142 | MPI_Comm_Wrapper(const MPI_Comm comm_ = MPI_COMM_NULL) : comm(comm_) {} 143 | 144 | MPI_Comm comm; 145 | 146 | int64_t GetRank(); 147 | int64_t GetSize(); 148 | 149 | Tensor MPIAllreduce(const Tensor& input, int64_t op); 150 | Tensor MPIBcast_(const Tensor& input, int64_t root); 151 | Tensor MPIReduce_(const Tensor& input, int64_t op, int64_t root); 152 | 153 | Tensor MPIGather(const Tensor& input, int64_t gatheraxis, int64_t root); 154 | Tensor MPIAllgather(const Tensor& input, int64_t gatheraxis); 155 | Tensor MPIScatter(const Tensor& input, int64_t scatteraxis, int64_t numelem, int64_t root); 156 | Tensor MPIAlltoall(const Tensor& input, int64_t gatheraxis, int64_t scatteraxis, int64_t numelem); 157 | 158 | variable_list MPIIsend(const Tensor& input, int64_t dest, int64_t tag); 159 | variable_list MPIIrecv(const Tensor& input, int64_t source, int64_t tag); 160 | Tensor MPIWait(const variable_list& input); 161 | }; 162 | 163 | c10::intrusive_ptr comm_world() 164 | { 165 | return c10::make_intrusive(MPI_COMM_WORLD); 166 | } 167 | 168 | c10::intrusive_ptr comm_from_fortran(int64_t fortran_handle) 169 | { 170 | return c10::make_intrusive(MPI_Comm_f2c(fortran_handle)); 171 | } 172 | 173 | Tensor JoinDummies(const Tensor& loopthrough, const variable_list& list); 174 | 175 | int64_t MPI_Comm_Wrapper::GetRank() 176 | { 177 | int rank; 178 | MPI_Comm_rank(comm,&rank); 179 | return rank; 180 | } 181 | 182 | int64_t MPI_Comm_Wrapper::GetSize() 183 | { 184 | int size; 185 | MPI_Comm_size(comm,&size); 186 | return size; 187 | } 188 | 189 | struct MPIBackwardNode : public torch::autograd::Node 190 | { 191 | MPI_Comm_Wrapper comm; 192 | }; 193 | 194 | struct MPIUnimplementedNode : MPIBackwardNode 195 | { 196 | variable_list apply(variable_list&& grads) override { 197 | throw std::runtime_error("This backward operation is currently unimplemented!"); 198 | } 199 | std::string name() const override { 200 | return std::string("MPIUnimplementedNode"); 201 | } 202 | }; 203 | 204 | enum Mpi4torchCollectiveOps : int64_t 205 | { 206 | mpi4torch_op_max, 207 | mpi4torch_op_min, 208 | mpi4torch_op_sum, 209 | mpi4torch_op_prod, 210 | mpi4torch_op_land, 211 | mpi4torch_op_band, 212 | mpi4torch_op_lor, 213 | mpi4torch_op_bor, 214 | mpi4torch_op_lxor, 215 | mpi4torch_op_bxor, 216 | mpi4torch_op_minloc, 217 | mpi4torch_op_maxloc 218 | }; 219 | 220 | MPI_Op __get_mpi_op(int64_t op) 221 | { 222 | switch(op) 223 | { 224 | case mpi4torch_op_max: 225 | return MPI_MAX; 226 | case mpi4torch_op_min: 227 | return MPI_MIN; 228 | case mpi4torch_op_sum: 229 | return MPI_SUM; 230 | case mpi4torch_op_prod: 231 | return MPI_PROD; 232 | case mpi4torch_op_land: 233 | return MPI_LAND; 234 | case mpi4torch_op_band: 235 | return MPI_BAND; 236 | case mpi4torch_op_lor: 237 | return MPI_LOR; 238 | case mpi4torch_op_bor: 239 | return MPI_BOR; 240 | case mpi4torch_op_lxor: 241 | return MPI_LXOR; 242 | case mpi4torch_op_bxor: 243 | return MPI_BXOR; 244 | case mpi4torch_op_minloc: 245 | return MPI_MINLOC; 246 | case mpi4torch_op_maxloc: 247 | return MPI_MAXLOC; 248 | default: 249 | break; 250 | } 251 | throw std::invalid_argument("mpi4torch: Collective operation not supported!"); 252 | } 253 | 254 | struct MPIAllreduceSumBackward : public MPIBackwardNode { 255 | variable_list apply(variable_list&& grads) override; 256 | std::string name() const override { 257 | return std::string("MPIAllreduceSumBackward"); 258 | } 259 | 260 | void release_variables() override { 261 | return; 262 | } 263 | }; 264 | 265 | 
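// Note: a sum-allreduce computes y_r = sum_s x_s on every rank r, so the gradient with
// respect to the local input is the sum of all ranks' incoming gradients, i.e. again a
// sum-allreduce applied to grads[0].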
variable_list MPIAllreduceSumBackward::apply (variable_list&& grads) 266 | { 267 | variable_list grad_inputs(1); 268 | if (should_compute_output(0)) { 269 | grad_inputs[0] = comm.MPIAllreduce(grads[0], mpi4torch_op_sum); 270 | } 271 | return grad_inputs; 272 | } 273 | 274 | Tensor MPI_Comm_Wrapper::MPIAllreduce(const Tensor& input, int64_t op) 275 | { 276 | std::shared_ptr grad_fn; 277 | auto mpiop = __get_mpi_op(op); 278 | if (torch::autograd::compute_requires_grad(input)) { 279 | if (op == mpi4torch_op_sum) { 280 | grad_fn = std::shared_ptr (new MPIAllreduceSumBackward(), torch::autograd::deleteNode); 281 | } else { 282 | grad_fn = std::shared_ptr(new MPIUnimplementedNode(), torch::autograd::deleteNode); 283 | } 284 | grad_fn->comm = *this; 285 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 286 | } 287 | auto result = ([&]() { 288 | at::AutoDispatchBelowADInplaceOrView guard; 289 | 290 | MPIDeviceHelper devhelper(input); 291 | 292 | // make input contiguous 293 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous(); 294 | 295 | auto recv = torch::empty_like(input_cont); 296 | 297 | check_mpi_return_value( 298 | MPI_Allreduce(input_cont.data_ptr(), recv.data_ptr(), input_cont.numel(), 299 | torch2mpitype(input_cont.scalar_type()), mpiop, comm) 300 | ); 301 | 302 | return devhelper.fromMPIToDevice(recv); 303 | })(); 304 | if (grad_fn) { 305 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 306 | } 307 | return result; 308 | } 309 | 310 | struct MPIBcastInPlaceBackward : public MPIBackwardNode { 311 | MPIBcastInPlaceBackward(int _root) : root(_root) {} 312 | variable_list apply(variable_list&& grads) override; 313 | std::string name() const override { 314 | return std::string("MPIBcastInPlaceBackward"); 315 | } 316 | 317 | void release_variables() override { 318 | return; 319 | } 320 | 321 | int root; 322 | }; 323 | 324 | variable_list MPIBcastInPlaceBackward::apply (variable_list&& grads) 325 | { 326 | variable_list grad_inputs(1); 327 | if (should_compute_output(0)) { 328 | grad_inputs[0] = comm.MPIReduce_(grads[0],mpi4torch_op_sum,root); 329 | } 330 | return grad_inputs; 331 | } 332 | 333 | Tensor MPI_Comm_Wrapper::MPIBcast_(const Tensor& input, int64_t root) 334 | { 335 | // TODO: check for root being in int range 336 | std::shared_ptr grad_fn; 337 | if (torch::autograd::compute_requires_grad(input)) { 338 | grad_fn = std::shared_ptr (new MPIBcastInPlaceBackward(static_cast(root)), 339 | torch::autograd::deleteNode); 340 | grad_fn->comm = *this, 341 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 342 | } 343 | auto result = ([&]() { 344 | at::AutoDispatchBelowADInplaceOrView guard; 345 | 346 | MPIDeviceHelper devhelper(input); 347 | 348 | // 1. Make input contiguous 349 | // 2. Call variable_data() to make a shallow copy of the input tensor without the autograd history, 350 | // such that it can be savely returned from this function. 
351 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data(); 352 | 353 | check_mpi_return_value( 354 | MPI_Bcast(input_cont.data_ptr(), input_cont.numel(), 355 | torch2mpitype(input_cont.scalar_type()), 356 | static_cast(root), comm) 357 | ); 358 | 359 | return devhelper.fromMPIToDevice(input_cont); 360 | })(); 361 | if (grad_fn) { 362 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 363 | } 364 | return result; 365 | } 366 | 367 | struct MPIReduceSumInPlaceBackward : public MPIBackwardNode { 368 | MPIReduceSumInPlaceBackward(int _root) : root(_root) {} 369 | variable_list apply(variable_list&& grads) override; 370 | std::string name() const override { 371 | return std::string("MPIReduceSumInPlaceBackward"); 372 | } 373 | 374 | void release_variables() override { 375 | return; 376 | } 377 | 378 | int root; 379 | }; 380 | 381 | variable_list MPIReduceSumInPlaceBackward::apply (variable_list&& grads) 382 | { 383 | variable_list grad_inputs(1); 384 | // TODO: for these simple functions the should_compute_output check is superfluous 385 | if (should_compute_output(0)) { 386 | // NOTE: It is probably safe to use in-place operations in the backward mode, 387 | // since I currently cannot think of any way how a bifurcation could 388 | // enter the DAG. 389 | // TODO: Proof that it is safe! 390 | grad_inputs[0] = comm.MPIBcast_(grads[0],root); 391 | } 392 | return grad_inputs; 393 | } 394 | 395 | struct MPINoInplaceBackward : public torch::autograd::Node { 396 | variable_list apply(variable_list&& grads) override 397 | { 398 | throw std::runtime_error("Reuse of variables passed to in-place MPI kernels not supported! Try using the return value"); 399 | } 400 | std::string name() const override { 401 | return std::string("MPINoInplaceBackward"); 402 | } 403 | }; 404 | 405 | Tensor MPI_Comm_Wrapper::MPIReduce_(const Tensor& input, int64_t op, int64_t root) 406 | { 407 | // TODO: check for root being in int range 408 | std::shared_ptr grad_fn; 409 | auto mpiop = __get_mpi_op(op); 410 | if (torch::autograd::compute_requires_grad(input)) { 411 | if (op == mpi4torch_op_sum) { 412 | grad_fn = std::shared_ptr (new MPIReduceSumInPlaceBackward(static_cast(root)), 413 | torch::autograd::deleteNode); 414 | } else { 415 | grad_fn = std::shared_ptr (new MPIUnimplementedNode(), torch::autograd::deleteNode); 416 | } 417 | grad_fn->comm = *this; 418 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 419 | } 420 | auto result = ([&]() { 421 | at::AutoDispatchBelowADInplaceOrView guard; 422 | 423 | MPIDeviceHelper devhelper(input); 424 | 425 | // 1. Make input contiguous 426 | // 2. Call variable_data() to make a shallow copy of the input tensor without the autograd history, 427 | // such that it can be savely returned from this function. 428 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data(); 429 | 430 | const int rank = GetRank(); 431 | 432 | void* sendbuf = input_cont.data_ptr(); 433 | if (rank == root) { 434 | // One is only allowed to pass MPI_IN_PLACE for the root process 435 | // cf. https://stackoverflow.com/a/17744793 436 | sendbuf = MPI_IN_PLACE; 437 | } 438 | 439 | check_mpi_return_value(MPI_Reduce(sendbuf, input_cont.data_ptr(), input_cont.numel(), 440 | torch2mpitype(input_cont.scalar_type()), 441 | mpiop, static_cast(root), comm)); 442 | 443 | if (rank != root) { 444 | // We fill the non-root results with zeros to make the function properly behaved. 
445 | // TODO: We could potentially let the return-value be undefined and save some ops? 446 | input_cont.zero_(); 447 | } 448 | 449 | return devhelper.fromMPIToDevice(input_cont); 450 | })(); 451 | if (grad_fn) { 452 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 453 | 454 | // We only activate this safeguard if the input variable is not a leaf in the DAG, 455 | // otherwise we would get errors from the AccumulateGrad Node. 456 | if (input.grad_fn()) { 457 | // prohibit misuse of input in autograd 458 | auto& input_non_const = const_cast(input); 459 | set_history(input_non_const, std::shared_ptr(new MPINoInplaceBackward(), 460 | torch::autograd::deleteNode)); 461 | } 462 | } 463 | return result; 464 | } 465 | 466 | struct MPIGatherBackward : public MPIBackwardNode { 467 | MPIGatherBackward(int64_t _gatheraxis, int64_t _root, int64_t _numelem) 468 | : gatheraxis(_gatheraxis), root(_root), numelem(_numelem) {} 469 | variable_list apply(variable_list&& grads) override; 470 | std::string name() const override { 471 | return std::string("MPIGatherBackward"); 472 | } 473 | 474 | void release_variables() override { 475 | return; 476 | } 477 | 478 | int64_t gatheraxis; 479 | int64_t root; 480 | int64_t numelem; 481 | }; 482 | 483 | variable_list MPIGatherBackward::apply (variable_list&& grads) 484 | { 485 | variable_list grad_inputs(1); 486 | // TODO: for these simple functions the should_compute_output check is superfluous 487 | if (should_compute_output(0)) { 488 | // pytorch/pytorch#79446 broke this code 489 | //auto next_node = next_edge(0).function; 490 | //auto input_nr = next_edge(0).input_nr; 491 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis]; 492 | grad_inputs[0] = comm.MPIScatter(grads[0],gatheraxis, numelem, root); 493 | } 494 | return grad_inputs; 495 | } 496 | 497 | Tensor MPI_Comm_Wrapper::MPIGather(const Tensor& input, int64_t gatheraxis, int64_t root) 498 | { 499 | // TODO: check for root being in int range 500 | std::shared_ptr grad_fn; 501 | if (torch::autograd::compute_requires_grad(input)) { 502 | grad_fn = std::shared_ptr 503 | (new MPIGatherBackward(gatheraxis, root, input.sizes()[(size_t) gatheraxis]), 504 | torch::autograd::deleteNode); 505 | grad_fn->comm = *this; 506 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 507 | } 508 | auto result = ([&]() { 509 | at::AutoDispatchBelowADInplaceOrView guard; 510 | 511 | MPIDeviceHelper devhelper(input); 512 | 513 | // 1. 
Make input contiguous 514 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous(); 515 | 516 | auto sizes = input_cont.sizes(); 517 | const size_t ndim = sizes.size(); 518 | const int npes = (int)GetSize(); 519 | 520 | int64_t beforegatheraxis64 = 1; 521 | for (size_t i = 0; i < (size_t)gatheraxis; ++i) { 522 | beforegatheraxis64 *= sizes[i]; 523 | } 524 | 525 | int64_t aftergatheraxis64 = 1; 526 | for (size_t i = (size_t)gatheraxis + 1; i < ndim; ++i) { 527 | aftergatheraxis64 *= sizes[i]; 528 | } 529 | 530 | if (beforegatheraxis64 > INT_MAX || aftergatheraxis64 * sizes[(size_t)gatheraxis] > INT_MAX) { 531 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!"); 532 | } 533 | 534 | const int beforegatheraxis = (int)beforegatheraxis64; 535 | const int gatheraxissize = (int)sizes[(size_t)gatheraxis]; 536 | const int aftergatheraxis = (int)aftergatheraxis64; 537 | 538 | std::vector recvcounts(npes); // TODO: only allocate on root process 539 | 540 | check_mpi_return_value(MPI_Gather(&gatheraxissize, 1, MPI_INT, 541 | &recvcounts[0], 1, 542 | MPI_INT, root, comm)); 543 | 544 | std::vector displs(npes); // TODO: only allocate on root process 545 | //displs[0] = 0; // This is a noop 546 | for (size_t i = 1; i < (size_t)npes; ++i) { 547 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) recvcounts[i-1]; 548 | 549 | if (tmpadd > INT_MAX) { 550 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!"); 551 | } 552 | displs[i] = (int) tmpadd; 553 | } 554 | const int newgatheraxissize = displs[npes-1] + recvcounts[npes-1]; // TODO: add overflow check 555 | 556 | MPI_Datatype tmpdatatype1; 557 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, aftergatheraxis * gatheraxissize, 558 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1)); 559 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1)); 560 | 561 | MPI_Datatype sendtype; 562 | MPI_Aint basic_lb, basic_extent; 563 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb, 564 | &basic_extent)); 565 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, aftergatheraxis * basic_extent, 566 | &sendtype)); 567 | check_mpi_return_value(MPI_Type_commit(&sendtype)); 568 | 569 | MPI_Datatype tmpdatatype2; 570 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, 571 | newgatheraxissize * aftergatheraxis, 572 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2)); 573 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2)); 574 | MPI_Datatype recvtype; 575 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, aftergatheraxis * basic_extent, 576 | &recvtype)); 577 | check_mpi_return_value(MPI_Type_commit(&recvtype)); 578 | 579 | std::vector newsizes(sizes.begin(), sizes.end()); 580 | newsizes[(size_t)gatheraxis] = newgatheraxissize; 581 | 582 | auto recvtensor = torch::empty(newsizes, input_cont.options(), c10::MemoryFormat::Contiguous); 583 | 584 | check_mpi_return_value(MPI_Gatherv(input_cont.data_ptr(), gatheraxissize, sendtype, 585 | recvtensor.data_ptr(), &recvcounts[0], &displs[0], recvtype, 586 | static_cast(root), comm)); 587 | 588 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1)); 589 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2)); 590 | check_mpi_return_value(MPI_Type_free(&sendtype)); 591 | check_mpi_return_value(MPI_Type_free(&recvtype)); 592 | 593 | return devhelper.fromMPIToDevice(recvtensor); 594 | })(); 595 | if (grad_fn) { 596 | 
set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 597 | } 598 | return result; 599 | } 600 | 601 | struct MPIAllgatherBackward : public MPIBackwardNode { 602 | MPIAllgatherBackward(int64_t _gatheraxis, int64_t _numelem) : gatheraxis(_gatheraxis), numelem(_numelem) {} 603 | variable_list apply(variable_list&& grads) override; 604 | std::string name() const override { 605 | return std::string("MPIAllgatherBackward"); 606 | } 607 | 608 | void release_variables() override { 609 | return; 610 | } 611 | 612 | int64_t gatheraxis; 613 | int64_t numelem; 614 | }; 615 | 616 | variable_list MPIAllgatherBackward::apply (variable_list&& grads) 617 | { 618 | variable_list grad_inputs(1); 619 | // TODO: for these simple functions the should_compute_output check is superfluous 620 | if (should_compute_output(0)) { 621 | // pytorch/pytorch#79446 broke this code 622 | //auto next_node = next_edge(0).function; 623 | //auto input_nr = next_edge(0).input_nr; 624 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis]; 625 | grad_inputs[0] = comm.MPIScatter(grads[0], gatheraxis, numelem, 0); 626 | for (int64_t root = 1; root < comm.GetSize(); ++root) { 627 | grad_inputs[0] += comm.MPIScatter(grads[0], gatheraxis, numelem, 1); 628 | } 629 | } 630 | return grad_inputs; 631 | } 632 | 633 | Tensor MPI_Comm_Wrapper::MPIAllgather(const Tensor& input, int64_t gatheraxis) 634 | { 635 | // TODO: check for root being in int range 636 | std::shared_ptr grad_fn; 637 | if (torch::autograd::compute_requires_grad(input)) { 638 | grad_fn = std::shared_ptr 639 | (new MPIAllgatherBackward(gatheraxis, input.sizes()[(size_t) gatheraxis]), torch::autograd::deleteNode); 640 | grad_fn->comm = *this; 641 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 642 | } 643 | auto result = ([&]() { 644 | at::AutoDispatchBelowADInplaceOrView guard; 645 | 646 | MPIDeviceHelper devhelper(input); 647 | 648 | // 1. 
Make input contiguous 649 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous(); 650 | 651 | auto sizes = input_cont.sizes(); 652 | const size_t ndim = sizes.size(); 653 | const int npes = (int)GetSize(); 654 | 655 | int64_t beforegatheraxis64 = 1; 656 | for (size_t i = 0; i < (size_t)gatheraxis; ++i) { 657 | beforegatheraxis64 *= sizes[i]; 658 | } 659 | 660 | int64_t aftergatheraxis64 = 1; 661 | for (size_t i = (size_t)gatheraxis + 1; i < ndim; ++i) { 662 | aftergatheraxis64 *= sizes[i]; 663 | } 664 | 665 | if (beforegatheraxis64 > INT_MAX || aftergatheraxis64 * sizes[(size_t)gatheraxis] > INT_MAX) { 666 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!"); 667 | } 668 | 669 | const int beforegatheraxis = (int)beforegatheraxis64; 670 | const int gatheraxissize = (int)sizes[(size_t)gatheraxis]; 671 | const int aftergatheraxis = (int)aftergatheraxis64; 672 | 673 | std::vector recvcounts(npes); 674 | 675 | check_mpi_return_value(MPI_Allgather(&gatheraxissize, 1, MPI_INT, 676 | &recvcounts[0], 1, 677 | MPI_INT, comm)); 678 | 679 | std::vector displs(npes); 680 | //displs[0] = 0; // This is a noop 681 | for (size_t i = 1; i < (size_t)npes; ++i) { 682 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) recvcounts[i-1]; 683 | 684 | if (tmpadd > INT_MAX) { 685 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!"); 686 | } 687 | displs[i] = (int) tmpadd; 688 | } 689 | const int newgatheraxissize = displs[npes-1] + recvcounts[npes-1]; 690 | 691 | MPI_Datatype tmpdatatype1; 692 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, aftergatheraxis * gatheraxissize, 693 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1)); 694 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1)); 695 | 696 | MPI_Datatype sendtype; 697 | MPI_Aint basic_lb, basic_extent; 698 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb, 699 | &basic_extent)); 700 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, aftergatheraxis * basic_extent, 701 | &sendtype)); 702 | check_mpi_return_value(MPI_Type_commit(&sendtype)); 703 | 704 | MPI_Datatype tmpdatatype2; 705 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, 706 | newgatheraxissize * aftergatheraxis, 707 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2)); 708 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2)); 709 | MPI_Datatype recvtype; 710 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, aftergatheraxis * basic_extent, 711 | &recvtype)); 712 | check_mpi_return_value(MPI_Type_commit(&recvtype)); 713 | 714 | std::vector newsizes(sizes.begin(), sizes.end()); 715 | newsizes[(size_t)gatheraxis] = newgatheraxissize; 716 | 717 | auto recvtensor = torch::empty(newsizes, input_cont.options(), c10::MemoryFormat::Contiguous); 718 | 719 | check_mpi_return_value(MPI_Allgatherv(input_cont.data_ptr(), gatheraxissize, sendtype, 720 | recvtensor.data_ptr(), &recvcounts[0], &displs[0], recvtype, 721 | comm)); 722 | 723 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1)); 724 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2)); 725 | check_mpi_return_value(MPI_Type_free(&sendtype)); 726 | check_mpi_return_value(MPI_Type_free(&recvtype)); 727 | 728 | return devhelper.fromMPIToDevice(recvtensor); 729 | })(); 730 | if (grad_fn) { 731 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 732 | } 733 | return result; 734 | } 735 | 736 | struct MPIScatterBackward : public 
MPIBackwardNode { 737 | MPIScatterBackward(int64_t _scatteraxis, int64_t _root) 738 | : scatteraxis(_scatteraxis), root(_root) {} 739 | variable_list apply(variable_list&& grads) override; 740 | std::string name() const override { 741 | return std::string("MPIScatterBackward"); 742 | } 743 | 744 | void release_variables() override { 745 | return; 746 | } 747 | 748 | int64_t scatteraxis; 749 | int64_t root; 750 | }; 751 | 752 | variable_list MPIScatterBackward::apply (variable_list&& grads) 753 | { 754 | variable_list grad_inputs(1); 755 | // TODO: for these simple functions the should_compute_output check is superfluous 756 | if (should_compute_output(0)) { 757 | auto tmp = comm.MPIGather(grads[0],scatteraxis, root); 758 | if (comm.GetRank() == root) { 759 | grad_inputs[0] = tmp; 760 | } else { 761 | auto next_node = next_edge(0).function; 762 | auto input_nr = next_edge(0).input_nr; 763 | grad_inputs[0] = JoinDummies(next_node->input_metadata(input_nr).zeros_like(), {tmp}); 764 | } 765 | } 766 | return grad_inputs; 767 | } 768 | 769 | Tensor MPI_Comm_Wrapper::MPIScatter(const Tensor& input, int64_t scatteraxis, int64_t numelem, int64_t root) 770 | { 771 | // TODO: check for root being in int range 772 | std::shared_ptr grad_fn; 773 | if (torch::autograd::compute_requires_grad(input)) { 774 | grad_fn = std::shared_ptr 775 | (new MPIScatterBackward(scatteraxis, root), 776 | torch::autograd::deleteNode); 777 | grad_fn->comm = *this; 778 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 779 | } 780 | auto result = ([&]() { 781 | at::AutoDispatchBelowADInplaceOrView guard; 782 | 783 | MPIDeviceHelper devhelper(input); 784 | 785 | // 1. Make input contiguous 786 | auto input_cont = GetRank() == root ? devhelper.fromDeviceToMPI(input).contiguous() : input; 787 | 788 | size_t ndim = input_cont.sizes().size(); 789 | check_mpi_return_value(MPI_Bcast(&ndim, 1, MPI_LONG, root, comm)); 790 | std::vector sizes; 791 | if (GetRank() == root) { 792 | sizes = std::move(input_cont.sizes().vec()); 793 | } else { 794 | sizes.resize(ndim); 795 | } 796 | check_mpi_return_value(MPI_Bcast(&sizes[0], (int)ndim, MPI_LONG, root, comm)); 797 | 798 | const int npes = (int)GetSize(); 799 | 800 | int64_t beforescatteraxis64 = 1; 801 | for (size_t i = 0; i < (size_t)scatteraxis; ++i) { 802 | beforescatteraxis64 *= sizes[i]; 803 | } 804 | 805 | int64_t afterscatteraxis64 = 1; 806 | for (size_t i = (size_t)scatteraxis + 1; i < ndim; ++i) { 807 | afterscatteraxis64 *= sizes[i]; 808 | } 809 | 810 | if (beforescatteraxis64 > INT_MAX || afterscatteraxis64 * sizes[(size_t)scatteraxis] > INT_MAX) { 811 | throw std::runtime_error("MPI_Scatter: Tensor sizes exceed INT_MAX!"); 812 | } 813 | 814 | const int beforescatteraxis = (int)beforescatteraxis64; 815 | const int scatteraxissize = (int)sizes[(size_t)scatteraxis]; 816 | const int afterscatteraxis = (int)afterscatteraxis64; 817 | const int newscatteraxissize = (int) numelem; 818 | 819 | std::vector sendcounts(npes); // TODO: only allocate on root process 820 | 821 | check_mpi_return_value(MPI_Gather(&newscatteraxissize, 1, MPI_INT, 822 | &sendcounts[0], 1, 823 | MPI_INT, root, comm)); 824 | 825 | std::vector displs(npes); // TODO: only allocate on root process 826 | //displs[0] = 0; // This is a noop 827 | for (size_t i = 1; i < (size_t)npes; ++i) { 828 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) sendcounts[i-1]; 829 | 830 | if (tmpadd > INT_MAX) { 831 | throw std::runtime_error("MPI_Scatter: Tensor sizes exceed INT_MAX!"); 832 | } 833 | displs[i] = 
(int) tmpadd; 834 | } 835 | if (root == GetRank() && scatteraxissize != displs[npes-1] + sendcounts[npes-1]) { 836 | throw std::runtime_error("MPI_Scatter: finaltensor.shape[scatteraxis] != sum(numelem)!"); 837 | } 838 | 839 | MPI_Datatype tmpdatatype1; 840 | check_mpi_return_value(MPI_Type_vector(beforescatteraxis, afterscatteraxis, 841 | afterscatteraxis * scatteraxissize, 842 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1)); 843 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1)); 844 | 845 | MPI_Datatype sendtype; 846 | MPI_Aint basic_lb, basic_extent; 847 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb, 848 | &basic_extent)); 849 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, afterscatteraxis * basic_extent, 850 | &sendtype)); 851 | check_mpi_return_value(MPI_Type_commit(&sendtype)); 852 | 853 | MPI_Datatype tmpdatatype2; 854 | check_mpi_return_value(MPI_Type_vector(beforescatteraxis, afterscatteraxis, 855 | newscatteraxissize * afterscatteraxis, 856 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2)); 857 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2)); 858 | MPI_Datatype recvtype; 859 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, afterscatteraxis * basic_extent, 860 | &recvtype)); 861 | check_mpi_return_value(MPI_Type_commit(&recvtype)); 862 | 863 | std::vector newsizes(sizes.begin(), sizes.end()); 864 | newsizes[(size_t)scatteraxis] = newscatteraxissize; 865 | 866 | auto recvtensor = torch::empty(newsizes, input_cont.options().device(devhelper.mpidevice), 867 | c10::MemoryFormat::Contiguous); 868 | 869 | check_mpi_return_value(MPI_Scatterv(input_cont.data_ptr(), &sendcounts[0], &displs[0], sendtype, 870 | recvtensor.data_ptr(), newscatteraxissize, recvtype, 871 | static_cast(root), comm)); 872 | 873 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1)); 874 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2)); 875 | check_mpi_return_value(MPI_Type_free(&sendtype)); 876 | check_mpi_return_value(MPI_Type_free(&recvtype)); 877 | 878 | return devhelper.fromMPIToDevice(recvtensor); 879 | })(); 880 | if (grad_fn) { 881 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 882 | } 883 | return result; 884 | } 885 | 886 | struct MPIAlltoallBackward : public MPIBackwardNode { 887 | MPIAlltoallBackward(int64_t _gatheraxis, int64_t _scatteraxis, int64_t _numelem) 888 | : gatheraxis(_gatheraxis), scatteraxis(_scatteraxis), numelem(_numelem) {} 889 | variable_list apply(variable_list&& grads) override; 890 | std::string name() const override { 891 | return std::string("MPIAlltoallBackward"); 892 | } 893 | 894 | void release_variables() override { 895 | return; 896 | } 897 | 898 | int64_t gatheraxis; 899 | int64_t scatteraxis; 900 | int64_t numelem; 901 | }; 902 | 903 | variable_list MPIAlltoallBackward::apply (variable_list&& grads) 904 | { 905 | variable_list grad_inputs(1); 906 | // TODO: for these simple functions the should_compute_output check is superfluous 907 | if (should_compute_output(0)) { 908 | // pytorch/pytorch#79446 broke this code 909 | //auto next_node = next_edge(0).function; 910 | //auto input_nr = next_edge(0).input_nr; 911 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis]; 912 | grad_inputs[0] = comm.MPIAlltoall(grads[0], scatteraxis, gatheraxis, numelem); 913 | } 914 | return grad_inputs; 915 | } 916 | 917 | Tensor MPI_Comm_Wrapper::MPIAlltoall(const Tensor& input, int64_t gatheraxis, int64_t 
scatteraxis, int64_t numelem) 918 | { 919 | // TODO: check for root being in int range 920 | std::shared_ptr grad_fn; 921 | if (torch::autograd::compute_requires_grad(input)) { 922 | grad_fn = std::shared_ptr 923 | (new MPIAlltoallBackward(gatheraxis, scatteraxis, input.sizes()[(size_t) gatheraxis]), 924 | torch::autograd::deleteNode); 925 | grad_fn->comm = *this; 926 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 927 | } 928 | auto result = ([&]() { 929 | at::AutoDispatchBelowADInplaceOrView guard; 930 | 931 | // 1. Make input contiguous 932 | auto input_cont = input.contiguous(); 933 | 934 | 935 | // TODO: This is probably not the best solution latency-wise, but the total number 936 | // of bytes sent should match roughly a solution that uses Alltoallw. 937 | // If the latency turns out to become a problem, we should switch to Alltoallw. 938 | // Memory usage could also become an issue, since this solution requires 939 | // temporarily twice the memory that an Alltoallw solution would require. 940 | std::vector scattered_tensors; 941 | 942 | if (gatheraxis != scatteraxis) { 943 | // This is the easy case 944 | for(int64_t root = 0; root < GetSize(); ++root) { 945 | scattered_tensors.emplace_back(MPIScatter(input_cont, scatteraxis, numelem, root)); 946 | } 947 | } else { 948 | // if gather- and scatteraxis coincide we first need to figure out who receives how many 949 | // elements from whom 950 | // 951 | // TODO: this could probably be solved more efficiently, with few more communication roundtrips 952 | 953 | const size_t npes = (size_t)GetSize(); 954 | const size_t rank = (size_t)GetRank(); 955 | std::vector numelem_cur(npes+1); 956 | std::vector numelem_new(npes+1); 957 | 958 | const int tmp1 = input_cont.sizes()[(size_t)gatheraxis]; 959 | check_mpi_return_value(MPI_Allgather(&tmp1, 1, MPI_INT, 960 | &numelem_cur[1], 1, 961 | MPI_INT, comm)); 962 | const int tmp2 = numelem; 963 | check_mpi_return_value(MPI_Allgather(&tmp2, 1, MPI_INT, 964 | &numelem_new[1], 1, 965 | MPI_INT, comm)); 966 | for (size_t i = 1; i < npes; ++i) { 967 | numelem_cur[i+1] += numelem_cur[i]; 968 | numelem_new[i+1] += numelem_new[i]; 969 | } 970 | 971 | for(int64_t root = 0; static_cast(root) < npes; ++root) { 972 | int64_t localnumelem = std::min(numelem_new[rank+1],numelem_cur[root+1]) 973 | - std::max(numelem_new[rank],numelem_cur[root]); 974 | if (localnumelem < 0) { 975 | localnumelem = 0; 976 | } 977 | scattered_tensors.emplace_back(MPIScatter(input_cont, scatteraxis, localnumelem, root)); 978 | } 979 | } 980 | 981 | return at::cat(scattered_tensors, gatheraxis); 982 | })(); 983 | if (grad_fn) { 984 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 985 | } 986 | return result; 987 | } 988 | 989 | struct JoinDummiesBackward : public torch::autograd::Node { 990 | variable_list apply(variable_list&& grads) override; 991 | std::string name() const override { 992 | return std::string("JoinDummiesBackward"); 993 | } 994 | 995 | void handle_dummies_helper(const size_t first, variable_list& grad_inputs); 996 | 997 | // TODO: I am not sure whether it is wise here to circumvent the saved_variables list, 998 | // and do our own thing. It might be that we create a memory leak that way! 
999 | Tensor loopthrough; 1000 | }; 1001 | 1002 | void JoinDummiesBackward::handle_dummies_helper(const size_t first, variable_list& grad_inputs) 1003 | { 1004 | for (size_t i = first; i < num_outputs(); ++i) { 1005 | if (should_compute_output(i)) { 1006 | auto next_node = next_edge(i).function; 1007 | auto input_nr = next_edge(i).input_nr; 1008 | grad_inputs[i] = JoinDummies(next_node->input_metadata(input_nr).zeros_like(),{loopthrough}); 1009 | } 1010 | } 1011 | } 1012 | 1013 | variable_list JoinDummiesBackward::apply (variable_list&& grads) 1014 | { 1015 | size_t numoutputs = num_outputs(); 1016 | variable_list grad_inputs(numoutputs); 1017 | if (should_compute_output(0)) { 1018 | grad_inputs[0] = JoinDummies(grads[0], {loopthrough}); 1019 | } 1020 | handle_dummies_helper(1,grad_inputs); 1021 | return grad_inputs; 1022 | } 1023 | 1024 | Tensor JoinDummies(const Tensor& loopthrough, const variable_list& list) 1025 | { 1026 | std::shared_ptr grad_fn; 1027 | if (torch::autograd::compute_requires_grad(list)) { 1028 | grad_fn = std::shared_ptr (new JoinDummiesBackward(), torch::autograd::deleteNode); 1029 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(loopthrough,list)); 1030 | } else { 1031 | // if none of the dummy variables needs a gradient, we just return the loopthrough variable 1032 | return loopthrough; 1033 | } 1034 | auto result = ([&]() { 1035 | at::AutoDispatchBelowADInplaceOrView guard; 1036 | 1037 | auto res = loopthrough.variable_data(); 1038 | return res; 1039 | })(); 1040 | // Checking for grad_fn is unneccessary 1041 | //if (grad_fn) { 1042 | grad_fn->loopthrough = result; 1043 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 1044 | //} 1045 | return result; 1046 | } 1047 | 1048 | enum NonBlockingOp 1049 | { 1050 | Isend_Op, 1051 | Irecv_Op, 1052 | }; 1053 | 1054 | struct MPINonBlockingBackward : public MPIBackwardNode { 1055 | variable_list apply(variable_list&& grads) override; 1056 | std::string name() const override { 1057 | return std::string("MPINonBlockingBackward"); 1058 | } 1059 | }; 1060 | 1061 | variable_list MPINonBlockingBackward::apply(variable_list&& grads) 1062 | { 1063 | variable_list grad_inputs(1); 1064 | // TODO: superfluous check?? 
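// The incoming grads are the wait handle produced by MPIWaitBackward's reversed Irecv/Isend;
// completing it with MPIWait here yields the gradient of the original send/receive buffer.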
1065 | if (should_compute_output(0)) { 1066 | grad_inputs[0] = comm.MPIWait(grads); 1067 | } 1068 | return grad_inputs; 1069 | } 1070 | 1071 | variable_list MPI_Comm_Wrapper::MPIIsend(const Tensor& input, int64_t dest, int64_t tag) 1072 | { 1073 | // TODO: check for dest and tag being in int's range 1074 | std::shared_ptr grad_fn; 1075 | if (torch::autograd::compute_requires_grad(input)) { 1076 | grad_fn = std::shared_ptr (new MPINonBlockingBackward(), torch::autograd::deleteNode); 1077 | grad_fn->comm = *this; 1078 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 1079 | } 1080 | auto result = ([&]() { 1081 | at::AutoDispatchBelowADInplaceOrView guard; 1082 | 1083 | MPIDeviceHelper devhelper(input); 1084 | 1085 | // make input contiguous 1086 | // we call variable_data() since we also return the input buffer to ensure it stays in scope 1087 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data(); 1088 | 1089 | MPI_Request req; 1090 | check_mpi_return_value(MPI_Isend(input_cont.data_ptr(), input_cont.numel(), 1091 | torch2mpitype(input_cont.scalar_type()), static_cast(dest), 1092 | static_cast(tag), comm, &req)); 1093 | 1094 | auto ret = torch::empty({7},at::kDouble); 1095 | auto fortran_handle = MPI_Request_c2f(req); 1096 | ret[0] = static_cast(fortran_handle); 1097 | ret[1] = static_cast(Isend_Op); 1098 | ret[2] = static_cast(dest); 1099 | ret[3] = static_cast(tag); 1100 | ret[4] = static_cast(0xFFFFFFFF & std::hash()(input_cont.data_ptr())); 1101 | ret[5] = static_cast(devhelper.devicetype); 1102 | ret[6] = static_cast(devhelper.device.index()); 1103 | variable_list retlist; 1104 | retlist.push_back(ret); 1105 | retlist.push_back(input_cont); // make sure the buffer stays in scope!!! 1106 | retlist.push_back(input.variable_data()); 1107 | return retlist; 1108 | })(); 1109 | if (grad_fn) { 1110 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 1111 | } 1112 | return result; 1113 | } 1114 | 1115 | variable_list MPI_Comm_Wrapper::MPIIrecv(const Tensor& input, int64_t source, int64_t tag) 1116 | { 1117 | // TODO: check for dest and tag being in int's range 1118 | std::shared_ptr grad_fn; 1119 | if (torch::autograd::compute_requires_grad(input)) { 1120 | grad_fn = std::shared_ptr (new MPINonBlockingBackward(), torch::autograd::deleteNode); 1121 | grad_fn->comm = *this; 1122 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 1123 | } 1124 | auto result = ([&]() { 1125 | at::AutoDispatchBelowADInplaceOrView guard; 1126 | 1127 | MPIDeviceHelper devhelper(input); 1128 | 1129 | // TODO: check whether input is contiguous 1130 | // TODO: Maybe add warning if input device is not mpi device 1131 | auto input_cont = devhelper.fromDeviceToMPI(input).variable_data(); 1132 | 1133 | MPI_Request req; 1134 | check_mpi_return_value(MPI_Irecv(input_cont.data_ptr(), input_cont.numel(), 1135 | torch2mpitype(input_cont.scalar_type()), static_cast(source), 1136 | static_cast(tag), comm, &req)); 1137 | 1138 | auto ret = torch::empty({7},at::kDouble); 1139 | auto fortran_handle = MPI_Request_c2f(req); 1140 | ret[0] = static_cast(fortran_handle); 1141 | ret[1] = static_cast(Irecv_Op); 1142 | ret[2] = static_cast(source); 1143 | ret[3] = static_cast(tag); 1144 | ret[4] = static_cast(0xFFFFFFFF & std::hash()(input_cont.data_ptr())); 1145 | ret[5] = static_cast(devhelper.devicetype); 1146 | ret[6] = static_cast(devhelper.device.index()); 1147 | variable_list retlist; 1148 | retlist.push_back(ret); 1149 | 
retlist.push_back(input_cont); // We ensure that the buffer stays in scope and is not garbage collected 1150 | retlist.push_back(input.variable_data()); // We do this for symmetry reasons, but it is more ISend which needs this 1151 | return retlist; 1152 | })(); 1153 | if (grad_fn) { 1154 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 1155 | } 1156 | return result; 1157 | } 1158 | 1159 | struct MPIWaitBackward : public MPIBackwardNode { 1160 | MPIWaitBackward(NonBlockingOp op, int64_t sourcedest_, int64_t tag_) 1161 | : operation(op), sourcedest(sourcedest_), tag(tag_+10) 1162 | {} 1163 | variable_list apply(variable_list&& grads) override; 1164 | std::string name() const override { 1165 | return std::string("MPIWaitBackward"); 1166 | } 1167 | 1168 | NonBlockingOp operation; 1169 | int64_t sourcedest; 1170 | int64_t tag; 1171 | }; 1172 | 1173 | variable_list MPIWaitBackward::apply(variable_list&& grads) 1174 | { 1175 | // TODO: maybe we should add some offset to the tag, since who knows how intertwined the forward and backward 1176 | // operations can become. Beware of the special value MPI_ANY_TAG 1177 | 1178 | // TODO: superfluous check?? 1179 | if (should_compute_output(0)) { 1180 | auto next_node = next_edge(1).function; 1181 | auto input_nr = next_edge(1).input_nr; 1182 | 1183 | // Some rationale for this error checking: 1184 | // Everytime when in forward mode a variables is used multiple times, i.e. where the forward graph 1185 | // bifurcates, the backward DAG, as generated by pytorch, has to assume that in the reverse operation 1186 | // the two incoming gradients need to add up. This is however unfortunate for the second part of our 1187 | // wait handle, since this is where we store the send/receive buffer. If one would now add e.g. a zero 1188 | // to this receive buffer a completely new buffer would be created from pytorch, and the original 1189 | // receive buffer would go out of scope, would then be garbage-collected and any actual receive operation 1190 | // potentially yields a seg-fault. On top of that the buffer as returned by MPIWait quite likely 1191 | // contains the wrong result, since MPIWait returns the from pytorch created buffer that stores the 1192 | // result from the addition operation. 1193 | // 1194 | // NOTE: This check cannot capture all cases of misuse. The additional storage of parts/hashes 1195 | // of the respective data_ptr aims at catching the same type of errors. 1196 | if (next_node->name() != "MPINonBlockingBackward") { 1197 | std::ostringstream oss; 1198 | oss << "mpi4torch: Detected bifurcation in MPIWait handle usage. 
Next node in DAG" 1199 | " should be MPINonBlockingBackward, but is " 1200 | << next_node->name() << "!"; 1201 | throw std::runtime_error(oss.str()); 1202 | } 1203 | switch(operation) { 1204 | case Isend_Op: 1205 | { 1206 | auto buf = next_node->input_metadata(input_nr).zeros_like(); 1207 | return comm.MPIIrecv(JoinDummies(buf,grads), sourcedest, tag); 1208 | } 1209 | case Irecv_Op: 1210 | { 1211 | return comm.MPIIsend(grads[0], sourcedest, tag); 1212 | } 1213 | default: 1214 | throw std::runtime_error("Unsupported NonBlockingOp!"); 1215 | } 1216 | } 1217 | return variable_list(); 1218 | } 1219 | 1220 | Tensor MPI_Comm_Wrapper::MPIWait(const variable_list& input) 1221 | { 1222 | auto fortran_handle = static_cast(input[0][0].item()); 1223 | MPI_Request req = MPI_Request_f2c(fortran_handle); 1224 | NonBlockingOp operation = static_cast(input[0][1].item()); 1225 | auto sourcedest = static_cast(input[0][2].item()); 1226 | auto tag = static_cast(input[0][3].item()); 1227 | auto hashvalue = static_cast(input[0][4].item()); 1228 | auto devicetype = static_cast(input[0][5].item()); 1229 | auto deviceindex = static_cast(input[0][6].item()); 1230 | 1231 | if (hashvalue != (0xFFFFFFFF & std::hash()(input[1].data_ptr()))) { 1232 | std::ostringstream oss; 1233 | oss << "mpi4torch: Detected bifurcation in MPIWait handle usage. " 1234 | "Modifying or consuming the handle by other functions than functions from the " 1235 | "MPIWait class is prohibited!"; 1236 | throw std::runtime_error(oss.str()); 1237 | } 1238 | 1239 | std::shared_ptr grad_fn; 1240 | if(torch::autograd::compute_requires_grad(input)) { 1241 | grad_fn = std::shared_ptr(new MPIWaitBackward(operation, sourcedest, tag),torch::autograd::deleteNode); 1242 | grad_fn->comm = *this; 1243 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input)); 1244 | } 1245 | auto result = ([&]() { 1246 | at::AutoDispatchBelowADInplaceOrView guard; 1247 | 1248 | MPI_Status status; // TODO: Handle use cases for MPI_Status 1249 | check_mpi_return_value(MPI_Wait(&req, & status)); 1250 | 1251 | if (operation == Isend_Op) { 1252 | // We do not do any device conversion for Isend, we then simply return the initial tensor 1253 | return input[2].variable_data(); 1254 | } 1255 | 1256 | MPIDeviceHelper devhelper(c10::Device((c10::DeviceType)devicetype, deviceindex)); 1257 | 1258 | // return a shallow copy of the second input tensor without the autograd strings attached 1259 | return devhelper.fromMPIToDevice(input[1]).variable_data(); 1260 | })(); 1261 | if (grad_fn) { 1262 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn); 1263 | } 1264 | return result; 1265 | } 1266 | 1267 | } 1268 | 1269 | // New-style API is more similar to pybind11, and will in the future be: 1270 | static auto mpi_comm_wrapper_registry = torch::class_<::MPI_Comm_Wrapper>("mpi4torch", "MPI_Comm_Wrapper") 1271 | .def("GetRank", &MPI_Comm_Wrapper::GetRank) 1272 | .def("GetSize", &MPI_Comm_Wrapper::GetSize) 1273 | .def("Allreduce", &MPI_Comm_Wrapper::MPIAllreduce) 1274 | .def("Bcast_", &MPI_Comm_Wrapper::MPIBcast_) 1275 | .def("Reduce_", &MPI_Comm_Wrapper::MPIReduce_) 1276 | .def("Gather", &MPI_Comm_Wrapper::MPIGather) 1277 | .def("Allgather", &MPI_Comm_Wrapper::MPIAllgather) 1278 | .def("Scatter", &MPI_Comm_Wrapper::MPIScatter) 1279 | .def("Alltoall", &MPI_Comm_Wrapper::MPIAlltoall) 1280 | .def("Isend", &MPI_Comm_Wrapper::MPIIsend) 1281 | .def("Irecv", &MPI_Comm_Wrapper::MPIIrecv) 1282 | .def("Wait", &MPI_Comm_Wrapper::MPIWait) 1283 | .def_pickle([](const 
Tensor MPI_Comm_Wrapper::MPIWait(const variable_list& input)
{
    auto fortran_handle = static_cast<MPI_Fint>(input[0][0].item<int64_t>());
    MPI_Request req = MPI_Request_f2c(fortran_handle);
    NonBlockingOp operation = static_cast<NonBlockingOp>(input[0][1].item<int64_t>());
    auto sourcedest = static_cast<int64_t>(input[0][2].item<int64_t>());
    auto tag = static_cast<int64_t>(input[0][3].item<int64_t>());
    auto hashvalue = static_cast<int64_t>(input[0][4].item<int64_t>());
    auto devicetype = static_cast<int64_t>(input[0][5].item<int64_t>());
    auto deviceindex = static_cast<int64_t>(input[0][6].item<int64_t>());

    if (hashvalue != (0xFFFFFFFF & std::hash<void*>()(input[1].data_ptr()))) {
        std::ostringstream oss;
        oss << "mpi4torch: Detected bifurcation in MPIWait handle usage. "
               "Modifying or consuming the handle by functions other than those of the "
               "MPIWait class is prohibited!";
        throw std::runtime_error(oss.str());
    }

    std::shared_ptr<MPIWaitBackward> grad_fn;
    if (torch::autograd::compute_requires_grad(input)) {
        grad_fn = std::shared_ptr<MPIWaitBackward>(new MPIWaitBackward(operation, sourcedest, tag),
                                                   torch::autograd::deleteNode);
        grad_fn->comm = *this;
        grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
    }
    auto result = ([&]() {
        at::AutoDispatchBelowADInplaceOrView guard;

        MPI_Status status; // TODO: Handle use cases for MPI_Status
        check_mpi_return_value(MPI_Wait(&req, &status));

        if (operation == Isend_Op) {
            // We do not do any device conversion for Isend; in that case we simply return the initial tensor
            return input[2].variable_data();
        }

        MPIDeviceHelper devhelper(c10::Device((c10::DeviceType)devicetype, deviceindex));

        // Return a shallow copy of the second input tensor without the autograd history attached
        return devhelper.fromMPIToDevice(input[1]).variable_data();
    })();
    if (grad_fn) {
        set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
    }
    return result;
}

}

// New-style registration API: it is more similar to pybind11 and will in the future be the preferred one:
static auto mpi_comm_wrapper_registry = torch::class_<::MPI_Comm_Wrapper>("mpi4torch", "MPI_Comm_Wrapper")
    .def("GetRank", &MPI_Comm_Wrapper::GetRank)
    .def("GetSize", &MPI_Comm_Wrapper::GetSize)
    .def("Allreduce", &MPI_Comm_Wrapper::MPIAllreduce)
    .def("Bcast_", &MPI_Comm_Wrapper::MPIBcast_)
    .def("Reduce_", &MPI_Comm_Wrapper::MPIReduce_)
    .def("Gather", &MPI_Comm_Wrapper::MPIGather)
    .def("Allgather", &MPI_Comm_Wrapper::MPIAllgather)
    .def("Scatter", &MPI_Comm_Wrapper::MPIScatter)
    .def("Alltoall", &MPI_Comm_Wrapper::MPIAlltoall)
    .def("Isend", &MPI_Comm_Wrapper::MPIIsend)
    .def("Irecv", &MPI_Comm_Wrapper::MPIIrecv)
    .def("Wait", &MPI_Comm_Wrapper::MPIWait)
    .def_pickle([](const c10::intrusive_ptr<MPI_Comm_Wrapper>& self) -> std::string
        {
            if (self->comm != MPI_COMM_WORLD) {
                throw std::runtime_error("MPI communicators other than MPI_COMM_WORLD are not serializable!");
            }
            return std::string("MPI_COMM_WORLD");
        },
        [](std::string input) -> c10::intrusive_ptr<MPI_Comm_Wrapper>
        {
            if (input != std::string("MPI_COMM_WORLD")) {
                throw std::runtime_error("Unknown MPI communicator");
            }
            return c10::make_intrusive<MPI_Comm_Wrapper>(MPI_COMM_WORLD);
        }
    )
;

// Old-style registration API (the only one available up to pytorch 1.4.0):
static auto registry = torch::RegisterOperators()
    .op("mpi4torch::COMM_WORLD", &comm_world)
    .op("mpi4torch::comm_from_fortran", &comm_from_fortran)
    .op("mpi4torch::JoinDummies", &JoinDummies);

#if defined(OPEN_MPI)
#if OMPI_MAJOR_VERSION < 3
#define _GNU_SOURCE
#include <dlfcn.h>
#endif
#endif

struct __MPI_Finalizer
{
    ~__MPI_Finalizer()
    {
        check_mpi_return_value(MPI_Finalize());
    }
};

static std::unique_ptr<__MPI_Finalizer> __finalizer;

static void __mpi4torch_mpi_init()
{
#if defined(OPEN_MPI)
#if OMPI_MAJOR_VERSION < 3
    // There are some issues with older versions of OpenMPI, which have difficulties with dlopen-ing the MPI
    // libraries the way it is done within python's extension system (i.e. with RTLD_LOCAL) [1].
    // In detail:
    //  - cpython loads this extension as a shared library with RTLD_LOCAL
    //  - this implies that several openmpi functions are not globally visible
    //  - during initialization openmpi itself dynamically loads shared objects,
    //    which reference back to some of the previously loaded (but not visible) symbols
    //  - openmpi crashes during the MPI_Init* calls, since these back-referenced symbols cannot be resolved
    //
    // [1] https://github.com/open-mpi/ompi/issues/3705

    // In principle we just want to reload libmpi with RTLD_GLOBAL specified before calling MPI_Init_thread,
    // in order to make the respective symbols visible.
    // But first we need to find out under which path name the library that we want to reload goes.

    // To do so, we query the address of a known symbol in libmpi, e.g. MPI_Init_thread
    const void* mpi_init_thread_symbol = dlsym(RTLD_DEFAULT, "MPI_Init_thread");

    if (mpi_init_thread_symbol == nullptr) {
        // in principle we should never be able to reach this point, but better safe than sorry
        throw std::runtime_error(std::string("mpi4torch failed with: ") + std::string(dlerror()));
    }

    // Now we query the info of the library in which the symbol resides
    Dl_info dlinfo;
    if (!dladdr(mpi_init_thread_symbol, &dlinfo)) {
        throw std::runtime_error(std::string("mpi4torch failed with: ") + std::string(dlerror()));
    }

    // dlinfo.dli_fname should now contain the pathname of the mpi library

    // As the man page of dlopen suggests, to promote the already loaded symbols from RTLD_LOCAL to
    // RTLD_GLOBAL, we just need to reopen the library with the RTLD_NOLOAD | RTLD_GLOBAL flags.
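    // Editorial note: with RTLD_NOLOAD dlopen does not map a second copy of the library; it only looks up
    // the already resident libmpi and upgrades its symbol visibility to RTLD_GLOBAL. If the library were
    // not resident (which should not happen here, since we just resolved a symbol from it), dlopen would
    // return NULL, hence the error check below.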
    const void* mpi_lib_handle = dlopen(dlinfo.dli_fname, RTLD_NOW | RTLD_NOLOAD | RTLD_GLOBAL);
    if (!mpi_lib_handle) {
        throw std::runtime_error(std::string("mpi4torch failed with: ") + std::string(dlerror()));
    }
#endif
#endif

    int mpi_initialized = -1;
    check_mpi_return_value(MPI_Initialized(&mpi_initialized));
    if (!mpi_initialized) {
        int provided = 0;
        // We play it safe and initialize for multithreaded use cases
        check_mpi_return_value(MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided));

        // TODO: maybe in the future we will actually depend on multithreading,
        //       so we leave this code here
#if 0
        if (provided != MPI_THREAD_MULTIPLE) {
            std::cerr << "mpi4torch WARNING: MPI version does not provide full multithreading support!\n";
            std::cerr << "mpi4torch WARNING: This may crash mpi4torch.\n";
        }
#endif
        __finalizer = std::make_unique<__MPI_Finalizer>();
    } else {
#if 0
        int provided = 0;
        check_mpi_return_value(MPI_Query_thread(&provided));
        if (provided != MPI_THREAD_MULTIPLE) {
            std::cerr << "mpi4torch WARNING: MPI is already initialized but not with MPI_THREAD_MULTIPLE!\n";
            std::cerr << "mpi4torch WARNING: This may crash mpi4torch.\n";
        }
#endif
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Initialize MPI
    __mpi4torch_mpi_init();

#if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS
    __setup_have_cuda_aware_mpi_support();
#endif

    m.def("deactivate_cuda_aware_mpi_support", &deactivate_cuda_aware_mpi_support,
          "Deactivates the CUDA-aware MPI support.\n"
          "\n"
          "Calling this function forces mpi4torch to first move any tensor into main memory before\n"
          "calling an MPI function on it, and then to move the result back into device memory after\n"
          "the MPI call has finished.\n"
          "\n"
          "Note\n"
          "----\n"
          " This function is useful in situations in which MPI advertises CUDA-awareness but the\n"
          " functionality is not really supported by the backend.\n");

    // TorchScript does not like the pybind11 enum_ solution
    //py::enum_<Mpi4torchCollectiveOps>(m, "MPI_Op")
    //    .value("MPI_MAX", Mpi4torchCollectiveOps::mpi4torch_op_max)
    //    .value("MPI_MIN", Mpi4torchCollectiveOps::mpi4torch_op_min)
    //    .value("MPI_SUM", Mpi4torchCollectiveOps::mpi4torch_op_sum)
    //    .value("MPI_PROD", Mpi4torchCollectiveOps::mpi4torch_op_prod)
    //    .export_values();

    m.attr("MPI_MAX") = py::int_((int64_t)mpi4torch_op_max);
    m.attr("MPI_MIN") = py::int_((int64_t)mpi4torch_op_min);
    m.attr("MPI_SUM") = py::int_((int64_t)mpi4torch_op_sum);
    m.attr("MPI_PROD") = py::int_((int64_t)mpi4torch_op_prod);
    m.attr("MPI_LAND") = py::int_((int64_t)mpi4torch_op_land);
    m.attr("MPI_BAND") = py::int_((int64_t)mpi4torch_op_band);
    m.attr("MPI_LOR") = py::int_((int64_t)mpi4torch_op_lor);
    m.attr("MPI_BOR") = py::int_((int64_t)mpi4torch_op_bor);
    m.attr("MPI_LXOR") = py::int_((int64_t)mpi4torch_op_lxor);
    m.attr("MPI_BXOR") = py::int_((int64_t)mpi4torch_op_bxor);
    m.attr("MPI_MINLOC") = py::int_((int64_t)mpi4torch_op_minloc);
    m.attr("MPI_MAXLOC") = py::int_((int64_t)mpi4torch_op_maxloc);
}

--------------------------------------------------------------------------------