├── doc
│   ├── _templates
│   │   └── .dummy
│   ├── _static
│   │   └── img
│   │       ├── mpi4torch-logo.png
│   │       ├── mpi4torch-logo-extrawhitespace.png
│   │       ├── mpi4torch-logo-extrawhitespace.svg
│   │       └── mpi4torch-logo.svg
│   ├── glossary.rst
│   ├── api_reference.rst
│   ├── Makefile
│   ├── index.rst
│   ├── conf.py
│   ├── examples.rst
│   └── basic_usage.rst
├── version.txt
├── requirements.txt
├── .gitignore
├── MANIFEST.in
├── pyproject.toml
├── examples
│   ├── isend-recv-wait.py
│   └── simple_linear_regression.py
├── .readthedocs.yaml
├── tests
│   ├── test_mpi4pyinterop.py
│   ├── test_joindummies.py
│   ├── test_nonblocking.py
│   └── test_collectives.py
├── LICENSE
├── .github
│   └── workflows
│       └── test.yml
├── README.md
├── setup.py
├── src
│   └── __init__.py
└── csrc
    └── extension.cpp
/doc/_templates/.dummy:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.1.3
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | dist/
3 | *.egg-info/
4 | *.egg
5 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include version.txt
2 | include pyproject.toml
3 | include doc/_static/img/mpi4torch-logo-extrawhitespace.png
4 |
--------------------------------------------------------------------------------
/doc/_static/img/mpi4torch-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/helmholtz-analytics/mpi4torch/HEAD/doc/_static/img/mpi4torch-logo.png
--------------------------------------------------------------------------------
/doc/_static/img/mpi4torch-logo-extrawhitespace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/helmholtz-analytics/mpi4torch/HEAD/doc/_static/img/mpi4torch-logo-extrawhitespace.png
--------------------------------------------------------------------------------
/doc/glossary.rst:
--------------------------------------------------------------------------------
1 | *******************
2 | Glossary
3 | *******************
4 |
5 | .. glossary::
6 |
7 | AD
8 | automatic differentiation
9 |
10 | DAG
11 | directed acyclic graph
12 |
13 | MPI
14 | Message Passing Interface
15 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | # PEP 517 and PEP 518 build requirements
3 | # cf. https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
4 |
5 | # we need torch as a build requirement, since we explicitly import it within setup.py
6 | requires = [
7 | "torch>=1.9.0",
8 | 'importlib-metadata;python_version<"3.8"', #PEP 508-style dependency: Only install for cpython<3.8
9 | "setuptools>=40.8.0",
10 | "wheel"
11 | ]
12 | build-backend = "setuptools.build_meta"
13 |
14 |
--------------------------------------------------------------------------------
/examples/isend-recv-wait.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 |
4 | comm = mpi4torch.COMM_WORLD
5 |
6 | a = torch.tensor([1.0 + comm.rank]).requires_grad_()
7 |
8 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
9 | recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
10 | b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
11 | wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))
12 |
13 | res = mpi4torch.JoinDummies(a+b, [wait_ret])
14 | print(res)
15 |
16 | res.backward()
17 | print(a.grad)
18 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | os: ubuntu-20.04
10 | tools:
11 | python: "3.10"
12 | apt_packages:
13 | - libopenmpi-dev
14 | - openmpi-bin
15 | - graphviz
16 |
17 | # Build documentation in the doc/ directory with Sphinx
18 | sphinx:
19 | configuration: doc/conf.py
20 |
21 | # If using Sphinx, optionally build your docs in additional formats such as PDF
22 | # formats:
23 | # - pdf
24 |
25 | # Python requirements to build the docs
26 | python:
27 | install:
28 | - method: pip
29 | path: .
30 |
31 |
--------------------------------------------------------------------------------
/doc/api_reference.rst:
--------------------------------------------------------------------------------
1 | ********************
2 | API Reference
3 | ********************
4 |
5 | .. automodule:: mpi4torch
6 | :members:
7 | :undoc-members:
8 |
9 | .. autofunction:: JoinDummies(loopthrough: torch.Tensor, dummies:List[torch.Tensor]) -> torch.Tensor
10 | .. autofunction:: JoinDummiesHandle(handle: mpi4torch.WaitHandle, dummies:List[torch.Tensor]) -> mpi4torch.WaitHandle
11 | .. data:: MPI_MAX
12 | .. data:: MPI_MIN
13 | .. data:: MPI_SUM
14 | .. data:: MPI_PROD
15 | .. data:: MPI_LAND
16 | .. data:: MPI_BAND
17 | .. data:: MPI_LOR
18 | .. data:: MPI_BOR
19 | .. data:: MPI_LXOR
20 | .. data:: MPI_BXOR
21 | .. data:: MPI_MINLOC
22 | .. data:: MPI_MAXLOC
23 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/tests/test_mpi4pyinterop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 | import unittest
4 | import mpi4py.MPI as MPI
5 |
6 | class TestMpi4PyInteroperability(unittest.TestCase):
7 | def test_rank_and_size(self):
8 | comm1 = MPI.COMM_WORLD
9 | comm2 = mpi4torch.comm_from_mpi4py(comm1)
10 | self.assertEqual(comm1.rank, comm2.rank)
11 | self.assertEqual(comm1.size, comm2.size)
12 |
13 | def test_simple_allreduce(self):
14 | comm1 = MPI.COMM_WORLD
15 | comm2 = mpi4torch.comm_from_mpi4py(comm1)
16 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
17 | res = comm2.Allreduce(tmp,mpi4torch.MPI_SUM)
18 | res.sum().backward()
19 | self.assertTrue((tmp.grad == comm2.size * torch.ones(10, dtype=torch.double)).all())
20 |
21 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | .. image:: _static/img/mpi4torch-logo-extrawhitespace.svg
2 |
3 | mpi4torch is an automatic-differentiable wrapper of MPI functions for the pytorch tensor library.
4 |
5 | MPI stands for Message Passing Interface and is the de facto standard communication interface on
6 | high-performance computing resources. To facilitate the usage of pytorch on these resources, an MPI wrapper
7 | that is transparent to pytorch's automatic differentiation (AD) engine is much needed. This library tries
8 | to bridge this gap.
9 |
10 | .. toctree::
11 | :maxdepth: 3
12 | :caption: Table of Contents
13 |
14 | basic_usage
15 | examples
16 | api_reference
17 | glossary
18 |
19 | Indices and tables
20 | ==================
21 |
22 | * :ref:`genindex`
23 | * :ref:`modindex`
24 | * :ref:`search`
25 |
--------------------------------------------------------------------------------
/tests/test_joindummies.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 | import unittest
4 |
5 | comm = mpi4torch.COMM_WORLD
6 |
7 | class TestJoinDummies(unittest.TestCase):
8 | def test_simple_allreduce(self):
9 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
10 | tmp2 = torch.rand(10, dtype=torch.double).requires_grad_()
11 | tmp3 = torch.rand(10, dtype=torch.double).requires_grad_()
12 | res = comm.Allreduce(tmp,mpi4torch.MPI_SUM)
13 | res2 = mpi4torch.JoinDummies(res,[tmp2,tmp3])
14 | res2.sum().backward()
15 | self.assertTrue((tmp2.grad == torch.zeros(10, dtype=torch.double)).all())
16 | self.assertTrue((tmp3.grad == torch.zeros(10, dtype=torch.double)).all())
17 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all())
18 |
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2020-2022 Philipp Knechtges and contributors
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
4 | associated documentation files (the "Software"), to deal in the Software without restriction, including
5 | without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
6 | copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the
7 | following conditions:
8 |
9 | The above copyright notice and this permission notice shall be included in all copies or substantial
10 | portions of the Software.
11 |
12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
13 | LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
14 | NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
15 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
16 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
--------------------------------------------------------------------------------
/examples/simple_linear_regression.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 | import mpi4py.MPI
4 |
5 | comm = mpi4torch.COMM_WORLD
6 |
7 | torch.manual_seed(42)
8 |
9 | num_points = 10000
10 | chunk_size = num_points // comm.size
11 | rest = num_points % comm.size
12 | if comm.rank < rest:
13 | chunk_size += 1
14 | offset = chunk_size * comm.rank
15 | else:
16 | offset = chunk_size * comm.rank + rest
17 |
18 | xinput = 2.0 * torch.rand([num_points],dtype=torch.double)[offset:offset+chunk_size]
19 |
20 | def some_parametrized_function(inp, params):
21 | return (params[2] * inp + params[1]) * inp + params[0]
22 |
23 | gen_params = torch.tensor([0.1, 1.0, -2.0])
24 |
25 | youtput = some_parametrized_function(xinput, gen_params)
26 |
27 | def lossfunction(params):
28 | # average initial params to bring all ranks on the same page
29 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
30 |
31 | # compute local loss
32 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
33 |
34 | # sum up the loss among all ranks
35 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
36 |
37 | params = torch.arange(3, dtype=torch.double).requires_grad_()
38 |
39 | # LBFGS only needs one outer iteration for a linear problem
40 | # with so few parameters
41 | num_iterations = 1
42 | optimizer = torch.optim.LBFGS([params], 1)
43 |
44 | for i in range(num_iterations):
45 | def closure():
46 | loss = lossfunction(params)
47 | optimizer.zero_grad()
48 | loss.backward()
49 | if comm.rank == 0:
50 | print("Params: ", params)
51 | print("Loss : ", loss)
52 | return loss
53 | optimizer.step(closure)
54 |
55 | # only print output on rank 0
56 | if comm.rank == 0:
57 | print("Final parameters: ", params)
58 |
--------------------------------------------------------------------------------
/tests/test_nonblocking.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 | import unittest
4 |
5 | comm = mpi4torch.COMM_WORLD
6 |
7 | class TestNonBlocking(unittest.TestCase):
8 | def test_simple_isendirecv(self):
9 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_()
10 | req = comm.Isend(tmp,(comm.rank+1)%comm.size,0)
11 | req2 = comm.Irecv(mpi4torch.JoinDummies(torch.empty_like(tmp),[req.dummy]),(comm.rank+comm.size-1)%comm.size,0)
12 | res = comm.Wait(mpi4torch.JoinDummiesHandle(req,[req2.dummy]))
13 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req2,[res]))
14 | res3 = res2 * comm.rank
15 | res3.sum().backward()
16 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all())
17 |
18 | def test_simple_isendrecv(self):
19 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_()
20 | req = comm.Isend(tmp,(comm.rank+1)%comm.size,0)
21 | res = comm.Recv(mpi4torch.JoinDummies(torch.empty_like(tmp),[req.dummy]),(comm.rank+comm.size-1)%comm.size,0)
22 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req,[res]))
23 | res3 = mpi4torch.JoinDummies(res,[res2]) * comm.rank
24 | res3.sum().backward()
25 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all())
26 |
27 | def test_simple_irecvsend(self):
28 | tmp = torch.rand(10000000, dtype=torch.double).requires_grad_()
29 | req = comm.Irecv(mpi4torch.JoinDummies(torch.empty_like(tmp),[tmp]),(comm.rank+comm.size-1)%comm.size,0)
30 | res = comm.Send(tmp,(comm.rank+1)%comm.size,0)
31 | res2 = comm.Wait(mpi4torch.JoinDummiesHandle(req,[res]))
32 | res3 = res2 * comm.rank
33 | res3.sum().backward()
34 | self.assertTrue((tmp.grad == ((comm.rank + 1 )%comm.size) * torch.ones_like(tmp)).all())
35 |
36 |
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 | import sphinx_rtd_theme
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = 'mpi4torch'
22 | copyright = '2020, Philipp Knechtges'
23 | author = 'Philipp Knechtges'
24 |
25 | # The full version, including alpha/beta/rc tags
26 | with open(os.path.join(os.path.dirname(__file__), '../version.txt'), encoding='utf-8') as filehandle:
27 | release = filehandle.read()
28 |
29 |
30 | # -- General configuration ---------------------------------------------------
31 |
32 | # Add any Sphinx extension module names here, as strings. They can be
33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
34 | # ones.
35 | extensions = [
36 | 'sphinx.ext.autodoc',
37 | 'sphinx.ext.intersphinx',
38 | 'sphinx_rtd_theme',
39 | 'sphinx.ext.napoleon',
40 | 'sphinx.ext.graphviz',
41 | ]
42 |
43 | # Add any paths that contain templates here, relative to this directory.
44 | templates_path = ['_templates']
45 |
46 | # List of patterns, relative to source directory, that match files and
47 | # directories to ignore when looking for source files.
48 | # This pattern also affects html_static_path and html_extra_path.
49 | exclude_patterns = []
50 |
51 |
52 | # -- Options for HTML output -------------------------------------------------
53 |
54 | # The theme to use for HTML and HTML Help pages. See the documentation for
55 | # a list of builtin themes.
56 | #
57 | html_theme = 'sphinx_rtd_theme'
58 |
59 | # Add any paths that contain custom static files (such as style sheets) here,
60 | # relative to this directory. They are copied after the builtin static files,
61 | # so a file named "default.css" will overwrite the builtin "default.css".
62 | html_static_path = ['_static']
63 |
64 |
65 | # -- Extension configuration -------------------------------------------------
66 |
67 | # -- Options for intersphinx extension ---------------------------------------
68 |
69 | # Example configuration for intersphinx: refer to the Python standard library.
70 | intersphinx_mapping = {'https://docs.python.org/3/': None,
71 | 'https://pytorch.org/docs/stable/': None,
72 | }
73 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | pull_request:
6 |
7 | jobs:
8 | tests:
9 | name: ${{ matrix.os }} cpy${{ matrix.python }} pytorch-${{ matrix.pytorch }} ${{ matrix.mpi }}
10 | runs-on: ${{ matrix.os }}
11 | strategy:
12 | fail-fast: false
13 | matrix:
14 | #os: [ "ubuntu-20.04", "ubuntu-22.04" ]
15 | os: [ "ubuntu-20.04" ]
16 | python: [ "3.7", "3.8", "3.9", "3.10", "3.11" ]
17 | pytorch: [ "1.9.1", "1.10.2", "1.11.0", "1.12.1", "1.13.1", "2.0.0" ]
18 | mpi: [ "openmpi", "mpich" ]
19 | exclude:
20 | - python: 3.7
21 | pytorch: 2.0.0
22 | - python: 3.10
23 | pytorch: 1.8.1
24 | - python: 3.10
25 | pytorch: 1.9.1
26 | - python: 3.10
27 | pytorch: 1.10.2
28 | - python: 3.11
29 | pytorch: 1.8.1
30 | - python: 3.11
31 | pytorch: 1.9.1
32 | - python: 3.11
33 | pytorch: 1.10.2
34 | - python: 3.11
35 | pytorch: 1.11.0
36 | - python: 3.11
37 | pytorch: 1.12.1
38 | steps:
39 | - name: Checkout
40 | uses: actions/checkout@v3
41 | - name: Install CPython ${{ matrix.python }}
42 | uses: actions/setup-python@v4
43 | with:
44 | python-version: "${{ matrix.python }}"
45 | architecture: x64
46 | - name: Install MPI ${{ matrix.mpi }}
47 | run: |
48 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then
49 | sudo apt install libopenmpi-dev openmpi-bin
50 | elif [[ "${{ matrix.mpi }}" == "mpich" ]]; then
51 | sudo apt install mpich
52 | else
53 | exit 1
54 | fi
55 | mpirun --version
56 | - name: Setup virtual environment
57 | run: |
58 | python -m venv venv
59 | - name: Install mpi4torch
60 | run: |
61 | . venv/bin/activate
62 | echo "torch==${{ matrix.pytorch }}" >> constraints.txt
63 | PIP_CONSTRAINT="constraints.txt" pip install -v . nose2 mpi4py
64 | - name: Run Tests (np=2)
65 | run: |
66 | . venv/bin/activate
67 | #mpirun -np 2 python -c 'from mpi4torch import COMM_WORLD; print("Communicator size:", COMM_WORLD.size) if COMM_WORLD.rank == 0 else exit(0)'
68 | mpirun -np 2 nose2
69 | - name: Run Tests (np=5)
70 | run: |
71 | . venv/bin/activate
72 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then
73 | mpirun -np 5 --oversubscribe nose2
74 | else
75 | mpirun -np 5 nose2
76 | fi
77 | - name: Run Tests (np=7)
78 | run: |
79 | . venv/bin/activate
80 | if [[ "${{ matrix.mpi }}" == "openmpi" ]]; then
81 | mpirun -np 7 --oversubscribe nose2
82 | else
83 | mpirun -np 7 nose2
84 | fi
85 |
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | --------------------------------------------------------------------------------
4 |
5 | mpi4torch is an automatic-differentiable wrapper of MPI functions for the pytorch tensor library.
6 |
7 | MPI stands for Message Passing Interface and is the de facto standard communication interface on
8 | high-performance computing resources. To facilitate the usage of pytorch on these resources, an MPI wrapper
9 | that is transparent to pytorch's automatic differentiation (AD) engine is much needed. This library tries
10 | to bridge this gap.
11 |
12 | # Installation
13 |
14 | mpi4torch is also hosted on PyPI. However, due to the ABI-incompatibility of the different MPI implementations, it
15 | is not provided as a binary wheel and needs to be built locally. Hence, you should have an appropriate C++ compiler
16 | installed, and the **development files of your MPI library** need to be present. The latter are usually provided
17 | either through the *module system* of your local cluster (consult your cluster's documentation for this) or through
18 | the package manager of your Linux distribution.
19 |
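On a Debian- or Ubuntu-based system, for example, the OpenMPI development files can typically be installed through the package manager (other distributions and MPI implementations provide analogous packages):

```
sudo apt install libopenmpi-dev openmpi-bin
```
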
20 | Once the dependencies have been satisfied the installation can be triggered by the usual
21 | ```
22 | pip install mpi4torch
23 | ```
24 |
25 | # Usage
26 |
27 | It is **highly advised** to first read [the basic usage chapter of the documentation](https://mpi4torch.readthedocs.io/en/latest/basic_usage.html)
28 | before jumping into action, since the design of pytorch's AD engine has some implications for the usage of mpi4torch.
29 | In other words, there are some footguns lurking!
30 |
31 | You have been warned, but if you insist on an easy usage example, consider the following code snippet,
32 | which is an excerpt from [examples/simple_linear_regression.py](examples/simple_linear_regression.py)
33 |
34 | ```python
35 | comm = mpi4torch.COMM_WORLD
36 |
37 | def lossfunction(params):
38 | # average initial params to bring all ranks on the same page
39 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
40 |
41 | # compute local loss
42 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
43 |
44 | # sum up the loss among all ranks
45 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
46 | ```
47 |
48 | Here we have parallelized a loss function simply by adding two calls to `Allreduce`. For a more thorough
49 | discussion of the example see [here](https://mpi4torch.readthedocs.io/en/latest/examples.html#simple-data-parallel-example).
50 |
51 | # Tests
52 |
53 | Running tests is as easy as
54 | ```
55 | mpirun -np 2 nose2
56 | ```
57 |
58 | # Project Status
59 |
60 | [](https://github.com/helmholtz-analytics/mpi4torch/actions/workflows/test.yml)
61 | [](https://mpi4torch.readthedocs.io/en/latest/?badge=latest)
62 |
--------------------------------------------------------------------------------
/doc/examples.rst:
--------------------------------------------------------------------------------
1 | *******************
2 | Examples
3 | *******************
4 |
5 | Simple data parallel example
6 | ============================
7 |
8 | Let us assume we are in a typical supervised learning situation. We have plenty of data (``xinput``, ``youtput``),
9 | and we search for unknown parameters minimizing some norm or general functional, simply referred to as the
10 | loss function. Furthermore, we assume that the
11 | loss function is just a summation of losses per data point. E.g. consider the following squared error:
12 |
13 | .. code-block:: python
14 |
15 | def lossfunction(params):
16 | # compute local loss
17 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
18 | return localloss
19 |
20 | This function is usually fed into a gradient-based optimizer to find the optimal parameters.
21 | We want to argue in the following that parallelizing this code in a data-parallel way is often as easy as
22 | adding two calls to :py:meth:`mpi4torch.MPI_Communicator.Allreduce`:
23 |
24 | .. code-block:: python
25 | :emphasize-lines: 3,9
26 |
27 | def lossfunction(params):
28 | # average initial params to bring all ranks on the same page
29 | params = comm.Allreduce(params, mpi4torch.MPI_SUM) / comm.size
30 |
31 | # compute local loss
32 | localloss = torch.sum(torch.square(youtput - some_parametrized_function(xinput, params)))
33 |
34 | # sum up the loss among all ranks
35 | return comm.Allreduce(localloss, mpi4torch.MPI_SUM)
36 |
37 | :py:meth:`mpi4torch.MPI_Communicator.Allreduce` is used once to compute the average of the incoming parameters
38 | and once to collect the total loss.
39 |
40 | Embedded in a whole program this may look like (the code is also available in the git repository
41 | in the examples folder):
42 |
43 | .. literalinclude:: ../examples/simple_linear_regression.py
44 | :linenos:
45 |
46 | Note that although the averaging in line 29 might seem superfluous at first --- since all ranks start
47 | off with the same initial set of parameters --- having the adjoint of :py:meth:`mpi4torch.MPI_Communicator.Allreduce`
48 | in the backward pass
49 | is essential for all instances of the LBFGS optimizer to perform the same update on all ranks.
50 |
51 | For the second call to :py:meth:`mpi4torch.MPI_Communicator.Allreduce` in line 35 it is actually the other way
52 | around: Here the forward pass is crucial, but the backward pass merely adds up the ones coming from the
53 | different ranks, which (surprise) results in a vector of length 1 that just contains the communicator size.
54 |
55 | It is easy to see that the forward pass is independent of the number of ranks used to compute the result. That
56 | the parallelized backward pass also gives the same result may at first seem a bit surprising,
57 | as we already saw that the gradient with respect to ``localloss`` will just store the size
58 | of the MPI communicator. However,
59 | the corresponding backward code of the averaging in line 29 divides again by ``comm.size``, such that
60 | in total all gradients from all ranks are simply added up. The final
61 | gradient as stored in ``params.grad`` is thus also independent of the number of processes.
62 |
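The same bookkeeping can be written out compactly. Writing :math:`N` for ``comm.size``, :math:`p_s` for the
copy of the parameters on rank :math:`s` and :math:`f_s` for the local loss on rank :math:`s` (these symbols
are merely shorthands introduced here for the quantities in the code above), the forward pass computes

.. math::

   \bar{p} = \frac{1}{N}\sum_{s} p_s, \qquad \mathrm{loss} = \sum_{s} f_s(\bar{p}).

In the backward pass the adjoint of the second ``Allreduce`` turns the seed of :math:`1` on every rank into a
cotangent of :math:`N` for ``localloss``, and the adjoint of the averaging divides by :math:`N` again and sums
over all ranks, so that every rank ends up with

.. math::

   \frac{1}{N}\sum_{s} N\,\nabla f_s(\bar{p}) = \sum_{s} \nabla f_s(\bar{p})

in ``params.grad``, which is exactly the gradient a single process working on the full data set would compute.
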
63 | Starting off with the same
64 | parameters on all ranks, it is thus ensured that all local LBFGS instances see the same parameters, the same losses
65 | and the same gradients, and thus perform the identical operations and give the same result.
66 |
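The full script can then be launched like any other MPI program, for instance with four processes (the exact
launcher invocation, e.g. ``mpirun`` or ``srun``, depends on your MPI installation and job scheduler):

.. code-block:: bash

   mpirun -np 4 python examples/simple_linear_regression.py
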
67 |
68 |
69 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 |
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import CppExtension, BuildExtension
4 | import torch
5 | import copy
6 | import os
7 | import sys
8 |
9 | try:
10 | # importlib only got added in cpython 3.8
11 | if sys.version_info >= (3, 8):
12 | from importlib import metadata as importlib_metadata
13 | else:
14 | # cf. pyproject.toml file, which makes pip install importlib-metadata if necessary
15 | import importlib_metadata
16 |
17 | torchversion = importlib_metadata.distribution("torch").version
18 | except Exception:
19 | # Fallback, but this should never happen.
20 | torchversion = torch.__version__.split("+")[0]
21 |
22 | class MpiBuildExtension(BuildExtension):
23 | def __init__(self, *args,**kwargs):
24 | super(MpiBuildExtension,self).__init__(*args,use_ninja=False,**kwargs)
25 |
26 | def build_extensions(self):
27 | """
28 | This code makes a lot of assumptions about distutils' internal implementation of
29 | the UnixCCompiler class. However, it seems to be standard to make these assumptions,
30 | as PyTorch and mpi4py make them as well.
31 |
32 | TODO: Obviously this only works for unix systems
33 | """
34 |
35 | # Save original compiler and reset it later on
36 | original_compiler = self.compiler.compiler_so
37 | new_compiler = copy.deepcopy(original_compiler)
38 | new_compiler[0] = 'mpicc'
39 | # Save original CXX compiler and reset it later on
40 |
41 | # distutils' UnixCCompiler likes to use the C++ compiler for linking, so we set it manually
42 | original_cxx_compiler = self.compiler.compiler_cxx
43 | new_cxx_compiler = copy.deepcopy(original_cxx_compiler)
44 | new_cxx_compiler[0] = 'mpicxx'
45 | # Save original linker and reset it later on
46 | # should not be used, but we set it anyway
47 | original_linker = self.compiler.linker_so
48 | new_linker = copy.deepcopy(original_linker)
49 | new_linker[0] = 'mpicc'
50 | try:
51 | self.compiler.set_executable('compiler_so', new_compiler)
52 | self.compiler.set_executable('compiler_cxx', new_cxx_compiler)
53 | self.compiler.set_executable('linker_so', new_linker)
54 | BuildExtension.build_extensions(self)
55 | finally:
56 | self.compiler.set_executable('compiler_so', original_compiler)
57 | self.compiler.set_executable('compiler_cxx', original_cxx_compiler)
58 | self.compiler.set_executable('linker_so', original_linker)
59 |
60 | with open(os.path.join(os.path.dirname(__file__), 'README.md'), encoding='utf-8') as filehandle:
61 | long_description = filehandle.read()
62 |
63 | with open(os.path.join(os.path.dirname(__file__), 'version.txt'), encoding='utf-8') as filehandle:
64 | versiontext = filehandle.read().rstrip()
65 |
66 | setup(
67 | name='mpi4torch',
68 | version=versiontext,
69 | description='AD-compatible implementation of several MPI functions for pytorch tensors',
70 | author='Philipp Knechtges',
71 | author_email='philipp.knechtges@dlr.de',
72 | long_description=long_description,
73 | long_description_content_type='text/markdown',
74 | classifiers=[
75 | "Development Status :: 3 - Alpha",
76 | "Programming Language :: Python :: 3",
77 | "Programming Language :: Python :: 3.7",
78 | "Programming Language :: Python :: 3.8",
79 | "Programming Language :: Python :: 3.9",
80 | "Programming Language :: Python :: 3.10",
81 | "Programming Language :: Python :: 3.11",
82 | "Programming Language :: C++",
83 | "License :: OSI Approved :: MIT License",
84 | "Operating System :: POSIX :: Linux",
85 | "Intended Audience :: Science/Research",
86 | "Topic :: Scientific/Engineering",
87 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
88 | "Topic :: Software Development :: Libraries :: Python Modules"
89 | ],
90 | package_dir = {'mpi4torch': 'src'},
91 | packages = ['mpi4torch'],
92 | ext_modules=[
93 | CppExtension(
94 | name='mpi4torch._mpi',
95 | sources=['csrc/extension.cpp'],
96 | extra_compile_args=['-g']),
97 | ],
98 | cmdclass={
99 | 'build_ext': MpiBuildExtension
100 | },
101 | install_requires=[
102 | # Pin the required pytorch version of the final binary wheels
103 | # to the pytorch version used at build-time. This way we
104 | # avoid possible ABI-incompatibilities.
105 | 'torch==' + torchversion,
106 | ]
107 | )
108 |
--------------------------------------------------------------------------------
/doc/_static/img/mpi4torch-logo-extrawhitespace.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/doc/_static/img/mpi4torch-logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/test_collectives.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mpi4torch
3 | import unittest
4 |
5 | comm = mpi4torch.COMM_WORLD
6 |
7 | class TestAllreduce(unittest.TestCase):
8 | def test_simple(self):
9 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
10 | res = comm.Allreduce(tmp,mpi4torch.MPI_SUM)
11 | res.sum().backward()
12 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all())
13 |
14 | def test_torchscript(self):
15 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
16 | @torch.jit.script
17 | def myfunc(x,comm_: mpi4torch.MPI_Communicator):
18 | return comm_.Allreduce(x,mpi4torch.MPI_SUM)
19 | res = myfunc(tmp,comm)
20 | res.sum().backward()
21 | self.assertTrue((tmp.grad == comm.size * torch.ones(10, dtype=torch.double)).all())
22 |
23 | class TestReduce(unittest.TestCase):
24 | def test_simple_inplace(self):
25 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
26 | res = comm.Reduce_(tmp,mpi4torch.MPI_SUM,0)
27 | res.sum().backward()
28 | self.assertTrue((tmp.grad == torch.ones(10,dtype=torch.double)).all())
29 |
30 | def test_noinplace_exception(self):
31 | # the 0. addition is just to make the resulting tmp variable not a leaf in the DAG,
32 | # since we currently don't see a way to add this safeguard for leaf nodes.
33 | tmp = 0. + torch.rand(10, dtype=torch.double).requires_grad_()
34 | res = tmp + comm.Reduce_(tmp,mpi4torch.MPI_SUM,0)
35 | with self.assertRaises(RuntimeError):
36 | res.sum().backward()
37 |
38 | class TestBcast(unittest.TestCase):
39 | def test_simple_inplace(self):
40 | tmp = torch.rand(10, dtype=torch.double).requires_grad_()
41 | res = comm.Bcast_(tmp,0)
42 | res.sum().backward()
43 | if comm.rank == 0:
44 | self.assertTrue((tmp.grad == comm.size * torch.ones(10,dtype=torch.double)).all())
45 | else:
46 | self.assertTrue((tmp.grad == torch.zeros(10,dtype=torch.double)).all())
47 |
48 | class TestGather(unittest.TestCase):
49 | def test_basic_functionality(self):
50 | numdim = 4
51 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double)
52 | tmp[0,0,:,0,0] = comm.rank
53 | res = comm.Gather(tmp, 2, 0)
54 | if comm.rank == 0:
55 | tmp2 = torch.squeeze(torch.sum(res[0,0,:,0,0]))
56 | self.assertTrue((tmp2 == numdim * (comm.size - 1) * comm.size // 2).all())
57 |
58 | def test_basic_ad(self):
59 | numdim = 4
60 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double).requires_grad_()
61 | res = comm.Gather(tmp, 2, 0)
62 | res.sum().backward()
63 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all())
64 |
65 | class TestAllgather(unittest.TestCase):
66 | def test_basic_functionality(self):
67 | numdim = 4
68 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double)
69 | tmp[0,0,:,0,0] = comm.rank
70 | res = comm.Allgather(tmp, 2)
71 | tmp2 = torch.squeeze(torch.sum(res[0,0,:,0,0]))
72 | self.assertTrue((tmp2 == numdim * (comm.size - 1) * comm.size // 2).all())
73 |
74 | def test_basic_ad(self):
75 | numdim = 4
76 | tmp = torch.rand([2,5,numdim,2,3],dtype=torch.double).requires_grad_()
77 | res = comm.Allgather(tmp, 2)
78 | res.sum().backward()
79 | self.assertTrue((tmp.grad == comm.size * torch.ones_like(tmp)).all())
80 |
81 | class TestScatter(unittest.TestCase):
82 | def test_basic_functionality(self):
83 | if comm.rank == 0:
84 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double)
85 | for i in range(comm.size):
86 | tmp[0,0,i,0,0] = i
87 | else:
88 | tmp = torch.rand([1],dtype=torch.double)
89 | res = comm.Scatter(tmp, 2, 1, 0)
90 | self.assertTrue((res[0,0,:,0,0] == comm.rank).all())
91 |
92 | def test_scattergather(self):
93 | if comm.rank == 0:
94 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double)
95 | else:
96 | tmp = torch.rand([1],dtype=torch.double)
97 | res = comm.Scatter(tmp, 2, 1, 0)
98 | res2 = comm.Gather(res, 2, 0)
99 | if comm.rank == 0:
100 | self.assertTrue((res2 == tmp).all())
101 |
102 | def test_basic_ad(self):
103 | if comm.rank == 0:
104 | tmp = torch.rand([2,5,comm.size,2,3],dtype=torch.double).requires_grad_()
105 | else:
106 | tmp = torch.rand([1],dtype=torch.double).requires_grad_()
107 | res = comm.Scatter(tmp, 2, 1, 0)
108 | res.sum().backward()
109 | if comm.rank == 0:
110 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all())
111 | else:
112 | self.assertTrue((tmp.grad == torch.zeros_like(tmp)).all())
113 |
114 | class TestAlltoall(unittest.TestCase):
115 | def test_gatherscatter_equivalence(self):
116 | tmp = torch.rand([3,4,1,4,comm.size,2],dtype=torch.double)
117 | res1 = comm.Scatter(comm.Gather(tmp,2,0),4,1,0)
118 | res2 = comm.Alltoall(tmp,2,4,1)
119 | self.assertTrue((res2 == res1).all())
120 |
121 | def test_gatherscatter_equivalence_varying_numelem(self):
122 | tmp = torch.rand([3,4,comm.rank+1,4,comm.size*(comm.size+1)//2,2],dtype=torch.double)
123 | res1 = comm.Scatter(comm.Gather(tmp,2,0),4,comm.rank+1,0)
124 | res2 = comm.Alltoall(tmp,2,4,comm.rank+1)
125 | self.assertTrue((res2 == res1).all())
126 |
127 | def test_gatheraxis_scatteraxis_equal(self):
128 | tmp = torch.rand([3,4,comm.rank+1,2],dtype=torch.double)
129 | tmp[0,0,:,0] = torch.arange(comm.rank*(comm.rank+1)//2, (comm.rank+1)*(comm.rank+2)//2)
130 | res = comm.Alltoall(tmp,2,2,comm.size-comm.rank)
131 | total_numelem = comm.size*(comm.size+1)//2
132 | correct_res = torch.arange(total_numelem - (comm.size-comm.rank)*(comm.size-comm.rank+1)//2,
133 | total_numelem - (comm.size-comm.rank-1)*(comm.size-comm.rank)//2,
134 | dtype=torch.double)
135 | self.assertTrue((res[0,0,:,0] == correct_res).all())
136 |
137 | def test_identity_equivalence(self):
138 | tmp = torch.rand([3,4,2,4,3*comm.size,2],dtype=torch.double)
139 | res = comm.Alltoall(tmp,2,4,3)
140 | res2 = comm.Alltoall(res,4,2,2)
141 | self.assertTrue((res2 == tmp).all())
142 |
143 | def test_basic_ad(self):
144 | tmp = torch.rand([3,4,2,4,comm.size,2],dtype=torch.double).requires_grad_()
145 | res = comm.Alltoall(tmp,2,4,1)
146 | res.sum().backward()
147 | self.assertTrue((tmp.grad == torch.ones_like(tmp)).all())
148 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ._mpi import *
3 | from typing import List
4 |
5 | __all__ = [
6 | "MPI_MAX",
7 | "MPI_MIN",
8 | "MPI_SUM",
9 | "MPI_PROD",
10 | "MPI_LAND",
11 | "MPI_BAND",
12 | "MPI_LOR",
13 | "MPI_BOR",
14 | "MPI_LXOR",
15 | "MPI_BXOR",
16 | "MPI_MINLOC",
17 | "MPI_MAXLOC",
18 | "WaitHandle",
19 | "JoinDummies",
20 | "JoinDummiesHandle",
21 | "MPI_Communicator",
22 | "COMM_WORLD",
23 | "comm_from_mpi4py",
24 | "deactivate_cuda_aware_mpi_support"
25 | ]
26 |
27 | @torch.jit.script
28 | class WaitHandle:
29 | """Class representing a wait handle, as they are returned from one of the non-blocking MPI calls."""
30 |
31 | def __init__(self, raw_handle: List[torch.Tensor]):
32 | self._handle = raw_handle
33 |
34 | @property
35 | def dummy(self):
36 | """A dummy variable that allows for the usage of the ``WaitHandle`` as one of the
37 | second arguments of :py:func:`mpi4torch.JoinDummies` and :py:func:`mpi4torch.JoinDummiesHandle`.
38 | """
39 |
40 | return self._handle[0]
41 |
42 | @torch.jit.script
43 | def JoinDummies(loopthrough: torch.Tensor, dummies:List[torch.Tensor]) -> torch.Tensor:
44 | """This function joins multiple dummy dependencies with the DAG.
45 |
46 | From the perspective of the forward pass, this function is mostly a no-op, since it simply
47 | loops through its first argument, and discards the ``dummies`` argument.
48 |
49 | However, for the backward pass, the AD engine still considers the ``dummies`` as actual
50 | dependencies. The main use of this function is thus to manually encode dependencies
51 | that the AD engine does not see on its own. See also the introductory text in
52 | the :ref:`section_implications_mpi4torch` section on how to use this function.
53 |
54 | Parameters
55 | ----------
56 | loopthrough:
57 | Variable to pass through.
58 | dummies:
59 | List of tensors that are added as dummy dependencies to the DAG.
60 |
61 | Returns
62 | -------
63 | :py:class:`torch.tensor`:
64 | Tensor that is a shallow copy of ``loopthrough``, but whose ``grad_fn``
65 | is ``JoinDummiesBackward``.
66 | """
67 | return torch.ops.mpi4torch.JoinDummies(loopthrough, dummies)
68 |
69 | @torch.jit.script
70 | def JoinDummiesHandle(handle: WaitHandle, dummies:List[torch.Tensor]) -> WaitHandle:
71 | """This function has the same purpose as :py:func:`JoinDummies`, but accepts :py:class:`mpi4torch.WaitHandle`
72 | as its first argument.
73 |
74 | Parameters
75 | ----------
76 | handle:
77 | :py:class:`mpi4torch.WaitHandle` to pass through.
78 | dummies:
79 | List of tensors that are added as dummy dependencies to the DAG.
80 |
81 | Returns
82 | -------
83 | :py:class:`mpi4torch.WaitHandle`:
84 | A wait handle with the additional dummy dependencies added.
85 | """
86 | raw_handle = handle._handle
87 | return WaitHandle([ torch.ops.mpi4torch.JoinDummies(raw_handle[0], dummies), raw_handle[1], raw_handle[2] ])
88 |
89 | @torch.jit.script
90 | class MPI_Communicator:
91 | """MPI communicator wrapper class
92 |
93 | The only supported ways to construct an ``MPI_Communicator`` are currently either through :py:const:`mpi4torch.COMM_WORLD` or
94 | :py:func:`mpi4torch.comm_from_mpi4py`.
95 |
96 | Note
97 | ----
98 | All methods with an underscore suffix are in-place operations.
99 | """
100 |
101 | def __init__(self, comm: torch.classes.mpi4torch.MPI_Comm_Wrapper):
102 | self._comm = comm
103 |
104 | @property
105 | def rank(self) -> int:
106 | """The rank or identification number of the local process with respect to this communicator.
107 |
108 | The processes participating in a communicator are consecutively given ranks
109 | in the interval [0, :py:attr:`mpi4torch.MPI_Communicator.size` - 1].
110 | """
111 | return self._comm.GetRank()
112 |
113 | @property
114 | def size(self) -> int:
115 | """The size of the MPI communicator, i.e. the number of processes involved."""
116 | return self._comm.GetSize()
117 |
118 | # This is currently not supported by torch.jit.script:
119 | #def __getattr__(self, attrName):
120 | # if attrName in self.__dict__["_comm"]._method_names():
121 | # return self.__dict__["_comm"].__getattr__(attrName)
122 | # return self.__dict__[attrName]
123 | # So we need to write out every function by hand
124 |
125 | def Allreduce(self, tensor: torch.Tensor, op: int) -> torch.Tensor:
126 | """Combines values from all processes and distributes the result back to all processes.
127 |
128 | The combination operation is performed element-wise on the tensor.
129 |
130 | This is the wrapper function of `MPI_Allreduce `_.
131 |
132 | Parameters
133 | ----------
134 | tensor:
135 | :py:class:`torch.Tensor` that shall be combined. It needs to have the same shape on all processes.
136 | op:
137 | Operation to combine the results. Only supported operations are :py:const:`mpi4torch.MPI_MAX`,
138 | :py:const:`mpi4torch.MPI_MIN`, :py:const:`mpi4torch.MPI_SUM`, :py:const:`mpi4torch.MPI_PROD`, :py:const:`mpi4torch.MPI_LAND`,
139 | :py:const:`mpi4torch.MPI_BAND`, :py:const:`mpi4torch.MPI_LOR`, :py:const:`mpi4torch.MPI_BOR`, :py:const:`mpi4torch.MPI_LXOR`,
140 | :py:const:`mpi4torch.MPI_BXOR`, :py:const:`mpi4torch.MPI_MINLOC`,
141 | :py:const:`mpi4torch.MPI_MAXLOC`
142 |
143 | Returns
144 | -------
145 | :py:class:`torch.Tensor`:
146 | Combined tensor of the same shape as the input `tensor`.
147 |
148 | Note
149 | ----
150 | Only :py:const:`mpi4torch.MPI_SUM` is supported in the backwards pass at the moment.
151 | """
152 | return self._comm.Allreduce(tensor, op)
153 |
154 | def Bcast_(self, tensor: torch.Tensor, root: int) -> torch.Tensor:
155 | """Broadcasts a tensor from the `root` process to all other processes.
156 |
157 | This is an in-place operation.
158 |
159 | This is the wrapper function of `MPI_Bcast `_.
160 |
161 | Parameters
162 | ----------
163 | tensor:
164 | :py:class:`torch.Tensor` that shall be broadcasted. The tensor needs to have the same shape on all processes,
165 | since it is an in-place operation.
166 | root:
167 | The root process, whose tensor shall be broadcasted to the others.
168 |
169 | Returns
170 | -------
171 | :py:class:`torch.Tensor`:
172 | For `rank == root` this is the same as the input `tensor`. For all other processes this is the input `tensor` filled with the content
173 | from the `root` process.
174 | """
175 | return self._comm.Bcast_(tensor, root)
176 |
177 | def Reduce_(self, tensor: torch.Tensor, op: int, root: int) -> torch.Tensor:
178 | """Reduces multiple tensors of the same shape, scattered over all processes, to a single tensor of the same shape stored on the `root` process.
179 |
180 | The combination operation is performed element-wise on the tensor.
181 |
182 | This is an in-place operation.
183 |
184 | This is the wrapper function of `MPI_Reduce `_.
185 |
186 | Parameters
187 | ----------
188 | tensor:
189 | :py:class:`torch.Tensor` that shall be reduced. The tensor needs to have the same shape on all processes,
190 | since it is an element-wise operation.
191 | op:
192 | Operation to combine the results. Only supported operations are :py:const:`mpi4torch.MPI_MAX`,
193 | :py:const:`mpi4torch.MPI_MIN`, :py:const:`mpi4torch.MPI_SUM`, :py:const:`mpi4torch.MPI_PROD`, :py:const:`mpi4torch.MPI_LAND`,
194 | :py:const:`mpi4torch.MPI_BAND`, :py:const:`mpi4torch.MPI_LOR`, :py:const:`mpi4torch.MPI_BOR`, :py:const:`mpi4torch.MPI_LXOR`,
195 | :py:const:`mpi4torch.MPI_BXOR`, :py:const:`mpi4torch.MPI_MINLOC`,
196 | :py:const:`mpi4torch.MPI_MAXLOC`
197 | root:
198 | The root process, where the resulting tensor shall be gathered.
199 |
200 | Returns
201 | -------
202 | :py:class:`torch.Tensor`:
203 | For `rank == root` the result stores the reduced tensor. For all other processes the content of the resulting tensor is undefined,
204 | with the exception that the result shall still suffice as input for the second argument of :py:func:`mpi4torch.JoinDummies`.
205 |
206 | Note
207 | ----
208 | Only :py:const:`mpi4torch.MPI_SUM` is supported in the backwards pass at the moment.
209 | """
210 | return self._comm.Reduce_(tensor, op, root)
211 |
212 | def Gather(self, tensor: torch.Tensor, gatheraxis: int, root: int) -> torch.Tensor:
213 | return self._comm.Gather(tensor, gatheraxis, root)
214 |
215 | def Allgather(self, tensor: torch.Tensor, gatheraxis: int) -> torch.Tensor:
216 | return self._comm.Allgather(tensor, gatheraxis)
217 |
218 | def Scatter(self, tensor: torch.Tensor, scatteraxis: int, numelem: int, root: int) -> torch.Tensor:
219 | return self._comm.Scatter(tensor, scatteraxis, numelem, root)
220 |
221 | def Alltoall(self, tensor: torch.Tensor, gatheraxis: int, scatteraxis: int,
222 | numelem: int) -> torch.Tensor:
223 | return self._comm.Alltoall(tensor, gatheraxis, scatteraxis, numelem)
224 |
225 | def Isend(self, tensor: torch.Tensor, dest: int, tag: int) -> WaitHandle:
226 | return WaitHandle(self._comm.Isend(tensor, dest, tag))
227 |
228 | def Irecv(self, tensor: torch.Tensor, source: int, tag: int) -> WaitHandle:
229 | return WaitHandle(self._comm.Irecv(tensor, source, tag))
230 |
231 | def Wait(self, waithandle: WaitHandle) -> torch.Tensor:
232 | return self._comm.Wait(waithandle._handle)
233 |
234 | def Send(self, tensor: torch.Tensor, dest: int, tag: int) -> torch.Tensor:
235 | handle = self._comm.Isend(tensor, dest, tag)
236 | return self._comm.Wait(handle)
237 |
238 | def Recv(self, tensor: torch.Tensor, source: int, tag: int) -> torch.Tensor:
239 | handle = self._comm.Irecv(tensor, source, tag)
240 | return self._comm.Wait(handle)
241 |
242 | COMM_WORLD = MPI_Communicator(torch.ops.mpi4torch.COMM_WORLD())
243 | """
244 | World communicator ``MPI_COMM_WORLD``.
245 | """
246 |
247 | try:
248 | from mpi4py import MPI as __mpi4py_MPI
249 |
250 | def comm_from_mpi4py(comm: __mpi4py_MPI.Comm) -> MPI_Communicator:
251 | """Converts an ``mpi4py`` communicator to an :py:class:`mpi4torch.MPI_Communicator`.
252 | """
253 |
254 | fortran_handle = comm.py2f()
255 | return MPI_Communicator(torch.ops.mpi4torch.comm_from_fortran(fortran_handle))
256 | except ModuleNotFoundError:
257 | def comm_from_mpi4py(comm) -> MPI_Communicator:
258 | """Converts an ``mpi4py`` communicator to an :py:class:`mpi4torch.MPI_Communicator`.
259 | """
260 |
261 | raise RuntimeError("mpi4py is not available!")
262 |
--------------------------------------------------------------------------------
/doc/basic_usage.rst:
--------------------------------------------------------------------------------
1 | ********************
2 | Basic Usage
3 | ********************
4 |
5 | In the following we are going to discuss the different options and caveats that come with mixing MPI
6 | and pytorch's automatic differentiation (AD) functionality, and what consequences this has for using
7 | the mpi4torch library.
8 |
9 | Note that although within this document we will mostly talk about the interplay of mpi4torch with pytorch's AD,
10 | this does not mean that mpi4torch could not in principle be used as one might expect coming from other
11 | MPI libraries. The main difference,
12 | however, is that if one plans to use mpi4torch as a building block in some automatic differentiable code,
13 | the usage of mpi4torch actually differs a lot from these "classical" programming paradigms.
14 | It is thus *highly* recommended
15 | to read this document before, e.g., literally translating MPI calls to mpi4torch.
16 |
17 | How pytorch's AD works
18 | ======================
19 |
20 | Since it is important for what follows, we start with a quick reminder on how the AD engine in
21 | pytorch is used. Consider the following code
22 |
23 | .. code-block:: python
24 |
25 | import torch
26 |
27 | a = torch.tensor([0.5]).requires_grad_()
28 | b = torch.exp(a)
29 | b.backward()
30 | assert(a.grad == b)
31 |
32 | This code simply computes the derivative of the function :math:`f(x) = e^x` at the point :math:`x=0.5`.
33 | In the code we do so by initializing a torch tensor ``a`` that has the flag ``requires_grad = True`` set,
34 | which we do here by calling the ``requires_grad_()`` method. This flag is in some sense contagious: Almost
35 | all torch functions that are called with ``a`` as their argument also pass this flag on to their output. In
36 | the example above this is the exponential function, which returns a tensor ``b`` that also has this flag set.
37 | In addition to this flag, ``b`` carries the information that it was computed from ``a``, as well as
38 | a property called the gradient function ``grad_fn``.
39 | This is the function that tells pytorch what to do in the backward automatic differentiation pass.
40 |
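If one wants to see this bookkeeping directly, both pieces of information can be inspected on the tensor
itself (the exact class name that is printed may differ between pytorch versions):

.. code-block:: python

    print(b.requires_grad)  # True
    print(b.grad_fn)        # e.g. <ExpBackward0 object at 0x...>
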
41 | To illustrate this a bit more, consider the following directed acyclic graph (DAG) that represents the
42 | computational flow in the forward phase:
43 |
44 | .. graphviz::
45 | :caption: Forward DAG
46 | :align: center
47 |
48 | digraph forwarddag {
49 | exp [shape=rectangle];
50 | "a" -> exp -> "b";
51 | }
52 |
53 | What now happens in the backward pass is that pytorch executes a reversed version of the forward DAG, just with
54 | the functions replaced by their gradient functions. E.g. in the example above this would look like
55 |
56 | .. graphviz::
57 | :caption: Backward DAG
58 | :align: center
59 |
60 | digraph backwarddag {
61 | rankdir=BT;
62 | ExpBackward [shape=rectangle];
63 | "1" -> ExpBackward -> "a.grad";
64 | }
65 |
66 | In particular, pytorch starts off with :math:`1`, which is obviously the derivative of ``b`` with respect to
67 | itself. It then executes the ``grad_fn`` function, which in this example is the ``ExpBackward`` function.
68 | Not shown in the illustration is that the ``grad_fn`` function internally has a reference to the result of the
69 | forward calculation, which is then multiplied with ``1`` and defines the output of ``ExpBackward``.
70 | Finally, pytorch stores this result in ``a.grad``, which now contains the derivative of ``b`` with respect to ``a``.
71 |
72 | This principle can of course be generalized to more complicated DAGs. In these situations pytorch still builds
73 | up the backward DAG by recording the gradient functions and the dependencies on the fly, and then executes this
74 | graph when the ``backward`` method is called. However, there are still some important implications for the
75 | usage in the following, which we want to highlight:
76 |
77 | .. _section_pure_functions:
78 |
79 | Automatic differentiable functions should at best be pure functions
80 | -------------------------------------------------------------------
81 |
82 | This statement is --- if written out like that --- probably not any news,
83 | since the concept of a differentiable function
84 | is a mathematical one, and all mathematical functions are pure in a procedural sense. Hence, a programmatic
85 | representation of a mathematically differentiable function should at best also be pure.
86 |
87 | This has some implications. One of the more important ones is, as obvious as it may seem, that this function
88 | needs to have an input and an output. Without an input (and without explicitly modifying the autograd meta data)
89 | the output of a function is, from the perspective of the AD engine, a constant. The same applies to functions
90 | with no output, whose branch in the backward DAG execution is simply omitted by the AD engine.
91 |
92 | Since this is so important, we repeat it:
93 |
94 | .. warning::
95 |
96 | **All automatic differentiable functions need to depend on some input tensor,
97 | and need to return an output tensor**.
98 |
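As a small illustration of this point, consider the following snippet, where ``not_ad_friendly`` is a made-up
stand-in for any routine that obtains its data from outside the DAG (e.g. from a file or a received message):

.. code-block:: python

    import torch

    a = torch.tensor([0.5]).requires_grad_()

    def not_ad_friendly():
        # the returned data is not connected to any input tensor
        return torch.tensor([1.5])

    b = a * not_ad_friendly()
    b.backward()
    # a.grad == 1.5, but there is no way for gradients to flow "into" not_ad_friendly,
    # since from the perspective of the AD engine its output is just a constant.
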
99 | DAG edges can only be pytorch tensors of floating point type
100 | ------------------------------------------------------------
101 |
102 | This goes in the same direction as the last remark. Obviously differentiability is, by its mathematical
103 | definition, strongly tied to the real numbers, and the floating point numbers are the only approximation
104 | to them we have in pytorch.
105 |
106 | As such we can only exchange floating point tensors along the edges in the DAG.
107 |
108 | That some form of additivity is required for the structures that are transported along the DAG edges
109 | can also be seen from the following example
110 |
111 | .. code-block:: python
112 |
113 | a = ...
114 | tmp1 = F(a)
115 | tmp2 = G1(tmp1)
116 | tmp3 = G2(tmp1)
117 | b = H(tmp2, tmp3)
118 | b.backward()
119 |
120 | Note in particular that the output from the node ``F`` is used twice: once as the input for ``G1``
121 | and once as the input for ``G2``. The forward DAG would then look like
122 |
123 | .. graphviz::
124 | :caption: Forward DAG with bifurcation
125 | :align: center
126 |
127 | digraph foo {
128 | F [shape=rectangle];
129 | G1 [shape=rectangle];
130 | G2 [shape=rectangle];
131 | H [shape=rectangle];
132 | "a" -> F -> G1 -> H -> "b";
133 | F -> G2 -> H;
134 | }
135 |
136 | The corresponding backward DAG would, by simply inverting the arrows and substituting
137 | the function calls with the respective backward function calls, have the form
138 |
139 | .. graphviz::
140 | :caption: Backward DAG with bifurcation
141 | :align: center
142 |
143 | digraph foo2 {
144 | rankdir=BT;
145 | FBackward [shape=rectangle];
146 | G1Backward [shape=rectangle];
147 | G2Backward [shape=rectangle];
148 | HBackward [shape=rectangle];
149 | "1" -> HBackward -> G1Backward -> FBackward -> "a.grad";
150 | HBackward -> G2Backward -> FBackward;
151 | }
152 |
153 | However, what this picture does not show is that the bifurcation in the forward evaluation of ``b``
154 | becomes an addition in the backward pass. A more detailed representation of the backward DAG
155 | would thus be
156 |
157 | .. graphviz::
158 | :caption: Backward DAG with explicit addition
159 | :align: center
160 |
161 | digraph foo2 {
162 | rankdir=BT;
163 | FBackward [shape=rectangle];
164 | "+" [shape=rectangle];
165 | G1Backward [shape=rectangle];
166 | G2Backward [shape=rectangle];
167 | HBackward [shape=rectangle];
168 | "1" -> HBackward -> G1Backward -> "+" -> FBackward -> "a.grad";
169 | HBackward -> G2Backward -> "+";
170 | }
171 |
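To make this concrete, here is a minimal runnable instance of the DAG above, with ``torch.exp`` playing the
role of ``F``, ``torch.sin`` and ``torch.cos`` playing the roles of ``G1`` and ``G2``, and the final addition
playing the role of ``H``:

.. code-block:: python

    import torch

    a = torch.tensor([0.5]).requires_grad_()
    tmp1 = torch.exp(a)     # F
    tmp2 = torch.sin(tmp1)  # G1
    tmp3 = torch.cos(tmp1)  # G2
    b = tmp2 + tmp3         # H
    b.backward()
    # The two branches meet again in FBackward, where their contributions are added up:
    # a.grad == exp(a) * (cos(exp(a)) - sin(exp(a)))
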
172 | To sum up:
173 |
174 | .. warning::
175 |
176 | The edges in the DAG representation can only be pytorch tensors of floating point type.
177 |
178 |
179 | .. _section_implications_mpi4torch:
180 |
181 | Implications for mpi4torch
182 | ==========================
183 |
184 | mpi4torch is an MPI wrapper library for pytorch tensors that tries to be as *transparent* as possible
185 | to pytorch's AD engine. By transparent we in particular mean that we do not touch the AD engine, but
186 | rather provide the MPI functions as nodes in the DAG that pytorch composes. To be more precise, one should
187 | say the DAGs that pytorch composes, which already brings us to one of the ramifications of this design
188 | decision: When parallelizing your program with mpi4torch it is still the case that each MPI rank has its
189 | individual DAG that is run during the backward step. Most importantly, these DAGs do not know anything
190 | about each other, and thus cannot resolve any dependencies with ``requires_grad`` set from any other rank.
191 | As a consequence **it is the sole responsibility of the user to manage these dependencies**.
192 |
193 | We will come in a minute to how the user can actually encode these dependencies, but let us first start
194 | with an example. Consider the following code, which shows the often-used Isend-Recv-Wait idiom.
195 | From a communication perspective it
196 | simply receives a tensor from the left process and passes its own tensor to the right, if all
197 | ranks are imagined to be arranged in a circle.
198 |
199 | .. code-block:: python
200 |
201 | import torch
202 | import mpi4torch
203 |
204 | comm = mpi4torch.COMM_WORLD
205 |
206 | a = torch.tensor([1.0 + comm.rank]).requires_grad_()
207 |
208 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
209 | b = comm.Recv(torch.empty_like(a), (comm.rank-1+comm.size)%comm.size, 0)
210 | comm.Wait(handle)
211 |
212 | res = a+b
213 | print(res)
214 |
215 | This code follows the usual MPI coding paradigms and works as expected. However, if we were to ask
216 | for the gradient of (the sum of all) ``res`` with respect to the individual ``a`` s, we would get an incorrect
217 | result.
218 |
219 | .. code-block:: python
220 |
221 | res.backward()
222 | print(a.grad) # <- this would print tensor([1.])
223 |
224 | The print function would actually display 1 as the result, whereas the derivative of the sum of all ``res`` variables on all ranks with respect to that specific ``a`` variable should be 2:
225 | on rank ``r`` we have ``res = a_r + a_(r-1)`` (indices taken modulo the number of ranks), so each ``a_r`` enters the implicit global sum exactly twice, once on its own rank and once on its right neighbour.
226 |
227 | This is just one of the things that could happen. There are many more situations in which the program would
228 | run flawlessly in forward mode, but would e.g. deadlock in the backward pass. To illustrate
229 | how this happens we will look once more at a graphical representation of the DAG.
230 |
231 | .. graphviz::
232 | :caption: Forward DAG for the Isend-Recv-Wait idiom
233 | :align: center
234 | :name: naiveforwardisendrecvwaitgraph
235 |
236 | digraph foo2 {
237 | rank = same;
238 | subgraph clusterrankm1 {
239 | a1 [label="a"];
240 | res1 [label="res"];
241 | node [shape=rectangle];
242 | Isend1 [label="Isend"];
243 | Wait1 [label="Wait"];
244 | Recv1 [label="Recv"];
245 | p1 [label="+"];
246 | a1 -> Isend1 -> Wait1;
247 | Recv1 -> p1 -> res1;
248 | a1 -> p1;
249 | label = "rank - 1";
250 | color = black;
251 | };
252 | subgraph clusterrank {
253 | a2 [label="a"];
254 | res2 [label="res"];
255 | node [shape=rectangle];
256 | Isend2 [label="Isend"];
257 | Wait2 [label="Wait"];
258 | Recv2 [label="Recv"];
259 | p2 [label="+"];
260 | a2 -> Isend2 -> Wait2;
261 | Recv2 -> p2 -> res2;
262 | a2 -> p2;
263 | label = "rank";
264 | }
265 | subgraph clusterrankp1 {
266 | a3 [label="a"];
267 | res3 [label="res"];
268 | node [shape=rectangle];
269 | Isend3 [label="Isend"];
270 | Wait3 [label="Wait"];
271 | Recv3 [label="Recv"];
272 | p3 [label="+"];
273 | a3 -> Isend3 -> Wait3;
274 | Recv3 -> p3 -> res3;
275 | a3 -> p3;
276 | label = "rank + 1";
277 | }
278 |
279 | Isend1 -> Recv2 [style=dotted, constraint=false];
280 | Isend2 -> Recv3 [style=dotted, constraint=false];
281 | #Isend3 -> Recv1 [style=dotted, constraint=false];
282 | }
283 |
284 | The graph above shows the dependencies between the different computations as seen from pytorch's
285 | perspective, with the addition of some dotted arrows that indicate the actual communication that is happening.
286 |
287 | If we now invert the arrows in order to get the corresponding backward DAG, we obtain
288 |
289 | .. graphviz::
290 | :caption: Backward DAG for the Isend-Recv-Wait idiom for a single rank
291 | :align: center
292 |
293 | digraph foo2 {
294 | rankdir=BT;
295 | subgraph clusterrankm1 {
296 | a1 [label="a.grad"];
297 | res1 [label="1"];
298 | node [shape=rectangle];
299 | Isend1 [label="IsendBackward", style=filled, color=gray];
300 | Wait1 [label="WaitBackward", style=filled, color=gray];
301 | Recv1 [label="RecvBackward", style=filled, color=gray];
302 | p1 [label="AddBackward0"];
303 | Wait1 -> Isend1 -> a1;
304 | res1 -> p1 -> Recv1;
305 | p1 -> a1;
306 | label = "rank";
307 | };
308 | }
309 |
310 | This graph immediately makes clear why ``a.grad`` contains 1 in the end. All grayed-out nodes are omitted
311 | (or, to be more precise, not even generated) by pytorch's AD engine, such that only ``AddBackward0``
312 | is called, which just passes 1 through to ``a.grad``.
313 |
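On a single rank one can also verify this directly by inspecting the backward graph that pytorch
built for ``res``. The snippet below is a sketch: it assumes the naive code from above has just been
executed, and the printed representations are only approximate:

.. code-block:: python

    print(res.grad_fn)
    # <AddBackward0 object at 0x...>

    print([edge[0] for edge in res.grad_fn.next_functions])
    # roughly: [<AccumulateGrad object at 0x...>, None]
    # The None entry corresponds to the Recv branch: no backward node was
    # generated for it, so no gradient is ever communicated back.
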
314 | From this discussion and the :ref:`naiveforwardisendrecvwaitgraph` it becomes apparent that there are some parts
315 | that are implicit in the program code but that are missing in the DAG representation:
316 |
317 | #. As noted earlier, the DAGs are local to each MPI rank, and they do not resolve any dependencies that
318 | are the effect of communication.
319 | #. The DAGs also lack any information that was present in the linear ordering of commands in the source code
320 | file. E.g. the ``Recv`` call has to happen after ``Isend``, and ``Wait`` has to happen after ``Recv``.
321 |
322 | **It is the user's responsibility to encode these dependencies in the DAG!**
323 | This brings us to the tools mpi4torch provides to mitigate this situation.
324 |
325 | The first one is a direct consequence of the discussion in the section on
326 | :ref:`pure functions `: all DAG nodes need an input and an output.
327 | In our example above, this would e.g. concern the :py:meth:`mpi4torch.MPI_Communicator.Wait`
328 | call. In principle, ``MPI_Wait`` does not return a floating-point tensor. However, mpi4torch's ``Wait``
329 | does return one, giving the user the possibility to use it to encode
330 | any further dependencies on the ``Wait`` call. These tensors are called **dummies** in mpi4torch.
331 | They do not convey any information other than that there is some (virtual/artificial)
332 | dependency to be encoded in the DAG.
333 |
334 | The dummies themselves are not really useful without a way to join them with the DAG. This is
335 | what the :py:func:`mpi4torch.JoinDummies` function is actually for. The call signature of
336 | :py:func:`mpi4torch.JoinDummies` is given by
337 |
338 | .. code-block:: python
339 |
340 | def JoinDummies(loopthrough: torch.Tensor, dummies: List[torch.Tensor]) -> torch.Tensor
341 |
342 | The function takes two arguments: the ``loopthrough`` variable and a list of dummies. From a forward
343 | execution perspective the ``JoinDummies`` function is a no-op: it simply, as the name suggests,
344 | loops through the ``loopthrough`` variable. The ``dummies`` are discarded and not used.
345 |
346 | However, pytorch does not know about this behaviour of the ``JoinDummies`` function, and considers
347 | the result of the function to actually depend on the dummies. Consequently, pytorch will also
348 | respect this dependency in the backward DAG.
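
The following small single-rank sketch illustrates both aspects: in the forward pass the result is
just the ``loopthrough`` tensor, while in the backward pass the dummy is nevertheless visited and
receives a (zero-valued) gradient:

.. code-block:: python

    import torch
    import mpi4torch

    x = torch.tensor([3.0], requires_grad=True)
    d = torch.tensor([0.0], requires_grad=True)  # stands in for a dummy tensor

    y = mpi4torch.JoinDummies(x, [d])
    print(torch.equal(y, x))  # True: the forward pass is a pure loop-through

    y.backward()
    print(x.grad)  # tensor([1.])
    print(d.grad)  # tensor([0.]): the dummy's branch was visited in the
                   # backward pass, even though it did not influence y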
349 |
350 | The :py:func:`mpi4torch.JoinDummies` function also has a sister function :py:func:`mpi4torch.JoinDummiesHandle`, which
351 | is intended for situations in which the ``loopthrough`` variable is a :py:class:`mpi4torch.WaitHandle`
352 | from a non-blocking MPI call, as e.g. returned by :py:func:`mpi4torch.MPI_Communicator.Isend`. The signature
353 | of :py:func:`mpi4torch.JoinDummiesHandle` is
354 |
355 | .. code-block:: python
356 |
357 | def JoinDummiesHandle(handle: WaitHandle, dummies: List[torch.Tensor]) -> WaitHandle
358 |
359 | Returning to the Isend-Recv-Wait example, we now want to put these tools to use. Starting with
360 | the call to :py:func:`mpi4torch.MPI_Communicator.Recv`, we want this call to happen after
361 | :py:func:`mpi4torch.MPI_Communicator.Isend`. Note that ``Isend`` returns a ``WaitHandle``, which
362 | cannot directly be passed to ``JoinDummies``. For these situations we will use the
363 | :py:attr:`mpi4torch.WaitHandle.dummy` property, which gives us a means to convert a ``WaitHandle``
364 | to a dummy tensor. In the example from above this could then
365 | look like
366 |
367 | .. code-block:: python
368 |
369 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
370 | recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
371 | # ~~~~~~~~~~~~~~~~~~~
372 | # This is what we
373 | # originally wanted
374 | # to pass to Recv
375 | # ~~~~~~~~~~~~~~
376 | # This adds the handle
377 | # from the previous Isend call
378 | # as a dummy dependency to the DAG
379 | b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
380 |
381 | For the ``Wait`` we now also want this to happen after the ``Recv`` call. This time we make use of
382 | :py:func:`mpi4torch.JoinDummiesHandle`.
383 |
384 | .. code-block:: python
385 |
386 | b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
387 | wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))
388 |
389 | Note that we already added a return variable for ``Wait``, since we still want to encode
390 | that our end result, the (implicit) sum of all ``res`` on all ranks, depends on the ``Isend``
391 | having finished. For that we introduce another call to :py:func:`mpi4torch.JoinDummies`.
392 |
393 | .. code-block:: python
394 |
395 | wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))
396 |
397 | res = mpi4torch.JoinDummies(a+b, [wait_ret])
398 |
399 |
400 | The full code example now looks like
401 |
402 | .. code-block:: python
403 |
404 | import torch
405 | import mpi4torch
406 |
407 | comm = mpi4torch.COMM_WORLD
408 |
409 | a = torch.tensor([1.0 + comm.rank]).requires_grad_()
410 |
411 | handle = comm.Isend(a,(comm.rank+1)%comm.size, 0)
412 | recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
413 | b = comm.Recv(recvbuffer, (comm.rank-1+comm.size)%comm.size, 0)
414 | wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle,[b]))
415 |
416 | res = mpi4torch.JoinDummies(a+b, [wait_ret])
417 | print(res)
418 |
419 | res.backward()
420 | print(a.grad) # <- this would now correctly print tensor([2.])
421 |
422 | This code now prints the correct result for ``a.grad``. To illustrate the differences from the
423 | first version of the code, we also look at the DAG of the new version:
424 |
425 | .. graphviz::
426 | :caption: Forward DAG for the Isend-Recv-Wait idiom with dummy dependencies
427 | :align: center
428 |
429 | digraph foo2 {
430 | subgraph clusterrankm1 {
431 | a [label="a"];
432 | res [label="res"];
433 | node [shape=rectangle];
434 | JoinDummies1 [label="JoinDummies"];
435 | JoinDummiesHandle [label="JoinDummiesHandle"];
436 | JoinDummies2 [label="JoinDummies"];
437 | Isend [label="Isend"];
438 | Wait [label="Wait"];
439 | Recv [label="Recv"];
440 | p1 [label="+"];
441 | a -> Isend;
442 | Isend -> JoinDummies1 -> Recv;
443 | Recv -> JoinDummiesHandle -> Wait;
444 | Isend -> JoinDummiesHandle;
445 | Recv -> p1 -> JoinDummies2;
446 | a -> p1;
447 | Wait -> JoinDummies2;
448 | JoinDummies2 -> res;
449 | label = "rank";
450 | color = black;
451 | };
452 | }
453 |
454 | The important point to note is that all communication is part of a path between
455 | ``a`` and ``res``, and in comparison to the first version of the code there are no "dead branches".
456 | pytorch's AD engine thus has to call the respective backward methods when it propagates the gradient
457 | back from ``res`` to ``a.grad``.
458 |
459 | .. warning::
460 |
461 | In general, if you write a function that uses mpi4torch internally and that shall be automatically
462 | differentiable, make sure that all communication primitives are, in one way or another, part of a DAG path
463 | that connects the input and the output of that function.
464 |
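To close, here is a minimal sketch of how the pattern above can be wrapped in a reusable helper
that satisfies this rule. The function ``ring_shift_add`` and its structure are purely illustrative
and not part of the mpi4torch API; it merely repackages the full example from before:

.. code-block:: python

    import torch
    import mpi4torch

    comm = mpi4torch.COMM_WORLD

    def ring_shift_add(a: torch.Tensor) -> torch.Tensor:
        """Add the left neighbour's tensor to ``a`` in a ring topology."""
        handle = comm.Isend(a, (comm.rank + 1) % comm.size, 0)
        recvbuffer = mpi4torch.JoinDummies(torch.empty_like(a), [handle.dummy])
        b = comm.Recv(recvbuffer, (comm.rank - 1 + comm.size) % comm.size, 0)
        wait_ret = comm.Wait(mpi4torch.JoinDummiesHandle(handle, [b]))
        # Tie the Wait result into the output, so that every communication
        # primitive lies on a DAG path between the input and the output.
        return mpi4torch.JoinDummies(a + b, [wait_ret])

    a = torch.tensor([1.0 + comm.rank]).requires_grad_()
    res = ring_shift_add(a)
    res.backward()
    print(a.grad)  # tensor([2.])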
465 |
466 |
--------------------------------------------------------------------------------
/csrc/extension.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include
3 | #include
4 | #include
5 | //#include
6 | #include
7 | #include
8 | #include
9 |
10 | #include <mpi.h>
11 |
12 | #if defined(OPEN_MPI) && OPEN_MPI
13 | // Needed for checking cuda-awareness
14 | #include <mpi-ext.h>
15 | #endif
16 |
17 | #include
18 | #include
19 | #include
20 |
21 | using torch::Tensor;
22 | using torch::ScalarType;
23 | using torch::autograd::variable_list;
24 |
25 | namespace
26 | {
27 |
28 | #if defined(MPIX_CUDA_AWARE_SUPPORT)
29 | // if it is already clear at compile time that OpenMPI has no support, we deactivate
30 | // the cuda-aware mpi support directly
31 | #define MPI4TORCH_BUILT_WITH_CUDA_AWARENESS MPIX_CUDA_AWARE_SUPPORT
32 | #else
33 | #define MPI4TORCH_BUILT_WITH_CUDA_AWARENESS 0
34 | #endif
35 |
36 | #if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS
37 |
38 | bool have_cuda_aware_mpi_support = false;
39 |
40 | void inline __setup_have_cuda_aware_mpi_support()
41 | {
42 | // OpenMPI (and presumably also Parastation MPI) provides this runtime query function
43 | #if defined(MPIX_CUDA_AWARE_SUPPORT)
44 | have_cuda_aware_mpi_support = MPIX_Query_cuda_support();
45 | #else
46 | have_cuda_aware_mpi_support = false;
47 | #endif
48 | }
49 |
50 | #else
51 | const bool have_cuda_aware_mpi_support = false;
52 | #endif
53 |
54 | void deactivate_cuda_aware_mpi_support()
55 | {
56 | #if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS
57 | have_cuda_aware_mpi_support = false;
58 | #endif
59 | }
60 |
61 | struct MPIDeviceHelper
62 | {
63 | MPIDeviceHelper(const Tensor& input)
64 | : device(input.device()), devicetype(device.type()), mpidevice(c10::kCPU)
65 | {
66 | setup();
67 | }
68 |
69 | MPIDeviceHelper(const c10::Device& _device)
70 | : device(_device), devicetype(device.type()), mpidevice(c10::kCPU)
71 | {
72 | setup();
73 | }
74 |
75 | Tensor fromDeviceToMPI(const Tensor& input)
76 | {
77 | if (input.device() == mpidevice) {
78 | return input;
79 | }
80 | return input.to(mpidevice);
81 | }
82 |
83 | Tensor fromMPIToDevice(const Tensor& output)
84 | {
85 | if (output.device() == device) {
86 | return output;
87 | }
88 | return output.to(device);
89 | }
90 |
91 | c10::Device device;
92 | c10::DeviceType devicetype;
93 | c10::Device mpidevice;
94 |
95 | private:
96 | void setup()
97 | {
98 | if (devicetype == c10::kCPU) {
99 | mpidevice = device;
100 | } else if (devicetype == c10::kCUDA && have_cuda_aware_mpi_support) {
101 | mpidevice = device;
102 | }
103 | }
104 | };
105 |
106 | MPI_Datatype torch2mpitype(ScalarType in)
107 | {
108 | switch(in)
109 | {
110 | case ScalarType::Byte:
111 | return MPI_BYTE;
112 | case ScalarType::Char:
113 | return MPI_CHAR;
114 | case ScalarType::Short:
115 | return MPI_SHORT;
116 | case ScalarType::Int:
117 | return MPI_INT;
118 | case ScalarType::Long:
119 | return MPI_LONG;
120 | case ScalarType::Float:
121 | return MPI_FLOAT;
122 | case ScalarType::Double:
123 | return MPI_DOUBLE;
124 | default:
125 | break;
126 | // just to silence compiler warnings of unhandled switch cases
127 | }
128 | throw std::invalid_argument("Failure to match torch::ScalarType to MPI_Datatype!");
129 | }
130 |
131 | void check_mpi_return_value(int ierr)
132 | {
133 | if (ierr != MPI_SUCCESS) {
134 | std::ostringstream oss;
135 | oss << ierr;
136 | throw std::runtime_error("MPI call failed with error code " + oss.str());
137 | }
138 | }
139 |
140 | struct MPI_Comm_Wrapper : torch::CustomClassHolder
141 | {
142 | MPI_Comm_Wrapper(const MPI_Comm comm_ = MPI_COMM_NULL) : comm(comm_) {}
143 |
144 | MPI_Comm comm;
145 |
146 | int64_t GetRank();
147 | int64_t GetSize();
148 |
149 | Tensor MPIAllreduce(const Tensor& input, int64_t op);
150 | Tensor MPIBcast_(const Tensor& input, int64_t root);
151 | Tensor MPIReduce_(const Tensor& input, int64_t op, int64_t root);
152 |
153 | Tensor MPIGather(const Tensor& input, int64_t gatheraxis, int64_t root);
154 | Tensor MPIAllgather(const Tensor& input, int64_t gatheraxis);
155 | Tensor MPIScatter(const Tensor& input, int64_t scatteraxis, int64_t numelem, int64_t root);
156 | Tensor MPIAlltoall(const Tensor& input, int64_t gatheraxis, int64_t scatteraxis, int64_t numelem);
157 |
158 | variable_list MPIIsend(const Tensor& input, int64_t dest, int64_t tag);
159 | variable_list MPIIrecv(const Tensor& input, int64_t source, int64_t tag);
160 | Tensor MPIWait(const variable_list& input);
161 | };
162 |
163 | c10::intrusive_ptr comm_world()
164 | {
165 | return c10::make_intrusive(MPI_COMM_WORLD);
166 | }
167 |
168 | c10::intrusive_ptr comm_from_fortran(int64_t fortran_handle)
169 | {
170 | return c10::make_intrusive(MPI_Comm_f2c(fortran_handle));
171 | }
172 |
173 | Tensor JoinDummies(const Tensor& loopthrough, const variable_list& list);
174 |
175 | int64_t MPI_Comm_Wrapper::GetRank()
176 | {
177 | int rank;
178 | MPI_Comm_rank(comm,&rank);
179 | return rank;
180 | }
181 |
182 | int64_t MPI_Comm_Wrapper::GetSize()
183 | {
184 | int size;
185 | MPI_Comm_size(comm,&size);
186 | return size;
187 | }
188 |
189 | struct MPIBackwardNode : public torch::autograd::Node
190 | {
191 | MPI_Comm_Wrapper comm;
192 | };
193 |
194 | struct MPIUnimplementedNode : MPIBackwardNode
195 | {
196 | variable_list apply(variable_list&& grads) override {
197 | throw std::runtime_error("This backward operation is currently unimplemented!");
198 | }
199 | std::string name() const override {
200 | return std::string("MPIUnimplementedNode");
201 | }
202 | };
203 |
204 | enum Mpi4torchCollectiveOps : int64_t
205 | {
206 | mpi4torch_op_max,
207 | mpi4torch_op_min,
208 | mpi4torch_op_sum,
209 | mpi4torch_op_prod,
210 | mpi4torch_op_land,
211 | mpi4torch_op_band,
212 | mpi4torch_op_lor,
213 | mpi4torch_op_bor,
214 | mpi4torch_op_lxor,
215 | mpi4torch_op_bxor,
216 | mpi4torch_op_minloc,
217 | mpi4torch_op_maxloc
218 | };
219 |
220 | MPI_Op __get_mpi_op(int64_t op)
221 | {
222 | switch(op)
223 | {
224 | case mpi4torch_op_max:
225 | return MPI_MAX;
226 | case mpi4torch_op_min:
227 | return MPI_MIN;
228 | case mpi4torch_op_sum:
229 | return MPI_SUM;
230 | case mpi4torch_op_prod:
231 | return MPI_PROD;
232 | case mpi4torch_op_land:
233 | return MPI_LAND;
234 | case mpi4torch_op_band:
235 | return MPI_BAND;
236 | case mpi4torch_op_lor:
237 | return MPI_LOR;
238 | case mpi4torch_op_bor:
239 | return MPI_BOR;
240 | case mpi4torch_op_lxor:
241 | return MPI_LXOR;
242 | case mpi4torch_op_bxor:
243 | return MPI_BXOR;
244 | case mpi4torch_op_minloc:
245 | return MPI_MINLOC;
246 | case mpi4torch_op_maxloc:
247 | return MPI_MAXLOC;
248 | default:
249 | break;
250 | }
251 | throw std::invalid_argument("mpi4torch: Collective operation not supported!");
252 | }
253 |
254 | struct MPIAllreduceSumBackward : public MPIBackwardNode {
255 | variable_list apply(variable_list&& grads) override;
256 | std::string name() const override {
257 | return std::string("MPIAllreduceSumBackward");
258 | }
259 |
260 | void release_variables() override {
261 | return;
262 | }
263 | };
264 |
265 | variable_list MPIAllreduceSumBackward::apply (variable_list&& grads)
266 | {
267 | variable_list grad_inputs(1);
268 | if (should_compute_output(0)) {
269 | grad_inputs[0] = comm.MPIAllreduce(grads[0], mpi4torch_op_sum);
270 | }
271 | return grad_inputs;
272 | }
273 |
274 | Tensor MPI_Comm_Wrapper::MPIAllreduce(const Tensor& input, int64_t op)
275 | {
276 | std::shared_ptr grad_fn;
277 | auto mpiop = __get_mpi_op(op);
278 | if (torch::autograd::compute_requires_grad(input)) {
279 | if (op == mpi4torch_op_sum) {
280 | grad_fn = std::shared_ptr (new MPIAllreduceSumBackward(), torch::autograd::deleteNode);
281 | } else {
282 | grad_fn = std::shared_ptr(new MPIUnimplementedNode(), torch::autograd::deleteNode);
283 | }
284 | grad_fn->comm = *this;
285 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
286 | }
287 | auto result = ([&]() {
288 | at::AutoDispatchBelowADInplaceOrView guard;
289 |
290 | MPIDeviceHelper devhelper(input);
291 |
292 | // make input contiguous
293 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous();
294 |
295 | auto recv = torch::empty_like(input_cont);
296 |
297 | check_mpi_return_value(
298 | MPI_Allreduce(input_cont.data_ptr(), recv.data_ptr(), input_cont.numel(),
299 | torch2mpitype(input_cont.scalar_type()), mpiop, comm)
300 | );
301 |
302 | return devhelper.fromMPIToDevice(recv);
303 | })();
304 | if (grad_fn) {
305 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
306 | }
307 | return result;
308 | }
309 |
310 | struct MPIBcastInPlaceBackward : public MPIBackwardNode {
311 | MPIBcastInPlaceBackward(int _root) : root(_root) {}
312 | variable_list apply(variable_list&& grads) override;
313 | std::string name() const override {
314 | return std::string("MPIBcastInPlaceBackward");
315 | }
316 |
317 | void release_variables() override {
318 | return;
319 | }
320 |
321 | int root;
322 | };
323 |
324 | variable_list MPIBcastInPlaceBackward::apply (variable_list&& grads)
325 | {
326 | variable_list grad_inputs(1);
327 | if (should_compute_output(0)) {
328 | grad_inputs[0] = comm.MPIReduce_(grads[0],mpi4torch_op_sum,root);
329 | }
330 | return grad_inputs;
331 | }
332 |
333 | Tensor MPI_Comm_Wrapper::MPIBcast_(const Tensor& input, int64_t root)
334 | {
335 | // TODO: check for root being in int range
336 | std::shared_ptr grad_fn;
337 | if (torch::autograd::compute_requires_grad(input)) {
338 | grad_fn = std::shared_ptr (new MPIBcastInPlaceBackward(static_cast(root)),
339 | torch::autograd::deleteNode);
340 | grad_fn->comm = *this,
341 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
342 | }
343 | auto result = ([&]() {
344 | at::AutoDispatchBelowADInplaceOrView guard;
345 |
346 | MPIDeviceHelper devhelper(input);
347 |
348 | // 1. Make input contiguous
349 | // 2. Call variable_data() to make a shallow copy of the input tensor without the autograd history,
350 | //    such that it can be safely returned from this function.
351 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data();
352 |
353 | check_mpi_return_value(
354 | MPI_Bcast(input_cont.data_ptr(), input_cont.numel(),
355 | torch2mpitype(input_cont.scalar_type()),
356 | static_cast(root), comm)
357 | );
358 |
359 | return devhelper.fromMPIToDevice(input_cont);
360 | })();
361 | if (grad_fn) {
362 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
363 | }
364 | return result;
365 | }
366 |
367 | struct MPIReduceSumInPlaceBackward : public MPIBackwardNode {
368 | MPIReduceSumInPlaceBackward(int _root) : root(_root) {}
369 | variable_list apply(variable_list&& grads) override;
370 | std::string name() const override {
371 | return std::string("MPIReduceSumInPlaceBackward");
372 | }
373 |
374 | void release_variables() override {
375 | return;
376 | }
377 |
378 | int root;
379 | };
380 |
381 | variable_list MPIReduceSumInPlaceBackward::apply (variable_list&& grads)
382 | {
383 | variable_list grad_inputs(1);
384 | // TODO: for these simple functions the should_compute_output check is superfluous
385 | if (should_compute_output(0)) {
386 | // NOTE: It is probably safe to use in-place operations in the backward mode,
387 | // since I currently cannot think of any way how a bifurcation could
388 | // enter the DAG.
389 | // TODO: Proof that it is safe!
390 | grad_inputs[0] = comm.MPIBcast_(grads[0],root);
391 | }
392 | return grad_inputs;
393 | }
394 |
395 | struct MPINoInplaceBackward : public torch::autograd::Node {
396 | variable_list apply(variable_list&& grads) override
397 | {
398 | throw std::runtime_error("Reuse of variables passed to in-place MPI kernels not supported! Try using the return value");
399 | }
400 | std::string name() const override {
401 | return std::string("MPINoInplaceBackward");
402 | }
403 | };
404 |
405 | Tensor MPI_Comm_Wrapper::MPIReduce_(const Tensor& input, int64_t op, int64_t root)
406 | {
407 | // TODO: check for root being in int range
408 | std::shared_ptr grad_fn;
409 | auto mpiop = __get_mpi_op(op);
410 | if (torch::autograd::compute_requires_grad(input)) {
411 | if (op == mpi4torch_op_sum) {
412 | grad_fn = std::shared_ptr (new MPIReduceSumInPlaceBackward(static_cast(root)),
413 | torch::autograd::deleteNode);
414 | } else {
415 | grad_fn = std::shared_ptr (new MPIUnimplementedNode(), torch::autograd::deleteNode);
416 | }
417 | grad_fn->comm = *this;
418 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
419 | }
420 | auto result = ([&]() {
421 | at::AutoDispatchBelowADInplaceOrView guard;
422 |
423 | MPIDeviceHelper devhelper(input);
424 |
425 | // 1. Make input contiguous
426 | // 2. Call variable_data() to make a shallow copy of the input tensor without the autograd history,
427 | //    such that it can be safely returned from this function.
428 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data();
429 |
430 | const int rank = GetRank();
431 |
432 | void* sendbuf = input_cont.data_ptr();
433 | if (rank == root) {
434 | // One is only allowed to pass MPI_IN_PLACE for the root process
435 | // cf. https://stackoverflow.com/a/17744793
436 | sendbuf = MPI_IN_PLACE;
437 | }
438 |
439 | check_mpi_return_value(MPI_Reduce(sendbuf, input_cont.data_ptr(), input_cont.numel(),
440 | torch2mpitype(input_cont.scalar_type()),
441 | mpiop, static_cast(root), comm));
442 |
443 | if (rank != root) {
444 | // We fill the non-root results with zeros to make the function properly behaved.
445 | // TODO: We could potentially let the return-value be undefined and save some ops?
446 | input_cont.zero_();
447 | }
448 |
449 | return devhelper.fromMPIToDevice(input_cont);
450 | })();
451 | if (grad_fn) {
452 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
453 |
454 | // We only activate this safeguard if the input variable is not a leaf in the DAG,
455 | // otherwise we would get errors from the AccumulateGrad Node.
456 | if (input.grad_fn()) {
457 | // prohibit misuse of input in autograd
458 | auto& input_non_const = const_cast(input);
459 | set_history(input_non_const, std::shared_ptr(new MPINoInplaceBackward(),
460 | torch::autograd::deleteNode));
461 | }
462 | }
463 | return result;
464 | }
465 |
466 | struct MPIGatherBackward : public MPIBackwardNode {
467 | MPIGatherBackward(int64_t _gatheraxis, int64_t _root, int64_t _numelem)
468 | : gatheraxis(_gatheraxis), root(_root), numelem(_numelem) {}
469 | variable_list apply(variable_list&& grads) override;
470 | std::string name() const override {
471 | return std::string("MPIGatherBackward");
472 | }
473 |
474 | void release_variables() override {
475 | return;
476 | }
477 |
478 | int64_t gatheraxis;
479 | int64_t root;
480 | int64_t numelem;
481 | };
482 |
483 | variable_list MPIGatherBackward::apply (variable_list&& grads)
484 | {
485 | variable_list grad_inputs(1);
486 | // TODO: for these simple functions the should_compute_output check is superfluous
487 | if (should_compute_output(0)) {
488 | // pytorch/pytorch#79446 broke this code
489 | //auto next_node = next_edge(0).function;
490 | //auto input_nr = next_edge(0).input_nr;
491 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis];
492 | grad_inputs[0] = comm.MPIScatter(grads[0],gatheraxis, numelem, root);
493 | }
494 | return grad_inputs;
495 | }
496 |
497 | Tensor MPI_Comm_Wrapper::MPIGather(const Tensor& input, int64_t gatheraxis, int64_t root)
498 | {
499 | // TODO: check for root being in int range
500 | std::shared_ptr grad_fn;
501 | if (torch::autograd::compute_requires_grad(input)) {
502 | grad_fn = std::shared_ptr
503 | (new MPIGatherBackward(gatheraxis, root, input.sizes()[(size_t) gatheraxis]),
504 | torch::autograd::deleteNode);
505 | grad_fn->comm = *this;
506 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
507 | }
508 | auto result = ([&]() {
509 | at::AutoDispatchBelowADInplaceOrView guard;
510 |
511 | MPIDeviceHelper devhelper(input);
512 |
513 | // 1. Make input contiguous
514 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous();
515 |
516 | auto sizes = input_cont.sizes();
517 | const size_t ndim = sizes.size();
518 | const int npes = (int)GetSize();
519 |
520 | int64_t beforegatheraxis64 = 1;
521 | for (size_t i = 0; i < (size_t)gatheraxis; ++i) {
522 | beforegatheraxis64 *= sizes[i];
523 | }
524 |
525 | int64_t aftergatheraxis64 = 1;
526 | for (size_t i = (size_t)gatheraxis + 1; i < ndim; ++i) {
527 | aftergatheraxis64 *= sizes[i];
528 | }
529 |
530 | if (beforegatheraxis64 > INT_MAX || aftergatheraxis64 * sizes[(size_t)gatheraxis] > INT_MAX) {
531 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!");
532 | }
533 |
534 | const int beforegatheraxis = (int)beforegatheraxis64;
535 | const int gatheraxissize = (int)sizes[(size_t)gatheraxis];
536 | const int aftergatheraxis = (int)aftergatheraxis64;
537 |
538 | std::vector recvcounts(npes); // TODO: only allocate on root process
539 |
540 | check_mpi_return_value(MPI_Gather(&gatheraxissize, 1, MPI_INT,
541 | &recvcounts[0], 1,
542 | MPI_INT, root, comm));
543 |
544 | std::vector displs(npes); // TODO: only allocate on root process
545 | //displs[0] = 0; // This is a noop
546 | for (size_t i = 1; i < (size_t)npes; ++i) {
547 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) recvcounts[i-1];
548 |
549 | if (tmpadd > INT_MAX) {
550 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!");
551 | }
552 | displs[i] = (int) tmpadd;
553 | }
554 | const int newgatheraxissize = displs[npes-1] + recvcounts[npes-1]; // TODO: add overflow check
555 |
556 | MPI_Datatype tmpdatatype1;
557 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, aftergatheraxis * gatheraxissize,
558 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1));
559 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1));
560 |
561 | MPI_Datatype sendtype;
562 | MPI_Aint basic_lb, basic_extent;
563 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb,
564 | &basic_extent));
565 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, aftergatheraxis * basic_extent,
566 | &sendtype));
567 | check_mpi_return_value(MPI_Type_commit(&sendtype));
568 |
569 | MPI_Datatype tmpdatatype2;
570 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis,
571 | newgatheraxissize * aftergatheraxis,
572 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2));
573 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2));
574 | MPI_Datatype recvtype;
575 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, aftergatheraxis * basic_extent,
576 | &recvtype));
577 | check_mpi_return_value(MPI_Type_commit(&recvtype));
578 |
579 | std::vector newsizes(sizes.begin(), sizes.end());
580 | newsizes[(size_t)gatheraxis] = newgatheraxissize;
581 |
582 | auto recvtensor = torch::empty(newsizes, input_cont.options(), c10::MemoryFormat::Contiguous);
583 |
584 | check_mpi_return_value(MPI_Gatherv(input_cont.data_ptr(), gatheraxissize, sendtype,
585 | recvtensor.data_ptr(), &recvcounts[0], &displs[0], recvtype,
586 | static_cast(root), comm));
587 |
588 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1));
589 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2));
590 | check_mpi_return_value(MPI_Type_free(&sendtype));
591 | check_mpi_return_value(MPI_Type_free(&recvtype));
592 |
593 | return devhelper.fromMPIToDevice(recvtensor);
594 | })();
595 | if (grad_fn) {
596 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
597 | }
598 | return result;
599 | }
600 |
601 | struct MPIAllgatherBackward : public MPIBackwardNode {
602 | MPIAllgatherBackward(int64_t _gatheraxis, int64_t _numelem) : gatheraxis(_gatheraxis), numelem(_numelem) {}
603 | variable_list apply(variable_list&& grads) override;
604 | std::string name() const override {
605 | return std::string("MPIAllgatherBackward");
606 | }
607 |
608 | void release_variables() override {
609 | return;
610 | }
611 |
612 | int64_t gatheraxis;
613 | int64_t numelem;
614 | };
615 |
616 | variable_list MPIAllgatherBackward::apply (variable_list&& grads)
617 | {
618 | variable_list grad_inputs(1);
619 | // TODO: for these simple functions the should_compute_output check is superfluous
620 | if (should_compute_output(0)) {
621 | // pytorch/pytorch#79446 broke this code
622 | //auto next_node = next_edge(0).function;
623 | //auto input_nr = next_edge(0).input_nr;
624 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis];
625 | grad_inputs[0] = comm.MPIScatter(grads[0], gatheraxis, numelem, 0);
626 | for (int64_t root = 1; root < comm.GetSize(); ++root) {
627 | grad_inputs[0] += comm.MPIScatter(grads[0], gatheraxis, numelem, root);
628 | }
629 | }
630 | return grad_inputs;
631 | }
632 |
633 | Tensor MPI_Comm_Wrapper::MPIAllgather(const Tensor& input, int64_t gatheraxis)
634 | {
635 | // TODO: check for root being in int range
636 | std::shared_ptr grad_fn;
637 | if (torch::autograd::compute_requires_grad(input)) {
638 | grad_fn = std::shared_ptr
639 | (new MPIAllgatherBackward(gatheraxis, input.sizes()[(size_t) gatheraxis]), torch::autograd::deleteNode);
640 | grad_fn->comm = *this;
641 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
642 | }
643 | auto result = ([&]() {
644 | at::AutoDispatchBelowADInplaceOrView guard;
645 |
646 | MPIDeviceHelper devhelper(input);
647 |
648 | // 1. Make input contiguous
649 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous();
650 |
651 | auto sizes = input_cont.sizes();
652 | const size_t ndim = sizes.size();
653 | const int npes = (int)GetSize();
654 |
655 | int64_t beforegatheraxis64 = 1;
656 | for (size_t i = 0; i < (size_t)gatheraxis; ++i) {
657 | beforegatheraxis64 *= sizes[i];
658 | }
659 |
660 | int64_t aftergatheraxis64 = 1;
661 | for (size_t i = (size_t)gatheraxis + 1; i < ndim; ++i) {
662 | aftergatheraxis64 *= sizes[i];
663 | }
664 |
665 | if (beforegatheraxis64 > INT_MAX || aftergatheraxis64 * sizes[(size_t)gatheraxis] > INT_MAX) {
666 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!");
667 | }
668 |
669 | const int beforegatheraxis = (int)beforegatheraxis64;
670 | const int gatheraxissize = (int)sizes[(size_t)gatheraxis];
671 | const int aftergatheraxis = (int)aftergatheraxis64;
672 |
673 | std::vector recvcounts(npes);
674 |
675 | check_mpi_return_value(MPI_Allgather(&gatheraxissize, 1, MPI_INT,
676 | &recvcounts[0], 1,
677 | MPI_INT, comm));
678 |
679 | std::vector displs(npes);
680 | //displs[0] = 0; // This is a noop
681 | for (size_t i = 1; i < (size_t)npes; ++i) {
682 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) recvcounts[i-1];
683 |
684 | if (tmpadd > INT_MAX) {
685 | throw std::runtime_error("MPI_Gather: Tensor sizes exceed INT_MAX!");
686 | }
687 | displs[i] = (int) tmpadd;
688 | }
689 | const int newgatheraxissize = displs[npes-1] + recvcounts[npes-1];
690 |
691 | MPI_Datatype tmpdatatype1;
692 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis, aftergatheraxis * gatheraxissize,
693 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1));
694 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1));
695 |
696 | MPI_Datatype sendtype;
697 | MPI_Aint basic_lb, basic_extent;
698 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb,
699 | &basic_extent));
700 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, aftergatheraxis * basic_extent,
701 | &sendtype));
702 | check_mpi_return_value(MPI_Type_commit(&sendtype));
703 |
704 | MPI_Datatype tmpdatatype2;
705 | check_mpi_return_value(MPI_Type_vector(beforegatheraxis, aftergatheraxis,
706 | newgatheraxissize * aftergatheraxis,
707 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2));
708 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2));
709 | MPI_Datatype recvtype;
710 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, aftergatheraxis * basic_extent,
711 | &recvtype));
712 | check_mpi_return_value(MPI_Type_commit(&recvtype));
713 |
714 | std::vector newsizes(sizes.begin(), sizes.end());
715 | newsizes[(size_t)gatheraxis] = newgatheraxissize;
716 |
717 | auto recvtensor = torch::empty(newsizes, input_cont.options(), c10::MemoryFormat::Contiguous);
718 |
719 | check_mpi_return_value(MPI_Allgatherv(input_cont.data_ptr(), gatheraxissize, sendtype,
720 | recvtensor.data_ptr(), &recvcounts[0], &displs[0], recvtype,
721 | comm));
722 |
723 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1));
724 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2));
725 | check_mpi_return_value(MPI_Type_free(&sendtype));
726 | check_mpi_return_value(MPI_Type_free(&recvtype));
727 |
728 | return devhelper.fromMPIToDevice(recvtensor);
729 | })();
730 | if (grad_fn) {
731 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
732 | }
733 | return result;
734 | }
735 |
736 | struct MPIScatterBackward : public MPIBackwardNode {
737 | MPIScatterBackward(int64_t _scatteraxis, int64_t _root)
738 | : scatteraxis(_scatteraxis), root(_root) {}
739 | variable_list apply(variable_list&& grads) override;
740 | std::string name() const override {
741 | return std::string("MPIScatterBackward");
742 | }
743 |
744 | void release_variables() override {
745 | return;
746 | }
747 |
748 | int64_t scatteraxis;
749 | int64_t root;
750 | };
751 |
752 | variable_list MPIScatterBackward::apply (variable_list&& grads)
753 | {
754 | variable_list grad_inputs(1);
755 | // TODO: for these simple functions the should_compute_output check is superfluous
756 | if (should_compute_output(0)) {
757 | auto tmp = comm.MPIGather(grads[0],scatteraxis, root);
758 | if (comm.GetRank() == root) {
759 | grad_inputs[0] = tmp;
760 | } else {
761 | auto next_node = next_edge(0).function;
762 | auto input_nr = next_edge(0).input_nr;
763 | grad_inputs[0] = JoinDummies(next_node->input_metadata(input_nr).zeros_like(), {tmp});
764 | }
765 | }
766 | return grad_inputs;
767 | }
768 |
769 | Tensor MPI_Comm_Wrapper::MPIScatter(const Tensor& input, int64_t scatteraxis, int64_t numelem, int64_t root)
770 | {
771 | // TODO: check for root being in int range
772 | std::shared_ptr grad_fn;
773 | if (torch::autograd::compute_requires_grad(input)) {
774 | grad_fn = std::shared_ptr
775 | (new MPIScatterBackward(scatteraxis, root),
776 | torch::autograd::deleteNode);
777 | grad_fn->comm = *this;
778 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
779 | }
780 | auto result = ([&]() {
781 | at::AutoDispatchBelowADInplaceOrView guard;
782 |
783 | MPIDeviceHelper devhelper(input);
784 |
785 | // 1. Make input contiguous
786 | auto input_cont = GetRank() == root ? devhelper.fromDeviceToMPI(input).contiguous() : input;
787 |
788 | size_t ndim = input_cont.sizes().size();
789 | check_mpi_return_value(MPI_Bcast(&ndim, 1, MPI_LONG, root, comm));
790 | std::vector sizes;
791 | if (GetRank() == root) {
792 | sizes = std::move(input_cont.sizes().vec());
793 | } else {
794 | sizes.resize(ndim);
795 | }
796 | check_mpi_return_value(MPI_Bcast(&sizes[0], (int)ndim, MPI_LONG, root, comm));
797 |
798 | const int npes = (int)GetSize();
799 |
800 | int64_t beforescatteraxis64 = 1;
801 | for (size_t i = 0; i < (size_t)scatteraxis; ++i) {
802 | beforescatteraxis64 *= sizes[i];
803 | }
804 |
805 | int64_t afterscatteraxis64 = 1;
806 | for (size_t i = (size_t)scatteraxis + 1; i < ndim; ++i) {
807 | afterscatteraxis64 *= sizes[i];
808 | }
809 |
810 | if (beforescatteraxis64 > INT_MAX || afterscatteraxis64 * sizes[(size_t)scatteraxis] > INT_MAX) {
811 | throw std::runtime_error("MPI_Scatter: Tensor sizes exceed INT_MAX!");
812 | }
813 |
814 | const int beforescatteraxis = (int)beforescatteraxis64;
815 | const int scatteraxissize = (int)sizes[(size_t)scatteraxis];
816 | const int afterscatteraxis = (int)afterscatteraxis64;
817 | const int newscatteraxissize = (int) numelem;
818 |
819 | std::vector sendcounts(npes); // TODO: only allocate on root process
820 |
821 | check_mpi_return_value(MPI_Gather(&newscatteraxissize, 1, MPI_INT,
822 | &sendcounts[0], 1,
823 | MPI_INT, root, comm));
824 |
825 | std::vector displs(npes); // TODO: only allocate on root process
826 | //displs[0] = 0; // This is a noop
827 | for (size_t i = 1; i < (size_t)npes; ++i) {
828 | int64_t tmpadd = (int64_t) displs[i-1] + (int64_t) sendcounts[i-1];
829 |
830 | if (tmpadd > INT_MAX) {
831 | throw std::runtime_error("MPI_Scatter: Tensor sizes exceed INT_MAX!");
832 | }
833 | displs[i] = (int) tmpadd;
834 | }
835 | if (root == GetRank() && scatteraxissize != displs[npes-1] + sendcounts[npes-1]) {
836 | throw std::runtime_error("MPI_Scatter: finaltensor.shape[scatteraxis] != sum(numelem)!");
837 | }
838 |
839 | MPI_Datatype tmpdatatype1;
840 | check_mpi_return_value(MPI_Type_vector(beforescatteraxis, afterscatteraxis,
841 | afterscatteraxis * scatteraxissize,
842 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype1));
843 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype1));
844 |
845 | MPI_Datatype sendtype;
846 | MPI_Aint basic_lb, basic_extent;
847 | check_mpi_return_value(MPI_Type_get_extent(torch2mpitype(input_cont.scalar_type()), &basic_lb,
848 | &basic_extent));
849 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype1, 0, afterscatteraxis * basic_extent,
850 | &sendtype));
851 | check_mpi_return_value(MPI_Type_commit(&sendtype));
852 |
853 | MPI_Datatype tmpdatatype2;
854 | check_mpi_return_value(MPI_Type_vector(beforescatteraxis, afterscatteraxis,
855 | newscatteraxissize * afterscatteraxis,
856 | torch2mpitype(input_cont.scalar_type()), &tmpdatatype2));
857 | check_mpi_return_value(MPI_Type_commit(&tmpdatatype2));
858 | MPI_Datatype recvtype;
859 | check_mpi_return_value(MPI_Type_create_resized(tmpdatatype2, 0, afterscatteraxis * basic_extent,
860 | &recvtype));
861 | check_mpi_return_value(MPI_Type_commit(&recvtype));
862 |
863 | std::vector newsizes(sizes.begin(), sizes.end());
864 | newsizes[(size_t)scatteraxis] = newscatteraxissize;
865 |
866 | auto recvtensor = torch::empty(newsizes, input_cont.options().device(devhelper.mpidevice),
867 | c10::MemoryFormat::Contiguous);
868 |
869 | check_mpi_return_value(MPI_Scatterv(input_cont.data_ptr(), &sendcounts[0], &displs[0], sendtype,
870 | recvtensor.data_ptr(), newscatteraxissize, recvtype,
871 | static_cast(root), comm));
872 |
873 | check_mpi_return_value(MPI_Type_free(&tmpdatatype1));
874 | check_mpi_return_value(MPI_Type_free(&tmpdatatype2));
875 | check_mpi_return_value(MPI_Type_free(&sendtype));
876 | check_mpi_return_value(MPI_Type_free(&recvtype));
877 |
878 | return devhelper.fromMPIToDevice(recvtensor);
879 | })();
880 | if (grad_fn) {
881 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
882 | }
883 | return result;
884 | }
885 |
886 | struct MPIAlltoallBackward : public MPIBackwardNode {
887 | MPIAlltoallBackward(int64_t _gatheraxis, int64_t _scatteraxis, int64_t _numelem)
888 | : gatheraxis(_gatheraxis), scatteraxis(_scatteraxis), numelem(_numelem) {}
889 | variable_list apply(variable_list&& grads) override;
890 | std::string name() const override {
891 | return std::string("MPIAlltoallBackward");
892 | }
893 |
894 | void release_variables() override {
895 | return;
896 | }
897 |
898 | int64_t gatheraxis;
899 | int64_t scatteraxis;
900 | int64_t numelem;
901 | };
902 |
903 | variable_list MPIAlltoallBackward::apply (variable_list&& grads)
904 | {
905 | variable_list grad_inputs(1);
906 | // TODO: for these simple functions the should_compute_output check is superfluous
907 | if (should_compute_output(0)) {
908 | // pytorch/pytorch#79446 broke this code
909 | //auto next_node = next_edge(0).function;
910 | //auto input_nr = next_edge(0).input_nr;
911 | //const int64_t numelem = next_node->input_metadata(input_nr).shape()[(size_t) gatheraxis];
912 | grad_inputs[0] = comm.MPIAlltoall(grads[0], scatteraxis, gatheraxis, numelem);
913 | }
914 | return grad_inputs;
915 | }
916 |
917 | Tensor MPI_Comm_Wrapper::MPIAlltoall(const Tensor& input, int64_t gatheraxis, int64_t scatteraxis, int64_t numelem)
918 | {
919 | // TODO: check for root being in int range
920 | std::shared_ptr grad_fn;
921 | if (torch::autograd::compute_requires_grad(input)) {
922 | grad_fn = std::shared_ptr
923 | (new MPIAlltoallBackward(gatheraxis, scatteraxis, input.sizes()[(size_t) gatheraxis]),
924 | torch::autograd::deleteNode);
925 | grad_fn->comm = *this;
926 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
927 | }
928 | auto result = ([&]() {
929 | at::AutoDispatchBelowADInplaceOrView guard;
930 |
931 | // 1. Make input contiguous
932 | auto input_cont = input.contiguous();
933 |
934 |
935 | // TODO: This is probably not the best solution latency-wise, but the total number
936 | // of bytes sent should match roughly a solution that uses Alltoallw.
937 | // If the latency turns out to become a problem, we should switch to Alltoallw.
938 | // Memory usage could also become an issue, since this solution requires
939 | // temporarily twice the memory that an Alltoallw solution would require.
940 | std::vector scattered_tensors;
941 |
942 | if (gatheraxis != scatteraxis) {
943 | // This is the easy case
944 | for(int64_t root = 0; root < GetSize(); ++root) {
945 | scattered_tensors.emplace_back(MPIScatter(input_cont, scatteraxis, numelem, root));
946 | }
947 | } else {
948 | // if gather- and scatteraxis coincide we first need to figure out who receives how many
949 | // elements from whom
950 | //
951 | // TODO: this could probably be solved more efficiently, with few more communication roundtrips
952 |
953 | const size_t npes = (size_t)GetSize();
954 | const size_t rank = (size_t)GetRank();
955 | std::vector numelem_cur(npes+1);
956 | std::vector numelem_new(npes+1);
957 |
958 | const int tmp1 = input_cont.sizes()[(size_t)gatheraxis];
959 | check_mpi_return_value(MPI_Allgather(&tmp1, 1, MPI_INT,
960 | &numelem_cur[1], 1,
961 | MPI_INT, comm));
962 | const int tmp2 = numelem;
963 | check_mpi_return_value(MPI_Allgather(&tmp2, 1, MPI_INT,
964 | &numelem_new[1], 1,
965 | MPI_INT, comm));
966 | for (size_t i = 1; i < npes; ++i) {
967 | numelem_cur[i+1] += numelem_cur[i];
968 | numelem_new[i+1] += numelem_new[i];
969 | }
970 |
971 | for(int64_t root = 0; static_cast(root) < npes; ++root) {
972 | int64_t localnumelem = std::min(numelem_new[rank+1],numelem_cur[root+1])
973 | - std::max(numelem_new[rank],numelem_cur[root]);
974 | if (localnumelem < 0) {
975 | localnumelem = 0;
976 | }
977 | scattered_tensors.emplace_back(MPIScatter(input_cont, scatteraxis, localnumelem, root));
978 | }
979 | }
980 |
981 | return at::cat(scattered_tensors, gatheraxis);
982 | })();
983 | if (grad_fn) {
984 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
985 | }
986 | return result;
987 | }
988 |
989 | struct JoinDummiesBackward : public torch::autograd::Node {
990 | variable_list apply(variable_list&& grads) override;
991 | std::string name() const override {
992 | return std::string("JoinDummiesBackward");
993 | }
994 |
995 | void handle_dummies_helper(const size_t first, variable_list& grad_inputs);
996 |
997 | // TODO: I am not sure whether it is wise here to circumvent the saved_variables list,
998 | // and do our own thing. It might be that we create a memory leak that way!
999 | Tensor loopthrough;
1000 | };
1001 |
1002 | void JoinDummiesBackward::handle_dummies_helper(const size_t first, variable_list& grad_inputs)
1003 | {
1004 | for (size_t i = first; i < num_outputs(); ++i) {
1005 | if (should_compute_output(i)) {
1006 | auto next_node = next_edge(i).function;
1007 | auto input_nr = next_edge(i).input_nr;
1008 | grad_inputs[i] = JoinDummies(next_node->input_metadata(input_nr).zeros_like(),{loopthrough});
1009 | }
1010 | }
1011 | }
1012 |
1013 | variable_list JoinDummiesBackward::apply (variable_list&& grads)
1014 | {
1015 | size_t numoutputs = num_outputs();
1016 | variable_list grad_inputs(numoutputs);
1017 | if (should_compute_output(0)) {
1018 | grad_inputs[0] = JoinDummies(grads[0], {loopthrough});
1019 | }
1020 | handle_dummies_helper(1,grad_inputs);
1021 | return grad_inputs;
1022 | }
1023 |
1024 | Tensor JoinDummies(const Tensor& loopthrough, const variable_list& list)
1025 | {
1026 | std::shared_ptr grad_fn;
1027 | if (torch::autograd::compute_requires_grad(list)) {
1028 | grad_fn = std::shared_ptr (new JoinDummiesBackward(), torch::autograd::deleteNode);
1029 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(loopthrough,list));
1030 | } else {
1031 | // if none of the dummy variables needs a gradient, we just return the loopthrough variable
1032 | return loopthrough;
1033 | }
1034 | auto result = ([&]() {
1035 | at::AutoDispatchBelowADInplaceOrView guard;
1036 |
1037 | auto res = loopthrough.variable_data();
1038 | return res;
1039 | })();
1040 | // Checking for grad_fn is unnecessary
1041 | //if (grad_fn) {
1042 | grad_fn->loopthrough = result;
1043 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
1044 | //}
1045 | return result;
1046 | }
1047 |
1048 | enum NonBlockingOp
1049 | {
1050 | Isend_Op,
1051 | Irecv_Op,
1052 | };
1053 |
1054 | struct MPINonBlockingBackward : public MPIBackwardNode {
1055 | variable_list apply(variable_list&& grads) override;
1056 | std::string name() const override {
1057 | return std::string("MPINonBlockingBackward");
1058 | }
1059 | };
1060 |
1061 | variable_list MPINonBlockingBackward::apply(variable_list&& grads)
1062 | {
1063 | variable_list grad_inputs(1);
1064 | // TODO: superfluous check??
1065 | if (should_compute_output(0)) {
1066 | grad_inputs[0] = comm.MPIWait(grads);
1067 | }
1068 | return grad_inputs;
1069 | }
1070 |
1071 | variable_list MPI_Comm_Wrapper::MPIIsend(const Tensor& input, int64_t dest, int64_t tag)
1072 | {
1073 | // TODO: check for dest and tag being in int's range
1074 | std::shared_ptr grad_fn;
1075 | if (torch::autograd::compute_requires_grad(input)) {
1076 | grad_fn = std::shared_ptr (new MPINonBlockingBackward(), torch::autograd::deleteNode);
1077 | grad_fn->comm = *this;
1078 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
1079 | }
1080 | auto result = ([&]() {
1081 | at::AutoDispatchBelowADInplaceOrView guard;
1082 |
1083 | MPIDeviceHelper devhelper(input);
1084 |
1085 | // make input contiguous
1086 | // we call variable_data() since we also return the input buffer to ensure it stays in scope
1087 | auto input_cont = devhelper.fromDeviceToMPI(input).contiguous().variable_data();
1088 |
1089 | MPI_Request req;
1090 | check_mpi_return_value(MPI_Isend(input_cont.data_ptr(), input_cont.numel(),
1091 | torch2mpitype(input_cont.scalar_type()), static_cast(dest),
1092 | static_cast(tag), comm, &req));
1093 |
1094 | auto ret = torch::empty({7},at::kDouble);
1095 | auto fortran_handle = MPI_Request_c2f(req);
1096 | ret[0] = static_cast(fortran_handle);
1097 | ret[1] = static_cast(Isend_Op);
1098 | ret[2] = static_cast(dest);
1099 | ret[3] = static_cast(tag);
1100 | ret[4] = static_cast(0xFFFFFFFF & std::hash()(input_cont.data_ptr()));
1101 | ret[5] = static_cast(devhelper.devicetype);
1102 | ret[6] = static_cast(devhelper.device.index());
1103 | variable_list retlist;
1104 | retlist.push_back(ret);
1105 | retlist.push_back(input_cont); // make sure the buffer stays in scope!!!
1106 | retlist.push_back(input.variable_data());
1107 | return retlist;
1108 | })();
1109 | if (grad_fn) {
1110 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
1111 | }
1112 | return result;
1113 | }
1114 |
1115 | variable_list MPI_Comm_Wrapper::MPIIrecv(const Tensor& input, int64_t source, int64_t tag)
1116 | {
1117 | // TODO: check for dest and tag being in int's range
1118 | std::shared_ptr grad_fn;
1119 | if (torch::autograd::compute_requires_grad(input)) {
1120 | grad_fn = std::shared_ptr (new MPINonBlockingBackward(), torch::autograd::deleteNode);
1121 | grad_fn->comm = *this;
1122 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
1123 | }
1124 | auto result = ([&]() {
1125 | at::AutoDispatchBelowADInplaceOrView guard;
1126 |
1127 | MPIDeviceHelper devhelper(input);
1128 |
1129 | // TODO: check whether input is contiguous
1130 | // TODO: Maybe add warning if input device is not mpi device
1131 | auto input_cont = devhelper.fromDeviceToMPI(input).variable_data();
1132 |
1133 | MPI_Request req;
1134 | check_mpi_return_value(MPI_Irecv(input_cont.data_ptr(), input_cont.numel(),
1135 | torch2mpitype(input_cont.scalar_type()), static_cast(source),
1136 | static_cast(tag), comm, &req));
1137 |
1138 | auto ret = torch::empty({7},at::kDouble);
1139 | auto fortran_handle = MPI_Request_c2f(req);
1140 | ret[0] = static_cast(fortran_handle);
1141 | ret[1] = static_cast(Irecv_Op);
1142 | ret[2] = static_cast(source);
1143 | ret[3] = static_cast(tag);
1144 | ret[4] = static_cast(0xFFFFFFFF & std::hash()(input_cont.data_ptr()));
1145 | ret[5] = static_cast(devhelper.devicetype);
1146 | ret[6] = static_cast(devhelper.device.index());
1147 | variable_list retlist;
1148 | retlist.push_back(ret);
1149 | retlist.push_back(input_cont); // We ensure that the buffer stays in scope and is not garbage collected
1150 | retlist.push_back(input.variable_data()); // We do this for symmetry reasons, but it is more ISend which needs this
1151 | return retlist;
1152 | })();
1153 | if (grad_fn) {
1154 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
1155 | }
1156 | return result;
1157 | }
1158 |
1159 | struct MPIWaitBackward : public MPIBackwardNode {
1160 | MPIWaitBackward(NonBlockingOp op, int64_t sourcedest_, int64_t tag_)
1161 | : operation(op), sourcedest(sourcedest_), tag(tag_+10)
1162 | {}
1163 | variable_list apply(variable_list&& grads) override;
1164 | std::string name() const override {
1165 | return std::string("MPIWaitBackward");
1166 | }
1167 |
1168 | NonBlockingOp operation;
1169 | int64_t sourcedest;
1170 | int64_t tag;
1171 | };
1172 |
1173 | variable_list MPIWaitBackward::apply(variable_list&& grads)
1174 | {
1175 | // TODO: maybe we should add some offset to the tag, since who knows how intertwined the forward and backward
1176 | // operations can become. Beware of the special value MPI_ANY_TAG
1177 |
1178 | // TODO: superfluous check??
1179 | if (should_compute_output(0)) {
1180 | auto next_node = next_edge(1).function;
1181 | auto input_nr = next_edge(1).input_nr;
1182 |
1183 | // Some rationale for this error checking:
1184 | // Every time a variable is used multiple times in forward mode, i.e. where the forward graph
1185 | // bifurcates, the backward DAG, as generated by pytorch, has to assume that in the reverse operation
1186 | // the two incoming gradients need to be added up. This is however unfortunate for the second part of our
1187 | // wait handle, since this is where we store the send/receive buffer. If one were now to add e.g. a zero
1188 | // to this receive buffer, a completely new buffer would be created by pytorch, and the original
1189 | // receive buffer would go out of scope, would then be garbage-collected, and any actual receive operation
1190 | // could potentially yield a seg-fault. On top of that, the buffer as returned by MPIWait quite likely
1191 | // contains the wrong result, since MPIWait returns the buffer created by pytorch that stores the
1192 | // result of the addition operation.
1193 | //
1194 | // NOTE: This check cannot capture all cases of misuse. The additional storage of parts/hashes
1195 | // of the respective data_ptr aims at catching the same type of errors.
1196 | if (next_node->name() != "MPINonBlockingBackward") {
1197 | std::ostringstream oss;
1198 | oss << "mpi4torch: Detected bifurcation in MPIWait handle usage. Next node in DAG"
1199 | " should be MPINonBlockingBackward, but is "
1200 | << next_node->name() << "!";
1201 | throw std::runtime_error(oss.str());
1202 | }
1203 | switch(operation) {
1204 | case Isend_Op:
1205 | {
1206 | auto buf = next_node->input_metadata(input_nr).zeros_like();
1207 | return comm.MPIIrecv(JoinDummies(buf,grads), sourcedest, tag);
1208 | }
1209 | case Irecv_Op:
1210 | {
1211 | return comm.MPIIsend(grads[0], sourcedest, tag);
1212 | }
1213 | default:
1214 | throw std::runtime_error("Unsupported NonBlockingOp!");
1215 | }
1216 | }
1217 | return variable_list();
1218 | }
1219 |
1220 | Tensor MPI_Comm_Wrapper::MPIWait(const variable_list& input)
1221 | {
1222 | auto fortran_handle = static_cast(input[0][0].item());
1223 | MPI_Request req = MPI_Request_f2c(fortran_handle);
1224 | NonBlockingOp operation = static_cast(input[0][1].item());
1225 | auto sourcedest = static_cast(input[0][2].item());
1226 | auto tag = static_cast(input[0][3].item());
1227 | auto hashvalue = static_cast(input[0][4].item());
1228 | auto devicetype = static_cast(input[0][5].item());
1229 | auto deviceindex = static_cast(input[0][6].item());
1230 |
1231 | if (hashvalue != (0xFFFFFFFF & std::hash()(input[1].data_ptr()))) {
1232 | std::ostringstream oss;
1233 | oss << "mpi4torch: Detected bifurcation in MPIWait handle usage. "
1234 | "Modifying or consuming the handle by other functions than functions from the "
1235 | "MPIWait class is prohibited!";
1236 | throw std::runtime_error(oss.str());
1237 | }
1238 |
1239 | std::shared_ptr<MPIWaitBackward> grad_fn;
1240 | if(torch::autograd::compute_requires_grad(input)) {
1241 | grad_fn = std::shared_ptr<MPIWaitBackward>(new MPIWaitBackward(operation, sourcedest, tag), torch::autograd::deleteNode);
1242 | grad_fn->comm = *this;
1243 | grad_fn->set_next_edges(torch::autograd::collect_next_edges(input));
1244 | }
1245 | auto result = ([&]() {
1246 | at::AutoDispatchBelowADInplaceOrView guard;
1247 |
1248 | MPI_Status status; // TODO: Handle use cases for MPI_Status
1249 | check_mpi_return_value(MPI_Wait(&req, &status));
1250 |
1251 | if (operation == Isend_Op) {
1252 | // We do not do any device conversion for Isend; we simply return the initial tensor
1253 | return input[2].variable_data();
1254 | }
1255 |
1256 | MPIDeviceHelper devhelper(c10::Device((c10::DeviceType)devicetype, deviceindex));
1257 |
1258 | // return a shallow copy of the second input tensor without the autograd strings attached
1259 | return devhelper.fromMPIToDevice(input[1]).variable_data();
1260 | })();
1261 | if (grad_fn) {
1262 | set_history(torch::autograd::flatten_tensor_args(result), grad_fn);
1263 | }
1264 | return result;
1265 | }
1266 |
1267 | }
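     | // Illustrative sketch (not from the original source) of the non-blocking round trip supported by
     | // the code above, as Python-level pseudocode; see examples/isend-recv-wait.py in this repository
     | // for the actual example. Names and signatures below are assumptions for illustration:
     | //
     | //     comm = mpi4torch.COMM_WORLD               # hypothetical Python-side spelling
     | //     if comm.GetRank() == 0:
     | //         handle = comm.Isend(x, 1, 0)          # (tensor, dest, tag)
     | //         x = comm.Wait(handle)                 # returns x once the send has completed
     | //     else:
     | //         handle = comm.Irecv(buf, 0, 0)        # (buffer, source, tag)
     | //         y = comm.Wait(handle)                 # returns the filled receive buffer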
1268 |
1269 | // New-style registration API, more similar to pybind11:
1270 | static auto mpi_comm_wrapper_registry = torch::class_<::MPI_Comm_Wrapper>("mpi4torch", "MPI_Comm_Wrapper")
1271 | .def("GetRank", &MPI_Comm_Wrapper::GetRank)
1272 | .def("GetSize", &MPI_Comm_Wrapper::GetSize)
1273 | .def("Allreduce", &MPI_Comm_Wrapper::MPIAllreduce)
1274 | .def("Bcast_", &MPI_Comm_Wrapper::MPIBcast_)
1275 | .def("Reduce_", &MPI_Comm_Wrapper::MPIReduce_)
1276 | .def("Gather", &MPI_Comm_Wrapper::MPIGather)
1277 | .def("Allgather", &MPI_Comm_Wrapper::MPIAllgather)
1278 | .def("Scatter", &MPI_Comm_Wrapper::MPIScatter)
1279 | .def("Alltoall", &MPI_Comm_Wrapper::MPIAlltoall)
1280 | .def("Isend", &MPI_Comm_Wrapper::MPIIsend)
1281 | .def("Irecv", &MPI_Comm_Wrapper::MPIIrecv)
1282 | .def("Wait", &MPI_Comm_Wrapper::MPIWait)
1283 | .def_pickle([](const c10::intrusive_ptr<MPI_Comm_Wrapper>& self) -> std::string
1284 | {
1285 | if (self->comm != MPI_COMM_WORLD) {
1286 | throw std::runtime_error("MPI communicators other than MPI_COMM_WORLD are not serializable!");
1287 | }
1288 | return std::string("MPI_COMM_WORLD");
1289 | },
1290 | [](std::string input) -> c10::intrusive_ptr<MPI_Comm_Wrapper>
1291 | {
1292 | if (input != std::string("MPI_COMM_WORLD")) {
1293 | throw std::runtime_error("Unknown MPI communicator");
1294 | }
1295 | return c10::make_intrusive<MPI_Comm_Wrapper>(MPI_COMM_WORLD);
1296 | }
1297 | )
1298 | ;
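     | // Added note: torch::class_ exposes the wrapper to TorchScript; on the Python side such a
     | // registered class is typically reachable as torch.classes.mpi4torch.MPI_Comm_Wrapper. The
     | // def_pickle hooks above only round-trip MPI_COMM_WORLD-backed wrappers, so scripted modules
     | // holding such a communicator can be serialized and reloaded.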
1299 |
1300 | // Old-style registration API (the one available up to pytorch 1.4.0), used here for the free functions:
1301 | static auto registry = torch::RegisterOperators()
1302 | .op("mpi4torch::COMM_WORLD", &comm_world)
1303 | .op("mpi4torch::comm_from_fortran", &comm_from_fortran)
1304 | .op("mpi4torch::JoinDummies", &JoinDummies);
1305 |
1306 | #if defined(OPEN_MPI)
1307 | #if OMPI_MAJOR_VERSION < 3
1308 | #define _GNU_SOURCE
1309 | #include <dlfcn.h>
1310 | #endif
1311 | #endif
1312 |
1313 | struct __MPI_Finalizer
1314 | {
1315 | ~__MPI_Finalizer()
1316 | {
1317 | check_mpi_return_value(MPI_Finalize());
1318 | }
1319 | };
1320 |
1321 | static std::unique_ptr<__MPI_Finalizer> __finalizer;
1322 |
1323 | static void __mpi4torch_mpi_init()
1324 | {
1325 | #if defined(OPEN_MPI)
1326 | #if OMPI_MAJOR_VERSION < 3
1327 | // Older versions of OpenMPI have difficulties with the way the MPI libraries are dlopen-ed
1328 | // within python's extension system (i.e. with RTLD_LOCAL) [1].
1329 | // In detail:
1330 | // - cpython loads this extension as a shared library with RTLD_LOCAL
1331 | // - this implies that several openmpi functions are not globally visible
1332 | // - during initialization openmpi itself dynamically loads shared objects,
1333 | //   which backreference some of the previously loaded (but not visible) symbols
1334 | // - openmpi crashes during the MPI_Init* calls since these backreferenced symbols cannot
1335 | //   be resolved
1336 | //
1337 | // [1] https://github.com/open-mpi/ompi/issues/3705
1338 |
1339 | // In principle we just want to reload libmpi with RTLD_GLOBAL specified before calling MPI_Init_thread
1340 | // to make the respective symbols visible.
1341 | // But first we need to find out under which path name the library we want to reload was loaded.
1342 |
1343 | // To do so, we query for the address of a known symbol in libmpi, e.g. MPI_Init_thread
1344 | const void* mpi_init_thread_symbol = dlsym(RTLD_DEFAULT, "MPI_Init_thread");
1345 |
1346 | if (mpi_init_thread_symbol == nullptr) {
1347 | // in principle we should never be able to reach this point, but better safe than sorry
1348 | throw std::runtime_error(std::string("mpi4torch failed with: ")+std::string(dlerror()));
1349 | }
1350 |
1351 | // Now we query the library info in which the symbol resides
1352 | Dl_info dlinfo;
1353 | if (!dladdr(mpi_init_thread_symbol, &dlinfo)) {
1354 | throw std::runtime_error(std::string("mpi4torch failed with: ")+std::string(dlerror()));
1355 | }
1356 |
1357 | // dlinfo.dli_fname should contain the pathname of the mpi library
1358 |
1359 | // As the man page of dlopen suggests, to promote the already loaded symbols from RTLD_LOCAL to RTLD_GLOBAL,
1360 | // we just need to reopen the library with RTLD_NOLOAD | RTLD_GLOBAL flags.
1361 | const void* mpi_lib_handle = dlopen(dlinfo.dli_fname, RTLD_NOW | RTLD_NOLOAD | RTLD_GLOBAL);
1362 | if (!mpi_lib_handle) {
1363 | throw std::runtime_error(std::string("mpi4torch failed with: ")+std::string(dlerror()));
1364 | }
1365 | #endif
1366 | #endif
1367 |
1368 | int mpi_initialized = -1;
1369 | check_mpi_return_value(MPI_Initialized(&mpi_initialized));
1370 | if (!mpi_initialized) {
1371 | int provided = 0;
1372 | // We play it safe and initialize for multithreaded use cases
1373 | check_mpi_return_value(MPI_Init_thread(NULL, NULL, MPI_THREAD_MULTIPLE, &provided));
1374 |
1375 | // TODO: we may actually come to depend on multithreading in the future,
1376 | // so we leave this code here
1377 | #if 0
1378 | if (provided != MPI_THREAD_MULTIPLE) {
1379 | std::cerr << "mpi4torch WARNING: MPI version does not provide full multithreading support!\n";
1380 | std::cerr << "mpi4torch WARNING: This may crash mpi4torch.\n";
1381 | }
1382 | #endif
1383 | __finalizer = std::make_unique<__MPI_Finalizer>();
1384 | } else {
1385 | #if 0
1386 | int provided = 0;
1387 | check_mpi_return_value(MPI_Query_thread(&provided));
1388 | if (provided != MPI_THREAD_MULTIPLE) {
1389 | std::cerr << "mpi4torch WARNING: MPI is already initialized but not with MPI_THREAD_MULTIPLE!\n";
1390 | std::cerr << "mpi4torch WARNING: This may crash mpi4torch.\n";
1391 | }
1392 | #endif
1393 | }
1394 | }
1395 |
1396 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
1397 | // Initialize MPI
1398 | __mpi4torch_mpi_init();
1399 |
1400 | #if MPI4TORCH_BUILT_WITH_CUDA_AWARENESS
1401 | __setup_have_cuda_aware_mpi_support();
1402 | #endif
1403 |
1404 | m.def("deactivate_cuda_aware_mpi_support",&deactivate_cuda_aware_mpi_support,
1405 | "Deactivates the CUDA-aware MPI support.\n"
1406 | "\n"
1407 | "Calling this function forces mpi4torch to first move any tensor into main memory before\n"
1408 | "calling a MPI function on it, and then to move the result back into device memory after\n"
1409 | "the MPI call has finished.\n"
1410 | "\n"
1411 | "Note\n"
1412 | "----\n"
1413 | " This function is useful in situations in which MPI advertises CUDA-awareness but the\n"
1414 | " functionality is not really supported by the backend.\n");
1415 |
1416 | // Torchscript does not like the pybind11 enum_ solution
1417 | //py::enum_<Mpi4torchCollectiveOps>(m, "MPI_Op")
1418 | // .value("MPI_MAX", Mpi4torchCollectiveOps::mpi4torch_op_max)
1419 | // .value("MPI_MIN", Mpi4torchCollectiveOps::mpi4torch_op_min)
1420 | // .value("MPI_SUM", Mpi4torchCollectiveOps::mpi4torch_op_sum)
1421 | // .value("MPI_PROD", Mpi4torchCollectiveOps::mpi4torch_op_prod)
1422 | // .export_values();
1423 |
1424 | m.attr("MPI_MAX") = py::int_((int64_t)mpi4torch_op_max);
1425 | m.attr("MPI_MIN") = py::int_((int64_t)mpi4torch_op_min);
1426 | m.attr("MPI_SUM") = py::int_((int64_t)mpi4torch_op_sum);
1427 | m.attr("MPI_PROD") = py::int_((int64_t)mpi4torch_op_prod);
1428 | m.attr("MPI_LAND") = py::int_((int64_t)mpi4torch_op_land);
1429 | m.attr("MPI_BAND") = py::int_((int64_t)mpi4torch_op_band);
1430 | m.attr("MPI_LOR") = py::int_((int64_t)mpi4torch_op_lor);
1431 | m.attr("MPI_BOR") = py::int_((int64_t)mpi4torch_op_bor);
1432 | m.attr("MPI_LXOR") = py::int_((int64_t)mpi4torch_op_lxor);
1433 | m.attr("MPI_BXOR") = py::int_((int64_t)mpi4torch_op_bxor);
1434 | m.attr("MPI_MINLOC") = py::int_((int64_t)mpi4torch_op_minloc);
1435 | m.attr("MPI_MAXLOC") = py::int_((int64_t)mpi4torch_op_maxloc);
1436 | }
1437 |
1438 |
--------------------------------------------------------------------------------