├── .clang-format ├── .flake8 ├── .github ├── actions │ └── build │ │ └── action.yaml └── workflows │ └── main.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── hipify.py ├── include ├── torch_ucc.hpp ├── torch_ucc_comm.hpp └── torch_ucc_tracing.hpp ├── requirements-flake8.txt ├── setup.py ├── src ├── torch_ucc.cpp ├── torch_ucc_comm.cpp ├── torch_ucc_init.cpp ├── torch_ucc_init_oss.cpp └── torch_ucc_tracing.cpp └── test ├── blocking_wait_test.py ├── start_test.sh ├── torch_allgather_test.py ├── torch_allreduce_test.py ├── torch_alltoall_bench.py ├── torch_alltoall_test.py ├── torch_alltoallv_test.py ├── torch_barrier_test.py ├── torch_bcast_test.py ├── torch_gather_test.py ├── torch_init_test.py ├── torch_multiple_comms_test.py ├── torch_pg_ucc_test.py ├── torch_pt2pt_test.py ├── torch_reduce_scatter_test.py ├── torch_reduce_test.py ├── torch_sendrecv_test.py ├── torch_tests.py ├── torch_timeout_test.py ├── torch_ucc_test_setup.py └── torch_work_test.py /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: AlwaysBreak 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: true 7 | AlignOperands: false 8 | AlignTrailingComments: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: true 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BraceWrapping: 21 | AfterClass: false 22 | AfterControlStatement: false 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterObjCDeclaration: false 27 | AfterStruct: false 28 | AfterUnion: false 29 | BeforeCatch: false 30 | BeforeElse: false 31 | IndentBraces: false 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeBraces: Attach 34 | BreakBeforeTernaryOperators: true 35 | BreakConstructorInitializersBeforeComma: false 36 | BreakAfterJavaFieldAnnotations: false 37 | BreakStringLiterals: false 38 | ColumnLimit: 80 39 | CommentPragmas: '^ IWYU pragma:' 40 | CompactNamespaces: false 41 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 42 | ConstructorInitializerIndentWidth: 4 43 | ContinuationIndentWidth: 4 44 | Cpp11BracedListStyle: true 45 | DerivePointerAlignment: false 46 | DisableFormat: false 47 | ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ] 48 | IncludeCategories: 49 | - Regex: '^<.*\.h(pp)?>' 50 | Priority: 1 51 | - Regex: '^<.*' 52 | Priority: 2 53 | - Regex: '.*' 54 | Priority: 3 55 | IndentCaseLabels: true 56 | IndentWidth: 2 57 | IndentWrappedFunctionNames: false 58 | KeepEmptyLinesAtTheStartOfBlocks: false 59 | MacroBlockBegin: '' 60 | MacroBlockEnd: '' 61 | MaxEmptyLinesToKeep: 1 62 | NamespaceIndentation: None 63 | ObjCBlockIndentWidth: 2 64 | ObjCSpaceAfterProperty: false 65 | ObjCSpaceBeforeProtocolList: false 66 | PenaltyBreakBeforeFirstCallParameter: 1 67 | PenaltyBreakComment: 300 68 | PenaltyBreakFirstLessLess: 120 69 | PenaltyBreakString: 1000 70 | PenaltyExcessCharacter: 1000000 71 | PenaltyReturnTypeOnItsOwnLine: 2000000 72 | PointerAlignment: Left 73 | ReflowComments: true 74 | SortIncludes: true 75 | SpaceAfterCStyleCast: false 76 | 
SpaceBeforeAssignmentOperators: true 77 | SpaceBeforeParens: ControlStatements 78 | SpaceInEmptyParentheses: false 79 | SpacesBeforeTrailingComments: 1 80 | SpacesInAngles: false 81 | SpacesInContainerLiterals: true 82 | SpacesInCStyleCastParentheses: false 83 | SpacesInParentheses: false 84 | SpacesInSquareBrackets: false 85 | Standard: Cpp11 86 | TabWidth: 8 87 | UseTab: Never 88 | ... 89 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B,C,E,F,P,T4,W,B9 3 | max-line-length = 120 4 | # C408 ignored because we like the dict keyword argument syntax 5 | # E501 is not flexible enough, we're using B950 instead 6 | ignore = 7 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, 8 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 9 | # to line this up with executable bit 10 | EXE001, 11 | # these ignores are from flake8-bugbear; please fix! 12 | B007,B008, 13 | # these ignores are from flake8-comprehensions; please fix! 14 | C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415, 15 | # bare 'except': ignored by fbcode lints; please fix! 16 | B001,E722, 17 | # import *: ignored by fbcode lints; please fix! 18 | F403 19 | per-file-ignores = __init__.py: F401 torch/utils/cpp_extension.py: B950 20 | optional-ascii-coding = True 21 | -------------------------------------------------------------------------------- /.github/actions/build/action.yaml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | description: 'Build UCX and UCC, and then install Torch_UCC with UCX and UCC' 4 | inputs: 5 | ucx: 6 | description: 'UCX git repository' 7 | required: true 8 | default: 'https://github.com/openucx/ucx.git' 9 | ucc: 10 | description: 'UCC git repository' 11 | required: true 12 | default: 'https://github.com/openucx/ucc.git' 13 | 14 | runs: 15 | using: "composite" 16 | steps: 17 | - name: Install packages 18 | shell: bash 19 | run: | 20 | apt-get update 21 | apt-get install -y --no-install-recommends build-essential git cmake libtool-bin wget autoconf automake clang 22 | conda uninstall -y pytorch torchvision 23 | pip3 install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 24 | - name: Get UCX 25 | shell: bash 26 | run: | 27 | git clone ${{ inputs.ucx }} /tmp/ucx 28 | cd /tmp/ucx 29 | ./autogen.sh 30 | CC=clang CXX=clang++ ./contrib/configure-release-mt --without-java --disable-numa --prefix=/opt/ucx 31 | make -j install 32 | - name: Get UCC 33 | shell: bash 34 | run: | 35 | git clone ${{ inputs.ucc }} /tmp/ucc 36 | cd /tmp/ucc 37 | ./autogen.sh 38 | CC=clang CXX=clang++ ./configure --with-ucx=/opt/ucx --prefix=/opt/ucc 39 | make -j install 40 | - name: Build TorchUCC with UCX and UCC 41 | shell: bash 42 | run: | 43 | CC=clang CXX=clang++ UCX_HOME=/opt/ucx/ UCC_HOME=/opt/ucc/ WITH_CUDA=no python setup.py install 44 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | env: 6 | OPENUCX_LINK: https://github.com/openucx/ucx.git 7 | UCC_LINK: https://github.com/openucx/ucc.git 8 | PARAM_LINK: https://github.com/facebookresearch/param 9 | 10 | jobs: 11 | torch-ucc-tests: 12 | runs-on: ubuntu-latest 13 | container: 14 | image: 
pytorch/pytorch:latest 15 | steps: 16 | - name: Checkout Action 17 | uses: actions/checkout@v1 18 | - name: Build TorchUCC 19 | uses: ./.github/actions/build 20 | with: 21 | ucx: ${OPENUCX_LINK} 22 | ucc: ${UCC_LINK} 23 | - name: Tests 24 | run: | 25 | export LD_LIBRARY_PATH=/opt/ucx/lib:/opt/ucc/lib:$LD_LIBRARY_PATH 26 | /opt/ucx/bin/ucx_info -e -u t 27 | export UCX_LOG_LEVEL=info 28 | export TORCH_UCC_ENABLE_HEALTH_CHECK=1 29 | export TORCH_SHOW_CPP_STACKTRACES=1 30 | for np in `seq 4` 31 | do 32 | echo "Test comm size $np" 33 | export TORCH_UCC_TEST_SIZE=$np 34 | echo "UCC barrier" 35 | /bin/bash ./test/start_test.sh ./test/torch_barrier_test.py --backend=gloo 36 | echo "UCC alltoall" 37 | /bin/bash ./test/start_test.sh ./test/torch_alltoall_test.py --backend=gloo 38 | echo "UCC alltoallv" 39 | /bin/bash ./test/start_test.sh ./test/torch_alltoallv_test.py --backend=gloo 40 | echo "UCC allgather" 41 | /bin/bash ./test/start_test.sh ./test/torch_allgather_test.py --backend=gloo 42 | echo "UCC allreduce" 43 | /bin/bash ./test/start_test.sh ./test/torch_allreduce_test.py --backend=gloo 44 | echo "UCC broadcast" 45 | /bin/bash ./test/start_test.sh ./test/torch_bcast_test.py --backend=gloo 46 | echo "UCC reduce" 47 | /bin/bash ./test/start_test.sh ./test/torch_reduce_test.py --backend=gloo 48 | # FIXME: disabled as UCC does not support gather on CPU tensor yet 49 | # echo "UCC gather" 50 | # /bin/bash ./test/start_test.sh ./test/torch_gather_test.py --backend=gloo 51 | done 52 | echo "UCC basic functionality test" 53 | /bin/bash ./test/start_test.sh ./test/torch_work_test.py --backend=gloo 54 | echo "UCC pt2pt" 55 | /bin/bash ./test/start_test.sh ./test/torch_pt2pt_test.py --backend=gloo 56 | echo "UCC timeout test" 57 | /bin/bash ./test/start_test.sh ./test/torch_timeout_test.py --backend=gloo 58 | echo "UCC multiple comms test" 59 | TORCH_UCC_SHARED_COMM=0 UCX_TLS=tcp /bin/bash ./test/start_test.sh ./test/torch_multiple_comms_test.py 60 | echo "UCC multiple comms test shared comm" 61 | TORCH_UCC_SHARED_COMM=1 /bin/bash ./test/start_test.sh ./test/torch_multiple_comms_test.py 62 | 63 | pytorch-unit-tests: 64 | runs-on: ubuntu-latest 65 | container: 66 | image: pytorch/pytorch:latest 67 | steps: 68 | - name: Checkout Action 69 | uses: actions/checkout@v1 70 | - name: Build TorchUCC 71 | uses: ./.github/actions/build 72 | with: 73 | ucx: ${OPENUCX_LINK} 74 | ucc: ${UCC_LINK} 75 | - name: PyTorch Unit Tests 76 | run: | 77 | export LD_LIBRARY_PATH=/opt/ucx/lib:/opt/ucc/lib:$LD_LIBRARY_PATH 78 | /opt/ucx/bin/ucx_info -e -u t 79 | export UCX_LOG_LEVEL=info 80 | export TORCH_UCC_ENABLE_HEALTH_CHECK=1 81 | export TORCH_SHOW_CPP_STACKTRACES=1 82 | pip3 install expecttest hypothesis xmlrunner unittest-xml-reporting 83 | cd test 84 | for np in `seq 4` 85 | do 86 | export BACKEND='ucc' 87 | export WORLD_SIZE=$np 88 | python torch_tests.py --subprocess 89 | done 90 | 91 | param-comm-tests: 92 | runs-on: ubuntu-latest 93 | container: 94 | image: pytorch/pytorch:latest 95 | steps: 96 | - name: Checkout Action 97 | uses: actions/checkout@v1 98 | - name: Build TorchUCC 99 | uses: ./.github/actions/build 100 | with: 101 | ucx: ${OPENUCX_LINK} 102 | ucc: ${UCC_LINK} 103 | - name: Test PARAM 104 | run: | 105 | git clone ${PARAM_LINK} /tmp/param 106 | export LD_LIBRARY_PATH=/opt/ucx/lib:/opt/ucc/lib:$LD_LIBRARY_PATH 107 | export TORCH_UCC_TEST_SIZE=4 108 | echo "PARAM-Comms Reduce w/ UCC" 109 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 
--collective reduce 110 | echo "PARAM-Comms Allreduce w/ UCC" 111 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective all_reduce 112 | echo "PARAM-Comms Alltoall w/ UCC" 113 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective all_to_all 114 | echo "PARAM-Comms Alltoallv w/ UCC" 115 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective all_to_allv 116 | echo "PARAM-Comms Broadcast w/ UCC" 117 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective broadcast 118 | echo "PARAM-Comms Allgather w/ UCC" 119 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective all_gather 120 | echo "PARAM-Comms Allgather_base w/ UCC" 121 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective all_gather_base 122 | # FIXME: disabled as UCC does not support gather on CPU tensor yet 123 | # echo "PARAM-Comms Gather w/ UCC" 124 | # /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --collective gather 125 | echo "PARAM-Comms Quantized Allreduce w/ UCC (use of c10d future)" 126 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --bitwidth 16 --collective all_reduce 127 | echo "PARAM-Comms Non-blocking Allreduce w/ UCC" 128 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --z 0 --collective all_reduce 129 | echo "PARAM-Comms Non-blocking Alltoall w/ UCC" 130 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --z 0 --collective all_to_all 131 | echo "PARAM-Comms Non-blocking Alltoallv w/ UCC" 132 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --z 0 --collective all_to_allv 133 | echo "PARAM-Comms Pt2pt w/ UCC" 134 | export TORCH_UCC_TEST_SIZE=2 135 | /bin/bash ./test/start_test.sh /tmp/param/train/comms/pt/comms.py --backend ucc --device cpu --b 4 --e 4M --c 1 --pt2pt one2one --src-ranks 0 --dst-ranks 1 136 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. 
Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | # Python binaries 55 | *.egg-info 56 | build/ 57 | dist/ 58 | *.pyc 59 | 60 | # vscode 61 | *.code-workspace 62 | .vscode 63 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torch_ucc 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to torch_ucc, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # This repository is deprecated 2 | 3 | The Torch-UCC plugin has been merged into PyTorch as one of the native ProcessGroup implementations (please refer to https://github.com/pytorch/pytorch/pull/79918), which is now the single source of Torch-UCC (i.e., ProcessGroupUCC). Please check out the PyTorch repository (https://github.com/pytorch/pytorch). Some code pointers for reference: 4 | * ProcessGroupUCC: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp, https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp 5 | * UCC Utilities: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/UCCUtils.cpp, https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/UCCUtils.hpp 6 | * UCC Comms Tracing: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/UCCTracing.cpp, https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/UCCTracing.hpp 7 | 8 | # PyTorch plugin for UCC 9 | 10 | This repo implements the PyTorch Process Group API for [UCC](https://www.ucfconsortium.org/projects/ucc/) as a third-party plugin. 11 | 12 | ## Requirements 13 | * PyTorch 14 | * [UCX](https://github.com/openucx/ucx) 15 | * [UCC](https://github.com/openucx/ucc) 16 | 17 | ## License 18 | 19 | This repo is released under the MIT license. Please see the [`LICENSE`](LICENSE) file for more information. 20 | 21 | ## Contributing 22 | 23 | We actively welcome your pull requests! Please see [`CONTRIBUTING.md`](CONTRIBUTING.md) and [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md) for more info. 24 | ```shell 25 | # Build 26 | UCX_HOME=<PATH_TO_UCX> UCC_HOME=<PATH_TO_UCC> WITH_CUDA=<no/yes> python setup.py install 27 | ``` 28 | UCX_HOME is required and specifies the path to the UCX installation directory 29 | 30 | UCC_HOME is required and specifies the path to the UCC installation directory 31 | 32 | WITH_CUDA is optional; if WITH_CUDA=no is set, only CPU tensors are supported 33 | 34 | ## Run 35 | Configuration variables 36 | | Name | Values | Description | 37 | |------------------------------------|---------------------------|---------------------------------------------------------------------------------------------------------------| 38 | | TORCH_UCC_ALLGATHER_BLOCKING_WAIT | 0 or 1 | Sets behavior of wait function for CUDA Allgather. [Async collective in PyTorch](https://pytorch.org/docs/stable/distributed.html#synchronous-and-asynchronous-collective-operations)| 39 | | TORCH_UCC_ALLREDUCE_BLOCKING_WAIT | 0 or 1 | Sets behavior of wait function for CUDA Allreduce. | 40 | | TORCH_UCC_ALLTOALL_BLOCKING_WAIT | 0 or 1 | Sets behavior of wait function for CUDA Alltoall. | 41 | | TORCH_UCC_BCAST_BLOCKING_WAIT | 0 or 1 | Sets behavior of wait function for CUDA Bcast.
| 42 | 43 | ```shell 44 | export LD_LIBRARY_PATH=/lib:/lib:$LD_LIBRARY_PATH 45 | python example.py 46 | ``` 47 | 48 | ```python 49 | import torch 50 | import torch.distributed as dist 51 | import torch_ucc 52 | 53 | .... 54 | dist.init_process_group('ucc', rank=comm_rank, world_size=comm_size) 55 | .... 56 | dist.all_to_all_single(recv_tensor, send_tensor) 57 | 58 | ``` 59 | -------------------------------------------------------------------------------- /hipify.py: -------------------------------------------------------------------------------- 1 | # Meta Platforms, Inc. and affiliates Copyright 2 | 3 | import os 4 | 5 | from torch.utils.hipify import hipify_python 6 | 7 | CUDA_TO_HIP_MAPPINGS = [ 8 | ("UCS_MEMORY_TYPE_CUDA", "UCS_MEMORY_TYPE_ROCM"), 9 | ("UCC_MEMORY_TYPE_CUDA", "UCC_MEMORY_TYPE_ROCM"), 10 | ("UCC_EE_CUDA_STREAM", "UCC_EE_ROCM_STREAM"), 11 | ("nccl", "rccl"), 12 | ] 13 | 14 | # TorchUCC specific hipification 15 | def torch_ucc_hipify_file(src_path, dst_path, verbose=True): 16 | if verbose: 17 | print("Torch-UCC hipification applied to {} -> {}".format(src_path, dst_path)) 18 | with open(src_path, "rt", encoding="utf-8") as fin: 19 | fin.seek(0) 20 | source = fin.read() 21 | for k, v in CUDA_TO_HIP_MAPPINGS: 22 | source = source.replace(k, v) 23 | fin.close() 24 | 25 | with open(dst_path, "wt", encoding="utf-8") as fout: 26 | fout.write(source) 27 | fout.close() 28 | 29 | 30 | # Overwrite each source file for hipification 31 | def torch_ucc_hipify(src_path_list, verbose=True): 32 | for src_path in src_path_list: 33 | torch_ucc_hipify_file(src_path, src_path) 34 | -------------------------------------------------------------------------------- /include/torch_ucc.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | * 4 | * Copyright (c) Facebook, Inc. and its affiliates. 5 | * 6 | * This source code is licensed under the MIT license found in the 7 | * LICENSE file in the root directory of this source tree. 
8 | * 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "torch_ucc_comm.hpp" 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #ifdef USE_CUDA 27 | #include 28 | #include 29 | #endif 30 | 31 | namespace c10d { 32 | 33 | #define TORCH_UCC_DEVICE_NOT_SET -2 34 | 35 | #define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \ 36 | ((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \ 37 | (((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \ 38 | (((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET)) 39 | 40 | #define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \ 41 | ((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \ 42 | (((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \ 43 | (((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET)) 44 | 45 | #define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \ 46 | do { \ 47 | (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \ 48 | } while (0) 49 | 50 | #define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1) 51 | #define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK) 52 | #define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1) 53 | 54 | #define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \ 55 | do { \ 56 | (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \ 57 | if ((_rank) == TORCH_UCX_ANY_SOURCE) { \ 58 | (_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \ 59 | } else { \ 60 | (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \ 61 | } \ 62 | } while (0) 63 | 64 | #define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \ 65 | do { \ 66 | (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \ 67 | } while (0) 68 | 69 | #define TORCH_UCX_MAKE_OOB_RECV_TAG( \ 70 | _ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \ 71 | do { \ 72 | (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \ 73 | (_ucp_tag_mask) = (uint64_t)-1; \ 74 | } while (0) 75 | 76 | #ifdef USE_CUDA 77 | #define SAVE_TENSORS(_TENSORS, _DATA) \ 78 | do { \ 79 | if ((_TENSORS)[0].device().is_cuda()) { \ 80 | for (const auto i : c10::irange((_TENSORS).size())) { \ 81 | c10::cuda::CUDACachingAllocator::recordStream( \ 82 | (_TENSORS)[i].storage().data_ptr(), (*stream)); \ 83 | } \ 84 | } else { \ 85 | (_DATA) = (_TENSORS); \ 86 | } \ 87 | } while (0) 88 | 89 | #else 90 | #define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS); 91 | #endif 92 | 93 | constexpr const char* UCC_BACKEND_NAME = "ucc"; 94 | 95 | enum torch_ucx_tag_type_t { TORCH_UCX_P2P_TAG, TORCH_UCX_OOB_TAG }; 96 | 97 | struct event_pool_t { 98 | #ifdef USE_CUDA 99 | std::queue> event_pool; 100 | #endif 101 | std::mutex event_pool_mutex; 102 | }; 103 | 104 | class Comm; 105 | 106 | // UCC does not support multiple CUDA devices per process. 
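// In practice each rank must keep all of its CUDA tensors on a single device;
// check_device() in torch_ucc.cpp throws "ProcessGroupUCC multidevice is not
// supported" when it sees two different CUDA devices.
// Illustrative construction path (a sketch assuming the plugin has been
// imported and registered as the "ucc" backend, as shown in the README):
//   import torch.distributed as dist
//   import torch_ucc
//   dist.init_process_group("ucc", rank=comm_rank, world_size=comm_size)
// which is expected to end up in createProcessGroupUCC() declared below.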
107 | class ProcessGroupUCC : public ProcessGroup { 108 | private: 109 | void set_timeout(ucc_coll_args_t &args); 110 | 111 | public: 112 | class WorkData { 113 | public: 114 | std::vector src; 115 | std::vector dst; 116 | std::vector flat; 117 | WorkData() {} 118 | virtual ~WorkData() = default; 119 | }; 120 | class AlltoallWorkData : public WorkData { 121 | public: 122 | AlltoallWorkData(int size) 123 | : send_lengths(size), 124 | send_offsets(size), 125 | recv_lengths(size), 126 | recv_offsets(size) {} 127 | std::vector send_lengths; 128 | std::vector send_offsets; 129 | std::vector recv_lengths; 130 | std::vector recv_offsets; 131 | }; 132 | 133 | class AllgathervWorkData : public WorkData { 134 | public: 135 | AllgathervWorkData(int size) 136 | : recv_lengths(size), 137 | recv_offsets(size) {} 138 | std::vector recv_lengths; 139 | std::vector recv_offsets; 140 | }; 141 | 142 | class ScattervWorkData : public WorkData { 143 | public: 144 | ScattervWorkData(int size) 145 | : send_lengths(size), 146 | send_offsets(size) {} 147 | std::vector send_lengths; 148 | std::vector send_offsets; 149 | }; 150 | 151 | class ProgressEntry { 152 | friend class ProcessGroupUCC; 153 | friend class Comm; 154 | 155 | public: 156 | ProgressEntry( 157 | CommBase* comm, 158 | ucc_coll_req_h request) 159 | : status_(UCC_INPROGRESS), comm_(comm), request_(request) {} 160 | // Finalizes UCC status or exception of collective request. 161 | void finalize(std::exception_ptr eptr = nullptr); 162 | ucc_status_t status_; 163 | CommBase* comm_; 164 | ucc_coll_req_h request_; 165 | std::unique_ptr data; 166 | c10::intrusive_ptr future_; 167 | std::exception_ptr eptr_; 168 | }; 169 | 170 | class WorkUCC : public Work { 171 | friend class ProcessGroupUCC; 172 | friend class Comm; 173 | 174 | public: 175 | WorkUCC( 176 | OpType opType, 177 | const char* prof_title) 178 | : Work(-1, opType, prof_title) {} 179 | WorkUCC( 180 | OpType opType, 181 | const char* prof_title, 182 | const c10::intrusive_ptr& logger) 183 | : Work(-1, opType, prof_title), logger_(logger) {} 184 | ~WorkUCC(); 185 | void setException(); 186 | void setAndThrowException(); 187 | bool isCompleted() override; 188 | bool isSuccess() const override; 189 | bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; 190 | c10::intrusive_ptr getFuture() override; 191 | std::vector result() override; 192 | #ifdef USE_CUDA 193 | std::unique_ptr fence = nullptr; 194 | event_pool_t* ep = nullptr; 195 | #endif 196 | protected: 197 | std::shared_ptr entry_; 198 | c10::intrusive_ptr logger_; 199 | 200 | private: 201 | // The future returned by getFuture. 202 | c10::intrusive_ptr future_; 203 | // Store a reference to collective's outputs, used by result 204 | std::shared_ptr> outputs_; 205 | }; 206 | 207 | explicit ProcessGroupUCC( 208 | const c10::intrusive_ptr& store, 209 | int rank = -1, 210 | int size = -1, 211 | std::chrono::duration timeout = kProcessGroupDefaultTimeout); 212 | 213 | void initComm(c10::Device dev); 214 | 215 | ~ProcessGroupUCC() override; 216 | 217 | const std::string getBackendName() const override { 218 | return std::string(UCC_BACKEND_NAME); 219 | } 220 | 221 | #ifdef USE_CUDA 222 | std::unique_ptr getPooledEvent(); 223 | #endif 224 | 225 | // Performs a health check by initializing dummy UCC & UCX communicators and then 226 | // destroying them. This will help indicate and signal any UCC/UCX-related issues 227 | // prior to the first collective. 
The actual initialization and subsequent 228 | // destruction is ran on a separate thread and the main thread is signalled 229 | // about timeouts/errors to report to the application. 230 | void runHealthCheck(); 231 | 232 | template 233 | c10::intrusive_ptr collective_post( 234 | OpType opType, 235 | PreProcess preproc, 236 | PostProcess postproc, 237 | ucc_coll_args_t& coll, 238 | std::unique_ptr data, 239 | c10::Device dev, 240 | std::vector& inputTensors, 241 | std::vector& outputTensors, 242 | const char* prof_title); 243 | 244 | c10::intrusive_ptr broadcast( 245 | std::vector& data, 246 | const BroadcastOptions& opts = BroadcastOptions()) override; 247 | 248 | c10::intrusive_ptr allreduce( 249 | std::vector& tensors, 250 | const AllreduceOptions& opts = AllreduceOptions()) override; 251 | 252 | c10::intrusive_ptr allreduce_coalesced( 253 | std::vector& tensors, 254 | const AllreduceCoalescedOptions& opts = 255 | AllreduceCoalescedOptions()) override; 256 | 257 | c10::intrusive_ptr reduce( 258 | std::vector& tensors, 259 | const ReduceOptions& opts = ReduceOptions()) override; 260 | 261 | c10::intrusive_ptr allgather( 262 | std::vector>& outputTensors, 263 | std::vector& inputTensors, 264 | const AllgatherOptions& opts = AllgatherOptions()) override; 265 | 266 | c10::intrusive_ptr _allgather_base( 267 | at::Tensor& outputBuffer, 268 | at::Tensor& inputBuffer, 269 | const AllgatherOptions& opts = AllgatherOptions()) override; 270 | 271 | c10::intrusive_ptr barrier( 272 | const BarrierOptions& opts = BarrierOptions()) override; 273 | 274 | c10::intrusive_ptr gather( 275 | std::vector>& outputTensors, 276 | std::vector& inputTensors, 277 | const GatherOptions& opts = GatherOptions()) override; 278 | 279 | c10::intrusive_ptr scatter( 280 | std::vector& outputTensors, 281 | std::vector>& inputTensors, 282 | const ScatterOptions& opts = ScatterOptions()) override; 283 | 284 | c10::intrusive_ptr reduce_scatter( 285 | std::vector& outputTensors, 286 | std::vector>& inputTensors, 287 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override; 288 | 289 | c10::intrusive_ptr alltoall_base( 290 | at::Tensor& outputTensor, 291 | at::Tensor& inputTensor, 292 | std::vector& outputSplitSizes, 293 | std::vector& inputSplitSizes, 294 | const AllToAllOptions& opts = AllToAllOptions()) override; 295 | 296 | c10::intrusive_ptr alltoall( 297 | std::vector& outputTensors, 298 | std::vector& inputTensors, 299 | const AllToAllOptions& opts = AllToAllOptions()) override; 300 | 301 | c10::intrusive_ptr send( 302 | std::vector& tensors, 303 | int dstRank, 304 | int tag) override; 305 | 306 | c10::intrusive_ptr recv( 307 | std::vector& tensors, 308 | int srcRank, 309 | int tag) override; 310 | 311 | c10::intrusive_ptr recvAnysource( 312 | std::vector& tensors, 313 | int tag) override; 314 | 315 | static c10::intrusive_ptr createProcessGroupUCC( 316 | const c10::intrusive_ptr<::c10d::Store>& store, 317 | int rank, 318 | int size, 319 | const std::chrono::duration& timeout); 320 | 321 | protected: 322 | const std::chrono::duration timeout_; 323 | std::shared_ptr oob; 324 | std::shared_ptr comm = {nullptr}; 325 | uint32_t comm_id; 326 | #ifndef USE_ACTIVE_SETS 327 | std::vector eps; 328 | #endif 329 | ucc_team_h team {nullptr}; 330 | ucc_ee_h cuda_ee {nullptr}; 331 | #ifdef USE_CUDA 332 | std::unique_ptr stream = nullptr; 333 | event_pool_t ep; 334 | #endif 335 | c10::intrusive_ptr logger; 336 | }; 337 | 338 | class Comm { 339 | c10::intrusive_ptr logger; 340 | std::shared_ptr oob; 341 | #ifndef 
USE_ACTIVE_SETS 342 | CommUCX ucx_comm; 343 | #endif 344 | CommUCC ucc_comm; 345 | std::mutex mutex; 346 | std::thread progress_thread; 347 | std::condition_variable queue_produce_cv; 348 | std::condition_variable queue_consume_cv; 349 | std::deque> progress_queue; 350 | bool stop_progress_loop; 351 | bool collective_inprogress; 352 | torch_ucc_phase_t finalize_phase; 353 | 354 | public: 355 | c10::DeviceIndex cuda_device_index; 356 | Comm(const c10::intrusive_ptr& logger, 357 | std::shared_ptr oob, 358 | c10::Device dev, bool is_health_check); 359 | 360 | ~Comm(); 361 | 362 | #ifndef USE_ACTIVE_SETS 363 | // Connects UCX end points. 364 | void ucx_connect_eps( 365 | std::vector& eps, 366 | std::shared_ptr oob); 367 | 368 | // Disconnects UCX end points. 369 | void ucx_disconnect_eps( 370 | std::vector& eps, 371 | std::shared_ptr oob); 372 | #endif 373 | 374 | void ucc_create_team( 375 | ucc_team_h& team, 376 | std::shared_ptr oob); 377 | 378 | void ucc_destroy_team(ucc_team_h& team); 379 | 380 | #ifndef USE_ACTIVE_SETS 381 | c10::intrusive_ptr enqueue_p2p( 382 | OpType opType, 383 | ucc_coll_req_h request, 384 | const char* prof_title); 385 | #endif 386 | 387 | #ifdef USE_CUDA 388 | void enqueue_cuda_collective( 389 | std::unique_ptr data, 390 | c10::intrusive_ptr work, 391 | ucc_coll_args_t& coll, 392 | ucc_team_h team, 393 | ucc_ee_h ee); 394 | #endif 395 | 396 | void enqueue_collective( 397 | std::unique_ptr data, 398 | c10::intrusive_ptr work, 399 | ucc_coll_args_t& coll, 400 | ucc_team_h team); 401 | 402 | static std::shared_ptr get_comm( 403 | uint32_t& id, 404 | c10::Device dev, 405 | std::shared_ptr oob, 406 | const c10::intrusive_ptr& logger, 407 | bool is_health_check = false); 408 | 409 | void progress_loop(); 410 | 411 | #ifndef USE_ACTIVE_SETS 412 | // Only used internally 413 | // Unused when USE_ACTIVE_SETS is ON, thus safe to disable 414 | ucc_coll_req_h send_nb( 415 | ucp_ep_h ep, 416 | void* data, 417 | ucs_memory_type_t mtype, 418 | size_t size, 419 | ucp_tag_t ucp_tag); 420 | 421 | ucc_coll_req_h recv_nb( 422 | void* data, 423 | ucs_memory_type_t mtype, 424 | size_t size, 425 | ucp_tag_t ucp_tag, 426 | ucp_tag_t ucp_tag_mask); 427 | #endif 428 | }; 429 | 430 | } // namespace c10d 431 | -------------------------------------------------------------------------------- /include/torch_ucc_comm.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | * 4 | * Copyright (c) Facebook, Inc. and its affiliates. 5 | * 6 | * This source code is licensed under the MIT license found in the 7 | * LICENSE file in the root directory of this source tree. 
8 | * 9 | */ 10 | 11 | #pragma once 12 | 13 | #include 14 | #include 15 | #include 16 | #ifndef USE_ACTIVE_SETS 17 | #include 18 | #endif 19 | 20 | #define TORCH_UCX_COMM_BITS 15 21 | #define TORCH_UCX_RANK_BITS 16 22 | #define TORCH_UCX_TAG_BITS 32 23 | #define TORCH_UCX_OOB_BITS 1 24 | 25 | #define TORCH_UCX_COMM_BITS_OFFSET 0 26 | #define TORCH_UCX_RANK_BITS_OFFSET TORCH_UCX_COMM_BITS 27 | #define TORCH_UCX_TAG_BITS_OFFSET (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS) 28 | #define TORCH_UCX_OOB_BITS_OFFSET \ 29 | (TORCH_UCX_COMM_BITS + TORCH_UCX_RANK_BITS + TORCH_UCX_TAG_BITS) 30 | 31 | #define TORCH_UCX_MAX_COMM ((((uint64_t)1) << TORCH_UCX_COMM_BITS) - 1) 32 | #define TORCH_UCX_MAX_RANK ((((uint64_t)1) << TORCH_UCX_RANK_BITS) - 1) 33 | #define TORCH_UCX_MAX_TAG ((((uint64_t)1) << TORCH_UCX_TAG_BITS) - 1) 34 | #define TORCH_UCX_MAX_OOB ((((uint64_t)1) << TORCH_UCX_OOB_BITS) - 1) 35 | 36 | #define TORCH_UCX_COMM_MASK (TORCH_UCX_MAX_COMM << TORCH_UCX_COMM_BITS_OFFSET) 37 | #define TORCH_UCX_RANK_MASK (TORCH_UCX_MAX_RANK << TORCH_UCX_RANK_BITS_OFFSET) 38 | #define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET) 39 | #define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET) 40 | 41 | namespace c10d { 42 | 43 | // Macro to throw on a non-successful UCC return value. 44 | #define TORCH_UCC_CHECK(_cmd, _error_msg) \ 45 | do { \ 46 | ucc_status_t result = _cmd; \ 47 | if (result != UCC_OK) { \ 48 | std::string err = c10::str( \ 49 | "[", \ 50 | std::string(__FILE__), \ 51 | ":", \ 52 | std::to_string(__LINE__), \ 53 | "] ", \ 54 | logger->getLogPrefix(), \ 55 | _error_msg, \ 56 | ", error code ", \ 57 | result, \ 58 | ": ", \ 59 | ucc_status_string(result), \ 60 | ", system error code ", \ 61 | errno); \ 62 | TORCH_CHECK(false, err); \ 63 | } \ 64 | } while (0) 65 | 66 | #ifndef USE_ACTIVE_SETS 67 | // Macro to throw on a non-successful UCX return value. 
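// Usage mirrors TORCH_UCC_CHECK above; a minimal sketch of a call site
// (illustrative only, not taken from this repository):
//   TORCH_UCX_CHECK(
//       ucp_init(&ucp_params, ucp_config, &ucp_context),
//       "failed to init UCP context");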
68 | #define TORCH_UCX_CHECK(_cmd, _error_msg) \ 69 | do { \ 70 | ucs_status_t result = _cmd; \ 71 | if (result != UCS_OK) { \ 72 | std::string err = c10::str( \ 73 | "[", \ 74 | std::string(__FILE__), \ 75 | ":", \ 76 | std::to_string(__LINE__), \ 77 | "] ", \ 78 | logger->getLogPrefix(), \ 79 | _error_msg, \ 80 | ", error code ", \ 81 | result, \ 82 | ": ", \ 83 | ucs_status_string(result), \ 84 | ", system error code ", \ 85 | errno); \ 86 | TORCH_CHECK(false, err); \ 87 | } \ 88 | } while (0) 89 | #endif 90 | 91 | // Macros to print logs with unified format 92 | #define TORCH_UCC_LOG_ERROR(_phase, _msg) \ 93 | LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg; 94 | #define TORCH_UCC_LOG_INFO(_phase, _msg) \ 95 | LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg; 96 | #define TORCH_UCC_LOG_DEBUG(_phase, _msg) \ 97 | VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg; 98 | 99 | enum torch_ucc_phase_t { 100 | TORCH_UCC_UNKNOWN = -1, 101 | TORCH_UCC_INIT, 102 | TORCH_UCC_HEALTH_CHECK, 103 | TORCH_UCC_READY, 104 | TORCH_UCC_COLL_POST, 105 | TORCH_UCC_COLL_PROGRESS, 106 | TORCH_UCC_FINALIZE, 107 | }; 108 | 109 | const std::map ucc_phase_map = { 110 | {TORCH_UCC_UNKNOWN, "UNKNOWN"}, 111 | {TORCH_UCC_INIT, "INIT"}, 112 | {TORCH_UCC_HEALTH_CHECK, "HEALTH_CHECK"}, 113 | {TORCH_UCC_READY, "READY"}, 114 | {TORCH_UCC_COLL_POST, "COLL_POST"}, 115 | {TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"}, 116 | {TORCH_UCC_FINALIZE, "FINALIZE"}, 117 | }; 118 | 119 | class CommTraceLogger; 120 | 121 | class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder { 122 | public: 123 | ProcessGroupUCCLogger(); 124 | ProcessGroupUCCLogger(std::string log_prefix, torch_ucc_phase_t phase); 125 | 126 | std::string getLogPrefix(torch_ucc_phase_t phase = TORCH_UCC_UNKNOWN); 127 | void setLogPrefix(std::string log_prefix); 128 | inline void setPhase(torch_ucc_phase_t phase) { 129 | local_phase = phase; 130 | } 131 | 132 | void initCommsTracer(); 133 | void flushComms(int rank, int world_size); 134 | std::shared_ptr trace_generator = nullptr; 135 | 136 | protected: 137 | std::string log_prefix; 138 | torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN; 139 | bool initialized_CommTraceLogger = false; 140 | }; 141 | 142 | struct torch_ucc_oob_coll_info_t { 143 | c10::intrusive_ptr store; 144 | uint32_t comm_id; 145 | int rank; 146 | int size; 147 | void* rbuf; 148 | size_t msglen; 149 | std::string getKey(std::string key) { 150 | return std::to_string(comm_id) + key; 151 | } 152 | }; 153 | 154 | class CommBase { 155 | public: 156 | CommBase(const c10::intrusive_ptr& logger_) 157 | : logger(logger_) {} 158 | virtual void progress() = 0; 159 | virtual void free_request(ucc_coll_req_h request) = 0; 160 | virtual ~CommBase() {} 161 | c10::intrusive_ptr logger; 162 | }; 163 | 164 | #ifndef USE_ACTIVE_SETS 165 | class CommUCX : public CommBase { 166 | public: 167 | ucp_context_h context{nullptr}; 168 | ucp_worker_h worker{nullptr}; 169 | 170 | public: 171 | void progress() override; 172 | void free_request(ucc_coll_req_h request) override; 173 | CommUCX( 174 | int comm_size, 175 | const c10::intrusive_ptr& logger); 176 | ~CommUCX(); 177 | }; 178 | #endif 179 | 180 | class CommUCC : public CommBase { 181 | public: 182 | ucc_lib_h lib{nullptr}; 183 | ucc_context_h context{nullptr}; 184 | 185 | public: 186 | void progress() override; 187 | CommUCC( 188 | std::shared_ptr oob, 189 | const c10::intrusive_ptr& logger); 190 | void free_request(ucc_coll_req_h request) override; 191 | ~CommUCC(); 
192 | }; 193 | 194 | ucc_status_t oob_allgather( 195 | void* sbuf, 196 | void* rbuf, 197 | size_t msglen, 198 | void* coll_info, 199 | void** req); 200 | 201 | ucc_status_t oob_allgather_test(void* req); 202 | 203 | ucc_status_t oob_allgather_free(void* req); 204 | 205 | } // namespace c10d 206 | -------------------------------------------------------------------------------- /include/torch_ucc_tracing.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | * 7 | */ 8 | 9 | #pragma once 10 | 11 | #include "torch_ucc_comm.hpp" 12 | 13 | namespace c10d { 14 | 15 | #define RECORD_COMMS_TRACE( \ 16 | _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \ 17 | do { \ 18 | if (torch_ucc_config.enable_comms_logger) { \ 19 | _comms_tracer->recordComms( \ 20 | opTypeToString(_opType), \ 21 | (uintptr_t)_work.get(), \ 22 | _rank, \ 23 | _comm_size, \ 24 | _inTensors, \ 25 | _outTensors); \ 26 | } \ 27 | } while (0) 28 | 29 | // interfaces to collect communication traces 30 | class TORCH_API CommTraceLogger : public torch::CustomClassHolder { 31 | private: 32 | std::vector comms_trace_; 33 | std::vector curBlocks_; /* unused */ 34 | std::vector curOutSplitSizes_; 35 | std::vector curInSplitSizes_; 36 | int curRoot_ = -1; 37 | unsigned long seqnum = 0; 38 | 39 | public: 40 | void setCurBlock(const std::string& name); /* unused */ 41 | void popBlock(); /* unused */ 42 | // record root info if applicable, e.g., broadcast, gather, scatter 43 | void recordOptionalInfo(int root = -1); 44 | // record input/output splits of Alltoallv 45 | void recordOptionalInfo( 46 | const std::vector& outputSplitSizes = {}, 47 | const std::vector& inputSplitSizes = {}); 48 | // record essential comms information 49 | void recordComms( 50 | const std::string& collName, 51 | const uintptr_t workReq = 0, 52 | const int rank = -1, 53 | const int world_size = -1, 54 | const std::vector& inputTensors = {}, 55 | const std::vector& outputTensor = {}); 56 | // return collected comms traces 57 | std::vector& getCommsTrace() { 58 | return comms_trace_; 59 | } 60 | }; 61 | 62 | } // namespace c10d 63 | -------------------------------------------------------------------------------- /requirements-flake8.txt: -------------------------------------------------------------------------------- 1 | flake8==3.8.2 2 | flake8-bugbear==20.1.4 3 | flake8-comprehensions==3.3.0 4 | flake8-executable==2.0.4 5 | git+https://github.com/malfet/flake8-coding.git 6 | flake8-pyi==20.5.0 7 | mccabe==0.6.1 8 | pycodestyle==2.6.0 9 | pyflakes==2.2.0 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import os 11 | import sys 12 | from setuptools import setup 13 | import torch 14 | from torch.utils import cpp_extension 15 | from hipify import torch_ucc_hipify 16 | 17 | ucc_plugin_dir = os.path.dirname(os.path.abspath(__file__)) 18 | ucx_home = os.environ.get("UCX_HOME") 19 | if ucx_home is None: 20 | print("Couldn't find UCX install dir, please set UCX_HOME env variable") 21 | sys.exit(1) 22 | 23 | ucc_home = os.environ.get("UCC_HOME") 24 | if ucc_home is None: 25 | print("Couldn't find UCC install dir, please set UCC_HOME env variable") 26 | sys.exit(1) 27 | 28 | plugin_compile_args = [] 29 | enable_debug = os.environ.get("ENABLE_DEBUG") 30 | if enable_debug is None or enable_debug == "no": 31 | print("Release build") 32 | else: 33 | print("Debug build") 34 | plugin_compile_args.extend(["-g", "-O0"]) 35 | 36 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 37 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 38 | 39 | def check_if_rocm_pytorch(): 40 | is_rocm_pytorch = False 41 | if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): 42 | from torch.utils.cpp_extension import ROCM_HOME 43 | is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False 44 | 45 | return is_rocm_pytorch 46 | 47 | IS_ROCM_PYTORCH = check_if_rocm_pytorch() 48 | 49 | plugin_sources = ["src/torch_ucc.cpp", 50 | "src/torch_ucc_comm.cpp", 51 | "src/torch_ucc_tracing.cpp"] 52 | plugin_include_dirs = ["{}/include/".format(ucc_plugin_dir), 53 | "{}/include/".format(ucx_home), 54 | "{}/include/".format(ucc_home)] 55 | plugin_library_dirs = ["{}/lib/".format(ucx_home), 56 | "{}/lib/".format(ucc_home)] 57 | plugin_libraries = ["ucp", "uct", "ucm", "ucs", "ucc"] 58 | 59 | if '--oss' in sys.argv: 60 | sys.argv.remove('--oss') 61 | plugin_sources += ["src/torch_ucc_init_oss.cpp"] 62 | else: 63 | plugin_sources += ["src/torch_ucc_init.cpp"] 64 | 65 | if '--active-sets' in sys.argv: 66 | sys.argv.remove('--active-sets') 67 | plugin_compile_args.append("-DUSE_ACTIVE_SETS") 68 | 69 | with_cuda = os.environ.get("WITH_CUDA") 70 | if with_cuda is None or with_cuda == "no": 71 | print("CUDA support is disabled") 72 | module = cpp_extension.CppExtension( 73 | name = "torch_ucc", 74 | sources = plugin_sources, 75 | include_dirs = plugin_include_dirs, 76 | library_dirs = plugin_library_dirs, 77 | libraries = plugin_libraries, 78 | extra_compile_args=plugin_compile_args 79 | ) 80 | else: 81 | print("CUDA support is enabled") 82 | plugin_compile_args.append("-DUSE_CUDA") 83 | module = cpp_extension.CUDAExtension( 84 | name = "torch_ucc", 85 | sources = plugin_sources, 86 | include_dirs = plugin_include_dirs, 87 | library_dirs = plugin_library_dirs, 88 | libraries = plugin_libraries, 89 | extra_compile_args=plugin_compile_args 90 | ) 91 | # Apply Torch-UCC specific hipification after Pytorch hipification 92 | if IS_ROCM_PYTORCH: 93 | torch_ucc_hipify(module.sources) 94 | 95 | setup( 96 | name = "torch-ucc", 97 | version = "1.0.0", 98 | ext_modules = [module], 99 | cmdclass={'build_ext': cpp_extension.BuildExtension} 100 | ) 101 | -------------------------------------------------------------------------------- /src/torch_ucc.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | * 4 | * Copyright (c) Facebook, Inc. and its affiliates. 5 | * 6 | * This source code is licensed under the MIT license found in the 7 | * LICENSE file in the root directory of this source tree. 
8 | * 9 | */ 10 | 11 | #include "torch_ucc.hpp" 12 | #include "torch_ucc_comm.hpp" 13 | #include "torch_ucc_tracing.hpp" 14 | #include 15 | #include 16 | 17 | namespace c10d { 18 | 19 | namespace { 20 | constexpr int64_t kBusyWaitMillis = 10; 21 | 22 | #ifndef USE_ACTIVE_SETS 23 | const std::map ucs_mtype_map = { 24 | {c10::kCPU, UCS_MEMORY_TYPE_HOST}, 25 | {c10::kCUDA, UCS_MEMORY_TYPE_CUDA}, 26 | }; 27 | 28 | ucs_memory_type_t to_ucs_memType(c10::DeviceType _c10_type) { 29 | if (ucs_mtype_map.find(_c10_type) != ucs_mtype_map.end()) 30 | return ucs_mtype_map.at(_c10_type); 31 | else 32 | return UCS_MEMORY_TYPE_UNKNOWN; 33 | } 34 | #endif 35 | 36 | const std::map ucc_mtype_map = { 37 | {c10::kCPU, UCC_MEMORY_TYPE_HOST}, 38 | {c10::kCUDA, UCC_MEMORY_TYPE_CUDA}, 39 | }; 40 | 41 | ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) { 42 | if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end()) 43 | return ucc_mtype_map.at(_c10_type); 44 | else 45 | return UCC_MEMORY_TYPE_UNKNOWN; 46 | } 47 | 48 | const std::map ucc_dtype_map = { 49 | {at::kByte, UCC_DT_UINT8}, 50 | {at::kChar, UCC_DT_INT8}, 51 | {at::kHalf, UCC_DT_FLOAT16}, 52 | {at::kBFloat16, UCC_DT_BFLOAT16}, 53 | {at::kDouble, UCC_DT_FLOAT64}, 54 | {at::kFloat, UCC_DT_FLOAT32}, 55 | {at::kInt, UCC_DT_INT32}, 56 | {at::kLong, UCC_DT_INT64}, 57 | {at::kBool, UCC_DT_UINT8}, 58 | }; 59 | 60 | ucc_datatype_t to_ucc_dType(at::Tensor _tensor) { 61 | if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) { 62 | TORCH_CHECK( 63 | false, "Size of Boolean type larger than 1 is not supported in UCC"); 64 | } 65 | try { 66 | return ucc_dtype_map.at(_tensor.scalar_type()); 67 | } catch (const std::out_of_range& e) { 68 | TORCH_CHECK(false, "Not supported data type for UCC"); 69 | } 70 | } 71 | 72 | const std::map ucc_op_map = { 73 | {ReduceOp::SUM, UCC_OP_SUM}, 74 | {ReduceOp::PRODUCT, UCC_OP_PROD}, 75 | {ReduceOp::MIN, UCC_OP_MIN}, 76 | {ReduceOp::MAX, UCC_OP_MAX}, 77 | {ReduceOp::BAND, UCC_OP_BAND}, 78 | {ReduceOp::BOR, UCC_OP_BOR}, 79 | {ReduceOp::BXOR, UCC_OP_BXOR}, 80 | #if TORCH_VERSION_MAJOR > 1 || (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 11) 81 | {ReduceOp::AVG, UCC_OP_AVG}, 82 | #endif 83 | }; 84 | 85 | ucc_reduction_op_t to_ucc_reduceOp( 86 | const ReduceOp _op, 87 | const at::ScalarType _dt) { 88 | if (_dt == at::kBool) { 89 | if (_op == ReduceOp::SUM) { 90 | // bitwise or 91 | return UCC_OP_MAX; 92 | } else if (_op == ReduceOp::PRODUCT) { 93 | // bitwise and 94 | return UCC_OP_MIN; 95 | } 96 | #if TORCH_VERSION_MAJOR > 1 || \ 97 | (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 11) 98 | else if (_op == ReduceOp::AVG) { 99 | TORCH_CHECK( 100 | false, "Cannot use ReduceOp.AVG with boolean inputs"); 101 | } 102 | #endif 103 | } 104 | 105 | try { 106 | return ucc_op_map.at(_op); 107 | } catch (const std::out_of_range& e) { 108 | TORCH_CHECK( 109 | false, "Not supported ReduceOp for UCC"); 110 | } 111 | } 112 | 113 | struct torch_ucc_config_t { 114 | std::once_flag flag; 115 | std::array blocking_wait; 116 | bool enable_profiling; 117 | bool enable_comms_logger; 118 | bool use_future; 119 | // Sharing UCC communicator among multiple PGs to save resource. 120 | bool shared_comm; 121 | // Using allgatherv to achieve allgather, without flattening the list of 122 | // (potentially non-contiguous) tensors. 123 | bool use_allgatherv; 124 | bool enable_health_check; 125 | } torch_ucc_config; 126 | 127 | // TODO: support UCC_BLOCKING_WAIT that applies to all collectives. 
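// The defaults below can be overridden per process through the environment,
// e.g. (hypothetical command line):
//   TORCH_UCC_ALLREDUCE_BLOCKING_WAIT=1 TORCH_UCC_SHARED_COMM=0 python train.py
// read_confg() copies every variable that is actually set into this map and
// then parses the values into torch_ucc_config.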
128 | std::map torch_ucc_envs_map = { 129 | {"TORCH_UCC_ALLGATHER_BLOCKING_WAIT", "0"}, 130 | {"TORCH_UCC_ALLGATHER_BASE_BLOCKING_WAIT", "0"}, 131 | {"TORCH_UCC_ALLREDUCE_BLOCKING_WAIT", "0"}, 132 | {"TORCH_UCC_ALLTOALL_BLOCKING_WAIT", "0"}, 133 | {"TORCH_UCC_BCAST_BLOCKING_WAIT", "0"}, 134 | {"TORCH_UCC_GATHER_BLOCKING_WAIT", "0"}, 135 | {"TORCH_UCC_REDUCE_BLOCKING_WAIT", "0"}, 136 | {"TORCH_UCC_REDUCE_SCATTER_BLOCKING_WAIT", "0"}, 137 | {"TORCH_UCC_SCATTER_BLOCKING_WAIT", "0"}, 138 | {"TORCH_UCC_SEND_BLOCKING_WAIT", "0"}, 139 | {"TORCH_UCC_RECV_BLOCKING_WAIT", "0"}, 140 | 141 | {"TORCH_UCC_USE_FUTURE", "1"}, 142 | {"TORCH_UCC_PROFILING_ENABLE", "0"}, 143 | {"TORCH_UCC_SHARED_COMM", "1"}, 144 | {"TORCH_UCC_USE_ALLGATHERV", "0"}, 145 | {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"}, 146 | {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"}, 147 | }; 148 | 149 | } // namespace 150 | 151 | void read_confg() { 152 | // default configuration 153 | torch_ucc_config.blocking_wait.fill(true); 154 | torch_ucc_config.enable_profiling = false; 155 | torch_ucc_config.use_future = true; 156 | torch_ucc_config.shared_comm = false; 157 | torch_ucc_config.use_allgatherv = false; 158 | torch_ucc_config.enable_health_check = false; 159 | torch_ucc_config.enable_comms_logger = false; 160 | 161 | // read all torch_ucc env. variables and update the map 162 | char* env; 163 | for (auto& torch_ucc_env : torch_ucc_envs_map) { 164 | env = std::getenv(torch_ucc_env.first.c_str()); 165 | if (env) { 166 | torch_ucc_envs_map[torch_ucc_env.first] = std::string(env); 167 | } 168 | } 169 | 170 | #define BUILD_BLOCKING_CFG(op, str) \ 171 | (torch_ucc_config.blocking_wait[(std::uint8_t)op] = \ 172 | std::stoi(torch_ucc_envs_map.at(str))) 173 | 174 | BUILD_BLOCKING_CFG(OpType::ALLGATHER, "TORCH_UCC_ALLGATHER_BLOCKING_WAIT"); 175 | BUILD_BLOCKING_CFG(OpType::_ALLGATHER_BASE, 176 | "TORCH_UCC_ALLGATHER_BASE_BLOCKING_WAIT"); 177 | BUILD_BLOCKING_CFG(OpType::ALLREDUCE, "TORCH_UCC_ALLREDUCE_BLOCKING_WAIT"); 178 | BUILD_BLOCKING_CFG(OpType::ALLTOALL_BASE, "TORCH_UCC_ALLTOALL_BLOCKING_WAIT"); 179 | BUILD_BLOCKING_CFG(OpType::BROADCAST, "TORCH_UCC_BCAST_BLOCKING_WAIT"); 180 | BUILD_BLOCKING_CFG(OpType::GATHER, "TORCH_UCC_GATHER_BLOCKING_WAIT"); 181 | BUILD_BLOCKING_CFG(OpType::REDUCE, "TORCH_UCC_REDUCE_BLOCKING_WAIT"); 182 | BUILD_BLOCKING_CFG(OpType::REDUCE_SCATTER, 183 | "TORCH_UCC_REDUCE_SCATTER_BLOCKING_WAIT"); 184 | BUILD_BLOCKING_CFG(OpType::SCATTER, "TORCH_UCC_SCATTER_BLOCKING_WAIT"); 185 | BUILD_BLOCKING_CFG(OpType::SEND, "TORCH_UCC_SEND_BLOCKING_WAIT"); 186 | BUILD_BLOCKING_CFG(OpType::RECV, "TORCH_UCC_RECV_BLOCKING_WAIT"); 187 | #undef BUILD_BLOCKING_CFG 188 | 189 | torch_ucc_config.use_future = 190 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE")); 191 | torch_ucc_config.enable_profiling = 192 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_PROFILING_ENABLE")); 193 | torch_ucc_config.shared_comm = 194 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM")); 195 | torch_ucc_config.use_allgatherv = 196 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV")); 197 | torch_ucc_config.enable_health_check = 198 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK")); 199 | torch_ucc_config.enable_comms_logger = 200 | std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER")); 201 | } 202 | 203 | void check_device(c10::Device dev1, c10::Device dev2) { 204 | if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) { 205 | throw std::runtime_error("ProcessGroupUCC multidevice is not supported"); 206 | } 207 | } 208 | 
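// check_tensor() below enforces the per-call input contract: exactly one
// tensor, contiguous, and dense (sparse tensors are rejected); CUDA-specific
// checks are still marked as a TODO in its body.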
209 | void check_tensor(const std::vector& tensors) { 210 | if (tensors.size() != 1) { 211 | throw std::runtime_error( 212 | "ProcessGroupUCC takes 1 tensor. Got " + 213 | std::to_string(tensors.size()) + ". "); 214 | } 215 | if (!tensors[0].is_contiguous()) { 216 | throw std::runtime_error( 217 | "ProcessGroupUCC input tensor has to be contiguous"); 218 | } 219 | if (tensors[0].is_sparse()) { 220 | throw std::runtime_error("ProcessGroupUCC input tensor has to be dense"); 221 | } 222 | // TODO: check cuda case 223 | } 224 | 225 | ProcessGroupUCC::WorkUCC::~WorkUCC() { 226 | #ifdef USE_CUDA 227 | if (fence && ep) { 228 | std::lock_guard lock(ep->event_pool_mutex); 229 | ep->event_pool.push(std::move(fence)); 230 | } 231 | #endif 232 | } 233 | 234 | void ProcessGroupUCC::WorkUCC::setException() { 235 | if (exception() || !entry_) { 236 | return; 237 | } 238 | exception_ = entry_->eptr_; 239 | } 240 | 241 | void ProcessGroupUCC::WorkUCC::setAndThrowException() { 242 | setException(); 243 | if (exception()) { 244 | std::rethrow_exception(exception()); 245 | } 246 | } 247 | 248 | bool ProcessGroupUCC::WorkUCC::isCompleted() { 249 | if (!entry_) { 250 | return true; 251 | } 252 | setException(); 253 | // status_ <= 0 to avoid listing all possible status codes. The main thread 254 | // needs to be unblocked when UCC (in progress thread) returns success (== 0) 255 | // or any error code (< 0). 256 | return exception() || entry_->status_ <= 0; 257 | } 258 | 259 | bool ProcessGroupUCC::WorkUCC::isSuccess() const { 260 | if (!entry_) { 261 | return true; 262 | } 263 | return !exception() && entry_->status_ == 0; 264 | } 265 | 266 | bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { 267 | if (torch_ucc_config.enable_comms_logger && logger_) { 268 | logger_->trace_generator->recordComms( 269 | "wait", 270 | (uintptr_t) this, 271 | rank_); 272 | } 273 | #ifdef USE_CUDA 274 | if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { 275 | // block user stream 276 | setAndThrowException(); 277 | fence->block(at::cuda::getCurrentCUDAStream()); 278 | return true; 279 | } 280 | #endif 281 | // wait for complete. For blocking case, the main thread will be blocked in 282 | // this loop until the progress thread changes the status of this request. 283 | // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The 284 | // main thread will throw out the exception then. There is no "abort" 285 | // function in UCC currently. 286 | while (!isCompleted()) 287 | ; 288 | setAndThrowException(); 289 | // manually call profiling end callbacks if they are set, 290 | // since progress thread does not own WorkUCC 291 | if (Work::recordFunctionEndCallback_) { 292 | Work::recordFunctionEndCallback_(); 293 | Work::recordFunctionEndCallback_ = nullptr; 294 | } 295 | return true; 296 | } 297 | 298 | c10::intrusive_ptr ProcessGroupUCC::WorkUCC::getFuture() { 299 | return future_; 300 | } 301 | 302 | std::vector ProcessGroupUCC::WorkUCC::result() { 303 | return *outputs_; 304 | } 305 | 306 | void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) { 307 | ucc_status_t status = UCC_OK; 308 | 309 | if (request_ != nullptr) { 310 | status = request_->status; 311 | comm_->free_request(request_); 312 | } 313 | if (eptr) { 314 | eptr_ = eptr; 315 | } else { 316 | status_ = status; 317 | } 318 | if (future_) { 319 | if (eptr) { 320 | future_->setError(eptr); 321 | } else { 322 | future_->markCompleted( 323 | c10::IValue(data ? 
data->dst : std::vector())); 324 | } 325 | } 326 | } 327 | 328 | Comm::Comm( 329 | const c10::intrusive_ptr& logger_, 330 | std::shared_ptr oob_, 331 | c10::Device dev, 332 | bool is_health_check) 333 | : logger(logger_), 334 | oob(oob_), 335 | #ifndef USE_ACTIVE_SETS 336 | ucx_comm(oob->size, logger), 337 | #endif 338 | ucc_comm(oob, logger), 339 | finalize_phase(is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE), 340 | cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) { 341 | if (dev.is_cuda()) { 342 | cuda_device_index = dev.index(); 343 | } 344 | stop_progress_loop = false; 345 | collective_inprogress = false; 346 | progress_thread = std::thread(&Comm::progress_loop, this); 347 | #ifdef _GNU_SOURCE 348 | pthread_setname_np(progress_thread.native_handle(), "ucc-progress"); 349 | #endif 350 | } 351 | 352 | Comm::~Comm() { 353 | std::unique_lock lock(mutex); 354 | queue_consume_cv.wait( 355 | lock, [&] { return progress_queue.empty() && !collective_inprogress; }); 356 | stop_progress_loop = true; 357 | lock.unlock(); 358 | queue_produce_cv.notify_all(); 359 | progress_thread.join(); 360 | } 361 | 362 | std::shared_ptr Comm::get_comm( 363 | uint32_t& id, 364 | c10::Device dev, 365 | std::shared_ptr oob, 366 | const c10::intrusive_ptr& logger, 367 | bool is_health_check) { 368 | static std::mutex m; 369 | static std::weak_ptr comm; 370 | static uint32_t comm_id; 371 | 372 | std::lock_guard lock(m); 373 | id = (comm_id % TORCH_UCX_MAX_COMM); 374 | 375 | std::string group_id = "group_id"; 376 | if (is_health_check) { 377 | group_id = c10::str(dev.type()) + "/" + group_id; 378 | } 379 | 380 | std::vector remote_comm_id; 381 | oob->store->deleteKey(group_id + std::to_string(0)); 382 | if (oob->rank != 0) { 383 | std::vector val = std::vector( 384 | reinterpret_cast(&id), 385 | reinterpret_cast(&id) + sizeof(id)); 386 | oob->store->set(group_id + std::to_string(oob->rank), val); 387 | } else { 388 | for (int i = 1; i < oob->size; i++) { 389 | remote_comm_id = oob->store->get(group_id + std::to_string(i)); 390 | oob->store->deleteKey(group_id + std::to_string(i)); 391 | // Find the highest id. 392 | id = std::max(id, *(reinterpret_cast(remote_comm_id.data()))); 393 | } 394 | std::vector val = std::vector( 395 | reinterpret_cast(&id), 396 | reinterpret_cast(&id) + sizeof(id)); 397 | oob->store->set(group_id + std::to_string(oob->rank), val); 398 | } 399 | remote_comm_id = oob->store->get(group_id + std::to_string(0)); 400 | oob->comm_id = *(reinterpret_cast(remote_comm_id.data())); 401 | // Prepare comm_id (static variable) to the next id. 402 | comm_id = oob->comm_id + 1; 403 | 404 | if (torch_ucc_config.shared_comm) { 405 | std::shared_ptr shared_comm = comm.lock(); 406 | if (!shared_comm) { 407 | shared_comm = std::make_shared( 408 | logger, oob, dev, is_health_check); 409 | comm = shared_comm; 410 | } else { 411 | if (dev.is_cuda() && !is_health_check) { 412 | if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && 413 | (shared_comm->cuda_device_index != dev.index())) { 414 | TORCH_UCC_LOG_ERROR( 415 | is_health_check ? 
TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT, 416 | "ucc communicator was initialized with different cuda device," 417 | "multi device is not supported"); 418 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); 419 | } 420 | shared_comm->cuda_device_index = dev.index(); 421 | } 422 | } 423 | return shared_comm; 424 | } else { 425 | return std::make_shared(logger, oob, dev, is_health_check); 426 | } 427 | } 428 | 429 | #ifndef USE_ACTIVE_SETS 430 | // Only called internally in initComm and runHealthCheck when USE_ACTIVE_SETS is off. 431 | void Comm::ucx_connect_eps( 432 | std::vector& eps, 433 | std::shared_ptr oob) { 434 | ucp_address_t* local_addr; 435 | size_t local_addr_len; 436 | std::vector peer_addr; 437 | 438 | TORCH_UCX_CHECK( 439 | ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len), 440 | "failed to get worker address"); 441 | 442 | std::vector val = std::vector( 443 | reinterpret_cast(local_addr), 444 | reinterpret_cast(local_addr) + local_addr_len); 445 | oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val); 446 | ucp_worker_release_address(ucx_comm.worker, local_addr); 447 | eps.resize(oob->size); 448 | for (int i = 0; i < oob->size; i++) { 449 | peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i))); 450 | ucp_ep_params_t ep_params; 451 | ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; 452 | ep_params.address = reinterpret_cast(peer_addr.data()); 453 | TORCH_UCX_CHECK( 454 | ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])), 455 | c10::str("failed to create endpoint with rank ", i)); 456 | } 457 | } 458 | 459 | // Only called internally in ~ProcessGroupUCC and runHealthCheck when USE_ACTIVE_SETS is off. 460 | void Comm::ucx_disconnect_eps( 461 | std::vector& eps, 462 | std::shared_ptr oob) { 463 | ucs_status_t st; 464 | 465 | for (ucp_ep_h& ep : eps) { 466 | ucs_status_ptr_t close_req = ucp_ep_close_nb(ep, UCP_EP_CLOSE_MODE_FLUSH); 467 | if (UCS_PTR_IS_ERR(close_req)) { 468 | TORCH_UCC_LOG_ERROR( 469 | finalize_phase, 470 | "failed to close endpoint, ignore and continue..."); 471 | return; 472 | } 473 | if (UCS_PTR_IS_PTR(close_req)) { 474 | do { 475 | ucp_worker_progress(ucx_comm.worker); 476 | st = ucp_request_check_status(close_req); 477 | } while (st != UCS_OK); 478 | ucp_request_free(close_req); 479 | } 480 | } 481 | if (!eps.size()) { 482 | return; 483 | } 484 | try { 485 | auto sz = (size_t)oob->store->add(oob->getKey("epclosed"), 1); 486 | while (sz != eps.size()) { 487 | ucp_worker_progress(ucx_comm.worker); 488 | std::this_thread::sleep_for(std::chrono::milliseconds(kBusyWaitMillis)); 489 | sz = (size_t)oob->store->add(oob->getKey("epclosed"), 0); 490 | } 491 | } catch (std::exception& ex) { 492 | LOG(ERROR) << "(disconnect_eps) Caught error in Store Operation .. " 493 | << "[" << ex.what() << "]"; 494 | } 495 | } 496 | 497 | // Only used internally by send when USE_ACTIVE_SETS is off. 
498 | ucc_coll_req_h Comm::send_nb( 499 | ucp_ep_h ep, 500 | void* data, 501 | ucs_memory_type_t mtype, 502 | size_t size, 503 | ucp_tag_t ucp_tag) { 504 | ucs_status_ptr_t st; 505 | ucp_request_param_t params; 506 | params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | 507 | UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE; 508 | params.datatype = ucp_dt_make_contig(size); 509 | params.memory_type = mtype; 510 | params.cb.send = [](void* request, ucs_status_t status, void* user_data) { 511 | static_cast(request)->status = UCC_OK; 512 | }; 513 | st = ucp_tag_send_nbx(ep, data, 1, ucp_tag, ¶ms); 514 | if (UCS_PTR_IS_ERR(st)) { 515 | TORCH_UCC_LOG_ERROR( 516 | TORCH_UCC_COLL_POST, 517 | c10::str( 518 | "failed to send message: ", ucs_status_string(UCS_PTR_STATUS(st)))); 519 | throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st))); 520 | } 521 | return reinterpret_cast(st); 522 | } 523 | 524 | // Only used internally by recv and recvAnysource when USE_ACTIVE_SETS is off. 525 | ucc_coll_req_h Comm::recv_nb( 526 | void* data, 527 | ucs_memory_type_t mtype, 528 | size_t size, 529 | ucp_tag_t ucp_tag, 530 | ucp_tag_t ucp_tag_mask) { 531 | ucs_status_ptr_t st; 532 | ucp_request_param_t params; 533 | params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | 534 | UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_MEMORY_TYPE; 535 | params.datatype = ucp_dt_make_contig(size); 536 | params.cb.recv = [](void* request, 537 | ucs_status_t status, 538 | const ucp_tag_recv_info_t* info, 539 | void* user_data) { 540 | static_cast(request)->status = UCC_OK; 541 | }; 542 | params.memory_type = mtype; 543 | st = ucp_tag_recv_nbx( 544 | ucx_comm.worker, data, 1, ucp_tag, ucp_tag_mask, ¶ms); 545 | if (UCS_PTR_IS_ERR(st)) { 546 | TORCH_UCC_LOG_ERROR( 547 | TORCH_UCC_COLL_POST, 548 | c10::str( 549 | "failed to recv message: ", ucs_status_string(UCS_PTR_STATUS(st)))); 550 | throw std::runtime_error(ucs_status_string(UCS_PTR_STATUS(st))); 551 | } 552 | return reinterpret_cast(st); 553 | } 554 | #endif // end of ifndef USE_ACTIVE_SETS 555 | 556 | void Comm::ucc_create_team( 557 | ucc_team_h& team, 558 | std::shared_ptr oob) { 559 | ucc_status_t st; 560 | ucc_team_params_t team_params; 561 | team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | 562 | UCC_TEAM_PARAM_FIELD_OOB; 563 | team_params.oob.allgather = oob_allgather; 564 | team_params.oob.req_test = oob_allgather_test; 565 | team_params.oob.req_free = oob_allgather_free; 566 | team_params.oob.coll_info = oob.get(); 567 | team_params.oob.n_oob_eps = oob->size; 568 | team_params.oob.oob_ep = oob->rank; 569 | team_params.ep = oob->rank; 570 | team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; 571 | TORCH_UCC_CHECK( 572 | ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team), 573 | "failed to post team create"); 574 | do { 575 | st = ucc_team_create_test(team); 576 | ucc_context_progress(ucc_comm.context); 577 | } while (st == UCC_INPROGRESS); 578 | TORCH_UCC_CHECK(st, "failed to create UCC team"); 579 | } 580 | 581 | void Comm::ucc_destroy_team(ucc_team_h& team) { 582 | std::unique_lock lock(mutex); 583 | queue_consume_cv.wait( 584 | lock, [&] { return progress_queue.empty() && !collective_inprogress; }); 585 | 586 | ucc_status_t status; 587 | while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) { 588 | if (UCC_OK != status) { 589 | TORCH_UCC_LOG_ERROR( 590 | finalize_phase, 591 | c10::str("ucc team destroy error: ", ucc_status_string(status))); 592 | break; 593 | } 594 | } 595 | 596 | lock.unlock(); 597 | } 598 | 
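// NOTE (editorial): the UCP helpers above (ucx_connect_eps, send_nb, recv_nb) back the
// point-to-point path when USE_ACTIVE_SETS is off; when it is on, ProcessGroupUCC::send/recv
// below instead issue a two-rank UCC broadcast over an active set. Either way, the framework
// surface is the ordinary torch.distributed p2p API. A hedged sketch (backend name "ucc" and
// an already-initialized process group assumed, as in this repo's tests):
//
//   import torch
//   import torch.distributed as dist
//   if dist.get_rank() == 0:
//       dist.send(torch.arange(8.0), dst=1, tag=7)   # lands in ProcessGroupUCC::send
//   elif dist.get_rank() == 1:
//       buf = torch.empty(8)
//       dist.recv(buf, src=0, tag=7)                 # lands in ProcessGroupUCC::recv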
599 | #ifndef USE_ACTIVE_SETS 600 | c10::intrusive_ptr Comm::enqueue_p2p( 601 | OpType opType, 602 | ucc_coll_req_h request, 603 | const char* prof_title) { 604 | auto work = c10::make_intrusive( 605 | opType, prof_title, logger); 606 | if (torch_ucc_config.use_future) { 607 | work->future_ = c10::make_intrusive( 608 | c10::ListType::create(c10::TensorType::get())); 609 | } 610 | if (request == nullptr) { 611 | // p2p2 request completed immediately don't save it to progress queue 612 | // and mark future completed immediately 613 | if (torch_ucc_config.use_future) { 614 | work->future_->markCompleted(c10::IValue(std::vector())); 615 | } 616 | return work; 617 | } 618 | auto entry = 619 | std::make_shared(&ucx_comm, request); 620 | work->entry_ = entry; 621 | std::unique_lock lock(mutex); 622 | progress_queue.push_back(entry); 623 | lock.unlock(); 624 | queue_produce_cv.notify_one(); 625 | return work; 626 | } 627 | #endif 628 | 629 | void Comm::enqueue_collective( 630 | std::unique_ptr data, 631 | c10::intrusive_ptr work, 632 | ucc_coll_args_t& coll, 633 | ucc_team_h team) { 634 | ucc_coll_req_h request; 635 | TORCH_UCC_CHECK( 636 | ucc_collective_init(&coll, &request, team), "failed to init collective"); 637 | TORCH_UCC_CHECK(ucc_collective_post(request), "failed to post collective"); 638 | 639 | auto entry = 640 | std::make_shared(&ucc_comm, request); 641 | entry->data = std::move(data); 642 | entry->future_ = work->getFuture(); 643 | work->entry_ = entry; 644 | std::unique_lock lock(mutex); 645 | progress_queue.push_back(entry); 646 | lock.unlock(); 647 | queue_produce_cv.notify_one(); 648 | } 649 | 650 | #ifdef USE_CUDA 651 | void Comm::enqueue_cuda_collective( 652 | std::unique_ptr data, 653 | c10::intrusive_ptr work, 654 | ucc_coll_args_t& coll, 655 | ucc_team_h team, 656 | ucc_ee_h ee) { 657 | ucc_coll_req_h request; 658 | TORCH_UCC_CHECK( 659 | ucc_collective_init(&coll, &request, team), 660 | "failed to init cuda collective"); 661 | ucc_ev_t comp_ev, *post_ev; 662 | comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; 663 | comp_ev.ev_context = nullptr; 664 | comp_ev.ev_context_size = 0; 665 | comp_ev.req = request; 666 | TORCH_UCC_CHECK( 667 | ucc_collective_triggered_post(ee, &comp_ev), 668 | "failed to post triggered collective"); 669 | ucc_status_t st = ucc_ee_get_event(ee, &post_ev); 670 | TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); 671 | ucc_ee_ack_event(ee, post_ev); 672 | auto entry = 673 | std::make_shared(&ucc_comm, request); 674 | entry->data = std::move(data); 675 | work->entry_ = entry; 676 | std::unique_lock lock(mutex); 677 | progress_queue.push_back(entry); 678 | lock.unlock(); 679 | queue_produce_cv.notify_one(); 680 | } 681 | #endif 682 | 683 | void Comm::progress_loop() { 684 | std::unique_lock lock(mutex); 685 | #ifdef USE_CUDA 686 | bool device_set = false; 687 | #endif 688 | while (!stop_progress_loop) { 689 | if (progress_queue.empty()) { 690 | queue_produce_cv.wait(lock); 691 | continue; 692 | } 693 | collective_inprogress = true; 694 | auto work = progress_queue.front(); 695 | progress_queue.pop_front(); 696 | lock.unlock(); 697 | #ifdef USE_CUDA 698 | if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { 699 | c10::cuda::set_device(cuda_device_index); 700 | device_set = true; 701 | } 702 | #endif 703 | std::exception_ptr eptr; 704 | try { 705 | while (work->request_->status > 0) { 706 | ucc_comm.progress(); 707 | #ifndef USE_ACTIVE_SETS 708 | ucx_comm.progress(); 709 | #endif 710 | } 711 | if 
(work->request_->status < 0) { 712 | eptr = std::make_exception_ptr( 713 | std::runtime_error(ucc_status_string(work->request_->status))); 714 | std::string err_log = c10::str( 715 | "Failed to progress communication", // TODO: report exact op type or 716 | // id? 717 | ucc_status_string(work->request_->status)); 718 | TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log); 719 | } 720 | } catch (...) { 721 | eptr = std::current_exception(); 722 | } 723 | work->finalize(eptr); 724 | work = nullptr; 725 | collective_inprogress = false; 726 | queue_consume_cv.notify_one(); 727 | lock.lock(); 728 | } 729 | } 730 | 731 | ProcessGroupUCC::ProcessGroupUCC( 732 | const c10::intrusive_ptr& store, 733 | int rank, 734 | int size, 735 | std::chrono::duration timeout) 736 | : ProcessGroup(rank, size), timeout_(timeout) { 737 | std::call_once(torch_ucc_config.flag, read_confg); 738 | oob = std::make_shared(); 739 | oob->rank = rank; 740 | oob->size = size; 741 | oob->store = store; 742 | comm = nullptr; 743 | cuda_ee = nullptr; 744 | static uint32_t id = 0; 745 | uint32_t pg_id = (id++ % TORCH_UCX_MAX_COMM); 746 | 747 | logger = c10::make_intrusive( 748 | c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"), 749 | TORCH_UCC_INIT); 750 | TORCH_UCC_LOG_INFO( 751 | TORCH_UCC_INIT, 752 | c10::str("Created ProcessGroupUCC with ", size, " ranks, with timeout ", timeout_.count(), " secs")); 753 | std::string envs = ""; 754 | for (auto& torch_ucc_env : torch_ucc_envs_map) { 755 | envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second); 756 | } 757 | TORCH_UCC_LOG_INFO( 758 | TORCH_UCC_INIT, 759 | c10::str( 760 | "Successfully read and set ProcessGroupUCC env. variables as followings", 761 | envs)); 762 | 763 | if (torch_ucc_config.enable_health_check) { 764 | // Perform health check by initializing dummy communicators and destroying 765 | // them. This will help indicate any UCC/UCX-related issues prior to the first 766 | // collective. 767 | // Run it in a separate thread and wait on CV to handle timeouts so that if there 768 | // are hangs, the main thread can still run correctly. 769 | runHealthCheck(); 770 | } 771 | if (torch_ucc_config.enable_comms_logger) { 772 | logger->initCommsTracer(); 773 | } 774 | } 775 | 776 | ProcessGroupUCC::~ProcessGroupUCC() { 777 | if (torch_ucc_config.enable_comms_logger) { 778 | logger->flushComms(this->getRank(), this->getSize()); 779 | } 780 | if (comm) { 781 | logger->setPhase(TORCH_UCC_FINALIZE); 782 | comm->ucc_destroy_team(team); 783 | TORCH_UCC_LOG_INFO( 784 | TORCH_UCC_FINALIZE, "Successfully destroyed UCC library"); 785 | #ifndef USE_ACTIVE_SETS 786 | comm->ucx_disconnect_eps(eps, oob); 787 | TORCH_UCC_LOG_INFO( 788 | TORCH_UCC_FINALIZE, "Successfully destroyed UCX library"); 789 | #endif 790 | try { 791 | if (cuda_ee) { 792 | ucc_ee_destroy(cuda_ee); 793 | } 794 | if ((size_t) oob->store->add(oob->getKey("ucc_pg_closed"), 1) == 795 | this->getSize()) { 796 | std::vector val = {1}; 797 | oob->store->set(oob->getKey("ucc_pg_finished"), val); 798 | } else { 799 | oob->store->wait({oob->getKey("ucc_pg_finished")}); 800 | } 801 | } catch (std::exception& ex) { 802 | TORCH_UCC_LOG_INFO( 803 | TORCH_UCC_FINALIZE, 804 | c10::str( 805 | "(~ProcessGroupUCC) Caught error in Store Operation .. ", 806 | "[", 807 | ex.what(), 808 | "]")); 809 | } 810 | comm = nullptr; 811 | } 812 | } 813 | 814 | #ifdef USE_CUDA 815 | // Return CUDA device with ordinal given by input rank. 
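// For example, with 8 visible GPUs, rank 10 is mapped to CUDA device 10 % 8 = 2.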
816 | c10::Device getCUDADeviceForRank(int rank) { 817 | TORCH_CHECK(rank >= 0, "Invalid rank ", rank); 818 | auto numGPUs = at::cuda::getNumGPUs(); 819 | auto deviceIdx = static_cast(rank % numGPUs); 820 | return c10::Device(c10::DeviceType::CUDA, deviceIdx); 821 | } 822 | #endif 823 | 824 | void ProcessGroupUCC::runHealthCheck() { 825 | // Run health check in a separate thread and wait on CV to handle timeouts. 826 | // This design allows us to handle hangs. 827 | 828 | // When size_ is 1, there is no need to do any communication at all. 829 | if (size_ == 1) return; 830 | 831 | struct HealthCheckData { 832 | std::mutex healthCheckMutex; 833 | std::condition_variable healthCheckCv; 834 | bool ucxHealthCheckSuccess = false; 835 | bool uccHealthCheckSuccess = false; 836 | std::exception_ptr healthCheckException; 837 | } healthCheckData; 838 | 839 | auto t = std::thread([&healthCheckData, this]() { 840 | std::list devices{c10::kCPU}; 841 | #ifdef USE_CUDA 842 | c10::cuda::OptionalCUDAGuard gpuGuard; 843 | if (at::cuda::is_available()) { 844 | devices.emplace_front(getCUDADeviceForRank(rank_)); 845 | } 846 | #endif 847 | for (auto device : devices) { 848 | bool is_last_device = (device == devices.back()); 849 | try { 850 | auto oob = std::make_shared(); 851 | oob->rank = this->oob->rank; 852 | oob->size = this->oob->size; 853 | oob->store = this->oob->store; 854 | 855 | ucc_team_h team = nullptr; 856 | uint32_t comm_id; 857 | #ifdef USE_CUDA 858 | if (device.is_cuda()) { 859 | gpuGuard.set_index(device.index()); 860 | } 861 | #endif 862 | auto comm = Comm::get_comm(comm_id, device, oob, logger, true); 863 | #ifdef USE_ACTIVE_SETS 864 | TORCH_UCC_LOG_INFO( 865 | TORCH_UCC_HEALTH_CHECK, 866 | c10::str( 867 | "Skip UCX health check in UCC when USE_ACTIVE_SETS is set.") 868 | ); 869 | healthCheckData.ucxHealthCheckSuccess = true; 870 | #else 871 | std::vector eps; 872 | comm->ucx_connect_eps(eps, oob); 873 | comm->ucx_disconnect_eps(eps, oob); 874 | TORCH_UCC_LOG_INFO( 875 | TORCH_UCC_HEALTH_CHECK, 876 | c10::str( 877 | "UCX library health check succeed for device ", 878 | c10::DeviceTypeName(device.type())) 879 | ); 880 | // Mark ucx health check as complete. 881 | if (is_last_device) { 882 | std::lock_guard lk(healthCheckData.healthCheckMutex); 883 | healthCheckData.ucxHealthCheckSuccess = true; 884 | } 885 | #endif 886 | comm->ucc_create_team(team, oob); 887 | comm->ucc_destroy_team(team); 888 | TORCH_UCC_LOG_INFO( 889 | TORCH_UCC_HEALTH_CHECK, 890 | c10::str( 891 | "UCC library health check succeed for device ", 892 | c10::DeviceTypeName(device.type())) 893 | ); 894 | // Mark ucc health check as complete. 895 | if (is_last_device) { 896 | std::lock_guard lk(healthCheckData.healthCheckMutex); 897 | healthCheckData.uccHealthCheckSuccess = true; 898 | } 899 | 900 | comm = nullptr; 901 | oob = nullptr; 902 | // Notify main thread the health check is complete. 903 | if (is_last_device) { 904 | healthCheckData.healthCheckCv.notify_one(); 905 | } 906 | } catch (const std::exception& e) { 907 | // Populate exception ptr. 908 | healthCheckData.healthCheckException = std::current_exception(); 909 | // Unblock waiting main thread which will report exception. 910 | healthCheckData.healthCheckCv.notify_one(); 911 | } // Unknown exceptions will just cause the program to terminate. 912 | } 913 | }); 914 | // We don't need to join the thread, just need to verify health check via the 915 | // CV. Hence we detach the thread here. 
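// (Editorial note: this health check only runs when TORCH_UCC_ENABLE_HEALTH_CHECK=1 is set
// in the environment -- see torch_ucc_envs_map above -- and is skipped entirely for
// single-rank groups.)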
916 | t.detach(); // NOLINT 917 | TORCH_UCC_LOG_INFO( 918 | TORCH_UCC_HEALTH_CHECK, 919 | c10::str( 920 | "will wait up to ", timeout_.count(), 921 | " msec for UCC health check to complete.") 922 | ); 923 | std::unique_lock lock(healthCheckData.healthCheckMutex); 924 | healthCheckData.healthCheckCv.wait_for( 925 | lock, timeout_, [&healthCheckData]() { 926 | return healthCheckData.ucxHealthCheckSuccess && healthCheckData.uccHealthCheckSuccess; 927 | }); 928 | 929 | if (healthCheckData.healthCheckException) { 930 | std::rethrow_exception(healthCheckData.healthCheckException); 931 | } 932 | // If there is no exception, the likely culprit is a timeout/hang 933 | #ifndef USE_ACTIVE_SETS 934 | TORCH_CHECK( 935 | healthCheckData.ucxHealthCheckSuccess, 936 | "ProcessGroupUCC: Health check failure: Failed to initialize UCX on rank ", 937 | rank_); 938 | #endif 939 | TORCH_CHECK( 940 | healthCheckData.uccHealthCheckSuccess, 941 | "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ", 942 | rank_); 943 | } 944 | 945 | void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) { 946 | args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; 947 | args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT; 948 | args.timeout = timeout_.count(); 949 | } 950 | 951 | #ifdef USE_CUDA 952 | std::unique_ptr ProcessGroupUCC::getPooledEvent() { 953 | std::unique_ptr ev; 954 | std::lock_guard lock(ep.event_pool_mutex); 955 | if (ep.event_pool.empty()) { 956 | ev = std::make_unique(); 957 | } else { 958 | ev = std::move(ep.event_pool.front()); 959 | ep.event_pool.pop(); 960 | } 961 | return ev; 962 | } 963 | #endif 964 | 965 | template 966 | c10::intrusive_ptr ProcessGroupUCC::collective_post( 967 | OpType opType, 968 | PreProcess preproc, 969 | PostProcess postproc, 970 | ucc_coll_args_t& coll, 971 | std::unique_ptr data, 972 | c10::Device dev, 973 | std::vector &inputTensors, 974 | std::vector &outputTensors, 975 | const char* prof_title) { 976 | set_timeout(coll); 977 | auto work = c10::make_intrusive( 978 | opType, torch_ucc_config.enable_profiling ? 
prof_title : nullptr, logger); 979 | 980 | RECORD_COMMS_TRACE( 981 | logger->trace_generator, 982 | work, 983 | opType, 984 | this->getRank(), 985 | this->getSize(), 986 | inputTensors, 987 | outputTensors); 988 | 989 | // Store references to outputs to be used by result 990 | work->outputs_ = std::make_shared>(outputTensors); 991 | switch (dev.type()) { 992 | case c10::DeviceType::CPU: { 993 | if (torch_ucc_config.use_future) { 994 | work->future_ = c10::make_intrusive( 995 | c10::ListType::create(c10::TensorType::get())); 996 | } 997 | comm->enqueue_collective(std::move(data), work, coll, team); 998 | return work; 999 | } 1000 | #ifdef USE_CUDA 1001 | case c10::DeviceType::CUDA: { 1002 | auto cuda_ev = getPooledEvent(); 1003 | cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index())); 1004 | cuda_ev->block(*stream); 1005 | at::cuda::CUDAStreamGuard guard(*stream); 1006 | preproc(); 1007 | comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee); 1008 | postproc(); 1009 | cuda_ev->record(*stream); 1010 | work->fence = std::move(cuda_ev); 1011 | work->ep = &ep; 1012 | if (torch_ucc_config.use_future) { 1013 | c10::cuda::CUDAMultiStreamGuard streamGuard(*stream); 1014 | std::vector devList{dev}; 1015 | work->future_ = c10::make_intrusive( 1016 | c10::ListType::create(c10::TensorType::get()), devList); 1017 | // Add a callback that runs profiling end callbacks 1018 | if (work->recordFunctionEndCallback_) { 1019 | work->future_->addCallback([work](at::ivalue::Future& /* unused */) { 1020 | work->recordFunctionEndCallback_(); 1021 | }); 1022 | } 1023 | 1024 | work->future_->markCompleted(c10::IValue(outputTensors)); 1025 | } 1026 | return work; 1027 | } 1028 | #endif // #ifdef USE_CUDA 1029 | default: { 1030 | TORCH_UCC_LOG_ERROR( 1031 | TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str())); 1032 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); 1033 | } 1034 | } 1035 | } 1036 | 1037 | c10::intrusive_ptr ProcessGroupUCC::allgather( 1038 | std::vector>& outputTensors, 1039 | std::vector& inputTensors, 1040 | const AllgatherOptions& /* unused */) { 1041 | auto& tensor = inputTensors[0]; 1042 | check_device(tensor.device(), outputTensors[0][0].device()); 1043 | initComm(tensor.device()); 1044 | 1045 | if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) { 1046 | AllgathervWorkData* data = new AllgathervWorkData(size_); 1047 | for (int i = 0; i < size_; i++) { 1048 | data->recv_lengths[i] = tensor.element_size() * tensor.numel(); 1049 | data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr(); 1050 | } 1051 | ucc_coll_args_t coll; 1052 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1053 | coll.flags = 1054 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; 1055 | coll.coll_type = UCC_COLL_TYPE_ALLGATHERV; 1056 | coll.src.info.buffer = tensor.data_ptr(); 1057 | coll.src.info.count = tensor.element_size() * tensor.numel(); 1058 | coll.src.info.datatype = UCC_DT_UINT8; 1059 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1060 | coll.dst.info_v.buffer = nullptr; 1061 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); 1062 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); 1063 | coll.dst.info_v.datatype = UCC_DT_UINT8; 1064 | coll.dst.info_v.mem_type = 1065 | to_ucc_memType(outputTensors[0][0].device().type()); 1066 | SAVE_TENSORS(inputTensors, data->src); 1067 | SAVE_TENSORS(outputTensors[0], data->dst); 1068 | 1069 | return collective_post( 1070 | 
OpType::ALLGATHER, 1071 | []() {}, 1072 | []() {}, 1073 | coll, 1074 | std::unique_ptr(data), 1075 | tensor.device(), 1076 | inputTensors, 1077 | outputTensors[0], 1078 | "ucc:allgatherv"); 1079 | } else { 1080 | WorkData* data = new WorkData(); 1081 | std::vector flat_output(outputTensors.size()); 1082 | for (size_t i = 0; i < outputTensors.size(); i++) { 1083 | TORCH_CHECK(outputTensors[i].size() == outputTensors.size() * size_, 1084 | "Tensor output list is not valid for the number of participants"); 1085 | flat_output[i] = c10d::newLikeFlat(outputTensors, i); 1086 | } 1087 | SAVE_TENSORS(flat_output, data->flat); 1088 | ucc_coll_args_t coll; 1089 | coll.mask = 0; 1090 | coll.flags = 0; 1091 | coll.coll_type = UCC_COLL_TYPE_ALLGATHER; 1092 | coll.src.info.buffer = tensor.data_ptr(); 1093 | coll.src.info.count = tensor.numel(); 1094 | coll.src.info.datatype = to_ucc_dType(tensor); 1095 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1096 | coll.dst.info.buffer = flat_output[0].data_ptr(); 1097 | coll.dst.info.count = flat_output[0].numel(); 1098 | coll.dst.info.datatype = to_ucc_dType(flat_output[0]); 1099 | coll.dst.info.mem_type = 1100 | to_ucc_memType(outputTensors[0][0].device().type()); 1101 | 1102 | auto copy_from_flat = [&] { 1103 | bool asyncCopy = false; 1104 | #ifdef USE_CUDA 1105 | bool isCuda = outputTensors[0][0].device().is_cuda();; 1106 | #endif 1107 | for (size_t i = 0; i < outputTensors.size(); i++) { 1108 | auto inumel = inputTensors[i].numel(); 1109 | for (size_t j = 0; j < outputTensors[i].size(); j++) { 1110 | TORCH_CHECK( 1111 | (outputTensors[i][j].numel() == inumel), 1112 | "Tensor operand counts must be same"); 1113 | #ifdef USE_CUDA 1114 | if (isCuda) { 1115 | c10::cuda::CUDACachingAllocator::recordStream( 1116 | outputTensors[i][j].storage().data_ptr(), (*stream)); 1117 | asyncCopy = true; 1118 | } 1119 | #endif 1120 | outputTensors[i][j].copy_(flat_output[i][j], asyncCopy); 1121 | } 1122 | } 1123 | }; 1124 | return collective_post( 1125 | OpType::ALLGATHER, 1126 | []() {}, 1127 | copy_from_flat, 1128 | coll, 1129 | std::unique_ptr(data), 1130 | tensor.device(), 1131 | inputTensors, 1132 | outputTensors[0], 1133 | "ucc:allgather"); 1134 | } 1135 | } 1136 | 1137 | c10::intrusive_ptr ProcessGroupUCC::_allgather_base( 1138 | at::Tensor& outputTensor, 1139 | at::Tensor& inputTensor, 1140 | const AllgatherOptions& opts) { 1141 | check_tensor({outputTensor}); 1142 | check_tensor({inputTensor}); 1143 | initComm(outputTensor.device()); 1144 | 1145 | WorkData* data = new WorkData(); 1146 | 1147 | ucc_coll_args_t coll; 1148 | coll.mask = 0; 1149 | coll.flags = 0; 1150 | coll.coll_type = UCC_COLL_TYPE_ALLGATHER; 1151 | coll.src.info.buffer = inputTensor.data_ptr(); 1152 | coll.src.info.count = inputTensor.numel(); 1153 | coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type()); 1154 | coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); 1155 | coll.dst.info.buffer = outputTensor.data_ptr(); 1156 | coll.dst.info.count = outputTensor.numel(); 1157 | coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type()); 1158 | coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); 1159 | 1160 | std::vector inputTensors = {inputTensor}; 1161 | std::vector outputTensors = {outputTensor}; 1162 | SAVE_TENSORS(inputTensors, data->src); 1163 | SAVE_TENSORS(outputTensors, data->dst); 1164 | 1165 | return collective_post( 1166 | OpType::_ALLGATHER_BASE, 1167 | []() {}, 1168 | []() {}, 1169 | coll, 1170 | 
std::unique_ptr(data), 1171 | outputTensor.device(), 1172 | inputTensors, 1173 | outputTensors, 1174 | "ucc:allgather_base"); 1175 | } 1176 | 1177 | c10::intrusive_ptr ProcessGroupUCC::allreduce( 1178 | std::vector& tensors, 1179 | const AllreduceOptions& opts) { 1180 | check_tensor(tensors); 1181 | auto& tensor = tensors[0]; 1182 | initComm(tensor.device()); 1183 | WorkData* data = new WorkData(); 1184 | 1185 | ucc_coll_args_t coll; 1186 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1187 | coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; 1188 | coll.coll_type = UCC_COLL_TYPE_ALLREDUCE; 1189 | coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type()); 1190 | coll.src.info.buffer = nullptr; 1191 | coll.src.info.count = tensor.numel(); 1192 | coll.src.info.datatype = to_ucc_dType(tensor); 1193 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1194 | coll.dst.info.buffer = tensor.data_ptr(); 1195 | coll.dst.info.count = tensor.numel(); 1196 | coll.dst.info.datatype = to_ucc_dType(tensor); 1197 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); 1198 | SAVE_TENSORS(tensors, data->dst); 1199 | return collective_post( 1200 | OpType::ALLREDUCE, 1201 | []() {}, 1202 | []() {}, 1203 | coll, 1204 | std::unique_ptr(data), 1205 | tensor.device(), 1206 | tensors, 1207 | tensors, 1208 | "ucc:allreduce"); 1209 | } 1210 | 1211 | c10::intrusive_ptr ProcessGroupUCC::allreduce_coalesced( 1212 | std::vector& /* unused */, 1213 | const AllreduceCoalescedOptions& /* unused */) { 1214 | throw std::runtime_error( 1215 | "ProcessGroupUCC does not support allreduce_coalesced"); 1216 | } 1217 | 1218 | c10::intrusive_ptr ProcessGroupUCC::alltoall( 1219 | std::vector& outputTensors, 1220 | std::vector& inputTensors, 1221 | const AllToAllOptions& /* unused */) { 1222 | auto device = outputTensors[0].device(); 1223 | for (const auto r : c10::irange(outputTensors.size())) { 1224 | TORCH_CHECK( 1225 | device == outputTensors[r].device() && 1226 | device == inputTensors[r].device(), 1227 | "Tensors must be on the same device") 1228 | } 1229 | 1230 | initComm(device); 1231 | ucc_coll_args_t coll; 1232 | AlltoallWorkData* data; 1233 | data = new AlltoallWorkData(size_); 1234 | 1235 | /* To avoid flattening the tensors, we use alltoallv to implement Alltoall as 1236 | follows: 1237 | 1. store the address of each tensor directly in displacements, keeping buffer 1238 | set to nullptr, i.e., 0 1239 | 2. convert the datatype to UINT8, which is always 1 byte, to avoid wrong size 1240 | calculation in the UCC layer 1241 | 3.
post Alltoallv 1242 | */ 1243 | for (const auto i : c10::irange(size_)) { 1244 | data->send_lengths[i] = 1245 | (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel()); 1246 | data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr(); 1247 | data->recv_lengths[i] = 1248 | (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel()); 1249 | data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr(); 1250 | } 1251 | 1252 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1253 | coll.flags = 1254 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; 1255 | coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; 1256 | coll.src.info_v.buffer = 0; 1257 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); 1258 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); 1259 | coll.src.info_v.datatype = UCC_DT_UINT8; 1260 | coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type()); 1261 | coll.dst.info_v.buffer = 0; 1262 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); 1263 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); 1264 | coll.dst.info_v.datatype = UCC_DT_UINT8; 1265 | coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type()); 1266 | 1267 | SAVE_TENSORS(inputTensors, data->src); 1268 | SAVE_TENSORS(outputTensors, data->dst); 1269 | 1270 | return collective_post( 1271 | OpType::ALLTOALL, 1272 | []() {}, 1273 | []() {}, 1274 | coll, 1275 | std::unique_ptr(data), 1276 | device, 1277 | inputTensors, 1278 | outputTensors, 1279 | "ucc:alltoall"); 1280 | } 1281 | 1282 | c10::intrusive_ptr ProcessGroupUCC::alltoall_base( 1283 | at::Tensor& outputTensor, 1284 | at::Tensor& inputTensor, 1285 | std::vector& outputSplitSizes, 1286 | std::vector& inputSplitSizes, 1287 | const AllToAllOptions& /* unused */) { 1288 | check_device(inputTensor.device(), outputTensor.device()); 1289 | initComm(inputTensor.device()); 1290 | ucc_coll_args_t coll; 1291 | AlltoallWorkData* data; 1292 | 1293 | if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) { 1294 | data = new AlltoallWorkData(0); 1295 | TORCH_CHECK( 1296 | (outputTensor.size(0) % size_ == 0) && 1297 | (inputTensor.size(0) % size_ == 0), 1298 | "Tensor's dim 0 does not divide equally across group size"); 1299 | coll.mask = 0; 1300 | coll.flags = 0; 1301 | coll.coll_type = UCC_COLL_TYPE_ALLTOALL; 1302 | coll.src.info.buffer = inputTensor.data_ptr(); 1303 | coll.src.info.count = inputTensor.element_size() * inputTensor.numel(); 1304 | coll.src.info.datatype = UCC_DT_UINT8; 1305 | coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); 1306 | coll.dst.info.buffer = outputTensor.data_ptr(); 1307 | coll.dst.info.count = outputTensor.element_size() * outputTensor.numel(); 1308 | coll.dst.info.datatype = UCC_DT_UINT8; 1309 | coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); 1310 | coll.flags = 0; 1311 | } else { 1312 | data = new AlltoallWorkData(size_); 1313 | c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); 1314 | c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); 1315 | computeLengthsAndOffsets( 1316 | outputSplitSizes, 1317 | outputTensor, 1318 | &data->recv_lengths, 1319 | &data->recv_offsets); 1320 | computeLengthsAndOffsets( 1321 | inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets); 1322 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1323 | coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; 1324 | coll.src.info_v.buffer = inputTensor.data_ptr(); 
1325 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); 1326 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); 1327 | coll.src.info_v.datatype = to_ucc_dType(inputTensor); 1328 | coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type()); 1329 | coll.dst.info_v.buffer = outputTensor.data_ptr(); 1330 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); 1331 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); 1332 | coll.dst.info_v.datatype = to_ucc_dType(outputTensor); 1333 | coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type()); 1334 | coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | 1335 | UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | 1336 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | 1337 | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; 1338 | 1339 | if (torch_ucc_config.enable_comms_logger) { 1340 | logger->trace_generator->recordOptionalInfo(outputSplitSizes, inputSplitSizes); 1341 | } 1342 | } 1343 | std::vector inputTensors = {inputTensor}; 1344 | std::vector outputTensors = {outputTensor}; 1345 | SAVE_TENSORS(inputTensors, data->src); 1346 | SAVE_TENSORS(outputTensors, data->dst); 1347 | 1348 | return collective_post( 1349 | OpType::ALLTOALL_BASE, 1350 | []() {}, 1351 | []() {}, 1352 | coll, 1353 | std::unique_ptr(data), 1354 | inputTensor.device(), 1355 | inputTensors, 1356 | outputTensors, 1357 | "ucc:alltoall"); 1358 | } 1359 | 1360 | c10::intrusive_ptr ProcessGroupUCC::barrier( 1361 | const BarrierOptions& opts) { 1362 | c10::Device device = c10::Device(c10::DeviceType::CPU); 1363 | #ifdef USE_CUDA 1364 | auto numGPUs = c10::cuda::device_count(); 1365 | if (!opts.device_ids.empty()) { 1366 | device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front()); 1367 | } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) { 1368 | device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index); 1369 | } else if (numGPUs > 0) { 1370 | int8_t deviceIdx = static_cast(c10::cuda::current_device()); 1371 | // if current device is 0, likely the device is not set, use the best guess 1372 | if (0 == (int)deviceIdx) { 1373 | deviceIdx = static_cast(this->getRank() % numGPUs); 1374 | } 1375 | TORCH_UCC_LOG_INFO( 1376 | TORCH_UCC_COLL_POST, 1377 | c10::str( 1378 | "post barrier before specifying any GPU while there are ", 1379 | numGPUs, 1380 | " GPUs available. ", 1381 | "Not clear if GPU barrier is required, using GPU ", 1382 | (int)deviceIdx, 1383 | " to perform barrier. 
", 1384 | "Specify device_ids option in barrier() to force ", 1385 | "use of a particular device")); 1386 | device = c10::Device(c10::DeviceType::CUDA, deviceIdx); 1387 | } 1388 | #endif 1389 | initComm(device); 1390 | 1391 | ucc_coll_args_t coll; 1392 | coll.mask = 0; 1393 | coll.flags = 0; 1394 | coll.coll_type = UCC_COLL_TYPE_BARRIER; 1395 | auto dummy_tensor = std::vector(); 1396 | return collective_post( 1397 | OpType::BARRIER, 1398 | []() {}, 1399 | []() {}, 1400 | coll, 1401 | nullptr, 1402 | device, 1403 | dummy_tensor, 1404 | dummy_tensor, 1405 | "ucc:barrier"); 1406 | } 1407 | 1408 | c10::intrusive_ptr ProcessGroupUCC::broadcast( 1409 | std::vector& tensors, 1410 | const BroadcastOptions& opts) { 1411 | check_tensor(tensors); 1412 | auto& tensor = tensors[0]; 1413 | initComm(tensor.device()); 1414 | WorkData* data = new WorkData(); 1415 | 1416 | ucc_coll_args_t coll; 1417 | coll.mask = 0; 1418 | coll.flags = 0; 1419 | coll.coll_type = UCC_COLL_TYPE_BCAST; 1420 | coll.src.info.buffer = tensor.data_ptr(); 1421 | coll.src.info.count = tensor.numel(); 1422 | coll.src.info.datatype = to_ucc_dType(tensor); 1423 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1424 | coll.root = opts.rootRank; 1425 | SAVE_TENSORS(tensors, data->dst); 1426 | 1427 | if (torch_ucc_config.enable_comms_logger) { 1428 | logger->trace_generator->recordOptionalInfo(opts.rootRank); 1429 | } 1430 | 1431 | return collective_post( 1432 | OpType::BROADCAST, 1433 | []() {}, 1434 | []() {}, 1435 | coll, 1436 | std::unique_ptr(data), 1437 | tensor.device(), 1438 | tensors, 1439 | tensors, 1440 | "ucc:broadcast"); 1441 | } 1442 | 1443 | c10::intrusive_ptr ProcessGroupUCC::gather( 1444 | std::vector>& outputTensors, 1445 | std::vector& inputTensors, 1446 | const GatherOptions& opts) { 1447 | std::vector outputs; 1448 | auto& input = inputTensors[0]; 1449 | initComm(input.device()); 1450 | 1451 | AllgathervWorkData* data = new AllgathervWorkData(size_); 1452 | ucc_coll_args_t coll; 1453 | coll.root = opts.rootRank; 1454 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1455 | coll.flags = 1456 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; 1457 | coll.coll_type = UCC_COLL_TYPE_GATHERV; 1458 | 1459 | /* for non-root ranks, only src is valid */ 1460 | coll.src.info.buffer = input.data_ptr(); 1461 | coll.src.info.count = (uint64_t)(input.element_size() * input.numel()); 1462 | coll.src.info.datatype = UCC_DT_UINT8; 1463 | coll.src.info.mem_type = to_ucc_memType(input.device().type()); 1464 | 1465 | if (getRank() == opts.rootRank) { 1466 | if (outputTensors.size() != 1) { 1467 | TORCH_UCC_LOG_ERROR( 1468 | TORCH_UCC_COLL_POST, 1469 | c10::str( 1470 | "gather requires a single-element output list containing a list with ", 1471 | getSize(), 1472 | " tensors.")); 1473 | } else if (outputTensors[0].size() != static_cast(getSize())) { 1474 | TORCH_UCC_LOG_ERROR( 1475 | TORCH_UCC_COLL_POST, 1476 | c10::str( 1477 | "Incorrect output list size ", 1478 | outputTensors[0].size(), 1479 | ". 
Output list size should be ", 1480 | getSize(), 1481 | ", same as size of the process group.")); 1482 | } 1483 | outputs = outputTensors[0]; 1484 | 1485 | for (int i = 0; i < size_; i++) { 1486 | data->recv_lengths[i] = 1487 | (uint64_t)(outputs[i].element_size() * outputs[i].numel()); 1488 | data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr(); 1489 | } 1490 | /* use gatherv and store non-contiguous addresses in displacements to avoid 1491 | * flatten outputTensors */ 1492 | coll.dst.info_v.buffer = nullptr; 1493 | coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); 1494 | coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); 1495 | coll.dst.info_v.datatype = UCC_DT_UINT8; 1496 | coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type()); 1497 | 1498 | SAVE_TENSORS(outputs, data->dst); 1499 | } else { 1500 | // for non-root ranks, outputTensors should be an empty list 1501 | if (outputTensors.size() != 0) { 1502 | TORCH_UCC_LOG_ERROR( 1503 | TORCH_UCC_COLL_POST, "requires empty output on non-root"); 1504 | } 1505 | outputs = {}; 1506 | // append a empty tensor to the list to be used by future mark 1507 | outputs.emplace_back(); 1508 | } 1509 | 1510 | SAVE_TENSORS(inputTensors, data->src); 1511 | 1512 | return collective_post( 1513 | OpType::GATHER, 1514 | []() {}, 1515 | []() {}, 1516 | coll, 1517 | std::unique_ptr(data), 1518 | input.device(), 1519 | inputTensors, 1520 | outputs, 1521 | "ucc:gather"); 1522 | } 1523 | 1524 | c10::intrusive_ptr ProcessGroupUCC::reduce( 1525 | std::vector& tensors, 1526 | const ReduceOptions& opts) { 1527 | check_tensor(tensors); 1528 | auto& tensor = tensors[0]; 1529 | initComm(tensor.device()); 1530 | WorkData* data = new WorkData(); 1531 | 1532 | ucc_coll_args_t coll; 1533 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1534 | coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; 1535 | coll.coll_type = UCC_COLL_TYPE_REDUCE; 1536 | coll.op = ucc_op_map.at(opts.reduceOp); 1537 | coll.root = opts.rootRank; 1538 | coll.src.info.buffer = tensor.data_ptr(); 1539 | coll.src.info.count = tensor.numel(); 1540 | coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); 1541 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1542 | coll.dst.info.buffer = tensor.data_ptr(); 1543 | coll.dst.info.count = tensor.numel(); 1544 | coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); 1545 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); 1546 | SAVE_TENSORS(tensors, data->dst); 1547 | return collective_post( 1548 | OpType::REDUCE, 1549 | []() {}, 1550 | []() {}, 1551 | coll, 1552 | std::unique_ptr(data), 1553 | tensor.device(), 1554 | tensors, 1555 | tensors, 1556 | "ucc:reduce"); 1557 | } 1558 | 1559 | c10::intrusive_ptr ProcessGroupUCC::reduce_scatter( 1560 | std::vector& outputTensors, 1561 | std::vector>& inputTensors, 1562 | const ReduceScatterOptions& opts) { 1563 | TORCH_CHECK( 1564 | (outputTensors.size() == inputTensors.size()), 1565 | "Tensor input/output list for reduce_scatter must have same size"); 1566 | check_tensor(outputTensors); 1567 | check_device(inputTensors[0][0].device(), outputTensors[0].device()); 1568 | initComm(inputTensors[0][0].device()); 1569 | auto data = std::make_unique(); 1570 | std::vector flat_input(inputTensors.size()); 1571 | for (size_t i = 0; i < inputTensors.size(); i++) { 1572 | TORCH_CHECK(inputTensors[i].size() == inputTensors.size() * size_, 1573 | "Tensor input list is not valid for the number of participants"); 1574 | flat_input[i] = 
c10d::newLikeFlat(inputTensors, i); 1575 | } 1576 | SAVE_TENSORS(flat_input, data->flat); 1577 | check_tensor(flat_input); 1578 | ucc_coll_args_t coll; 1579 | coll.mask = 0; 1580 | coll.flags = 0; 1581 | coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; 1582 | coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type()); 1583 | 1584 | coll.src.info.buffer = flat_input[0].data_ptr(); 1585 | coll.src.info.count = flat_input[0].numel(); 1586 | coll.src.info.datatype = to_ucc_dType(flat_input[0]); 1587 | coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type()); 1588 | coll.dst.info.buffer = outputTensors[0].data_ptr(); 1589 | coll.dst.info.count = outputTensors[0].numel(); 1590 | coll.dst.info.datatype = to_ucc_dType(outputTensors[0]); 1591 | coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type()); 1592 | 1593 | SAVE_TENSORS(inputTensors[0], data->src); 1594 | SAVE_TENSORS(outputTensors, data->dst); 1595 | 1596 | auto copy_to_flat = [&] { 1597 | bool asyncCopy = false; 1598 | auto isize = inputTensors.size(); 1599 | #ifdef USE_CUDA 1600 | bool isCuda = inputTensors[0][0].device().is_cuda(); 1601 | #endif 1602 | for (size_t i = 0; i < isize; i++) { 1603 | auto onumel = outputTensors[i].numel(); 1604 | for (size_t j = 0; j < inputTensors[i].size(); j++) { 1605 | TORCH_CHECK( 1606 | (inputTensors[i][j].numel() == onumel), 1607 | "Tensor operand counts must be same"); 1608 | #ifdef USE_CUDA 1609 | if (isCuda) { 1610 | c10::cuda::CUDACachingAllocator::recordStream( 1611 | inputTensors[i][j].storage().data_ptr(), (*stream)); 1612 | asyncCopy = true; 1613 | } 1614 | #endif 1615 | flat_input[i][j].copy_(inputTensors[i][j], asyncCopy); 1616 | } 1617 | } 1618 | }; 1619 | 1620 | return collective_post( 1621 | OpType::REDUCE_SCATTER, 1622 | copy_to_flat, 1623 | []() {}, 1624 | coll, 1625 | std::move(data), 1626 | inputTensors[0][0].device(), 1627 | inputTensors[0], 1628 | outputTensors, 1629 | "ucc:reduce_scatter"); 1630 | } 1631 | 1632 | c10::intrusive_ptr ProcessGroupUCC::scatter( 1633 | std::vector& outputTensors, 1634 | std::vector>& inputTensors, 1635 | const ScatterOptions& opts) { 1636 | auto& tensor = outputTensors[0]; 1637 | initComm(tensor.device()); 1638 | 1639 | ScattervWorkData* data = new ScattervWorkData(size_); 1640 | ucc_coll_args_t coll; 1641 | coll.root = opts.rootRank; 1642 | coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; 1643 | coll.flags = 1644 | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; 1645 | coll.coll_type = UCC_COLL_TYPE_SCATTERV; 1646 | 1647 | if (getRank() == opts.rootRank) { 1648 | /* src is only valid at the root rank */ 1649 | if (inputTensors.size() != 1) { 1650 | TORCH_UCC_LOG_ERROR( 1651 | TORCH_UCC_COLL_POST, 1652 | c10::str( 1653 | "scatter requires a single-element input list containing a list with ", 1654 | getSize(), 1655 | " tensors.")); 1656 | } else if (inputTensors[0].size() != static_cast(getSize())) { 1657 | TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_POST, 1658 | c10::str( 1659 | "Incorrect output list size ", inputTensors[0].size(), 1660 | ". 
Output list size should be ", getSize(), 1661 | ", same as size of the process group.")); 1662 | } 1663 | 1664 | for (int i = 0; i < size_; i++) { 1665 | data->send_lengths[i] = (uint64_t) tensor.element_size() * tensor.numel(); 1666 | data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr(); 1667 | } 1668 | /* use scatter and store non-contiguous addresses in displacements to avoid 1669 | * flatten inputTensors */ 1670 | coll.src.info_v.buffer = nullptr; 1671 | coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); 1672 | coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); 1673 | coll.src.info_v.datatype = UCC_DT_UINT8; 1674 | coll.src.info_v.mem_type = 1675 | to_ucc_memType(inputTensors[0][0].device().type()); 1676 | 1677 | SAVE_TENSORS(inputTensors[0], data->src); 1678 | } else { 1679 | // for non-root ranks, inputTensors should be an empty list 1680 | if (inputTensors.size() != 0) { 1681 | TORCH_UCC_LOG_ERROR( 1682 | TORCH_UCC_COLL_POST, "requires empty output on non-root"); 1683 | } 1684 | } 1685 | 1686 | coll.dst.info.buffer = tensor.data_ptr(); 1687 | coll.dst.info.count = (uint64_t) tensor.element_size() * tensor.numel(); 1688 | coll.dst.info.datatype = UCC_DT_UINT8; 1689 | coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); 1690 | SAVE_TENSORS(outputTensors, data->dst); 1691 | 1692 | return collective_post( 1693 | OpType::SCATTER, 1694 | []() {}, 1695 | []() {}, 1696 | coll, 1697 | std::unique_ptr(data), 1698 | tensor.device(), 1699 | inputTensors[0], 1700 | outputTensors, 1701 | "ucc:scatter"); 1702 | } 1703 | 1704 | c10::intrusive_ptr ProcessGroupUCC::send( 1705 | std::vector& tensors, 1706 | int dstRank, 1707 | int tag) { 1708 | check_tensor(tensors); 1709 | auto& tensor = tensors[0]; 1710 | initComm(tensor.device()); 1711 | 1712 | #ifdef USE_ACTIVE_SETS 1713 | WorkData* data = new WorkData(); 1714 | ucc_coll_args_t coll; 1715 | coll.tag = tag; 1716 | coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; 1717 | coll.flags = 0; 1718 | coll.coll_type = UCC_COLL_TYPE_BCAST; 1719 | coll.src.info.buffer = tensor.data_ptr(); 1720 | coll.src.info.count = tensor.numel(); 1721 | coll.src.info.datatype = to_ucc_dType(tensor); 1722 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1723 | coll.root = getRank(); 1724 | 1725 | coll.active_set.size = 2; 1726 | coll.active_set.start = getRank(); 1727 | coll.active_set.stride = dstRank - getRank(); 1728 | SAVE_TENSORS(tensors, data->dst); 1729 | 1730 | return collective_post( 1731 | OpType::SEND, 1732 | []() {}, 1733 | []() {}, 1734 | coll, 1735 | std::unique_ptr(data), 1736 | tensor.device(), 1737 | tensors, 1738 | tensors, 1739 | "ucc:send"); 1740 | #else 1741 | ucp_tag_t ucp_tag; 1742 | TORCH_UCX_MAKE_SEND_TAG(ucp_tag, tag, rank_, comm_id); 1743 | ucc_coll_req_h request = comm->send_nb( 1744 | eps[dstRank], 1745 | tensor.data_ptr(), 1746 | to_ucs_memType(tensor.device().type()), 1747 | tensor.numel() * tensor.element_size(), 1748 | ucp_tag); 1749 | 1750 | auto work = comm->enqueue_p2p(OpType::SEND, request, "ucc:send"); 1751 | // TODO: record src, dst ranks and tag 1752 | RECORD_COMMS_TRACE( 1753 | logger->trace_generator, 1754 | work, 1755 | OpType::SEND, 1756 | this->getRank(), 1757 | this->getSize(), 1758 | tensors, 1759 | tensors); 1760 | return work; 1761 | #endif 1762 | } 1763 | 1764 | c10::intrusive_ptr ProcessGroupUCC::recv( 1765 | std::vector& tensors, 1766 | int srcRank, 1767 | int tag) { 1768 | check_tensor(tensors); 1769 | auto& tensor = 
tensors[0]; 1770 | initComm(tensor.device()); 1771 | 1772 | #ifdef USE_ACTIVE_SETS 1773 | WorkData* data = new WorkData(); 1774 | ucc_coll_args_t coll; 1775 | coll.tag = tag; 1776 | coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; 1777 | coll.flags = 0; 1778 | coll.coll_type = UCC_COLL_TYPE_BCAST; 1779 | coll.src.info.buffer = tensor.data_ptr(); 1780 | coll.src.info.count = tensor.numel(); 1781 | coll.src.info.datatype = to_ucc_dType(tensor); 1782 | coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); 1783 | coll.root = srcRank; 1784 | 1785 | coll.active_set.size = 2; 1786 | coll.active_set.start = srcRank; 1787 | coll.active_set.stride = getRank() - srcRank; 1788 | SAVE_TENSORS(tensors, data->dst); 1789 | 1790 | return collective_post( 1791 | OpType::RECV, 1792 | []() {}, 1793 | []() {}, 1794 | coll, 1795 | std::unique_ptr(data), 1796 | tensor.device(), 1797 | tensors, 1798 | tensors, 1799 | "ucc:recv"); 1800 | #else 1801 | ucp_tag_t ucp_tag, ucp_tag_mask; 1802 | TORCH_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, srcRank, comm_id); 1803 | ucc_coll_req_h request = comm->recv_nb( 1804 | tensor.data_ptr(), 1805 | to_ucs_memType(tensor.device().type()), 1806 | tensor.numel() * tensor.element_size(), 1807 | ucp_tag, 1808 | ucp_tag_mask); 1809 | 1810 | auto work = comm->enqueue_p2p(OpType::RECV, request, "ucc:recv"); 1811 | // TODO: record src, dst ranks and tag 1812 | RECORD_COMMS_TRACE( 1813 | logger->trace_generator, 1814 | work, 1815 | OpType::RECV, 1816 | this->getRank(), 1817 | this->getSize(), 1818 | tensors, 1819 | tensors); 1820 | return work; 1821 | #endif 1822 | } 1823 | 1824 | c10::intrusive_ptr ProcessGroupUCC::recvAnysource( 1825 | std::vector& tensors, 1826 | int tag) { 1827 | check_tensor(tensors); 1828 | auto& tensor = tensors[0]; 1829 | initComm(tensor.device()); 1830 | 1831 | #ifdef USE_ACTIVE_SETS 1832 | TORCH_CHECK(false, "recvAnysource is not supported in UCC when USE_ACTIVE_SETS is set"); 1833 | #else 1834 | ucp_tag_t ucp_tag, ucp_tag_mask; 1835 | TORCH_UCX_MAKE_RECV_TAG( 1836 | ucp_tag, ucp_tag_mask, tag, TORCH_UCX_ANY_SOURCE, comm_id); 1837 | ucc_coll_req_h request = comm->recv_nb( 1838 | tensor.data_ptr(), 1839 | to_ucs_memType(tensor.device().type()), 1840 | tensor.numel() * tensor.element_size(), 1841 | ucp_tag, 1842 | ucp_tag_mask); 1843 | 1844 | auto work = comm->enqueue_p2p(OpType::RECVANYSOURCE, request, "ucc:recv"); 1845 | // TODO: record dst rank and tag 1846 | RECORD_COMMS_TRACE( 1847 | logger->trace_generator, 1848 | work, 1849 | OpType::RECVANYSOURCE, 1850 | this->getRank(), 1851 | this->getSize(), 1852 | tensors, 1853 | tensors); 1854 | return work; 1855 | #endif 1856 | } 1857 | 1858 | c10::intrusive_ptr ProcessGroupUCC::createProcessGroupUCC( 1859 | const c10::intrusive_ptr<::c10d::Store>& store, 1860 | int rank, 1861 | int size, 1862 | const std::chrono::duration& timeout) { 1863 | return c10::make_intrusive(store, rank, size, timeout); 1864 | } 1865 | 1866 | void ProcessGroupUCC::initComm(c10::Device dev) { 1867 | if (!comm) { 1868 | #ifdef USE_CUDA 1869 | if (dev.is_cuda()) { 1870 | c10::cuda::set_device(dev.index()); 1871 | } 1872 | #endif 1873 | comm = Comm::get_comm(comm_id, dev, oob, logger); 1874 | #ifndef USE_ACTIVE_SETS 1875 | comm->ucx_connect_eps(eps, oob); 1876 | TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library"); 1877 | #endif 1878 | comm->ucc_create_team(team, oob); 1879 | TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library"); 1880 | 
logger->setPhase(TORCH_UCC_READY); 1881 | } else { 1882 | if (dev.is_cuda()) { 1883 | if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && 1884 | (comm->cuda_device_index != dev.index())) { 1885 | TORCH_UCC_LOG_ERROR( 1886 | TORCH_UCC_INIT, 1887 | "ucc communicator was initialized with different cuda device," 1888 | "multi device is not supported"); 1889 | throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); 1890 | } 1891 | comm->cuda_device_index = dev.index(); 1892 | } 1893 | } 1894 | #ifdef USE_CUDA 1895 | // Create UCC execution engine. 1896 | if (!cuda_ee && dev.is_cuda()) { 1897 | stream = std::make_unique( 1898 | at::cuda::getStreamFromPool(true, dev.index())); 1899 | ucc_ee_params_t params; 1900 | params.ee_type = UCC_EE_CUDA_STREAM; 1901 | params.ee_context = (void*)stream->stream(); 1902 | params.ee_context_size = sizeof(cudaStream_t); 1903 | TORCH_UCC_CHECK( 1904 | ucc_ee_create(team, ¶ms, &cuda_ee), 1905 | "failed to create UCC execution engine"); 1906 | } 1907 | #endif 1908 | } 1909 | 1910 | } // namespace c10d 1911 | -------------------------------------------------------------------------------- /src/torch_ucc_comm.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | * 4 | * Copyright (c) Facebook, Inc. and its affiliates. 5 | * 6 | * This source code is licensed under the MIT license found in the 7 | * LICENSE file in the root directory of this source tree. 8 | * 9 | */ 10 | 11 | #include "torch_ucc_comm.hpp" 12 | #include "torch_ucc_tracing.hpp" 13 | 14 | namespace c10d { 15 | 16 | namespace { 17 | constexpr char kTeamRank[] = "teamr"; 18 | constexpr char kAllGatherDone[] = "ag_done"; 19 | constexpr char kAllGatherFree[] = "ag_free"; 20 | } // namespace 21 | 22 | #ifndef USE_ACTIVE_SETS 23 | CommUCX::CommUCX( 24 | int comm_size, 25 | const c10::intrusive_ptr& logger) 26 | : CommBase(logger) { 27 | ucp_params_t params; 28 | ucp_config_t* config; 29 | ucs_status_t st; 30 | ucp_worker_params_t worker_params; 31 | ucp_lib_attr_t ucp_attr; 32 | 33 | ucp_attr.field_mask = UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL; 34 | TORCH_UCX_CHECK( 35 | ucp_lib_query(&ucp_attr), "failed to query UCP lib attributes"); 36 | TORCH_CHECK( 37 | ucp_attr.max_thread_level == UCS_THREAD_MODE_MULTI, 38 | "ucx library wasn't initialized with multithreading support, " 39 | "please check ucx build options"); 40 | TORCH_UCX_CHECK( 41 | ucp_config_read("TORCH", nullptr, &config), "failed to read UCP config"); 42 | 43 | memset(¶ms, 0, sizeof(ucp_params_t)); 44 | params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_REQUEST_SIZE | 45 | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_TAG_SENDER_MASK | 46 | UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP; 47 | params.request_size = sizeof(ucc_coll_req_t); 48 | params.features = UCP_FEATURE_TAG; 49 | params.estimated_num_eps = comm_size; 50 | params.tag_sender_mask = TORCH_UCX_RANK_MASK; 51 | params.request_init = [](void* request) { 52 | static_cast(request)->status = UCC_INPROGRESS; 53 | }; 54 | params.request_cleanup = [](void*) {}; 55 | TORCH_UCX_CHECK( 56 | ucp_init(¶ms, config, &context), "failed to init UCP context"); 57 | ucp_config_release(config); 58 | 59 | memset(&worker_params, 0, sizeof(ucp_worker_params_t)); 60 | worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; 61 | worker_params.thread_mode = UCS_THREAD_MODE_MULTI; 62 | st = ucp_worker_create(context, &worker_params, &worker); 63 | if 
(st != UCS_OK) { 64 | TORCH_UCC_LOG_ERROR( 65 | TORCH_UCC_INIT, 66 | c10::str("UCX failed to create UCP worker:", ucs_status_string(st))); 67 | ucp_cleanup(context); 68 | throw std::runtime_error(ucs_status_string(st)); 69 | } 70 | } 71 | 72 | void CommUCX::progress() { 73 | ucp_worker_progress(worker); 74 | } 75 | 76 | void CommUCX::free_request(ucc_coll_req_h request) { 77 | request->status = UCC_INPROGRESS; 78 | ucp_request_free(request); 79 | } 80 | 81 | CommUCX::~CommUCX() { 82 | if (worker != nullptr) { 83 | ucp_worker_destroy(worker); 84 | } 85 | if (context != nullptr) { 86 | ucp_cleanup(context); 87 | } 88 | worker = nullptr; 89 | context = nullptr; 90 | } 91 | #endif 92 | 93 | ucc_status_t oob_allgather( 94 | void* sbuf, 95 | void* rbuf, 96 | size_t msglen, 97 | void* coll_info, 98 | void** req) { 99 | auto* info = reinterpret_cast(coll_info); 100 | TORCH_CHECK(info != nullptr); 101 | std::vector val = std::vector( 102 | reinterpret_cast(sbuf), 103 | reinterpret_cast(sbuf) + msglen); 104 | try { 105 | info->store->set(info->getKey(kTeamRank + std::to_string(info->rank)), val); 106 | info->rbuf = rbuf; 107 | info->msglen = msglen; 108 | *req = coll_info; 109 | } catch (std::exception& ex) { 110 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " 111 | << "[" << ex.what() << "]"; 112 | return UCC_ERR_NO_MESSAGE; 113 | } 114 | return UCC_OK; 115 | } 116 | 117 | ucc_status_t oob_allgather_test(void* req) { 118 | auto* info = reinterpret_cast(req); 119 | TORCH_CHECK(info != nullptr); 120 | 121 | try { 122 | for (int r = 0; r < info->size; r++) { 123 | if (!info->store->check({info->getKey(kTeamRank + std::to_string(r))})) { 124 | return UCC_INPROGRESS; 125 | } 126 | } 127 | for (int r = 0; r < info->size; r++) { 128 | std::vector data = 129 | info->store->get(info->getKey(kTeamRank + std::to_string(r))); 130 | memcpy( 131 | (void*)((ptrdiff_t)info->rbuf + info->msglen * r), 132 | data.data(), 133 | info->msglen); 134 | } 135 | } catch (std::exception& ex) { 136 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. " 137 | << "[" << ex.what() << "]"; 138 | return UCC_ERR_NO_MESSAGE; 139 | } 140 | return UCC_OK; 141 | } 142 | 143 | ucc_status_t oob_allgather_free(void* req) { 144 | auto* info = reinterpret_cast(req); 145 | TORCH_CHECK(info != nullptr); 146 | try { 147 | int num_done = info->store->add({info->getKey(kAllGatherDone)}, 1); 148 | if (num_done == info->size) { 149 | info->store->deleteKey(info->getKey(kAllGatherDone)); 150 | // Note: to avoid race condition, it's important to remove all keys in 151 | // oob_allgather_free first and only after that signal completion to 152 | // other ranks 153 | for (const auto r : c10::irange(info->size)) { 154 | info->store->deleteKey(info->getKey(kTeamRank + std::to_string(r))); 155 | } 156 | for (const auto r : c10::irange(info->size)) { 157 | info->store->add({info->getKey(kAllGatherFree + std::to_string(r))}, 1); 158 | } 159 | } else { 160 | info->store->wait( 161 | {info->getKey(kAllGatherFree + std::to_string(info->rank))}); 162 | } 163 | info->store->deleteKey( 164 | info->getKey(kAllGatherFree + std::to_string(info->rank))); 165 | } catch (std::exception& ex) { 166 | LOG(ERROR) << "(oob_allgather) Caught exception in Store Operation .. 
" 167 | << "[" << ex.what() << "]"; 168 | return UCC_ERR_NO_MESSAGE; 169 | } 170 | return UCC_OK; 171 | } 172 | 173 | CommUCC::CommUCC( 174 | std::shared_ptr oob, 175 | const c10::intrusive_ptr& logger) 176 | : CommBase(logger) { 177 | ucc_lib_config_h lib_config; 178 | ucc_context_config_h context_config; 179 | ucc_lib_params_t lib_params; 180 | ucc_context_params_t context_params; 181 | ucc_status_t st; 182 | 183 | TORCH_UCC_CHECK( 184 | ucc_lib_config_read("TORCH", nullptr, &lib_config), 185 | "failed to read UCC lib config"); 186 | memset(&lib_params, 0, sizeof(ucc_lib_params_t)); 187 | lib_params.mask = UCC_LIB_PARAM_FIELD_THREAD_MODE; 188 | lib_params.thread_mode = UCC_THREAD_MULTIPLE; 189 | TORCH_UCC_CHECK( 190 | ucc_init(&lib_params, lib_config, &lib), "failed to init UCC lib"); 191 | ucc_lib_config_release(lib_config); 192 | ucc_lib_attr_t lib_attr; 193 | lib_attr.mask = UCC_LIB_ATTR_FIELD_THREAD_MODE; 194 | TORCH_UCC_CHECK( 195 | ucc_lib_get_attr(lib, &lib_attr), "failed to query for lib attr"); 196 | TORCH_CHECK( 197 | lib_attr.thread_mode == UCC_THREAD_MULTIPLE, 198 | "ucc library wasn't initialized with multithreading support, " 199 | "please check ucc build options"); 200 | st = ucc_context_config_read(lib, NULL, &context_config); 201 | if (st != UCC_OK) { 202 | // FIXME: would this cause deadlock if only one rank fails? 203 | TORCH_UCC_CHECK( 204 | ucc_finalize(lib), 205 | "failed to finalize UCC library when failing to read UCC context config"); 206 | TORCH_UCC_LOG_ERROR( 207 | TORCH_UCC_INIT, 208 | c10::str("failed to read UCC context config: ", ucc_status_string(st))); 209 | throw std::runtime_error(ucc_status_string(st)); 210 | } 211 | st = ucc_context_config_modify( 212 | context_config, 213 | NULL, 214 | "ESTIMATED_NUM_EPS", 215 | std::to_string(oob->size).c_str()); 216 | if (st != UCC_OK) { 217 | ucc_context_config_release(context_config); 218 | ucc_finalize(lib); 219 | TORCH_UCC_LOG_ERROR( 220 | TORCH_UCC_INIT, 221 | c10::str( 222 | "UCC failed to modify UCC context config: ", 223 | ucc_status_string(st))); 224 | throw std::runtime_error(ucc_status_string(st)); 225 | } 226 | memset(&context_params, 0, sizeof(ucc_context_params_t)); 227 | context_params.mask = 228 | UCC_CONTEXT_PARAM_FIELD_TYPE | UCC_CONTEXT_PARAM_FIELD_OOB; 229 | context_params.type = UCC_CONTEXT_SHARED; 230 | context_params.oob.n_oob_eps = oob->size; 231 | context_params.oob.oob_ep = oob->rank; 232 | context_params.oob.allgather = oob_allgather; 233 | context_params.oob.req_test = oob_allgather_test; 234 | context_params.oob.req_free = oob_allgather_free; 235 | context_params.oob.coll_info = oob.get(); 236 | st = ucc_context_create(lib, &context_params, context_config, &context); 237 | ucc_context_config_release(context_config); 238 | if (st != UCC_OK) { 239 | TORCH_UCC_CHECK( 240 | ucc_finalize(lib), 241 | "failed to finalize UCC library when failing to creat UCC context"); 242 | TORCH_UCC_LOG_ERROR( 243 | TORCH_UCC_INIT, 244 | c10::str("UCC failed to create UCC context: ", ucc_status_string(st))); 245 | throw std::runtime_error(ucc_status_string(st)); 246 | } 247 | } 248 | 249 | void CommUCC::progress() { 250 | TORCH_UCC_CHECK( 251 | ucc_context_progress(context), "failed to progress UCC collective"); 252 | } 253 | 254 | void CommUCC::free_request(ucc_coll_req_h request) { 255 | TORCH_UCC_CHECK( 256 | ucc_collective_finalize(request), "failed to release UCC request"); 257 | } 258 | 259 | CommUCC::~CommUCC() { 260 | if (context != nullptr) { 261 | TORCH_UCC_CHECK( 262 | 
ucc_context_destroy(context), "failed to destory UCC context"); 263 | } 264 | if (lib != nullptr) { 265 | TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library"); 266 | } 267 | context = nullptr; 268 | lib = nullptr; 269 | } 270 | 271 | std::string ProcessGroupUCCLogger::getLogPrefix(torch_ucc_phase_t phase) { 272 | // caller can override the phase stored locally 273 | torch_ucc_phase_t phase_ = 274 | (local_phase != phase && phase != TORCH_UCC_UNKNOWN) ? phase 275 | : local_phase; 276 | return c10::str(log_prefix, "[", ucc_phase_map.at(phase_), "]"); 277 | } 278 | void ProcessGroupUCCLogger::setLogPrefix(std::string log_prefix_) { 279 | log_prefix = log_prefix_; 280 | } 281 | 282 | ProcessGroupUCCLogger::ProcessGroupUCCLogger() { 283 | setLogPrefix("[ProcessGroupUCC]"); 284 | } 285 | ProcessGroupUCCLogger::ProcessGroupUCCLogger( 286 | std::string log_prefix, 287 | torch_ucc_phase_t phase) 288 | : local_phase(phase) { 289 | setLogPrefix(log_prefix); 290 | } 291 | 292 | } // namespace c10d 293 | -------------------------------------------------------------------------------- /src/torch_ucc_init.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | */ 4 | 5 | #include 6 | #include 7 | 8 | #include "torch_ucc.hpp" 9 | 10 | static void __attribute__((constructor)) ProcessGroupUCCConstructor() { 11 | py::object module = py::module::import("torch.distributed"); 12 | py::object register_backend = 13 | module.attr("Backend").attr("register_backend"); 14 | register_backend("ucc", py::cpp_function(c10d::ProcessGroupUCC::createProcessGroupUCC)); 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("createProcessGroupUCC", &c10d::ProcessGroupUCC::createProcessGroupUCC); 19 | } 20 | -------------------------------------------------------------------------------- /src/torch_ucc_init_oss.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | */ 4 | 5 | #include 6 | #include "torch_ucc.hpp" 7 | 8 | using namespace c10d; 9 | 10 | extern "C" C10_EXPORT c10::intrusive_ptr createProcessGroupUCC( 11 | const c10::intrusive_ptr& store, 12 | int rank, int size 13 | ) { 14 | return c10::make_intrusive(store, rank, size); 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} 18 | -------------------------------------------------------------------------------- /src/torch_ucc_tracing.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | * 7 | */ 8 | 9 | #include "torch_ucc_tracing.hpp" 10 | #include "torch_ucc_comm.hpp" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #ifdef FBCODE_CAFFE2 20 | #include "torch_ucc_internal_utils.hpp" 21 | #endif 22 | 23 | namespace c10d { 24 | 25 | void ProcessGroupUCCLogger::initCommsTracer() { 26 | trace_generator = std::make_shared(); 27 | initialized_CommTraceLogger = true; 28 | } 29 | 30 | void ProcessGroupUCCLogger::flushComms(int rank, int world_size) { 31 | if (!initialized_CommTraceLogger || 32 | trace_generator->getCommsTrace().empty()) { 33 | return; 34 | } 35 | 36 | std::string dirname = c10::str("ProcessGroupUCC_trace_np", world_size); 37 | time_t now_ = time(0); 38 | std::tm* ltm = localtime(&now_); 39 | if (ltm) { 40 | dirname += c10::str( 41 | "_", 42 | (1 + ltm->tm_mon), 43 | "_", 44 | ltm->tm_mday, 45 | "_", 46 | (1900 + ltm->tm_year)); 47 | } 48 | 49 | std::string fullpath = "/tmp/" + dirname; 50 | char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR"); 51 | if (user_path) { 52 | fullpath = user_path; 53 | } 54 | std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json"); 55 | std::ofstream _outfile; 56 | if (!_outfile.is_open()) { 57 | if (!mkdir(fullpath.c_str(), 0777)) { 58 | LOG(INFO) << getLogPrefix() << "[INFO] failed to mkdir " << fullpath; 59 | } else if (errno != EEXIST) { 60 | return; 61 | } 62 | _outfile.open(trace_filename, std::ofstream::out | std::ofstream::trunc); 63 | } 64 | // flush the traced comms 65 | if (_outfile.is_open()) { 66 | _outfile << "[" << c10::Join(",", trace_generator->getCommsTrace()) 67 | << "\n]"; 68 | _outfile.flush(); 69 | _outfile.close(); 70 | } 71 | #ifdef FBCODE_CAFFE2 72 | uploadTrace_internal( 73 | trace_filename, dirname, c10::str("rank", rank, ".json")); 74 | #endif 75 | } 76 | 77 | /* unused */ 78 | void CommTraceLogger::setCurBlock(const std::string& name) { 79 | curBlocks_.push_back( 80 | c10::str("\"", name, "\"")); // add quote marks for JSON format 81 | } 82 | 83 | /* unused */ 84 | void CommTraceLogger::popBlock() { 85 | // TODO: remove specific name 86 | curBlocks_.pop_back(); 87 | } 88 | 89 | void CommTraceLogger::recordOptionalInfo(int root) { 90 | curRoot_ = root; 91 | } 92 | 93 | void CommTraceLogger::recordOptionalInfo( 94 | const std::vector& outputSplitSizes, 95 | const std::vector& inputSplitSizes) { 96 | curOutSplitSizes_ = outputSplitSizes; 97 | curInSplitSizes_ = inputSplitSizes; 98 | } 99 | 100 | void CommTraceLogger::recordComms( 101 | const std::string& commName, 102 | const uintptr_t workReq, 103 | const int rank, 104 | const int world_size, 105 | const std::vector& inputTensors, 106 | const std::vector& outputTensors) { 107 | auto inSize = (!inputTensors.empty()) ? inputTensors[0].numel() : 0; 108 | auto outSize = (!outputTensors.empty()) ? outputTensors[0].numel() : 0; 109 | auto dtype = 110 | (!outputTensors.empty()) ? outputTensors[0].scalar_type() : at::kByte; 111 | auto devType = (!outputTensors.empty()) ? 
outputTensors[0].device().type() 112 | : c10::DeviceType::CPU; 113 | auto now = std::chrono::system_clock::now(); 114 | static auto startTS = now; 115 | int64_t time_since_begin = 116 | std::chrono::duration_cast(now - startTS) 117 | .count(); 118 | 119 | // TODO: get markers from torch profiler if enabled 120 | 121 | // common fields for all operations 122 | std::string cur_trace_ = c10::str( 123 | "\n\t\t\"markers\": [", 124 | curBlocks_, 125 | "]", 126 | ",\n\t\t\"startTime_ns\": ", 127 | time_since_begin, 128 | ",\n\t\t\"comms\": \"", 129 | commName, 130 | "\"", 131 | ",\n\t\t\"req\": ", 132 | workReq, 133 | ",\n\t\t\"seqnum\": ", 134 | seqnum, 135 | ",\n\t\t\"world_size\": ", 136 | world_size); 137 | 138 | if (inSize > 0 || outSize > 0) { 139 | // for most collectives - append msg sizes, data type, device type 140 | cur_trace_ = c10::str( 141 | cur_trace_, 142 | ",\n\t\t\"in_msg_size\": ", 143 | inSize, 144 | ",\n\t\t\"out_msg_size\": ", 145 | outSize, 146 | ",\n\t\t\"dtype\": \"", 147 | at::toString(dtype), 148 | "\",\n\t\t\"devType\": \"", 149 | c10::DeviceTypeName(devType), 150 | "\""); 151 | } 152 | if (curRoot_ != -1) { 153 | // append root rank if applicable, e.g., broadcast, gather, scatter 154 | cur_trace_ = c10::str(cur_trace_, ",\n\t\t\"root\": ", curRoot_); 155 | } 156 | if (!curInSplitSizes_.empty() || !curOutSplitSizes_.empty()) { 157 | // append input and output splits if applicable, e.g., ALLTOALL_BASE 158 | cur_trace_ = c10::str( 159 | cur_trace_, 160 | ",\n\t\t\"in_split\": [", 161 | c10::Join(",", curInSplitSizes_), 162 | "]" 163 | ",\n\t\t\"out_split\": [", 164 | c10::Join(",", curOutSplitSizes_), 165 | "]"); 166 | } 167 | comms_trace_.push_back(c10::str("\n\t{", cur_trace_, "\n\t}")); 168 | 169 | // record the trace to kineto trace if applicable 170 | RECORD_PARAM_COMMS( 171 | static_cast(seqnum), // seq 172 | std::make_tuple("0", ""), // pg_name tuple 173 | rank, 174 | commName.c_str(), 175 | inSize, 176 | outSize, 177 | dtype, 178 | curInSplitSizes_, 179 | curOutSplitSizes_, 180 | 0, 181 | 0, 182 | world_size); 183 | 184 | ++seqnum; 185 | 186 | // reset optional field 187 | curRoot_ = -1; 188 | curInSplitSizes_ = {}; 189 | curOutSplitSizes_ = {}; 190 | } 191 | 192 | } // namespace c10d 193 | -------------------------------------------------------------------------------- /test/blocking_wait_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | # torch_ucc is required to enable ucc PG 18 | import torch_ucc # noqa: F401 19 | 20 | def init_pg(backend): 21 | global comm_rank, comm_size 22 | try: 23 | comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 24 | comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 25 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 26 | except: 27 | print("OMPI env variables are not found") 28 | sys.exit(1) 29 | torch.cuda.set_device(local_rank) 30 | 31 | os.environ["MASTER_PORT"] = "32167" 32 | os.environ["MASTER_ADDR"] = "localhost" 33 | os.environ["RANK"] = str(comm_rank) 34 | os.environ["WORLD_SIZE"] = str(comm_size) 35 | dist.init_process_group(backend, rank=comm_rank, world_size=comm_size) 36 | 37 | 38 | def allreduce_test(): 39 | global comm_rank, comm_size 40 | num_iters = 10 41 | dev = torch.device("cuda") 42 | t = torch.ones(100, device=dev) 43 | for i in range(10): 44 | dist.all_reduce(t) 45 | if torch.all(torch.eq(t, comm_size ** num_iters)): 46 | print(f"Rank {comm_rank}: success") 47 | else: 48 | print(f"Rank {comm_rank}: failed") 49 | 50 | 51 | def alltoall_test(): 52 | global comm_rank, comm_size 53 | dev = torch.device("cuda") 54 | t_send = torch.zeros(comm_size, device=dev) + comm_rank 55 | t_recv = torch.zeros(comm_size, device=dev) 56 | dist.all_to_all_single(t_recv, t_send) 57 | t_recv = t_recv + 1 58 | dist.all_reduce(t_recv) 59 | if torch.all( 60 | torch.eq( 61 | t_recv, comm_size * torch.arange(start=1, end=comm_size + 1, device=dev) 62 | ) 63 | ): 64 | print(f"Rank {comm_rank}: success") 65 | else: 66 | print(f"Rank {comm_rank}: failed") 67 | 68 | 69 | if __name__ == "__main__": 70 | if not torch.cuda.is_available(): 71 | print("cuda is not available") 72 | sys.exit(1) 73 | parser = argparse.ArgumentParser(description="PG UCC nonblocking test") 74 | parser.add_argument("--backend", type=str, default="ucc") 75 | parser.add_argument("--test", type=str, default="ucc") 76 | args = parser.parse_args() 77 | 78 | comm_rank = -1 79 | comm_size = -1 80 | init_pg(args.backend) 81 | if args.test == "allreduce": 82 | allreduce_test() 83 | elif args.test == "alltoall": 84 | alltoall_test() 85 | else: 86 | print("Wrong test name") 87 | -------------------------------------------------------------------------------- /test/start_test.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | #!/bin/bash 9 | 10 | size="${TORCH_UCC_TEST_SIZE:-4}" 11 | for i in $(seq 0 $(($size-1))) 12 | do 13 | OMPI_COMM_WORLD_LOCAL_RANK=$i OMPI_COMM_WORLD_RANK=$i OMPI_COMM_WORLD_SIZE=$size python $@ & 14 | processes[${i}]=$! 15 | done 16 | 17 | for p in ${processes[*]}; do 18 | wait $p 19 | done 20 | -------------------------------------------------------------------------------- /test/torch_allgather_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | 11 | import numpy as np 12 | from torch_ucc_test_setup import * 13 | 14 | args = parse_test_args() 15 | pg = init_process_groups(args.backend, args.use_cuda) 16 | 17 | comm_size = dist.get_world_size() 18 | comm_rank = dist.get_rank() 19 | 20 | counts = 2 ** np.arange(24) 21 | print_test_head("Allgather", comm_rank) 22 | for count in counts: 23 | tensor_input = get_tensor(count, args.use_cuda) 24 | tensors_out_ucc = [] 25 | tensors_out_test = [] 26 | for p in range(comm_size): 27 | tensors_out_ucc.append(get_tensor(count, args.use_cuda)) 28 | tensors_out_test.append(get_tensor(count, is_cuda=False)) 29 | dist.all_gather(tensors_out_ucc, tensor_input) 30 | dist.all_gather(tensors_out_test, tensor_input.cpu(), group=pg) 31 | status = check_tensor_list_equal(tensors_out_ucc, tensors_out_test) 32 | dist.all_reduce(status, group=pg) 33 | print_test_result(status, count, comm_rank, comm_size) 34 | if comm_rank == 0: 35 | print("Test allgather: succeeded") 36 | -------------------------------------------------------------------------------- /test/torch_allreduce_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import numpy as np 11 | from torch_ucc_test_setup import * 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(24) 20 | print_test_head("Allreduce", comm_rank) 21 | for count in counts: 22 | tensor_ucc = get_tensor(count, args.use_cuda) 23 | tensor_test = tensor_ucc.cpu() 24 | dist.all_reduce(tensor_ucc) 25 | dist.all_reduce(tensor_test, group=pg) 26 | status = check_tensor_equal(tensor_ucc, tensor_test) 27 | dist.all_reduce(status, group=pg) 28 | print_test_result(status, count, comm_rank, comm_size) 29 | 30 | if comm_rank == 0: 31 | print("Test allreduce: succeeded") 32 | -------------------------------------------------------------------------------- /test/torch_alltoall_bench.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import argparse 11 | import os 12 | import sys 13 | from time import perf_counter 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | # torch_ucc is required to enable ucc PG 19 | import torch_ucc # noqa: F401 20 | 21 | 22 | def get_tensor(size, device, val): 23 | count = size // 4 24 | t = torch.ones([count], dtype=torch.int32, device=device) 25 | t = t + val 26 | return t 27 | 28 | 29 | parser = argparse.ArgumentParser(description="Process Group Alltoall Benchmark") 30 | parser.add_argument("--backend", type=str, default="mpi") 31 | parser.add_argument("--use-cuda", default=False, action="store_true") 32 | parser.add_argument("--min-size", type=int, default=2 ** 5) 33 | parser.add_argument("--max-size", type=int, default=2 ** 15) 34 | parser.add_argument("--skip", type=int, default=500) 35 | parser.add_argument("--iter", type=int, default=100) 36 | args = parser.parse_args() 37 | 38 | try: 39 | comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 40 | comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 41 | except: 42 | try: 43 | comm_size = int(os.environ["WORLD_SIZE"]) 44 | comm_rank = int(os.environ["RANK"]) 45 | except: 46 | print("OMPI env variables are not found") 47 | sys.exit(1) 48 | 49 | if not os.environ.get("MASTER_PORT", None): 50 | os.environ["MASTER_PORT"] = "32167" 51 | if not os.environ.get("MASTER_ADDR", None): 52 | os.environ["MASTER_ADDR"] = "localhost" 53 | if not os.environ.get("RANK", None): 54 | os.environ["RANK"] = str(comm_rank) 55 | if not os.environ.get("WORLD_SIZE", None): 56 | os.environ["WORLD_SIZE"] = str(comm_size) 57 | 58 | if args.use_cuda and not torch.cuda.is_available(): 59 | print("CUDA is not available") 60 | sys.exit(0) 61 | 62 | if args.backend == "nccl" and not args.use_cuda: 63 | print("NCCL backend doesn't support host buffers") 64 | sys.exit(0) 65 | 66 | if args.use_cuda: 67 | torch.cuda.set_device(comm_rank) 68 | args.device = torch.device("cuda") 69 | else: 70 | args.device = torch.device("cpu") 71 | 72 | if comm_rank == 0: 73 | print("World size {}".format(comm_size)) 74 | print("%-10s %-10s %-10s %-10s" % ("size", "min, us", "avg, us", "max, us")) 75 | 76 | if args.backend != "mpi": 77 | dist.init_process_group(args.backend, rank=comm_rank, world_size=comm_size) 78 | else: 79 | dist.init_process_group(args.backend) 80 | 81 | size = args.min_size 82 | while size <= args.max_size: 83 | bufsize = size * comm_size 84 | send_tensor = get_tensor(bufsize, args.device, comm_rank) 85 | recv_tensor = get_tensor(bufsize, args.device, 0) 86 | time = 0 87 | for i in range(args.iter + args.skip): 88 | start = perf_counter() 89 | req = dist.all_to_all_single(recv_tensor, send_tensor, async_op=True) 90 | # req = dist.all_reduce(send_tensor, op=dist.ReduceOp.SUM, async_op=True) 91 | req.wait() 92 | if args.backend == "nccl": 93 | torch.cuda.synchronize(args.device) 94 | finish = perf_counter() 95 | dist.barrier() 96 | if i > args.skip: 97 | time += finish - start 98 | time = [time / args.iter] 99 | if args.use_cuda: 100 | max_time = torch.tensor([time], device=args.device) 101 | min_time = torch.tensor([time], device=args.device) 102 | avg_time = torch.tensor([time], device=args.device) 103 | else: 104 | max_time = torch.tensor([time]) 105 | min_time = torch.tensor([time]) 106 | avg_time = torch.tensor([time]) 107 | 108 | dist.all_reduce(max_time, op=dist.ReduceOp.MAX) 109 | dist.all_reduce(min_time, op=dist.ReduceOp.MIN) 110 | dist.all_reduce(avg_time, op=dist.ReduceOp.SUM) 111 | if comm_rank == 0: 112 | print( 113 | 
"%-10i %-10.3f %-10.3f %-10.3f" 114 | % ( 115 | size, 116 | min_time[0] * (10 ** 6), 117 | avg_time[0] * (10 ** 6) / comm_size, 118 | max_time[0] * (10 ** 6), 119 | ) 120 | ) 121 | size = size * 2 122 | -------------------------------------------------------------------------------- /test/torch_alltoall_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | from torch_ucc_test_setup import * 11 | import numpy as np 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(4, 22) 20 | 21 | print_test_head("Alltoall", comm_rank) 22 | for count in counts: 23 | recv_tensor_test = get_tensor(count * comm_size, is_cuda=False) 24 | send_tensor_list = [] 25 | recv_tensor_ucc = [] 26 | for p in range(comm_size): 27 | recv_tensor_ucc.append(get_tensor(count, args.use_cuda)) 28 | send_tensor_list.append(get_tensor(count, args.use_cuda)) 29 | 30 | dist.all_to_all( 31 | recv_tensor_ucc, 32 | send_tensor_list, 33 | ) 34 | # flatten the send_tensor_list and use all_to_all_single as not all PGs support all_to_all primitive 35 | dist.all_to_all_single( 36 | recv_tensor_test, torch.stack(send_tensor_list, dim=0).view(-1).cpu(), group=pg 37 | ) 38 | 39 | status = check_tensor_equal( 40 | torch.stack(recv_tensor_ucc, dim=0).view(-1), recv_tensor_test 41 | ) 42 | dist.all_reduce(status, group=pg) 43 | print_test_result(status, "{} x {}".format(count, comm_size), comm_rank, comm_size) 44 | 45 | if comm_rank == 0: 46 | print("Test alltoall: succeeded") 47 | -------------------------------------------------------------------------------- /test/torch_alltoallv_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | from torch_ucc_test_setup import * 11 | import numpy as np 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(4, 24) 20 | 21 | print_test_head("Alltoallv", comm_rank) 22 | for count in counts: 23 | np.random.seed(3131) 24 | 25 | split = np.random.randint( 26 | low=1, high=2 * count // comm_size, size=(comm_size, comm_size) 27 | ) 28 | input_size = np.sum(split, axis=1) 29 | output_size = np.sum(split, axis=0) 30 | 31 | send_tensor = get_tensor(input_size[comm_rank], args.use_cuda) 32 | recv_tensor = get_tensor(output_size[comm_rank], args.use_cuda) 33 | recv_tensor_test = get_tensor(output_size[comm_rank], is_cuda=False) 34 | dist.all_to_all_single( 35 | recv_tensor, send_tensor, split[:, comm_rank], split[comm_rank, :] 36 | ) 37 | dist.all_to_all_single( 38 | recv_tensor_test, 39 | send_tensor.cpu(), 40 | split[:, comm_rank], 41 | split[comm_rank, :], 42 | group=pg, 43 | ) 44 | status = check_tensor_equal(recv_tensor, recv_tensor_test) 45 | dist.all_reduce(status, group=pg) 46 | print_test_result( 47 | status, "{}({})".format(count, input_size[comm_rank]), comm_rank, comm_size 48 | ) 49 | if comm_rank == 0: 50 | print("Test alltoallv: succeeded") 51 | -------------------------------------------------------------------------------- /test/torch_barrier_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import random 11 | import sys 12 | import time 13 | from torch_ucc_test_setup import * 14 | 15 | args = parse_test_args() 16 | pg = init_process_groups(args.backend, args.use_cuda) 17 | 18 | comm_size = dist.get_world_size() 19 | comm_rank = dist.get_rank() 20 | 21 | for i in range(comm_size): 22 | rand_sleep = random.randint(1, 1000) 23 | time.sleep(rand_sleep / 1000) 24 | if i == comm_rank: 25 | print("rank {} checks in".format(comm_rank)) 26 | sys.stdout.flush() 27 | dist.barrier() 28 | dist.barrier() 29 | if comm_rank == 0: 30 | print("Test barrier: succeeded") 31 | -------------------------------------------------------------------------------- /test/torch_bcast_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import numpy as np 11 | from torch_ucc_test_setup import * 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(24) 20 | print_test_head("Broadcast", comm_rank) 21 | for count in counts: 22 | tensor_ucc = get_tensor(count, args.use_cuda) 23 | tensor_ucc = do_compute(tensor_ucc) 24 | tensor_test = tensor_ucc.cpu() 25 | dist.broadcast(tensor_ucc, 0) 26 | dist.broadcast(tensor_test, 0, group=pg) 27 | status = check_tensor_equal(tensor_ucc, tensor_test) 28 | dist.all_reduce(status, group=pg) 29 | print_test_result(status, count, comm_rank, comm_size) 30 | 31 | if comm_rank == 0: 32 | print("Test Broadcast: succeeded") 33 | -------------------------------------------------------------------------------- /test/torch_gather_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | 11 | import numpy as np 12 | from torch_ucc_test_setup import * 13 | 14 | args = parse_test_args() 15 | pg = init_process_groups(args.backend, args.use_cuda) 16 | 17 | comm_size = dist.get_world_size() 18 | comm_rank = dist.get_rank() 19 | 20 | counts = 2 ** np.arange(24) 21 | print_test_head("Gather", comm_rank) 22 | for count in counts: 23 | tensor_input = get_tensor(count, args.use_cuda) 24 | tensors_out_ucc = None 25 | tensors_out_test = None 26 | if comm_rank == 0: 27 | tensors_out_ucc = [] 28 | tensors_out_test = [] 29 | for p in range(comm_size): 30 | tensors_out_ucc.append(get_tensor(count, args.use_cuda)) 31 | tensors_out_test.append(get_tensor(count, is_cuda=False)) 32 | 33 | dist.gather(gather_list=tensors_out_ucc, tensor=tensor_input, dst=0) 34 | dist.gather(gather_list=tensors_out_test, tensor=tensor_input.cpu(), dst=0, group=pg) 35 | 36 | if comm_rank == 0: 37 | status = check_tensor_list_equal(tensors_out_ucc, tensors_out_test) 38 | else: 39 | status = torch.tensor(1, device=tensor_input.device) 40 | 41 | dist.all_reduce(status, group=pg) 42 | print_test_result(status, count, comm_rank, comm_size) 43 | if comm_rank == 0: 44 | print("Test gather: succeeded") 45 | -------------------------------------------------------------------------------- /test/torch_init_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import os 9 | import random 10 | import sys 11 | import time 12 | 13 | import torch.distributed as dist 14 | 15 | # torch_ucc is required to enable ucc PG 16 | import torch_ucc # noqa: F401 17 | 18 | comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 19 | comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 20 | 21 | os.environ["MASTER_PORT"] = "32167" 22 | os.environ["MASTER_ADDR"] = "localhost" 23 | os.environ["RANK"] = str(comm_rank) 24 | os.environ["WORLD_SIZE"] = str(comm_size) 25 | dist.init_process_group("ucc", rank=comm_rank, world_size=comm_size) 26 | # dist.new_group(ranks=[0, 1], backend='ucc') 27 | for i in range(comm_size): 28 | rand_sleep = random.randint(1, 1000) 29 | time.sleep(rand_sleep / 1000) 30 | if i == comm_rank: 31 | print("rank {} checks in".format(comm_rank)) 32 | sys.stdout.flush() 33 | dist.barrier() 34 | dist.barrier() 35 | if comm_rank == 0: 36 | print("Test barrier: succeeded") 37 | -------------------------------------------------------------------------------- /test/torch_multiple_comms_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | from torch_ucc_test_setup import * 4 | 5 | # create 2 UCC PGs 6 | ucc_pg = init_process_groups("ucc", False) 7 | dist.barrier() 8 | ucc_pg.barrier() 9 | -------------------------------------------------------------------------------- /test/torch_pg_ucc_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | parser = argparse.ArgumentParser(description="Process Group UCC test") 18 | parser.add_argument("--backend", type=str, default="mpi") 19 | parser.add_argument("--op", type=str, default="p2p") 20 | parser.add_argument("--use-cuda", type=bool, default=False) 21 | args = parser.parse_args() 22 | 23 | try: 24 | size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 25 | rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 26 | except: 27 | print("OMPI env variables are not found") 28 | sys.exit(1) 29 | 30 | os.environ["MASTER_PORT"] = "32167" 31 | os.environ["MASTER_ADDR"] = "localhost" 32 | os.environ["RANK"] = str(rank) 33 | os.environ["WORLD_SIZE"] = str(size) 34 | 35 | 36 | torch.cuda.set_device(rank) 37 | print("World size {}, rank {}".format(size, rank)) 38 | dist.init_process_group(args.backend, rank=rank, world_size=size) 39 | 40 | t = torch.zeros([size]) + rank + 1 41 | t1 = torch.zeros([size]) 42 | t2 = torch.zeros([size]) 43 | use_cuda = args.use_cuda and torch.cuda.is_available() 44 | 45 | if (args.backend == "nccl") or use_cuda: 46 | print("Using cuda tensor") 47 | t = t.cuda() 48 | t1 = t1.cuda() 49 | t2 = t2.cuda() 50 | 51 | if args.op == "p2p": 52 | if rank == 0: 53 | dist.send(t, 1) 54 | else: 55 | dist.recv(t, 0) 56 | elif args.op == "broadcast": 57 | dist.broadcast(t, 0) 58 | elif args.op == "allreduce": 59 | dist.all_reduce(t, op=dist.ReduceOp.SUM) 60 | elif args.op == "reduce": 61 | dist.reduce(t, 0, op=dist.ReduceOp.SUM) 62 | elif args.op == "alltoall": 63 | dist.all_to_all_single(t2, t) 64 | elif args.op == "alltoallv": 65 | out_split = [1] * size 66 | in_split = [1] * size 67 | dist.all_to_all_single(t2, t, out_split, in_split) 68 | elif args.op == "allgather": 69 | dist.all_gather([t1, t2], t) 70 | 71 | else: 72 | print("Incorrect operation") 73 | sys.exit(1) 74 | 75 | # dist.barrier() 76 | print("rank ", rank, ":", t, ":", t1, ":", t2) 77 | dist.destroy_process_group() 78 | -------------------------------------------------------------------------------- /test/torch_pt2pt_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import numpy as np 9 | from torch_ucc_test_setup import * 10 | 11 | args = parse_test_args() 12 | pg = init_process_groups(args.backend, args.use_cuda) 13 | 14 | comm_size = dist.get_world_size() 15 | comm_rank = dist.get_rank() 16 | 17 | counts = 2 ** np.arange(24) 18 | print_test_head("Point-to-point", comm_rank) 19 | 20 | for count in counts: 21 | tensor_ucc = get_tensor(count, args.use_cuda) 22 | tensor_ucc = do_compute(tensor_ucc) 23 | tensor_test = tensor_ucc.cpu() 24 | # pt2pt-based bcast if more than 2 processes 25 | if comm_rank == 0: 26 | for dst in range(comm_size - 1): 27 | dist.send(tensor_ucc, dst=dst + 1, tag=0) 28 | dist.send(tensor_test, dst=dst + 1, tag=0, group=pg) 29 | status = torch.tensor(1, device=tensor_test.device) 30 | else: 31 | dist.recv(tensor_ucc, src=0, tag=0) 32 | dist.recv(tensor_test, src=0, tag=0, group=pg) 33 | status = check_tensor_equal(tensor_ucc, tensor_test) 34 | dist.all_reduce(status, group=pg) 35 | print_test_result(status, count, comm_rank, comm_size) 36 | 37 | if comm_rank == 0: 38 | print("Test Point-to-point: succeeded") 39 | -------------------------------------------------------------------------------- /test/torch_reduce_scatter_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import numpy as np 11 | from torch_ucc_test_setup import * 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(24) 20 | print_test_head("Reduce_scatter", comm_rank) 21 | for count in counts: 22 | tensors_input = [] 23 | for p in range(comm_size): 24 | tensors_input.append(get_tensor(count, args.use_cuda)) 25 | tensor_ucc = get_tensor(count, args.use_cuda) 26 | tensor_test = tensor_ucc.cpu() 27 | tensors_input[0] = do_compute(tensors_input[0]) 28 | dist.reduce_scatter(tensor_ucc, tensors_input) 29 | dist.reduce_scatter(tensor_test, [t.cpu() for t in tensors_input], group=pg) 30 | status = check_tensor_equal(tensor_ucc, tensor_test) 31 | dist.all_reduce(status, group=pg) 32 | print_test_result(status, count, comm_rank, comm_size) 33 | 34 | if comm_rank == 0: 35 | print("Test reduce_scatter: succeeded") 36 | -------------------------------------------------------------------------------- /test/torch_reduce_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import numpy as np 11 | from torch_ucc_test_setup import * 12 | 13 | args = parse_test_args() 14 | pg = init_process_groups(args.backend, args.use_cuda) 15 | 16 | comm_size = dist.get_world_size() 17 | comm_rank = dist.get_rank() 18 | 19 | counts = 2 ** np.arange(24) 20 | print_test_head("Reduce", comm_rank) 21 | for count in counts: 22 | tensor_ucc = get_tensor(count, args.use_cuda) 23 | tensor_ucc = do_compute(tensor_ucc) 24 | tensor_test = tensor_ucc.cpu() 25 | dist.reduce(tensor_ucc, dst=0) 26 | dist.reduce(tensor_test, dst=0, group=pg) 27 | # only root (i.e., rank 0 here) need to check results 28 | if comm_rank == 0: 29 | status = check_tensor_equal(tensor_ucc, tensor_test) 30 | else: 31 | status = torch.tensor(1, device=tensor_ucc.device) 32 | dist.all_reduce(status, group=pg) 33 | print_test_result(status, count, comm_rank, comm_size) 34 | 35 | if comm_rank == 0: 36 | print("Test Reduce: succeeded") 37 | -------------------------------------------------------------------------------- /test/torch_sendrecv_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import sys 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | # torch_ucc is required to enable ucc PG 15 | import torch_ucc # noqa: F401 16 | 17 | comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 18 | comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 19 | 20 | if comm_size != 2: 21 | print("sendrecv rest requires exactly 2 ranks") 22 | sys.exit(0) 23 | 24 | os.environ["MASTER_PORT"] = "32167" 25 | os.environ["MASTER_ADDR"] = "localhost" 26 | os.environ["RANK"] = str(comm_rank) 27 | os.environ["WORLD_SIZE"] = str(comm_size) 28 | dist.init_process_group("ucc", rank=comm_rank, world_size=comm_size) 29 | 30 | if comm_rank == 0: 31 | t = torch.full([16], comm_rank + 1) 32 | print("send: ", t) 33 | dist.send(t, 1, tag=128) 34 | if comm_rank == 1: 35 | t = torch.full([16], 0) 36 | print("recv before: ", t) 37 | dist.recv(t, 0, tag=128) 38 | print("recv after: ", t) 39 | -------------------------------------------------------------------------------- /test/torch_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | torch.backends.cuda.matmul.allow_tf32 = False 10 | 11 | if not dist.is_available(): 12 | print("Distributed not available, skipping tests", file=sys.stderr) 13 | sys.exit(0) 14 | 15 | from torch.testing._internal.common_distributed import DistTestCases 16 | from torch.testing._internal.common_utils import ( 17 | run_tests, 18 | sandcastle_skip_if, 19 | ) 20 | 21 | # Sets showing that a collective isn't implemented 22 | DistTestCases.skip_collective["allgather_coalesced"] = {"ucc"} 23 | DistTestCases.skip_collective["gather"] = {"ucc"} 24 | DistTestCases.skip_collective["scatter"] = {"ucc"} 25 | DistTestCases.skip_collective["reduce"] = {"ucc"} 26 | DistTestCases.skip_collective["sendrecv anysource"] = {"ucc"} 27 | DistTestCases.skip_collective["cpu barrier"] = {"ucc"} 28 | 29 | # Sets showing that something is implemented 30 | DistTestCases.backend_feature["gpu"] = {"ucc"} 31 | DistTestCases.backend_feature["cuda"] = {"ucc"} 32 | DistTestCases.backend_feature["ddp"] = {"ucc"} 33 | DistTestCases.backend_feature["subgroup"] = {"ucc"} 34 | DistTestCases.backend_feature["plugin"] = {"ucc"} 35 | 36 | os.environ["MASTER_ADDR"] = "localhost" 37 | 38 | if "MASTER_PORT" not in os.environ: 39 | try: 40 | from caffe2.torch.fb.common.utils import get_free_port 41 | 42 | os.environ["MASTER_PORT"] = str(get_free_port()) 43 | except ImportError: 44 | os.environ["MASTER_PORT"] = "12375" 45 | 46 | os.environ["INIT_METHOD"] = "tcp://localhost:" + os.environ["MASTER_PORT"] 47 | 48 | if "UCX_TLS" not in os.environ: 49 | os.environ["UCX_TLS"] = "sm,tcp" 50 | 51 | import torch_ucc # noqa: F401 52 | 53 | BACKEND = os.environ["BACKEND"] 54 | 55 | # We have to import this after we change the values in DistTestCases 56 | from torch.testing._internal.distributed.distributed_test import ( 57 | TestDistBackend, 58 | DistributedTest, 59 | ) 60 | 61 | 62 | class TestDistBackendWithSpawn(TestDistBackend, DistributedTest._DistTestBase): 63 | port_num = str(os.environ["MASTER_PORT"]) 64 | 65 | def setUp(self): 66 | super().setUp() 67 | self._spawn_processes() 68 | torch.backends.cudnn.flags(allow_tf32=False).__enter__() 69 | 70 | @sandcastle_skip_if( 71 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 72 | ) 73 | def test_ddp_logging_data_cpu(self): 74 | raise Exception("This test fails with UCC, not running it") 75 | 76 | @sandcastle_skip_if( 77 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 78 | ) 79 | def test_DistributedDataParallelCPU(self): 80 | raise Exception("This test fails with UCC, not running it") 81 | 82 | @sandcastle_skip_if( 83 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 84 | ) 85 | def test_DistributedDataParallelCPU_grad_is_view(self): 86 | raise Exception("This test fails with UCC, not running it") 87 | 88 | @sandcastle_skip_if( 89 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 90 | ) 91 | def test_ddp_create_graph(self): 92 | raise Exception("This test fails with UCC, not running it") 93 | 94 | @sandcastle_skip_if( 95 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 96 | ) 97 | def test_destroy_group(self): 98 | raise Exception("This test fails with UCC, not running it") 99 | 100 | @sandcastle_skip_if( 101 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 102 | ) 103 | def test_gather(self): 104 | raise Exception("This test fails 
with UCC, not running it") 105 | 106 | @sandcastle_skip_if( 107 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 108 | ) 109 | def test_gather_checks(self): 110 | raise Exception("This test fails with UCC, not running it") 111 | 112 | @sandcastle_skip_if( 113 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 114 | ) 115 | def test_gather_group(self): 116 | raise Exception("This test fails with UCC, not running it") 117 | 118 | @sandcastle_skip_if( 119 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 120 | ) 121 | def test_gather_full_group(self): 122 | raise Exception("This test fails with UCC, not running it") 123 | 124 | @sandcastle_skip_if( 125 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 126 | ) 127 | def test_gather_object(self): 128 | raise Exception("This test fails with UCC, not running it") 129 | 130 | @sandcastle_skip_if( 131 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 132 | ) 133 | def test_gather_object_subgroup(self): 134 | raise Exception("This test fails with UCC, not running it") 135 | 136 | @sandcastle_skip_if( 137 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 138 | ) 139 | def test_get_backend(self): 140 | raise Exception("This test fails with UCC, not running it") 141 | 142 | @sandcastle_skip_if( 143 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 144 | ) 145 | def test_get_rank_size_group(self): 146 | raise Exception("This test fails with UCC, not running it") 147 | 148 | @sandcastle_skip_if( 149 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 150 | ) 151 | def test_scatter(self): 152 | raise Exception("This test fails with UCC, not running it") 153 | 154 | @sandcastle_skip_if( 155 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 156 | ) 157 | def test_scatter_group(self): 158 | raise Exception("This test fails with UCC, not running it") 159 | 160 | @sandcastle_skip_if( 161 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 162 | ) 163 | def test_scatter_checks(self): 164 | raise Exception("This test fails with UCC, not running it") 165 | 166 | @sandcastle_skip_if( 167 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 168 | ) 169 | def test_scatter_complex(self): 170 | raise Exception("This test fails with UCC, not running it") 171 | 172 | @sandcastle_skip_if( 173 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 174 | ) 175 | def test_scatter_full_group(self): 176 | raise Exception("This test fails with UCC, not running it") 177 | 178 | @sandcastle_skip_if( 179 | BACKEND == "ucc", "This test fails on UCC, so we are not running it today" 180 | ) 181 | def test_static_graph_api_cpu(self): 182 | raise Exception("This test fails with UCC, not running it") 183 | 184 | 185 | if __name__ == "__main__": 186 | run_tests() 187 | -------------------------------------------------------------------------------- /test/torch_timeout_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import sys 11 | import time 12 | from torch_ucc_test_setup import * 13 | from datetime import timedelta 14 | 15 | args = parse_test_args() 16 | pg = init_process_groups(args.backend, args.use_cuda, timedelta(seconds=10)) 17 | 18 | comm_size = dist.get_world_size() 19 | comm_rank = dist.get_rank() 20 | 21 | dist.barrier() 22 | if comm_rank == 0: 23 | time.sleep(20) 24 | 25 | estr = "" 26 | try: 27 | req = dist.barrier() 28 | except Exception as e: 29 | estr = str(e) 30 | 31 | if comm_rank != 0: 32 | if "Timeout expired" in estr: 33 | print("Test OK") 34 | else: 35 | print("Test Failed") 36 | sys.exit(1) 37 | -------------------------------------------------------------------------------- /test/torch_ucc_test_setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import argparse 11 | import os 12 | import sys 13 | from datetime import timedelta 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | # torch_ucc is required to enable ucc PG 19 | import torch_ucc # noqa: F401 20 | 21 | 22 | def parse_test_args(): 23 | parser = argparse.ArgumentParser(description="PG UCC Test") 24 | parser.add_argument("--backend", type=str, default="mpi") 25 | parser.add_argument("--use-cuda", default=False, action="store_true") 26 | parser.add_argument("--enable-prof", default=False, action="store_true") 27 | args = parser.parse_args() 28 | 29 | if args.use_cuda and not torch.cuda.is_available(): 30 | print("CUDA is not available") 31 | sys.exit(0) 32 | 33 | # Tensor mem type support seems to rely on static definition at https://pytorch.org/docs/stable/distributed.html 34 | valid_bends = ["mpi", "ucc", "gloo"] 35 | if args.backend not in valid_bends: 36 | print( 37 | "The specified backend {} does not support CPU tensors for result validation. Please choose from {}".format( 38 | args.backend, ", ".join(valid_bends) 39 | ) 40 | ) 41 | sys.exit(0) 42 | 43 | return args 44 | 45 | 46 | def get_tensor(count, is_cuda): 47 | dev = torch.device("cuda") if is_cuda else torch.device("cpu") 48 | t = torch.randint(0, 100, (count,), dtype=torch.int, device=dev) 49 | return t 50 | 51 | 52 | def init_process_groups(bend, use_cuda, to=timedelta(seconds=60)): 53 | try: 54 | comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 55 | comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 56 | local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 57 | except: 58 | print("OMPI env variables are not found") 59 | sys.exit(1) 60 | 61 | if use_cuda: 62 | torch.cuda.set_device(local_rank) 63 | 64 | os.environ["MASTER_PORT"] = "32167" 65 | os.environ["MASTER_ADDR"] = "localhost" 66 | os.environ["RANK"] = str(comm_rank) 67 | os.environ["WORLD_SIZE"] = str(comm_size) 68 | dist.init_process_group("ucc", rank=comm_rank, world_size=comm_size, timeout=to) 69 | pg = dist.new_group(backend=bend) 70 | 71 | return pg 72 | 73 | 74 | # Compare UCC result tensor with the checking PG's result tensor. 
75 | # Return check status allocated on PG's device because the result is exchanged by PG 76 | def check_tensor_equal(t_ucc, t_pg): 77 | # Copy to CPU before comparing with PG's resut which is always on CPU 78 | if t_ucc.is_cuda: 79 | t_ucc = t_ucc.cpu() 80 | if torch.all(torch.eq(t_ucc, t_pg)): 81 | return torch.tensor(1, device=t_pg.device) 82 | else: 83 | print("failed on rank {}".format(os.environ["RANK"])) 84 | return torch.tensor(0, device=t_pg.device) 85 | 86 | 87 | # Compare UCC result tensor list with the checking PG's result tensor list. 88 | # Return check status allocated on PG's device because the result is exchanged by PG 89 | def check_tensor_list_equal(t_ucc, t_pg): 90 | num_tensors = len(t_ucc) 91 | for i in range(num_tensors): 92 | # Copy to CPU before comparing with PG's resut which is always on CPU 93 | if t_ucc[i].is_cuda: 94 | t_ucc[i] = t_ucc[i].cpu() 95 | if not torch.all(torch.eq(t_ucc[i], t_pg[i])): 96 | return torch.tensor(0, device=t_pg[i].device) 97 | return torch.tensor(1, device=t_pg[i].device) 98 | 99 | 100 | def print_test_head(test_name, comm_rank): 101 | if comm_rank == 0: 102 | print("{} test".format(test_name)) 103 | print("{0:20} {1}".format("count", "result")) 104 | 105 | 106 | def print_test_result(status, count, comm_rank, comm_size): 107 | if comm_rank == 0: 108 | result = "OK" if status == comm_size else "Failed" 109 | print("{0:20} {1}".format(str(count), result)) 110 | if status != comm_size: 111 | sys.exit(1) 112 | 113 | 114 | def do_compute(t): 115 | return torch.topk(t, t.size()[0])[0] 116 | -------------------------------------------------------------------------------- /test/torch_work_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) Mellanox Technologies Ltd. 2001-2021. 3 | # 4 | # Copyright (c) Facebook, Inc. and its affiliates. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | from torch_ucc_test_setup import * 11 | 12 | 13 | def test_future(obj): 14 | print("Test WorkUCC: succeeded") 15 | 16 | 17 | args = parse_test_args() 18 | pg = init_process_groups(args.backend, args.use_cuda) 19 | 20 | comm_size = dist.get_world_size() 21 | comm_rank = dist.get_rank() 22 | 23 | print_test_head("WorkUCC", comm_rank) 24 | count = 32 25 | tensor_ucc = get_tensor(count, args.use_cuda) 26 | 27 | work = dist.all_reduce(tensor_ucc, async_op=True) 28 | 29 | # test future functionality 30 | fut = work.get_future().then(test_future) 31 | 32 | # test result functionality 33 | work.wait() 34 | opTensor = work.result() 35 | 36 | # test future functionality 37 | fut.wait() 38 | --------------------------------------------------------------------------------
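All of the Python tests above share the same bring-up sequence: read the rank and world size from the OMPI_COMM_WORLD_* environment variables exported by test/start_test.sh, point MASTER_ADDR/MASTER_PORT at a local rendezvous, and call dist.init_process_group("ucc", ...) after importing torch_ucc so the "ucc" backend is registered. The sketch below condenses that sequence into a minimal standalone script. It is an illustration only: the file name minimal_ucc_allreduce.py is hypothetical, and it assumes Torch-UCC has already been built and installed so that importing torch_ucc succeeds.

    # minimal_ucc_allreduce.py -- hypothetical example, not part of the test suite
    import os

    import torch
    import torch.distributed as dist

    # torch_ucc is required to enable ucc PG (same convention as the tests above)
    import torch_ucc  # noqa: F401

    # start_test.sh exports these for every local process it spawns
    comm_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    comm_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

    # same local rendezvous settings used by torch_ucc_test_setup.py
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "32167")

    dist.init_process_group("ucc", rank=comm_rank, world_size=comm_size)

    t = torch.ones(16)
    dist.all_reduce(t)  # every element should now equal comm_size
    print("rank {}: allreduce result {}".format(comm_rank, t[0].item()))

Such a script would be launched with four local processes in the same way as the tests, for example:

    TORCH_UCC_TEST_SIZE=4 bash test/start_test.sh minimal_ucc_allreduce.py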