├── BUILD ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── WORKSPACE.bazel ├── bazel ├── cuda │ └── BUILD.bazel └── nccl │ └── BUILD.bazel ├── compat.h ├── net_fastsocket.cc ├── utilities.cc └── utilities.h /BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_pkg//pkg:tar.bzl", "pkg_tar") 2 | load("@rules_pkg//pkg:deb.bzl", "pkg_deb") 3 | load("@rules_license//rules:license.bzl", "license") 4 | 5 | package(default_applicable_licenses = [":license"]) 6 | 7 | license( 8 | name = "license", 9 | package_name = "fastsocket_plugin", 10 | ) 11 | 12 | # Dual-licensed, using the least restrictive per go/thirdpartylicenses#same 13 | licenses(["notice"]) 14 | 15 | exports_files(["LICENSE"]) 16 | 17 | cc_library( 18 | name = "nccl_utilities", 19 | srcs = [ 20 | "utilities.cc", 21 | ], 22 | hdrs = [ 23 | "utilities.h", 24 | ], 25 | visibility = ["//visibility:public"], 26 | deps = [ 27 | "@nccl//:plugin_lib", 28 | ], 29 | ) 30 | 31 | # Faster socket plugin for NCCL applications. 32 | cc_library( 33 | name = "plugin", 34 | srcs = [ 35 | "net_fastsocket.cc", 36 | ], 37 | hdrs = [ 38 | "compat.h", 39 | ], 40 | # Export the symbol containing the NCCL plugin vtable so it can be 41 | # loaded at runtime via dlopen + dlsym. This means we also need to 42 | # always link this library, otherwise it'll be dropped at build time. 43 | linkopts = [ 44 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v7", 45 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v6", 46 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v5", 47 | ], 48 | visibility = ["//visibility:public"], 49 | deps = [ 50 | ":nccl_utilities", 51 | "@nccl//:plugin_lib", 52 | ], 53 | alwayslink = 1, 54 | ) 55 | 56 | cc_library( 57 | name = "collnet_plugin", 58 | srcs = [ 59 | "net_fastsocket.cc", 60 | ], 61 | hdrs = [ 62 | "compat.h", 63 | ], 64 | # Export the symbol containing the NCCL plugin vtable so it can be 65 | # loaded at runtime via dlopen + dlsym. 
This means we also need to 66 | # always link this library, otherwise it'll be dropped at build time. 67 | linkopts = [ 68 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v7", 69 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v6", 70 | "-Wl,--export-dynamic-symbol=ncclNetPlugin_v5", 71 | ], 72 | local_defines = ["CHECK_COLLNET_ENABLE"], 73 | visibility = ["//visibility:public"], 74 | deps = [ 75 | ":nccl_utilities", 76 | "@nccl//:plugin_lib", 77 | ], 78 | alwayslink = 1, 79 | ) 80 | 81 | cc_binary( 82 | name = "libnccl-net.so", 83 | linkshared = True, 84 | linkstatic = True, 85 | deps = [ 86 | ":plugin", 87 | ], 88 | ) 89 | 90 | genrule( 91 | name = "gen_triggers", 92 | outs = ["triggers"], 93 | cmd = "echo 'activate-noawait ldconfig' > $@", 94 | ) 95 | 96 | pkg_deb( 97 | name = "package-deb", 98 | architecture = "amd64", 99 | data = ":tarball", 100 | description = "Fast Socket for NCCL 2", 101 | maintainer = "Chang Lan ", 102 | package = "google-fast-socket", 103 | recommends = ["libnccl2"], 104 | triggers = ":gen_triggers", 105 | version = "0.0.5", 106 | ) 107 | 108 | pkg_tar( 109 | name = "tarball", 110 | extension = "tar.gz", 111 | deps = [ 112 | ":doc", 113 | ":lib", 114 | ], 115 | ) 116 | 117 | pkg_tar( 118 | name = "lib", 119 | srcs = [":libnccl-net.so"], 120 | mode = "0644", 121 | package_dir = "/usr/lib/", 122 | ) 123 | 124 | pkg_tar( 125 | name = "doc", 126 | srcs = [":LICENSE"], 127 | mode = "0644", 128 | package_dir = "/usr/share/doc/google-fast-socket/", 129 | ) 130 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). 
You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | <https://cla.developers.google.com/> to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 Google LLC 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Google LLC nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission.
16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | ------------------ 30 | 31 | Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions 35 | are met: 36 | * Redistributions of source code must retain the above copyright 37 | notice, this list of conditions and the following disclaimer. 38 | * Redistributions in binary form must reproduce the above copyright 39 | notice, this list of conditions and the following disclaimer in the 40 | documentation and/or other materials provided with the distribution. 41 | * Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National 42 | Laboratory, the U.S. Department of Energy, nor the names of their 43 | contributors may be used to endorse or promote products derived 44 | from this software without specific prior written permission. 45 | 46 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 47 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 48 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 49 | PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 50 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 51 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 52 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 53 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 54 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 55 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 56 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 57 | 58 | The U.S. Department of Energy funded the development of this software 59 | under subcontract 7078610 with Lawrence Berkeley National Laboratory. 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NCCL Fast Socket 2 | 3 | NCCL Fast Socket is a transport layer plugin to improve NCCL collective 4 | communication performance on Google Cloud. 5 | 6 | ## Overview 7 | 8 | Collective communication primitives such as all-reduce and all-gather have been 9 | widely used in distributed training in machine learning. The NVIDIA Collective 10 | Communications Library (NCCL) is a highly optimized implementation of these 11 | multi-GPU and multi-node collective communication primitives that supports 12 | NVIDIA GPUs. 13 | 14 | NCCL Fast Socket is based on TCP/IP communication and uses a number of 15 | techniques to achieve better and more consistent performance, especially with 16 | 100 Gbps networking on Google Cloud. 17 | 18 | ## Getting Started 19 | 20 | ### Dependencies 21 | 22 | Fast Socket requires a working installation of CUDA to build. After building the 23 | plugin, it has to be in a directory listed in `LD_LIBRARY_PATH` in order to be loaded by NCCL. 24 | 25 | ### Build 26 | 27 | The plugin uses Bazel to build.
You can build the plugin as follows: 28 | 29 | ``` 30 | $ bazel build :all 31 | ``` 32 | 33 | The plugin is located at `bazel-bin/libnccl-net.so` and can be copied into a directory on your 34 | `LD_LIBRARY_PATH`. 35 | 36 | ## Getting Help 37 | 38 | Please open an issue if you have any questions or if you think you may have 39 | found any bugs. 40 | 41 | ## Contributing 42 | 43 | Contributions are always welcome. Please refer to our [contributing guidelines](CONTRIBUTING.md) 44 | to learn how to contribute. 45 | 46 | ## License 47 | 48 | Fast Socket is licensed under the terms of a BSD-style license. 49 | See [LICENSE](LICENSE) for more information. 50 | 51 | This is not an officially supported Google product. 52 | -------------------------------------------------------------------------------- /WORKSPACE.bazel: -------------------------------------------------------------------------------- 1 | workspace(name = "fastsocket") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") 5 | load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") 6 | 7 | # NCCL 8 | maybe( 9 | new_git_repository, 10 | name = "nccl", 11 | build_file = "//:bazel/nccl/BUILD.bazel", 12 | remote = "https://github.com/NVIDIA/nccl.git", 13 | tag = "v2.13.4-1", 14 | ) 15 | 16 | # CUDA 17 | maybe( 18 | new_local_repository, 19 | name = "local_config_cuda", 20 | build_file = "//:bazel/cuda/BUILD.bazel", 21 | path = "/usr/local/cuda", 22 | ) 23 | 24 | # rules_pkg 25 | http_archive( 26 | name = "rules_pkg", 27 | url = "https://github.com/bazelbuild/rules_pkg/archive/main.zip", 28 | strip_prefix = "rules_pkg-main", 29 | ) 30 | 31 | load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") 32 | rules_pkg_dependencies() 33 | -------------------------------------------------------------------------------- /bazel/cuda/BUILD.bazel: -------------------------------------------------------------------------------- 1 |
package( 2 | default_visibility = ["//visibility:public"], 3 | ) 4 | 5 | cc_library( 6 | name = "cuda_headers", 7 | hdrs = glob(["include/**"]), 8 | includes = ["include"], 9 | ) 10 | -------------------------------------------------------------------------------- /bazel/nccl/BUILD.bazel: -------------------------------------------------------------------------------- 1 | NCCL_MAJOR = 2 2 | NCCL_MINOR = 13 3 | NCCL_PATCH = 4 4 | NCCL_SUFFIX = 1 5 | NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH 6 | 7 | genrule( 8 | name = "gen_nccl_h", 9 | srcs = [ 10 | "src/nccl.h.in", 11 | ], 12 | outs = [ 13 | "src/nccl.h", 14 | ], 15 | cmd = 'sed -e "s/\\$${{nccl:Major}}/{}/g" -e "s/\\$${{nccl:Minor}}/{}/g" -e "s/\\$${{nccl:Patch}}/{}/g" -e "s/\\$${{nccl:Suffix}}/{}/g" -e "s/\\$${{nccl:Version}}/{}/g" $< > $@'.format( 16 | NCCL_MAJOR, 17 | NCCL_MINOR, 18 | NCCL_PATCH, 19 | NCCL_SUFFIX, 20 | NCCL_VERSION, 21 | ), 22 | ) 23 | 24 | cc_library( 25 | name = "src_hdrs", 26 | hdrs = [ 27 | "src/nccl.h", 28 | ], 29 | includes = ["src"], 30 | deps = [ 31 | "@local_config_cuda//:cuda_headers", 32 | ], 33 | ) 34 | 35 | cc_library( 36 | name = "include_hdrs", 37 | hdrs = glob([ 38 | "src/include/*.h", 39 | "src/include/*.hpp", 40 | ]), 41 | includes = ["src/include"], 42 | deps = [ 43 | ":src_hdrs", 44 | ], 45 | ) 46 | 47 | cc_library( 48 | name = "plugin_lib", 49 | srcs = [ 50 | "src/debug.cc", 51 | "src/misc/utils.cc", 52 | ], 53 | visibility = ["//visibility:public"], 54 | deps = [ 55 | ":include_hdrs", 56 | ], 57 | ) 58 | -------------------------------------------------------------------------------- /compat.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Google LLC 2 | // 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file or at 5 | // https://developers.google.com/open-source/licenses/bsd 6 | 7 | #ifndef 
THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_COMPAT_H_ 8 | #define THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_COMPAT_H_ 9 | 10 | #include "nccl_net.h" 11 | 12 | // ncclNet_v2_t: defined on earliest supported version, removed on 2.6. 13 | #if (NCCL_MAJOR == 2 && NCCL_MINOR >= 6) 14 | 15 | typedef struct { 16 | const char* name; 17 | ncclResult_t (*init)(ncclDebugLogger_t logFunction); 18 | ncclResult_t (*devices)(int* ndev); 19 | ncclResult_t (*pci_path)(int dev, char** path); 20 | ncclResult_t (*ptr_support)(int dev, int* supportedTypes); 21 | ncclResult_t (*listen)(int dev, void* handle, void** listenComm); 22 | ncclResult_t (*connect)(int dev, void* handle, void** sendComm); 23 | ncclResult_t (*accept)(void* listenComm, void** recvComm); 24 | ncclResult_t (*reg_mr)(void* comm, void* data, int size, int type, 25 | void** mhandle); 26 | ncclResult_t (*dereg_mr)(void* comm, void* mhandle); 27 | ncclResult_t (*isend)(void* sendComm, void* data, int size, void* mhandle, 28 | void** request); 29 | ncclResult_t (*irecv)(void* recvComm, void* data, int size, void* mhandle, 30 | void** request); 31 | ncclResult_t (*flush)(void* recvComm, void* data, int size, void* mhandle); 32 | ncclResult_t (*test)(void* request, int* done, int* size); 33 | ncclResult_t (*close_send)(void* sendComm); 34 | ncclResult_t (*close_recv)(void* recvComm); 35 | ncclResult_t (*close_listen)(void* listenComm); 36 | } ncclNet_v2_t; 37 | 38 | #endif 39 | 40 | // ncclNet_v3_t: defined on 2.6, removed on 2.8. 
41 | #if (NCCL_MAJOR == 2 && NCCL_MINOR < 6) || (NCCL_MAJOR == 2 && NCCL_MINOR >= 8) 42 | // Fallback declaration of the v3 plugin vtable for NCCL versions whose nccl_net.h does not declare ncclNet_v3_t. NOTE(review): the properties typedef below names ncclNetProperties_v6_t, but this header's own fallback typedef for that name (guarded by NCCL_MINOR < 13) only appears further down, so for NCCL < 2.13 the name does not look declared yet at this point -- TODO confirm whether builds against NCCL < 2.6 or 2.8-2.12 actually compile. 43 | typedef ncclNetProperties_v6_t ncclNetProperties_v3_t; 44 | typedef struct { 45 | const char* name; 46 | ncclResult_t (*init)(ncclDebugLogger_t logFunction); 47 | ncclResult_t (*devices)(int* ndev); 48 | ncclResult_t (*get_properties)(int dev, ncclNetProperties_v3_t* props); 49 | ncclResult_t (*listen)(int dev, void* handle, void** listenComm); 50 | ncclResult_t (*connect)(int dev, void* handle, void** sendComm); 51 | ncclResult_t (*accept)(void* listenComm, void** recvComm); 52 | ncclResult_t (*reg_mr)(void* comm, void* data, int size, int type, 53 | void** mhandle); 54 | ncclResult_t (*dereg_mr)(void* comm, void* mhandle); 55 | ncclResult_t (*isend)(void* sendComm, void* data, int size, void* mhandle, 56 | void** request); 57 | ncclResult_t (*irecv)(void* recvComm, void* data, int size, void* mhandle, 58 | void** request); 59 | ncclResult_t (*flush)(void* recvComm, void* data, int size, void* mhandle); 60 | ncclResult_t (*test)(void* request, int* done, int* size); 61 | ncclResult_t (*close_send)(void* sendComm); 62 | ncclResult_t (*close_recv)(void* recvComm); 63 | ncclResult_t (*close_listen)(void* listenComm); 64 | } ncclNet_v3_t; 65 | 66 | #endif 67 | 68 | // ncclNet_v4_t: defined on 2.8, removed on 2.19.
69 | #if (NCCL_MAJOR == 2 && NCCL_MINOR < 8) || (NCCL_MAJOR == 2 && NCCL_MINOR >= 19) 70 | 71 | typedef ncclNetProperties_v6_t ncclNetProperties_v4_t; 72 | typedef struct { 73 | const char* name; 74 | ncclResult_t (*init)(ncclDebugLogger_t logFunction); 75 | ncclResult_t (*devices)(int* ndev); 76 | ncclResult_t (*get_properties)(int dev, ncclNetProperties_v4_t* props); 77 | ncclResult_t (*listen)(int dev, void* handle, void** listenComm); 78 | ncclResult_t (*connect)(int dev, void* handle, void** sendComm); 79 | ncclResult_t (*accept)(void* listenComm, void** recvComm); 80 | ncclResult_t (*reg_mr)(void* comm, void* data, int size, int type, 81 | void** mhandle); 82 | ncclResult_t (*dereg_mr)(void* comm, void* mhandle); 83 | ncclResult_t (*isend)(void* sendComm, void* data, int size, void* mhandle, 84 | void** request); 85 | ncclResult_t (*irecv)(void* recvComm, void* data, int size, void* mhandle, 86 | void** request); 87 | ncclResult_t (*iflush)(void* recvComm, void* data, int size, void* mhandle, 88 | void** request); 89 | ncclResult_t (*test)(void* request, int* done, int* size); 90 | ncclResult_t (*close_send)(void* sendComm); 91 | ncclResult_t (*close_recv)(void* recvComm); 92 | ncclResult_t (*close_listen)(void* listenComm); 93 | ncclResult_t (*calloc)(int dev, int size, int type, void** data, 94 | void** mhandle); 95 | ncclResult_t (*free)(void* mhandle); 96 | } ncclNet_v4_t; 97 | 98 | #endif 99 | 100 | // ncclNet_v5_t: defined on 2.12, live until now. 101 | #if NCCL_MAJOR == 2 && NCCL_MINOR < 12 102 | 103 | typedef struct { 104 | char* name; // Used mostly for logging. 105 | char* pciPath; // Path to the PCI device in /sys. 106 | uint64_t guid; // Unique identifier for the NIC chip. Important for 107 | // cards with multiple PCI functions (Physical or virtual). 108 | int ptrSupport; // NCCL_PTR_HOST or NCCL_PTR_HOST|NCCL_PTR_CUDA 109 | int speed; // Port speed in Mbps. 110 | int port; // Port number. 
111 | float latency; // Network latency 112 | int maxComms; // Maximum number of comms we can create 113 | int maxRecvs; // Maximum number of grouped receives. 114 | } ncclNetProperties_v5_t; 115 | 116 | typedef struct { 117 | // Name of the network (mainly for logs) 118 | const char* name; 119 | // Initialize the network. 120 | ncclResult_t (*init)(ncclDebugLogger_t logFunction); 121 | // Return the number of adapters. 122 | ncclResult_t (*devices)(int* ndev); 123 | // Get various device properties. 124 | ncclResult_t (*get_properties)(int dev, ncclNetProperties_v5_t* props); 125 | // Create a receiving object and provide a handle to connect to it. The 126 | // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged 127 | // between ranks to create a connection. 128 | ncclResult_t (*listen)(int dev, void* handle, void** listenComm); 129 | // Connect to a handle and return a sending comm object for that peer. 130 | // This call must not block for the connection to be established, and instead 131 | // should return successfully with sendComm == NULL with the expectation that 132 | // it will be called again until sendComm != NULL. 133 | ncclResult_t (*connect)(int dev, void* handle, void** sendComm); 134 | // Finalize connection establishment after remote peer has called connect. 135 | // This call must not block for the connection to be established, and instead 136 | // should return successfully with recvComm == NULL with the expectation that 137 | // it will be called again until recvComm != NULL. 138 | ncclResult_t (*accept)(void* listenComm, void** recvComm); 139 | // Register/Deregister memory. Comm can be either a sendComm or a recvComm. 140 | // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. 141 | ncclResult_t (*reg_mr)(void* comm, void* data, int size, int type, 142 | void** mhandle); 143 | ncclResult_t (*dereg_mr)(void* comm, void* mhandle); 144 | // Asynchronous send to a peer. 
145 | // May return request == NULL if the call cannot be performed (or would block) 146 | ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, 147 | void* mhandle, void** request); 148 | // Asynchronous recv from a peer. 149 | // May return request == NULL if the call cannot be performed (or would block) 150 | ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, 151 | int* tags, void** mhandles, void** request); 152 | // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is 153 | // visible to the GPU 154 | ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, 155 | void** mhandles, void** request); 156 | // Test whether a request is complete. If size is not NULL, it returns the 157 | // number of bytes sent/received. 158 | ncclResult_t (*test)(void* request, int* done, int* sizes); 159 | ncclResult_t (*close_send)(void* sendComm); 160 | ncclResult_t (*close_recv)(void* recvComm); 161 | ncclResult_t (*close_listen)(void* listenComm); 162 | } ncclNet_v5_t; 163 | 164 | #endif 165 | 166 | // ncclNet_v6_t: defined on 2.13, live until now. 167 | #if NCCL_MAJOR == 2 && NCCL_MINOR < 13 168 | 169 | typedef ncclNetProperties_v5_t ncclNetProperties_v6_t; 170 | typedef struct { 171 | // Name of the network (mainly for logs) 172 | const char* name; 173 | // Initialize the network. 174 | ncclResult_t (*init)(ncclDebugLogger_t logFunction); 175 | // Return the number of adapters. 176 | ncclResult_t (*devices)(int* ndev); 177 | // Get various device properties. 178 | ncclResult_t (*getProperties)(int dev, ncclNetProperties_v6_t* props); 179 | // Create a receiving object and provide a handle to connect to it. The 180 | // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged 181 | // between ranks to create a connection. 182 | ncclResult_t (*listen)(int dev, void* handle, void** listenComm); 183 | // Connect to a handle and return a sending comm object for that peer. 
184 | // This call must not block for the connection to be established, and instead 185 | // should return successfully with sendComm == NULL with the expectation that 186 | // it will be called again until sendComm != NULL. 187 | ncclResult_t (*connect)(int dev, void* handle, void** sendComm); 188 | // Finalize connection establishment after remote peer has called connect. 189 | // This call must not block for the connection to be established, and instead 190 | // should return successfully with recvComm == NULL with the expectation that 191 | // it will be called again until recvComm != NULL. 192 | ncclResult_t (*accept)(void* listenComm, void** recvComm); 193 | // Register/Deregister memory. Comm can be either a sendComm or a recvComm. 194 | // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. 195 | ncclResult_t (*regMr)(void* comm, void* data, int size, int type, 196 | void** mhandle); 197 | /* DMA-BUF support */ 198 | ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, 199 | uint64_t offset, int fd, void** mhandle); 200 | ncclResult_t (*deregMr)(void* comm, void* mhandle); 201 | // Asynchronous send to a peer. 202 | // May return request == NULL if the call cannot be performed (or would block) 203 | ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, 204 | void* mhandle, void** request); 205 | // Asynchronous recv from a peer. 206 | // May return request == NULL if the call cannot be performed (or would block) 207 | ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, 208 | int* tags, void** mhandles, void** request); 209 | // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is 210 | // visible to the GPU 211 | ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, 212 | void** mhandles, void** request); 213 | // Test whether a request is complete. If size is not NULL, it returns the 214 | // number of bytes sent/received. 
215 | ncclResult_t (*test)(void* request, int* done, int* sizes); 216 | // Close and free send/recv comm objects 217 | ncclResult_t (*closeSend)(void* sendComm); 218 | ncclResult_t (*closeRecv)(void* recvComm); 219 | ncclResult_t (*closeListen)(void* listenComm); 220 | // Allocates a zero-initialized host buffer to be used by the net subsystem, 221 | // either allocated or registered through CUDA for host memory - For more 222 | // details, refer to cudaHostRegister(). The plugin can keep track of 223 | // allocation-related information by via mhandle. If this is is being used, 224 | // regMR and deregMr must be set to NULL. 225 | ncclResult_t (*calloc)(int dev, int size, int type, void** data, 226 | void** mhandle); 227 | // Frees the buffer identified by mhandle. The buffer must have been allocated 228 | // by a previous alloc() call. It is legal to pass a nullptr to this function, 229 | // in which case it should return success and do nothing. 230 | ncclResult_t (*free)(void* mhandle); 231 | } ncclNet_v6_t; 232 | 233 | #endif 234 | 235 | #endif // THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_COMPAT_H_ 236 | -------------------------------------------------------------------------------- /net_fastsocket.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Google LLC 2 | // 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file or at 5 | // https://developers.google.com/open-source/licenses/bsd 6 | 7 | /************************************************************************* 8 | * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved. 
9 | * 10 | * See LICENSE for license information 11 | ************************************************************************/ 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | #include 32 | #include 33 | #include 34 | 35 | #include "nccl_net.h" 36 | #include "compat.h" 37 | #include "utilities.h" 38 | 39 | #define MAX_INLINE_THRESHOLD 2048 40 | #define MAX_SOCKETS 32 41 | #define MAX_THREADS 16 42 | #define MAX_REQUESTS 16 43 | #define MAX_QUEUE_LEN MAX_REQUESTS 44 | #define MAX_TASKS 6 45 | 46 | #define MAX_FLOW_ENGINES 16 47 | #define MAX_CONNECT_RETRY 1000 48 | 49 | #define BUFFERED_CTRL 50 | #define TX_ZCOPY 51 | 52 | #ifndef SO_BUSY_POLL 53 | #define SO_BUSY_POLL 46 54 | #endif 55 | 56 | #ifndef TCP_NOTSENT_LOWAT 57 | #define TCP_NOTSENT_LOWAT 25 58 | #endif 59 | 60 | #ifndef SO_ZEROCOPY 61 | #define SO_ZEROCOPY 60 62 | #endif 63 | 64 | #ifndef MSG_ZEROCOPY 65 | #define MSG_ZEROCOPY 0x4000000 66 | #endif 67 | 68 | #ifndef SO_EE_ORIGIN_ZEROCOPY 69 | #define SO_EE_ORIGIN_ZEROCOPY 5 70 | #endif 71 | 72 | #ifndef SO_EE_CODE_ZEROCOPY_COPIED 73 | #define SO_EE_CODE_ZEROCOPY_COPIED 1 74 | #endif 75 | 76 | #ifndef SO_INCOMING_CPU 77 | #define SO_INCOMING_CPU 49 78 | #endif 79 | 80 | #define NCCL_SOCKET_SEND 0 81 | #define NCCL_SOCKET_RECV 1 82 | 83 | #define HINT_BOTTLENECK 84 | 85 | // Global variables 86 | static int kNcclNetIfs = -1; 87 | struct ncclSocketDev { 88 | union socketAddress addr; 89 | char dev_name[MAX_IF_NAME_SIZE]; 90 | char* pci_path; 91 | }; 92 | static struct ncclSocketDev kNcclSocketDevs[MAX_IFS]; 93 | pthread_mutex_t kNcclFastSocketLock = PTHREAD_MUTEX_INITIALIZER; 94 | static int kEnableSpin = 0; 95 | static int kInlineThreshold = 0; 96 | static int kSockBusyPoll = 0; 97 | static int kSockNotsentLowat = 0; 98 | static int kSockSentbuf = 0; 99 | 
static int kMinZcopySize = 0; 100 | static int kDynamicChunkSize = 128 * 1024; 101 | static int kEnableFlowPlacement = 0; 102 | static int kNumFlowEngine = 4; 103 | /* CPU pinning and queue-skip settings; ncclFastSocketInit overwrites them from TX_CPU_START, RX_CPU_START and QUEUE_SKIP. */ 104 | static int kTxCPUStart = -1; 105 | static int kRxCPUStart = -1; 106 | static int kQueueSkip = 0; 107 | 108 | // Whether to enable the plugin. Default is enabled. 109 | NCCL_PARAM(EnableFastSocket, "FAST_SOCKET_ENABLE", 1); 110 | 111 | // Maximum chunk size in bytes for dynamic load balancing. 112 | // Default is 128 KB (a value <= 0 keeps the built-in kDynamicChunkSize). 113 | NCCL_PARAM(DynamicChunkSize, "DYNAMIC_CHUNK_SIZE", 0); 114 | 115 | // Whether to spin the helper thread. Default is disabled. 116 | NCCL_PARAM(EnableThreadSpin, "THREAD_SPIN_ENABLE", 0); 117 | 118 | // Maximum size of data to inline with a control message. 119 | // 0 means disable inlining. 120 | NCCL_PARAM(InlineThreshold, "INLINE_THRESHOLD", 0); 121 | 122 | // Whether to busy poll the control socket. Default is disabled. 123 | NCCL_PARAM(SockBusyPoll, "SOCK_BUSY_POLL", 0); 124 | 125 | // Limit of unsent bytes in sockets: https://lwn.net/Articles/560082/ 126 | // The backpressure mechanism shifts the load to userspace, helping 127 | // the load balancing algorithm better detect the load of each socket. 128 | // 0 means disable backpressure. 129 | NCCL_PARAM(SockNotsentLowat, "SOCK_NOTSENT_LOWAT", 0); 130 | 131 | // Size of socket send buffer in bytes. 0 means the kernel default value. 132 | NCCL_PARAM(SockSendBuf, "SOCK_SEND_BUF", 0); 133 | 134 | // Minimum data size to use zero-copy. 0 means disabled.
135 | NCCL_PARAM(MinZcopySize, "MIN_ZCOPY_SIZE", 0); 136 | /* -2 requests auto-detection of the per-thread socket count and the thread count from the NIC vendor (see ncclFastSocketGetNsockNthread). */ 137 | NCCL_PARAM(NsocksPerThread, "NSOCKS_PERTHREAD", -2); 138 | NCCL_PARAM(NThreads, "SOCKET_NTHREADS", -2); 139 | /* Flow placement pairs each data socket with a queue chosen from its SO_INCOMING_CPU; NUM_FLOW_ENGINE sets both the socket and the thread count in that mode. */ 140 | NCCL_PARAM(EnableFlowPlacement, "FLOW_PLACEMENT_ENABLE", 1); 141 | NCCL_PARAM(NumFlowEngine, "NUM_FLOW_ENGINE", 4); 142 | /* First CPU for pinning TX/RX helper threads (thread t uses start + t); a negative value disables pinning. QUEUE_SKIP makes connects retry while the incoming CPU is below it. */ 143 | NCCL_PARAM(TxCPUStart, "TX_CPU_START", -2); 144 | NCCL_PARAM(RxCPUStart, "RX_CPU_START", -2); 145 | NCCL_PARAM(QueueSkip, "QUEUE_SKIP", 0); 146 | /* Busy-loops socketProgressOpt until *offset reaches size, returning the first error. */ 147 | static ncclResult_t socketSpin(int op, int fd, void* ptr, int size, 148 | int* offset) { 149 | while (*offset < size) 150 | NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 0)); 151 | return ncclSuccess; 152 | } 153 | 154 | // Data Structures 155 | /* Fixed-capacity ring of MAX_ITEMS items that move through NSTATES states. idx[] holds monotonically increasing cursors (idx[STATE - 1] is the oldest item of STATE) and slots are addressed modulo MAX_ITEMS. */ template 157 | struct ncclItemQueue { 158 | // 0: next dequeue slot, NSTATES - 1: next enqueue slot 159 | IndexType idx[NSTATES]; 160 | ItemType items[MAX_ITEMS]; 161 | ncclItemQueue() { 162 | for (int i = 0; i < NSTATES; ++i) idx[i] = 0; 163 | } 164 | 165 | bool empty() { return idx[0] == idx[NSTATES - 1]; } 166 | bool has_free() { return idx[NSTATES - 1] - idx[0] < MAX_ITEMS; } 167 | /* has<STATE>(): whether STATE holds an item (state 0: whether a free slot exists). */ 168 | template 169 | bool has() { 170 | if (STATE == 0) return has_free(); 171 | return idx[STATE] > idx[STATE - 1]; 172 | } 173 | /* first<STATE>(): oldest item of STATE (state 0: the slot the next enqueue fills). */ 174 | template 175 | ItemType* first() { 176 | if (STATE == 0) return items + idx[NSTATES - 1] % MAX_ITEMS; 177 | return items + idx[STATE - 1] % MAX_ITEMS; 178 | } 179 | /* advance<STATE>(): moves the oldest item of STATE into the next lower-numbered state. */ 180 | template 181 | void advance() { 182 | ++idx[STATE - 1]; 183 | } 184 | /* enqueue() claims the slot returned by first<0>(); dequeue() releases the oldest item. */ 185 | void enqueue() { ++idx[NSTATES - 1]; } 186 | void dequeue() { ++idx[0]; } 187 | 188 | // For cases when we need to iterate through all items in a state (other than 189 | 0).
190 | /* get_iterator<STATE>() returns the cursor of the oldest item in STATE, and is<STATE>(it) reports whether it is still before idx[STATE]. */ template 191 | IndexUnderType get_iterator() { 192 | return idx[STATE - 1]; 193 | } 194 | IndexUnderType next(IndexUnderType it) { return it + 1; } 195 | template 196 | bool is(IndexUnderType it) { 197 | return it < idx[STATE]; 198 | } 199 | ItemType* to_item(IndexUnderType it) { return items + it % MAX_ITEMS; } 200 | }; 201 | /* Packed control-message header; CtrlType below presumably lists the values of type. NOTE(review): field semantics are not visible in this excerpt. */ 202 | struct ncclCtrl { 203 | uint16_t type; 204 | uint16_t index; 205 | uint32_t size; 206 | uint32_t offset; 207 | uint32_t total; 208 | } __attribute__((__packed__)); 209 | /* Connection handle filled in by ncclFastSocketListen and read by the connect functions. */ 210 | struct ncclSocketHandle { 211 | union socketAddress connect_addr; 212 | int num_socks; 213 | int num_threads; 214 | }; 215 | /* One send or receive request on a comm; ncclSocketTask entries point back to it through r. */ 216 | struct ncclSocketRequest { 217 | struct ncclFastSocketComm* comm; 218 | void* data; 219 | int op; 220 | int next_sock_id; 221 | int next_size; 222 | int offset; 223 | int size; 224 | int size_pending; 225 | }; 226 | /* A slice of a request assigned to one data socket. With TX_ZCOPY, tx_bound records the socket's zero-copy send counter when the task finished sending and tx_count accumulates its completed zero-copy sends. */ 227 | struct ncclSocketTask { 228 | int op; 229 | int size; 230 | int offset; 231 | void* data; 232 | struct ncclSocketRequest* r; 233 | #ifdef TX_ZCOPY 234 | uint32_t tx_count; 235 | uint32_t tx_bound; 236 | #endif 237 | ncclResult_t result; 238 | }; 239 | 240 | enum ThreadState { start, stop }; 241 | enum CtrlType { 242 | CTRL_NORMAL = 0, 243 | CTRL_INLINE = 1, 244 | }; 245 | /* Per-helper-thread state; ncclFastSocketClose sets state to stop and signals thread_cond under thread_lock. */ 246 | struct ncclSocketThreadResources { 247 | int id; // thread index 248 | std::atomic_uint next; 249 | enum ThreadState state; 250 | struct ncclFastSocketComm* comm; 251 | pthread_mutex_t thread_lock; 252 | pthread_cond_t thread_cond; 253 | }; 254 | 255 | // Must be identical to ncclSocketListenComm in net_socket.cc 256 | struct ncclSocketListenComm { 257 | int fd; 258 | int num_socks; 259 | int num_threads; 260 | }; 261 | 262 | // Request state transition: 263 | // FREE->ACTIVE->INACTIVE->FREE 264 | enum { 265 | REQUEST_FREE = 0, 266 | REQUEST_INACTIVE = 1, 267 | REQUEST_ACTIVE = 2, 268 | REQUEST_MAX_STATES = 3, 269 | }; 270 | 271 | struct ncclSocketRequestQueue 272 | : ncclItemQueue { 274 | using Base = ncclItemQueue; 276 |
ncclSocketRequestQueue() : Base() {} 277 | bool has_active() { return has(); } 278 | bool has_inactive() { return has(); } 279 | struct ncclSocketRequest* next_free() { return first(); } 280 | struct ncclSocketRequest* next_active() { return first(); } 281 | struct ncclSocketRequest* next_inactive() { 282 | return first(); 283 | } 284 | void mark_inactive() { advance(); } 285 | }; 286 | 287 | // Task state transition: 288 | // FREE->ACTIVE->COMPLETING->INACTIVE->FREE (COMPLETING: fully sent, waiting for zero-copy completions) 289 | enum { 290 | TASK_FREE = 0, 291 | TASK_INACTIVE = 1, 292 | TASK_COMPLETING = 2, 293 | TASK_ACTIVE = 3, 294 | TASK_MAX_STATES = 4, 295 | }; 296 | 297 | struct ncclSocketTaskQueue 298 | : ncclItemQueue { 300 | using Base = ncclItemQueue; 302 | ncclSocketTaskQueue() : Base() {} 303 | bool has_active() { return has(); } 304 | bool has_inactive() { return has(); } 305 | bool has_completing() { return has(); } 306 | ncclSocketTask* next_free() { return first(); } 307 | ncclSocketTask* next_active() { return first(); } 308 | ncclSocketTask* next_completing() { return first(); } 309 | ncclSocketTask* next_inactive() { return first(); } 310 | void finish_active() { advance(); } 311 | void finish_completing() { advance(); } 312 | }; 313 | /* Coalesces small control-message writes in buf; sync() flushes them with a blocking socketSpin, and send() flushes first when s would not fit. */ 314 | template 315 | struct ncclBufferedSendSocket { 316 | ncclBufferedSendSocket() : fd(-1), cur(0) {} 317 | void setFd(int fileFd) { fd = fileFd; } 318 | ncclResult_t sync() { 319 | if (cur == 0) return ncclSuccess; 320 | int off = 0; 321 | NCCLCHECK(socketSpin(NCCL_SOCKET_SEND, fd, buf, cur, &off)); 322 | cur = 0; 323 | return ncclSuccess; 324 | } 325 | ncclResult_t send(void* ptr, unsigned s) { 326 | if (s > BUF_SIZE) return ncclInternalError; 327 | if (cur + s > BUF_SIZE) NCCLCHECK(sync()); 328 | memcpy(buf + cur, ptr, s); 329 | cur += s; 330 | return ncclSuccess; 331 | } 332 | 333 | int fd; 334 | int cur; 335 | char buf[BUF_SIZE]; 336 | }; 337 | /* Buffered reader used for the control socket. */ 338 | template 339 | struct ncclBufferedRecvSocket { 340 | ncclBufferedRecvSocket() : fd(-1), cur(0), end(0) {} 341 | void setFd(int
fileFd) { fd = fileFd; } 342 | bool empty() { return cur == end; } 343 | ncclResult_t refill() { 344 | if (!empty()) return ncclSuccess; 345 | cur = end = 0; 346 | return socketProgress(NCCL_SOCKET_RECV, fd, buf, BUF_SIZE, &end); 347 | } 348 | ncclResult_t recv(void* ptr, int s) { 349 | while (s) { 350 | refill(); 351 | int len = std::min(s, end - cur); 352 | memcpy(ptr, buf + cur, len); 353 | cur += len; 354 | ptr = reinterpret_cast(ptr) + len; 355 | s -= len; 356 | } 357 | return ncclSuccess; 358 | } 359 | int brecv(void* ptr, int s) { 360 | int sz = std::min(s, end - cur); 361 | memcpy(ptr, buf + cur, sz); 362 | cur += sz; 363 | return sz; 364 | } 365 | 366 | int fd; 367 | int cur; 368 | int end; 369 | char buf[BUF_SIZE]; 370 | }; 371 | 372 | struct ncclFdData { 373 | int fd; 374 | #ifdef TX_ZCOPY 375 | uint32_t tx_upper; 376 | uint32_t tx_lower; 377 | #endif 378 | bool used; 379 | uint64_t stat; 380 | ncclSocketTaskQueue tasks; 381 | }; 382 | 383 | struct ncclFastSocketComm { 384 | int ctrl_fd; // control socket fd 385 | bool passive; 386 | std::atomic connected; 387 | struct ncclFdData 388 | fd_data[MAX_SOCKETS]; // data socket fd and its auxiliary data 389 | int num_socks; // total number of socket fds per comm 390 | int num_threads; // number of helper threads per comm 391 | int last_fd; // the last enqueued fd idx 392 | ncclSocketRequestQueue rq; // requests queue 393 | 394 | #ifdef BUFFERED_CTRL 395 | #define CTRL_BUFFER_SIZE (sizeof(ncclCtrl) * 8) 396 | ncclBufferedSendSocket ctrl_send; 397 | ncclBufferedRecvSocket ctrl_recv; 398 | #endif 399 | 400 | #ifdef HINT_BOTTLENECK 401 | struct timeval start_time; 402 | #endif 403 | 404 | // helper threads 405 | pthread_t helper_thread[MAX_THREADS]; 406 | pthread_t connect_thread; 407 | // auxiliary data with helper threads 408 | struct ncclSocketThreadResources thread_resource[MAX_THREADS]; 409 | union socketAddress connect_addr; 410 | }; 411 | 412 | // Control Path Functions 413 | static inline void 
setSockBusyPoll(int fd) { 414 | /* Enables SO_BUSY_POLL when kSockBusyPoll is non-zero; failure only warns. */ if (kSockBusyPoll) { 415 | if (setsockopt(fd, SOL_SOCKET, SO_BUSY_POLL, &kSockBusyPoll, 416 | sizeof kSockBusyPoll) < 0) { 417 | WARN("Cannot enable socket busy poll"); 418 | } 419 | } 420 | } 421 | /* Sets TCP_NOTSENT_LOWAT when kSockNotsentLowat is non-zero; failure only warns. */ 422 | static inline void setSockNotsentLowat(int fd) { 423 | if (kSockNotsentLowat) { 424 | if (setsockopt(fd, SOL_TCP, TCP_NOTSENT_LOWAT, &kSockNotsentLowat, 425 | sizeof kSockNotsentLowat) < 0) { 426 | WARN("Cannot set socket TCP_NOTSENT_LOWAT"); 427 | } 428 | } 429 | } 430 | /* Sets SO_SNDBUF when kSockSentbuf is non-zero; failure only warns. */ 431 | static inline void setSockSendBuf(int fd) { 432 | if (kSockSentbuf) { 433 | if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &kSockSentbuf, 434 | sizeof kSockSentbuf) < 0) { 435 | WARN("Cannot set socket SO_SNDBUF"); 436 | } 437 | } 438 | } 439 | /* Enables SO_ZEROCOPY when zero-copy sends are configured; on failure zero-copy is disabled process-wide by clearing kMinZcopySize. */ 440 | static inline void setSockZcopy(int fd) { 441 | if (kMinZcopySize > 0) { 442 | int one = 1; 443 | if (setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof one) < 0) { 444 | WARN("Cannot set socket to SO_ZEROCOPY"); 445 | kMinZcopySize = 0; 446 | } 447 | } 448 | } 449 | /* Resolves the sysfs device path of devName into a malloc'd string (NULL if unresolvable); always returns ncclSuccess. */ 450 | static ncclResult_t ncclFastSocketGetPciPath(char* devName, char** pciPath) { 451 | char devicePath[PATH_MAX]; 452 | snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", devName); 453 | // May return NULL if the file doesn't exist. 454 | *pciPath = realpath(devicePath, nullptr); 455 | return ncclSuccess; 456 | } 457 | /* Same lookup for interface index dev, but reports ncclSystemError when the path cannot be resolved. */ 458 | ncclResult_t ncclFastSocketPciPath(int dev, char** pciPath) { 459 | char devicePath[PATH_MAX]; 460 | snprintf(devicePath, PATH_MAX, "/sys/class/net/%s/device", 461 | kNcclSocketDevs[dev].dev_name); 462 | // May return NULL if the file doesn't exist.
463 | *pciPath = realpath(devicePath, nullptr); 464 | if (*pciPath == nullptr) { 465 | INFO(NCCL_NET | NCCL_INIT, "Could not find real path of %s", devicePath); 466 | return ncclSystemError; 467 | } 468 | return ncclSuccess; 469 | } 470 | /* Plugin entry point: loads tunables from NCCL params into the k* globals and, on the first call, discovers interfaces under kNcclFastSocketLock. Returns ncclInternalError when the plugin is disabled or no interface is found. */ 471 | ncclResult_t ncclFastSocketInit(ncclDebugLogger_t logFunction) { 472 | int enable = ncclParamEnableFastSocket(); 473 | nccl_log_func = logFunction; /* The collnet build only stays enabled when NCCL_COLLNET_ENABLE is set to something other than "0". */ 474 | #ifdef CHECK_COLLNET_ENABLE 475 | char* collnet_enable = getenv("NCCL_COLLNET_ENABLE"); 476 | if (!collnet_enable || strcmp(collnet_enable, "0") == 0) { 477 | enable = 0; 478 | } 479 | #endif 480 | if (!enable) { 481 | INFO(NCCL_NET | NCCL_INIT, "NET/FastSocket disabled"); 482 | return ncclInternalError; 483 | } 484 | 485 | int dcs = ncclParamDynamicChunkSize(); 486 | if (dcs > 0) kDynamicChunkSize = dcs; 487 | 488 | kInlineThreshold = ncclParamInlineThreshold(); 489 | if (kInlineThreshold < 0) kInlineThreshold = 0; 490 | if (kInlineThreshold > MAX_INLINE_THRESHOLD) 491 | kInlineThreshold = MAX_INLINE_THRESHOLD; 492 | 493 | kEnableSpin = ncclParamEnableThreadSpin(); 494 | 495 | kSockBusyPoll = ncclParamSockBusyPoll(); 496 | if (kSockBusyPoll < 0) kSockBusyPoll = 0; 497 | 498 | int snl = ncclParamSockNotsentLowat(); 499 | if (snl > 0) kSockNotsentLowat = snl; 500 | 501 | kSockSentbuf = ncclParamSockSendBuf(); 502 | if (kSockSentbuf < 0) kSockSentbuf = 0; 503 | kMinZcopySize = ncclParamMinZcopySize(); 504 | 505 | kTxCPUStart = ncclParamTxCPUStart(); 506 | kRxCPUStart = ncclParamRxCPUStart(); 507 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Tx CPU start: %d", kTxCPUStart); 508 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Rx CPU start: %d", kRxCPUStart); 509 | 510 | kEnableFlowPlacement = ncclParamEnableFlowPlacement(); 511 | if (kEnableFlowPlacement) { 512 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Flow placement enabled."); 513 | kNumFlowEngine = ncclParamNumFlowEngine(); 514 | if (kNumFlowEngine < 1) kNumFlowEngine = 1; 515 | if (kNumFlowEngine > MAX_FLOW_ENGINES)
kNumFlowEngine = MAX_FLOW_ENGINES; 516 | } 517 | 518 | kQueueSkip = ncclParamQueueSkip(); 519 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : queue skip: %d", kQueueSkip); 520 | 521 | /* NOTE(review): kNcclNetIfs is a plain int read here before taking the lock; it is re-checked under the lock below. */ if (kNcclNetIfs == -1) { 522 | pthread_mutex_lock(&kNcclFastSocketLock); 523 | if (kNcclNetIfs == -1) { 524 | char names[MAX_IF_NAME_SIZE * MAX_IFS]; 525 | union socketAddress addrs[MAX_IFS]; 526 | kNcclNetIfs = findInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS); 527 | if (kNcclNetIfs <= 0) { 528 | WARN("NET/FastSocket : no interface found"); 529 | pthread_mutex_unlock(&kNcclFastSocketLock); 530 | return ncclInternalError; 531 | } else { 532 | char line[2048]; 533 | char addrline[2048]; 534 | line[0] = '\0'; 535 | for (int i = 0; i < kNcclNetIfs; i++) { 536 | /* NOTE(review): strncpy leaves dev_name unterminated if a name fills all MAX_IF_NAME_SIZE bytes. */ strncpy(kNcclSocketDevs[i].dev_name, names + i * MAX_IF_NAME_SIZE, 537 | MAX_IF_NAME_SIZE); 538 | memcpy(&kNcclSocketDevs[i].addr, addrs + i, 539 | sizeof(union socketAddress)); 540 | /* NOTE(review): a failure here would return with kNcclFastSocketLock still held; currently harmless because ncclFastSocketGetPciPath always returns ncclSuccess. */ NCCLCHECK(ncclFastSocketGetPciPath(kNcclSocketDevs[i].dev_name, 541 | &kNcclSocketDevs[i].pci_path)); 542 | snprintf(line + strlen(line), 2047 - strlen(line), " [%d]%s:%s", i, 543 | names + i * MAX_IF_NAME_SIZE, 544 | socketToString(&addrs[i].sa, addrline)); 545 | } 546 | line[2047] = '\0'; 547 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Using%s", line); 548 | } 549 | } 550 | pthread_mutex_unlock(&kNcclFastSocketLock); 551 | } 552 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket plugin initialized"); 553 | return ncclSuccess; 554 | } 555 | /* Reports the number of interfaces discovered by ncclFastSocketInit. */ 556 | ncclResult_t ncclFastSocketDevices(int* ndev) { 557 | *ndev = kNcclNetIfs; 558 | return ncclSuccess; 559 | } 560 | /* With flow placement, both the socket count and the thread count are kNumFlowEngine (one socket per thread). */ 561 | static ncclResult_t ncclFlowPlacementGetNsockNthread(int* ns, int* nt) { 562 | *ns = kNumFlowEngine; 563 | *nt = *ns; 564 | INFO(NCCL_NET, "Flow placement forcing parameters: nthreads %d nsocks %d", 565 | *nt, *ns); 566 | return ncclSuccess; 567 | } 568 | /* Chooses the total socket count (*ns) and helper-thread count (*nt) for device dev, from NCCL params or NIC-vendor auto-detection. */ 569 | static ncclResult_t ncclFastSocketGetNsockNthread(int dev, int* ns, int* nt) { 570 | if (kEnableFlowPlacement) { 571 | return
ncclFlowPlacementGetNsockNthread(ns, nt); 572 | } 573 | int nSocksPerThread = ncclParamNsocksPerThread(); 574 | int nThreads = ncclParamNThreads(); 575 | if (nThreads > MAX_THREADS) { 576 | WARN( 577 | "NET/Socket : NCCL_SOCKET_NTHREADS is greater than the maximum " 578 | "allowed, setting to %d", 579 | MAX_THREADS); 580 | nThreads = MAX_THREADS; 581 | } 582 | if (nThreads == -2 || nSocksPerThread == -2) { 583 | // Auto-detection 584 | int autoNt = 1, autoNs = 1; 585 | char vendorPath[PATH_MAX]; 586 | snprintf(vendorPath, PATH_MAX, "/sys/class/net/%s/device/vendor", 587 | kNcclSocketDevs[dev].dev_name); 588 | char* rPath = realpath(vendorPath, nullptr); 589 | int fd = open(rPath, O_RDONLY); 590 | free(rPath); 591 | if (fd == -1) { 592 | // Could not find device vendor. This is handled silently so 593 | // we don't want to print an INFO error. 594 | INFO(NCCL_NET, "Open of %s failed : %s\n", vendorPath, strerror(errno)); 595 | goto end; 596 | } 597 | char vendor[7]; 598 | strncpy(vendor, "0x0000", 7); 599 | int len; 600 | SYSCHECKVAL(read(fd, vendor, 6), "read", len); 601 | SYSCHECK(close(fd), "close"); 602 | if (strcmp(vendor, "0x1d0f") == 0) { // AWS 603 | autoNt = 2; 604 | autoNs = 8; 605 | } else if (strcmp(vendor, "0x1ae0") == 0) { // GCP 606 | autoNt = 6; 607 | autoNs = 1; 608 | } 609 | end: 610 | if (nThreads == -2) nThreads = autoNt; 611 | if (nSocksPerThread == -2) nSocksPerThread = autoNs; 612 | } 613 | int nSocks = nSocksPerThread * nThreads; 614 | if (nSocks > MAX_SOCKETS) { 615 | nSocksPerThread = MAX_SOCKETS / nThreads; 616 | WARN( 617 | "NET/Socket : the total number of sockets is greater than the maximum " 618 | "allowed, setting NCCL_NSOCKS_PERTHREAD to %d", 619 | nSocksPerThread); 620 | nSocks = nSocksPerThread * nThreads; 621 | } 622 | *ns = nSocks; 623 | *nt = nThreads; 624 | INFO(NCCL_INIT, "NET/Socket: Using %d threads and %d sockets per thread", 625 | nThreads, nSocksPerThread); 626 | return ncclSuccess; 627 | } 628 | 629 | ncclResult_t 
ncclFastSocketNewComm(struct ncclFastSocketComm** comm) { 630 | /* Zero-allocates the comm, marks every data slot unused (fd -1) and stamps start_time for the throughput hint logged at close. */ NCCLCHECK(ncclCalloc(comm, 1)); 631 | (*comm)->ctrl_fd = -1; 632 | (*comm)->last_fd = 0; 633 | for (int i = 0; i < MAX_SOCKETS; i++) { 634 | (*comm)->fd_data[i].fd = -1; 635 | (*comm)->fd_data[i].used = false; 636 | (*comm)->fd_data[i].stat = 0; 637 | #ifdef TX_ZCOPY 638 | (*comm)->fd_data[i].tx_upper = 0; 639 | (*comm)->fd_data[i].tx_lower = 0; 640 | #endif 641 | } 642 | gettimeofday(&(*comm)->start_time, nullptr); 643 | return ncclSuccess; 644 | } 645 | /* Allocates a listen comm with no socket yet (fd -1). */ 646 | static ncclResult_t ncclSocketNewListenComm( 647 | struct ncclSocketListenComm** comm) { 648 | NCCLCHECK(ncclCalloc(comm, 1)); 649 | (*comm)->fd = -1; 650 | return ncclSuccess; 651 | } 652 | /* Copies the address of interface dev; only the upper bound of dev is checked here. */ 653 | static ncclResult_t GetSocketAddr(int dev, union socketAddress* addr) { 654 | if (dev >= kNcclNetIfs) return ncclInternalError; 655 | memcpy(addr, &kNcclSocketDevs[dev].addr, sizeof(*addr)); 656 | return ncclSuccess; 657 | } 658 | /* Opens a listening socket on interface dev and publishes its address plus the socket/thread counts through opaqueHandle. */ 659 | ncclResult_t ncclFastSocketListen(int dev, void* opaqueHandle, 660 | void** listenComm) { 661 | if (dev < 0) { // data transfer socket is based on specified dev 662 | return ncclInternalError; 663 | } 664 | struct ncclSocketHandle* handle = 665 | static_cast(opaqueHandle); 666 | static_assert(sizeof(struct ncclSocketHandle) < NCCL_NET_HANDLE_MAXSIZE, 667 | "ncclSocketHandle size too large"); 668 | struct ncclSocketListenComm* comm; 669 | NCCLCHECK(ncclSocketNewListenComm(&comm)); 670 | NCCLCHECK(GetSocketAddr(dev, &handle->connect_addr)); 671 | NCCLCHECK(createListenSocket(&comm->fd, &handle->connect_addr)); 672 | NCCLCHECK( 673 | ncclFastSocketGetNsockNthread(dev, &comm->num_socks, &comm->num_threads)); 674 | handle->num_socks = comm->num_socks; 675 | handle->num_threads = comm->num_threads; 676 | *listenComm = comm; 677 | return ncclSuccess; 678 | } 679 | /* Installs fd as the control socket (and, with BUFFERED_CTRL, as the fd of the buffered control reader/writer). */ 680 | static void initCtrlFd(struct ncclFastSocketComm* comm, int fd) { 681 | comm->ctrl_fd = fd; 682 | #ifdef BUFFERED_CTRL 683 | comm->ctrl_send.setFd(fd); 684 |
comm->ctrl_recv.setFd(fd); 685 | #endif 686 | } 687 | 688 | void waitConnect(struct ncclFastSocketComm* comm) { 689 | while (!comm->connected) { 690 | pthread_yield(); 691 | } 692 | } 693 | 694 | ncclResult_t ncclSocketAsyncConnectV2(struct ncclFastSocketComm* comm) { 695 | int i = 0; 696 | int retry = 0; 697 | while (i < comm->num_socks + 1) { 698 | int tmpFd, offset; 699 | NCCLCHECK(connectAddress(&tmpFd, &comm->connect_addr)); 700 | if (i == comm->num_socks) { 701 | int ii = 0; 702 | offset = 0; 703 | NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &ii, sizeof(int), &offset)); 704 | initCtrlFd(comm, tmpFd); 705 | } else { 706 | int qid, dqid; 707 | int rqid = 0; 708 | int cpu = 0; 709 | socklen_t opt_len = sizeof cpu; 710 | 711 | if (retry < MAX_CONNECT_RETRY) { 712 | if (getsockopt(tmpFd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, &opt_len) < 713 | 0) { 714 | WARN("Cannot get incoming CPU."); 715 | } 716 | 717 | qid = cpu % kNumFlowEngine; 718 | dqid = cpu % (kNumFlowEngine * 2); 719 | if (cpu < kQueueSkip || dqid >= kNumFlowEngine || 720 | comm->fd_data[qid].used) { 721 | qid = -1; 722 | } 723 | } else { 724 | int j = 0; 725 | while (j < comm->num_socks) { 726 | if (!comm->fd_data[j].used) break; 727 | ++j; 728 | } 729 | if (j == comm->num_socks) { 730 | WARN("Cannot find empty socket for %d.", i); 731 | return ncclInternalError; 732 | } 733 | dqid = j; 734 | qid = j; 735 | if (retry == MAX_CONNECT_RETRY) { 736 | WARN("Maximum retry reached for connect %d.", i); 737 | } 738 | } 739 | 740 | offset = 0; 741 | NCCLCHECK( 742 | socketWait(NCCL_SOCKET_RECV, tmpFd, &rqid, sizeof(int), &offset)); 743 | offset = 0; 744 | NCCLCHECK( 745 | socketWait(NCCL_SOCKET_SEND, tmpFd, &qid, sizeof(int), &offset)); 746 | if (qid < 0 || rqid < 0) { 747 | close(tmpFd); 748 | ++retry; 749 | continue; 750 | } 751 | 752 | INFO(NCCL_INIT | NCCL_NET, "connect incoming cpu: %u", cpu); 753 | INFO(NCCL_INIT | NCCL_NET, "connect qid: %d, rqid: %d", qid, rqid); 754 | 755 | setSockNotsentLowat(tmpFd); 
756 | setSockSendBuf(tmpFd); 757 | setSockZcopy(tmpFd); 758 | comm->fd_data[rqid].fd = tmpFd; 759 | comm->fd_data[qid].used = true; // qid, not rqid 760 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Connected after %d retries.", 761 | retry); 762 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Connected data socket %d", 763 | i); 764 | retry = 0; 765 | } 766 | ++i; 767 | } 768 | setSockBusyPoll(comm->ctrl_fd); 769 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Async connect done"); 770 | comm->connected = true; 771 | return ncclSuccess; 772 | } 773 | 774 | void* asyncConnect(void* opaque) { 775 | ncclSocketAsyncConnectV2(static_cast(opaque)); 776 | return nullptr; 777 | } 778 | 779 | ncclResult_t ncclSocketConnectV2(int dev, void* opaqueHandle, void** sendComm) { 780 | if (dev < 0) { // data transfer socket is based on specified dev 781 | return ncclInternalError; 782 | } 783 | struct ncclFastSocketComm* comm; 784 | NCCLCHECK(ncclFastSocketNewComm(&comm)); 785 | struct ncclSocketHandle* handle = 786 | static_cast(opaqueHandle); 787 | comm->num_socks = handle->num_socks; 788 | comm->num_threads = handle->num_threads; 789 | comm->connect_addr = handle->connect_addr; 790 | comm->passive = false; 791 | comm->connected = false; 792 | 793 | pthread_create(&comm->connect_thread, nullptr, asyncConnect, 794 | reinterpret_cast(comm)); 795 | pthread_detach(comm->connect_thread); 796 | *sendComm = comm; 797 | return ncclSuccess; 798 | } 799 | 800 | ncclResult_t ncclSocketAcceptV2(void* listenComm, void** recvComm) { 801 | struct ncclSocketListenComm* lComm = 802 | static_cast(listenComm); 803 | struct ncclFastSocketComm* rComm; 804 | NCCLCHECK(ncclFastSocketNewComm(&rComm)); 805 | rComm->num_socks = lComm->num_socks; 806 | rComm->num_threads = lComm->num_threads; 807 | rComm->passive = true; 808 | int i = 0; 809 | int retry = 0; 810 | while (i < rComm->num_socks + 1) { 811 | int tmpFd, offset; 812 | struct sockaddr_in sockaddr; 813 | socklen_t socklen = sizeof(struct 
sockaddr_in); 814 | SYSCHECKVAL(accept(lComm->fd, (struct sockaddr*)&sockaddr, &socklen), 815 | "accept", tmpFd); 816 | /* The final connection (i == num_socks) is the control socket; i is sent to the peer. */ if (i == rComm->num_socks) { 817 | offset = 0; 818 | NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset)); 819 | initCtrlFd(rComm, tmpFd); 820 | } else { 821 | unsigned cpu = 0; 822 | int qid, dqid; 823 | int rqid; 824 | socklen_t opt_len = sizeof cpu; 825 | /* Keep the socket only if its incoming CPU maps to the upper half of [0, 2 * kNumFlowEngine) and queue qid is still free; qid = -1 makes both sides drop the socket and retry. After MAX_CONNECT_RETRY attempts the first free slot is used instead. */ 826 | if (retry < MAX_CONNECT_RETRY) { 827 | if (getsockopt(tmpFd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, &opt_len) < 828 | 0) { 829 | WARN("Cannot get incoming CPU."); 830 | } 831 | qid = static_cast(cpu) % kNumFlowEngine; 832 | dqid = static_cast(cpu) % (kNumFlowEngine * 2); 833 | if (dqid < kNumFlowEngine || rComm->fd_data[qid].used) { 834 | qid = -1; 835 | } 836 | } else { 837 | int j = 0; 838 | while (j < rComm->num_socks) { 839 | if (!rComm->fd_data[j].used) break; 840 | ++j; 841 | } 842 | if (j == rComm->num_socks) { 843 | WARN("Cannot find empty socket for %d.", i); 844 | return ncclInternalError; 845 | } 846 | qid = j; 847 | dqid = j + kNumFlowEngine; 848 | if (retry == MAX_CONNECT_RETRY) { 849 | WARN("Maximum retry reached for accept %d.", i); 850 | } 851 | } 852 | /* Exchange queue ids: ours is sent first, then the peer's is received. */ 853 | offset = 0; 854 | NCCLCHECK( 855 | socketWait(NCCL_SOCKET_SEND, tmpFd, &qid, sizeof(int), &offset)); 856 | rqid = 0; 857 | offset = 0; 858 | NCCLCHECK( 859 | socketWait(NCCL_SOCKET_RECV, tmpFd, &rqid, sizeof(int), &offset)); 860 | if (qid < 0 || rqid < 0) { 861 | close(tmpFd); 862 | ++retry; 863 | continue; 864 | } 865 | 866 | INFO(NCCL_INIT | NCCL_NET, "accept qid: %d, rqid: %d", qid, rqid); 867 | INFO(NCCL_INIT | NCCL_NET, "accept incoming cpu: %u", cpu); 868 | 869 | setSockNotsentLowat(tmpFd); 870 | setSockSendBuf(tmpFd); 871 | setSockZcopy(tmpFd); 872 | rComm->fd_data[qid].fd = tmpFd; 873 | rComm->fd_data[qid].used = true; 874 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Connected after %d retries.", 875 | retry); 876 | INFO(NCCL_INIT | NCCL_NET, "NET/FastSocket : Accepted data socket %d", i);
877 | retry = 0; 878 | } 879 | ++i; 880 | } 881 | setSockBusyPoll(rComm->ctrl_fd); 882 | *recvComm = rComm; 883 | rComm->connected = true; 884 | return ncclSuccess; 885 | } 886 | /* Blocking connect: opens num_socks data sockets and then the control socket, sending each socket's index so the acceptor can slot it; delegates to ncclSocketConnectV2 when flow placement is enabled. */ 887 | ncclResult_t ncclFastSocketConnect(int dev, void* opaqueHandle, 888 | void** sendComm) { 889 | if (dev < 0) { // data transfer socket is based on specified dev 890 | return ncclInternalError; 891 | } 892 | if (kEnableFlowPlacement) { 893 | return ncclSocketConnectV2(dev, opaqueHandle, sendComm); 894 | } 895 | struct ncclFastSocketComm* comm; 896 | NCCLCHECK(ncclFastSocketNewComm(&comm)); 897 | struct ncclSocketHandle* handle = 898 | static_cast(opaqueHandle); 899 | comm->num_socks = handle->num_socks; 900 | comm->num_threads = handle->num_threads; 901 | for (int i = 0; i < comm->num_socks + 1; i++) { 902 | int tmpFd, offset = 0; 903 | NCCLCHECK(connectAddress(&tmpFd, &handle->connect_addr)); 904 | NCCLCHECK(socketWait(NCCL_SOCKET_SEND, tmpFd, &i, sizeof(int), &offset)); 905 | if (i == comm->num_socks) { 906 | initCtrlFd(comm, tmpFd); 907 | } else { 908 | setSockNotsentLowat(tmpFd); 909 | setSockSendBuf(tmpFd); 910 | setSockZcopy(tmpFd); 911 | comm->fd_data[i].fd = tmpFd; 912 | } 913 | } 914 | setSockBusyPoll(comm->ctrl_fd); 915 | *sendComm = comm; 916 | comm->passive = false; 917 | comm->connected = true; 918 | return ncclSuccess; 919 | } 920 | /* Blocking accept counterpart of ncclFastSocketConnect; each socket is slotted by the index the peer sends. Delegates to ncclSocketAcceptV2 when flow placement is enabled. */ 921 | ncclResult_t ncclFastSocketAccept(void* listenComm, void** recvComm) { 922 | if (kEnableFlowPlacement) { 923 | return ncclSocketAcceptV2(listenComm, recvComm); 924 | } 925 | struct ncclSocketListenComm* lComm = 926 | static_cast(listenComm); 927 | struct ncclFastSocketComm* rComm; 928 | NCCLCHECK(ncclFastSocketNewComm(&rComm)); 929 | rComm->num_socks = lComm->num_socks; 930 | rComm->num_threads = lComm->num_threads; 931 | for (int i = 0; i < rComm->num_socks + 1; i++) { 932 | int tmpFd, sendSockIdx, offset = 0; 933 | struct sockaddr_in sockaddr; 934 | socklen_t socklen = sizeof(struct sockaddr_in); 935 | SYSCHECKVAL(accept(lComm->fd,
(struct sockaddr*)&sockaddr, &socklen), 936 | "accept", tmpFd); 937 | NCCLCHECK(socketWait(NCCL_SOCKET_RECV, tmpFd, &sendSockIdx, sizeof(int), 938 | &offset)); /* sendSockIdx comes from the peer and indexes fd_data, so anything outside [0, num_socks] is rejected (num_socks itself denotes the control socket). */ if (sendSockIdx < 0 || sendSockIdx > rComm->num_socks) { WARN("NET/FastSocket : invalid socket index %d from peer", sendSockIdx); close(tmpFd); return ncclInternalError; } 939 | if (sendSockIdx == rComm->num_socks) { 940 | initCtrlFd(rComm, tmpFd); 941 | } else { 942 | setSockNotsentLowat(tmpFd); 943 | setSockSendBuf(tmpFd); 944 | setSockZcopy(tmpFd); 945 | rComm->fd_data[sendSockIdx].fd = tmpFd; 946 | } 947 | } 948 | setSockBusyPoll(rComm->ctrl_fd); 949 | *recvComm = rComm; 950 | rComm->passive = true; 951 | rComm->connected = true; 952 | return ncclSuccess; 953 | } 954 | /* Stops and joins the helper threads, closes every socket, logs per-socket byte counts and frees comm. NOTE(review): a comm created by ncclSocketConnectV2 is also used by its detached connect thread, which this function does not wait for. */ 955 | ncclResult_t ncclFastSocketClose(void* opaqueComm) { 956 | struct ncclFastSocketComm* comm = 957 | static_cast(opaqueComm); 958 | if (comm) { 959 | for (int i = 0; i < comm->num_threads; i++) { 960 | struct ncclSocketThreadResources* res = comm->thread_resource + i; 961 | if (comm->helper_thread[i]) { 962 | pthread_mutex_lock(&res->thread_lock); 963 | res->state = stop; 964 | pthread_cond_signal(&res->thread_cond); 965 | pthread_mutex_unlock(&res->thread_lock); 966 | pthread_join(comm->helper_thread[i], nullptr); 967 | } 968 | } 969 | if (comm->ctrl_fd != -1) close(comm->ctrl_fd); 970 | uint64_t total = 0; 971 | for (int i = 0; i < comm->num_socks; i++) { 972 | if (comm->fd_data[i].fd != -1) close(comm->fd_data[i].fd); 973 | if (comm->fd_data[i].stat) { 974 | INFO(NCCL_NET, "Socket %i total bytes: %lu, passive = %d", i, 975 | comm->fd_data[i].stat, (int)comm->passive); 976 | total += comm->fd_data[i].stat; 977 | } 978 | } 979 | INFO(NCCL_NET, "All bytes: %lu", total); 980 | #ifdef HINT_BOTTLENECK 981 | struct timeval current_time; 982 | gettimeofday(&current_time, nullptr); 983 | timersub(&current_time, &comm->start_time, &current_time); 984 | /* total bytes divided by elapsed microseconds gives MB per second. */ double avg_throughput_mb = 985 | (double)total / (1e6 * current_time.tv_sec + current_time.tv_usec); 986 | if (avg_throughput_mb > 1000) { 987 | INFO(NCCL_INIT, "Average throughput: %f MB/s", avg_throughput_mb); 988 | INFO(NCCL_INIT, 989 | "This training job might be network
bound. Reduction Server boosts " 990 | "performance of network bound training jobs. " 991 | "More details at " 992 | "https://cloud.google.com/blog/products/ai-machine-learning/" 993 | "faster-distributed-training-with-google-clouds-reduction-server."); 994 | } 995 | #endif 996 | free(comm); 997 | } 998 | return ncclSuccess; 999 | } 1000 | 1001 | // Data-path functions 1002 | #ifdef TX_ZCOPY 1003 | static int taskProgress(int fd, struct ncclSocketTask* t) { 1004 | int bytes = 0; 1005 | char* data = reinterpret_cast(t->data); 1006 | int count = 0; 1007 | do { 1008 | int s = t->size - t->offset; 1009 | int flags = MSG_DONTWAIT; 1010 | int op = t->op; 1011 | if (op == NCCL_SOCKET_SEND && kMinZcopySize > 0 && s >= kMinZcopySize) 1012 | flags |= MSG_ZEROCOPY; 1013 | if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data + t->offset, s, flags); 1014 | if (op == NCCL_SOCKET_SEND) bytes = send(fd, data + t->offset, s, flags); 1015 | 1016 | if (op == NCCL_SOCKET_RECV && bytes == 0) { 1017 | WARN("Net : Connection closed by remote peer"); 1018 | return -1; 1019 | } 1020 | if (bytes == -1) { 1021 | if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { 1022 | WARN("Call to socket op %d flags %x failed : %s", op, flags, 1023 | strerror(errno)); 1024 | if (flags & MSG_ZEROCOPY) { 1025 | WARN("Turning off TX zero copy"); 1026 | kMinZcopySize = 0; 1027 | bytes = 0; 1028 | } else { 1029 | return -1; 1030 | } 1031 | } else { 1032 | bytes = 0; 1033 | } 1034 | } 1035 | t->offset += bytes; 1036 | if (bytes && (flags & MSG_ZEROCOPY)) ++count; 1037 | } while (bytes > 0 && t->offset < t->size); 1038 | return count; 1039 | } 1040 | 1041 | static int readNotification(struct msghdr* msg, uint32_t* lower, 1042 | uint32_t* upper) { 1043 | struct sock_extended_err* serr; 1044 | struct cmsghdr* cm; 1045 | cm = CMSG_FIRSTHDR(msg); 1046 | if (cm->cmsg_level != SOL_IP && cm->cmsg_type != IP_RECVERR) { 1047 | WARN("Invalid message level %d or type %d from errorqueue!", 1048 | 
(int)cm->cmsg_level, (int)cm->cmsg_type); 1049 | return -1; 1050 | } 1051 | serr = reinterpret_cast(CMSG_DATA(cm)); 1052 | if (serr->ee_errno != 0 || serr->ee_origin != SO_EE_ORIGIN_ZEROCOPY) { 1053 | WARN("Invalid message errno %d or origin %d from errorqueue!", 1054 | (int)serr->ee_errno, (int)serr->ee_origin); 1055 | return -1; 1056 | } 1057 | *lower = serr->ee_info; 1058 | *upper = serr->ee_data + 1; 1059 | return 0; 1060 | } 1061 | /* Reads one MSG_ZEROCOPY completion from the error queue of fd. Returns 0 when none is pending (EAGAIN), otherwise the number of sends completed in the range stored through lower/upper, or a negative value on error. */ 1062 | static int readErrqueue(int fd, uint32_t* lower, uint32_t* upper) { 1063 | char control[100]; 1064 | struct msghdr msg = {}; 1065 | msg.msg_control = control; 1066 | msg.msg_controllen = sizeof control; 1067 | int ret = recvmsg(fd, &msg, MSG_ERRQUEUE); 1068 | if (ret < 0 && errno == EAGAIN) return 0; 1069 | if (ret < 0) { 1070 | WARN("Read error from errqueue: %d", errno); 1071 | return -errno; 1072 | } 1073 | ret = readNotification(&msg, lower, upper); 1074 | if (ret < 0) return ret; 1075 | return *upper - *lower; 1076 | } 1077 | /* Credits the completed zero-copy send ids [lower, upper) to queued tasks: walking from the queue iterator, each task owns the ids from the previous bound (starting at clower, the caller's tx_lower) up to its tx_bound; leftover ids go to the next task. NOTE(review): the template arguments of get_iterator/is are missing from this copy, so which task states are scanned cannot be confirmed here. */ 1078 | void processCompletion(ncclSocketTaskQueue* tasks, uint32_t clower, 1079 | uint32_t lower, uint32_t upper) { 1080 | auto it = tasks->get_iterator(); 1081 | while (lower < upper && tasks->is(it)) { 1082 | ncclSocketTask* r = tasks->to_item(it); 1083 | uint32_t cupper = r->tx_bound; 1084 | uint32_t left = std::max(clower, lower); 1085 | uint32_t right = std::min(cupper, upper); 1086 | if (right > left) { 1087 | r->tx_count += right - left; 1088 | } 1089 | lower = std::max(lower, cupper); 1090 | clower = cupper; 1091 | it = tasks->next(it); 1092 | } 1093 | if (lower < upper && tasks->is(it)) { 1094 | ncclSocketTask* r = tasks->to_item(it); 1095 | r->tx_count += upper - lower; 1096 | } 1097 | } 1098 | #endif 1099 | /* Helper-thread body; resource identifies the comm and this thread's index. */ 1100 | static void* persistentSocketThread(void* args_) { 1101 | struct ncclSocketThreadResources* resource = 1102 | static_cast(args_); 1103 | struct ncclFastSocketComm* comm = resource->comm; 1104 | volatile enum ThreadState* state = &resource->state; 1105 | int nSocksPerThread =
comm->num_socks / comm->num_threads; 1106 | int tid = resource->id; 1107 | unsigned int mark = 0; 1108 | int core = comm->passive ? kRxCPUStart : kTxCPUStart; 1109 | 1110 | INFO(NCCL_INIT | NCCL_NET, "Comm %p thread %d started", comm, tid); 1111 | if (core >= 0) { 1112 | cpu_set_t my_set; 1113 | core += tid; 1114 | CPU_ZERO(&my_set); 1115 | CPU_SET(core, &my_set); 1116 | sched_setaffinity(0, sizeof my_set, &my_set); 1117 | } 1118 | INFO(NCCL_INIT | NCCL_NET, "Comm %p thread %d binding to core %d", comm, tid, 1119 | core); 1120 | while (true) { 1121 | int idle = 1; 1122 | // iterate all the sockets associate with the current thread 1123 | for (int i = 0; i < nSocksPerThread; ++i) { 1124 | // int idx = i + tid * nSocksPerThread; // sequential access 1125 | int idx = tid + i * comm->num_threads; // strided access 1126 | #ifdef TX_ZCOPY 1127 | struct ncclFdData* fd_data = comm->fd_data + idx; 1128 | ncclSocketTaskQueue* tasks = &(fd_data->tasks); 1129 | if (tasks->has_active()) { 1130 | struct ncclSocketTask* r = fd_data->tasks.next_active(); 1131 | int old_offset = r->offset; 1132 | int cnt = taskProgress(fd_data->fd, r); 1133 | if (cnt < 0) return nullptr; 1134 | fd_data->tx_upper += cnt; 1135 | if (r->op == NCCL_SOCKET_SEND) fd_data->stat += r->offset - old_offset; 1136 | if (r->offset == r->size) { 1137 | r->tx_bound = fd_data->tx_upper; 1138 | tasks->finish_active(); 1139 | } 1140 | idle = 0; 1141 | } 1142 | 1143 | // poll errqueue for send completion 1144 | if (fd_data->tx_upper > fd_data->tx_lower) { 1145 | uint32_t lower, upper; 1146 | while (true) { 1147 | int ret = readErrqueue(fd_data->fd, &lower, &upper); 1148 | if (ret == 0) break; 1149 | if (ret < 0) return nullptr; 1150 | processCompletion(tasks, fd_data->tx_lower, lower, upper); 1151 | } 1152 | idle = 0; 1153 | } 1154 | 1155 | if (tasks->has_completing()) { 1156 | struct ncclSocketTask* r = tasks->next_completing(); 1157 | if (r->tx_count == r->tx_bound - fd_data->tx_lower) { 1158 | fd_data->tx_lower = 
r->tx_bound; 1159 | tasks->finish_completing(); 1160 | } 1161 | idle = 0; 1162 | } 1163 | #else 1164 | if (!comm->fdData[idx].tasks.has_active()) continue; 1165 | struct ncclSocketTask* r = comm->fdData[idx].tasks.next_active(); 1166 | int fd = comm->fdData[idx].fd; 1167 | if (r->offset < r->size) { 1168 | int old_offset = r->offset; 1169 | r->result = socketProgress(r->op, fd, r->data, r->size, &r->offset); 1170 | if (r->result != ncclSuccess) { 1171 | WARN("NET/Socket : socket progress error"); 1172 | return NULL; 1173 | } 1174 | if (r->op == NCCL_SOCKET_SEND) 1175 | comm->fdData[idx].stat += r->offset - old_offset; 1176 | if (r->offset == r->size) { 1177 | comm->fdData[idx].tasks.finish_active(); 1178 | comm->fdData[idx].tasks.finish_completing(); 1179 | } 1180 | idle = 0; 1181 | } 1182 | #endif 1183 | } 1184 | if (kEnableSpin) idle = 0; 1185 | if (idle) { 1186 | pthread_mutex_lock(&resource->thread_lock); 1187 | while (mark == resource->next && *state != stop) { // no new tasks, wait 1188 | pthread_cond_wait(&resource->thread_cond, &resource->thread_lock); 1189 | } 1190 | mark = resource->next; 1191 | pthread_mutex_unlock(&resource->thread_lock); 1192 | } 1193 | if (*state == stop) return nullptr; 1194 | } 1195 | } 1196 | 1197 | static ncclResult_t ncclFastSocketGetRequest(struct ncclFastSocketComm* comm, 1198 | int op, void* data, int size, 1199 | struct ncclSocketRequest** req) { 1200 | if (!comm->rq.has_free()) { 1201 | WARN("NET/Socket : unable to allocate requests"); 1202 | return ncclInternalError; 1203 | } 1204 | struct ncclSocketRequest* r = comm->rq.next_free(); 1205 | r->op = op; 1206 | r->next_sock_id = -1; 1207 | r->next_size = 0; 1208 | r->data = data; 1209 | r->offset = 0; 1210 | r->size = size; 1211 | if (op == NCCL_SOCKET_SEND) 1212 | r->size_pending = size; 1213 | else 1214 | r->size_pending = -1; 1215 | r->comm = comm; 1216 | *req = r; 1217 | comm->rq.enqueue(); 1218 | return ncclSuccess; 1219 | } 1220 | 1221 | #define CTRL_DONE(r) 
((r)->next_sock_id >= 0) 1222 | #define RESET_CTRL(r) ((r)->next_sock_id = -1) 1223 | 1224 | #define REQUEST_DONE(r) \ 1225 | (((r)->size == 0 && CTRL_DONE(r)) || ((r)->size && (r)->size_pending == 0)) 1226 | #define REQUEST_INACTIVE(r) ((r)->size == (r)->offset) 1227 | 1228 | #ifndef BUFFERED_CTRL 1229 | static ncclResult_t ncclProcessCtrl(struct ncclFastSocketComm* comm, 1230 | struct ncclSocketRequest* r, 1231 | struct ncclCtrl* ctrl) { 1232 | int s = 0; 1233 | NCCLCHECK(socketSpin(r->op, comm->ctrl_fd, ctrl, sizeof *ctrl, &s)); 1234 | if (s == 0) return ncclSuccess; 1235 | if (s < sizeof *ctrl) { 1236 | NCCLCHECK(socketSpin(r->op, comm->ctrl_fd, ctrl, sizeof *ctrl, &s)); 1237 | } 1238 | if (s) { 1239 | // save control information to request 1240 | r->next_sock_id = ctrl->index; 1241 | r->next_size = ctrl->size; 1242 | if (r->size_pending < 0) { 1243 | r->size_pending = r->size = ctrl->total; 1244 | } 1245 | } 1246 | return ncclSuccess; 1247 | } 1248 | #endif 1249 | 1250 | static ncclResult_t ncclCtrlRecv(struct ncclFastSocketComm* comm, 1251 | struct ncclSocketRequest* r, 1252 | struct ncclCtrl* ctrl) { 1253 | #ifdef BUFFERED_CTRL 1254 | NCCLCHECK(comm->ctrl_recv.refill()); 1255 | if (comm->ctrl_recv.empty()) return ncclSuccess; 1256 | NCCLCHECK(comm->ctrl_recv.recv(ctrl, sizeof *ctrl)); 1257 | // save control information to request 1258 | r->next_sock_id = ctrl->index; 1259 | r->next_size = ctrl->size; 1260 | if (r->size_pending < 0) { 1261 | r->size_pending = r->size = ctrl->total; 1262 | } 1263 | return ncclSuccess; 1264 | #else 1265 | return ncclProcessCtrl(comm, r, ctrl); 1266 | #endif 1267 | } 1268 | 1269 | static inline ncclResult_t ncclCtrlSendSync(struct ncclFastSocketComm* comm) { 1270 | #ifdef BUFFERED_CTRL 1271 | NCCLCHECK(comm->ctrl_send.sync()); 1272 | #endif 1273 | return ncclSuccess; 1274 | } 1275 | 1276 | static inline ncclResult_t ncclCtrlSend(struct ncclFastSocketComm* comm, 1277 | struct ncclSocketRequest* r, 1278 | struct ncclCtrl* ctrl) { 
1279 | #ifdef BUFFERED_CTRL 1280 | NCCLCHECK(comm->ctrl_send.send(ctrl, sizeof *ctrl)); 1281 | r->next_sock_id = ctrl->index; 1282 | r->next_size = ctrl->size; 1283 | return ncclSuccess; 1284 | #else 1285 | return ncclProcessCtrl(comm, r, ctrl); 1286 | #endif 1287 | } 1288 | 1289 | static void enqueueTask(struct ncclFastSocketComm* comm, 1290 | struct ncclSocketRequest* r) { 1291 | int sockId = r->next_sock_id; 1292 | RESET_CTRL(r); 1293 | int sz = r->next_size; 1294 | struct ncclSocketTask* task = comm->fd_data[sockId].tasks.next_free(); 1295 | task->op = r->op; 1296 | task->data = reinterpret_cast(r->data) + r->offset; 1297 | task->r = r; 1298 | task->result = ncclSuccess; 1299 | task->offset = 0; 1300 | task->size = sz; 1301 | #ifdef TX_ZCOPY 1302 | task->tx_count = 0; 1303 | #endif 1304 | comm->fd_data[sockId].tasks.enqueue(); 1305 | 1306 | r->offset += sz; 1307 | if (REQUEST_INACTIVE(r)) { 1308 | comm->rq.mark_inactive(); 1309 | } 1310 | 1311 | // notify thread 1312 | // int tid = sockId * comm->nThreads / comm->nSocks; 1313 | int tid = sockId % comm->num_threads; 1314 | struct ncclSocketThreadResources* res = comm->thread_resource + tid; 1315 | if (res->comm == nullptr) { 1316 | res->id = tid; 1317 | res->next = 0; 1318 | res->comm = comm; 1319 | res->state = start; 1320 | waitConnect(comm); 1321 | pthread_mutex_init(&res->thread_lock, nullptr); 1322 | pthread_cond_init(&res->thread_cond, nullptr); 1323 | pthread_create(comm->helper_thread + tid, nullptr, persistentSocketThread, 1324 | res); 1325 | } else { 1326 | if (kEnableSpin) { 1327 | ++res->next; 1328 | } else { 1329 | pthread_mutex_lock(&res->thread_lock); 1330 | ++res->next; 1331 | pthread_cond_signal(&res->thread_cond); 1332 | pthread_mutex_unlock(&res->thread_lock); 1333 | } 1334 | } 1335 | } 1336 | 1337 | static ncclResult_t ncclCommProgress(struct ncclFastSocketComm* comm) { 1338 | int empty_tasks[MAX_SOCKETS]; 1339 | int num_empty = 0; 1340 | 1341 | // no more requests 1342 | if 
(comm->rq.empty()) return ncclSuccess; 1343 | 1344 | for (int i = 0; i < comm->num_socks; ++i) { 1345 | int idx = comm->last_fd - i; 1346 | if (idx < 0) idx += comm->num_socks; 1347 | ncclSocketTaskQueue* tasks = &(comm->fd_data[idx].tasks); 1348 | if (tasks->has_inactive()) { 1349 | ncclSocketTask* task = tasks->next_inactive(); 1350 | task->r->size_pending -= task->size; 1351 | tasks->dequeue(); // inactive -> free 1352 | } 1353 | if (tasks->has_free()) { 1354 | // socket fd_idx has room for more tasks 1355 | empty_tasks[num_empty++] = idx; 1356 | } 1357 | } 1358 | 1359 | // no active requests or no socket has room for new tasks 1360 | if (!comm->rq.has_active() || num_empty == 0) return ncclSuccess; 1361 | 1362 | ncclSocketRequest* ar = comm->rq.next_active(); 1363 | if (ar->op == NCCL_SOCKET_SEND) { 1364 | // small enough to send via control socket 1365 | if (ar->size <= kInlineThreshold) { 1366 | ncclCtrl ctrl = {CTRL_INLINE, 0, static_cast(ar->size), 0, 1367 | static_cast(ar->size)}; 1368 | NCCLCHECK(ncclCtrlSend(comm, ar, &ctrl)); 1369 | NCCLCHECK(ncclCtrlSendSync(comm)); 1370 | if (CTRL_DONE(ar)) { 1371 | if (ar->size > 0) { 1372 | int off = 0; 1373 | // send data through control socket 1374 | NCCLCHECK(socketSpin(NCCL_SOCKET_SEND, comm->ctrl_fd, ar->data, 1375 | ar->size, &off)); 1376 | ar->offset = ar->size; 1377 | ar->size_pending = 0; 1378 | } 1379 | comm->rq.mark_inactive(); 1380 | } 1381 | 1382 | return ncclSuccess; 1383 | } 1384 | // there are pending requests and we have available sockets 1385 | while (ar->offset < ar->size && num_empty) { 1386 | --num_empty; 1387 | uint32_t send_size = std::min(kDynamicChunkSize, ar->size - ar->offset); 1388 | ncclCtrl ctrl = { 1389 | CTRL_NORMAL, static_cast(empty_tasks[num_empty]), send_size, 1390 | static_cast(ar->offset), static_cast(ar->size)}; 1391 | NCCLCHECK(ncclCtrlSend(comm, ar, &ctrl)); 1392 | if (!CTRL_DONE(ar)) { 1393 | break; 1394 | } 1395 | enqueueTask(comm, ar); 1396 | comm->last_fd = 
empty_tasks[num_empty]; 1397 | } 1398 | NCCLCHECK(ncclCtrlSendSync(comm)); 1399 | } else { 1400 | do { 1401 | ncclCtrl ctrl; 1402 | if (!CTRL_DONE(ar)) { 1403 | NCCLCHECK(ncclCtrlRecv(comm, ar, &ctrl)); 1404 | if (!CTRL_DONE(ar)) break; 1405 | if (ctrl.type == CTRL_INLINE) { 1406 | if (ar->size) { 1407 | #ifdef BUFFERED_CTRL 1408 | ar->offset = comm->ctrl_recv.brecv(ar->data, ar->size); 1409 | #endif 1410 | NCCLCHECK(socketSpin(NCCL_SOCKET_RECV, comm->ctrl_fd, ar->data, 1411 | ar->size, &ar->offset)); 1412 | ar->size_pending = 0; 1413 | } 1414 | comm->rq.mark_inactive(); 1415 | break; 1416 | } 1417 | } 1418 | if (!comm->fd_data[ar->next_sock_id].tasks.has_free()) { 1419 | break; 1420 | WARN("No free space for recv task"); 1421 | } 1422 | // uint32_t recv_size = std::min(dynamic_chunk_size, ar->size - 1423 | // ar->offset); 1424 | enqueueTask(comm, ar); 1425 | } while (ar->offset < ar->size); 1426 | } 1427 | 1428 | return ncclSuccess; 1429 | } 1430 | 1431 | // Called by netSendProxy and netRecvProxy from the proxy thread 1432 | ncclResult_t ncclFastSocketTest(void* request, int* done, int* size) { 1433 | *done = 0; 1434 | struct ncclSocketRequest* r = static_cast(request); 1435 | if (r == nullptr) { 1436 | WARN("NET/FastSocket : test called with NULL request"); 1437 | return ncclInternalError; 1438 | } 1439 | NCCLCHECK(ncclCommProgress(r->comm)); 1440 | if (r->comm->rq.has_inactive()) { 1441 | if (r != r->comm->rq.next_inactive()) { 1442 | WARN("NET/FastSocket : test called with invalid request"); 1443 | return ncclInternalError; 1444 | } 1445 | if (REQUEST_DONE(r)) { 1446 | r->comm->rq.dequeue(); 1447 | *done = 1; 1448 | } 1449 | } 1450 | return ncclSuccess; 1451 | } 1452 | 1453 | static ncclResult_t ncclSocketGetSpeed(char* devName, int* speed) { 1454 | *speed = 0; 1455 | char speedPath[PATH_MAX]; 1456 | snprintf(speedPath, PATH_MAX, "/sys/class/net/%s/speed", devName); 1457 | int fd = open(speedPath, O_RDONLY); 1458 | if (fd != -1) { 1459 | char speedStr[] = " "; 
1460 | if (read(fd, speedStr, sizeof(speedStr) - 1) > 0) { 1461 | *speed = strtol(speedStr, nullptr, 0); 1462 | } 1463 | close(fd); 1464 | } 1465 | if (*speed <= 0) { 1466 | INFO(NCCL_NET, "Could not get speed from %s. Defaulting to 10 Gbps.", 1467 | speedPath); 1468 | *speed = 10000; 1469 | } 1470 | return ncclSuccess; 1471 | } 1472 | 1473 | template 1474 | ncclResult_t ncclFastSocketGetProperties(int dev, T* props) { 1475 | props->name = kNcclSocketDevs[dev].dev_name; 1476 | props->pciPath = kNcclSocketDevs[dev].pci_path; 1477 | props->guid = dev; 1478 | props->ptrSupport = NCCL_PTR_HOST; 1479 | NCCLCHECK(ncclSocketGetSpeed(props->name, &props->speed)); 1480 | props->port = 0; 1481 | props->maxComms = 65536; 1482 | if constexpr (std::is_same::value) { 1483 | props->latency = 0; 1484 | props->maxRecvs = 1; 1485 | } 1486 | return ncclSuccess; 1487 | } 1488 | 1489 | ncclResult_t ncclFastSocketRegMr(void* comm, void* data, int size, int type, 1490 | void** mhandle) { 1491 | return (type != NCCL_PTR_HOST) ? 
ncclInternalError : ncclSuccess; 1492 | } 1493 | 1494 | ncclResult_t ncclFastSocketDeregMr(void* comm, void* mhandle) { 1495 | return ncclSuccess; 1496 | } 1497 | 1498 | ncclResult_t ncclFastSocketFlush_v2(void* recvComm, void* data, int size, 1499 | void* mhandle) { 1500 | // We don't support CUDA pointers, so we don't need a flush operation 1501 | return ncclInternalError; 1502 | } 1503 | 1504 | ncclResult_t ncclFastSocketIsend_v2(void* sendComm, void* data, int size, 1505 | void* mhandle, void** request) { 1506 | struct ncclFastSocketComm* comm = 1507 | static_cast(sendComm); 1508 | NCCLCHECK(ncclFastSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, 1509 | (struct ncclSocketRequest**)request)); 1510 | return ncclSuccess; 1511 | } 1512 | 1513 | ncclResult_t ncclFastSocketIrecv_v2(void* recvComm, void* data, int size, 1514 | void* mhandle, void** request) { 1515 | struct ncclFastSocketComm* comm = 1516 | static_cast(recvComm); 1517 | NCCLCHECK(ncclFastSocketGetRequest(comm, NCCL_SOCKET_RECV, data, size, 1518 | (struct ncclSocketRequest**)request)); 1519 | return ncclSuccess; 1520 | } 1521 | 1522 | ncclResult_t ncclFastSocketIflush_v4(void* recvComm, void* data, int size, 1523 | void* mhandle, void** request) { 1524 | // We don't support CUDA pointers, so we don't need a flush operation 1525 | return ncclInternalError; 1526 | } 1527 | 1528 | ncclResult_t ncclFastSocketIsend_v5(void* sendComm, void* data, int size, 1529 | int tag, void* mhandle, void** request) { 1530 | struct ncclFastSocketComm* comm = 1531 | static_cast(sendComm); 1532 | NCCLCHECK(ncclFastSocketGetRequest(comm, NCCL_SOCKET_SEND, data, size, 1533 | (struct ncclSocketRequest**)request)); 1534 | return ncclSuccess; 1535 | } 1536 | 1537 | ncclResult_t ncclFastSocketIrecv_v5(void* recvComm, int n, void** data, 1538 | int* sizes, int* tags, void** mhandles, 1539 | void** request) { 1540 | struct ncclFastSocketComm* comm = 1541 | static_cast(recvComm); 1542 | if (n != 1) return ncclInternalError; 1543 
| NCCLCHECK(ncclFastSocketGetRequest(comm, NCCL_SOCKET_RECV, data[0], sizes[0], 1544 | (struct ncclSocketRequest**)request)); 1545 | return ncclSuccess; 1546 | } 1547 | 1548 | ncclResult_t ncclFastSocketIflush_v5(void* recvComm, int n, void** data, 1549 | int* sizes, void** mhandle, 1550 | void** request) { 1551 | // We don't support CUDA pointers, so we don't need a flush operation 1552 | return ncclInternalError; 1553 | } 1554 | 1555 | ncclResult_t ncclFastSocketCloseListen(void* opaqueComm) { 1556 | struct ncclSocketListenComm* comm = (struct ncclSocketListenComm*)opaqueComm; 1557 | if (comm) { 1558 | if (comm->fd != -1) close(comm->fd); 1559 | free(comm); 1560 | } 1561 | return ncclSuccess; 1562 | } 1563 | 1564 | ncclResult_t ncclFastSocketPtrSupport(int dev, int* supportedTypes) { 1565 | *supportedTypes = NCCL_PTR_HOST; 1566 | return ncclSuccess; 1567 | } 1568 | 1569 | ncclResult_t ncclFastSocketFlush(void* recvComm, void* data, int size, 1570 | void* mhandle) { 1571 | // We don't support CUDA pointers, so we don't need a flush operation 1572 | return ncclInternalError; 1573 | } 1574 | 1575 | 1576 | volatile ncclNet_v2_t ncclNetPlugin_v2 = { 1577 | "FastSocket", ncclFastSocketInit, ncclFastSocketDevices, 1578 | ncclFastSocketPciPath, ncclFastSocketPtrSupport, ncclFastSocketListen, 1579 | ncclFastSocketConnect, ncclFastSocketAccept, ncclFastSocketRegMr, 1580 | ncclFastSocketDeregMr, ncclFastSocketIsend_v2, ncclFastSocketIrecv_v2, 1581 | ncclFastSocketFlush, ncclFastSocketTest, ncclFastSocketClose, 1582 | ncclFastSocketClose, ncclFastSocketCloseListen}; 1583 | 1584 | volatile ncclNet_v3_t ncclNetPlugin_v3 = { 1585 | "FastSocket", ncclFastSocketInit, 1586 | ncclFastSocketDevices, ncclFastSocketGetProperties, 1587 | ncclFastSocketListen, ncclFastSocketConnect, 1588 | ncclFastSocketAccept, ncclFastSocketRegMr, 1589 | ncclFastSocketDeregMr, ncclFastSocketIsend_v2, 1590 | ncclFastSocketIrecv_v2, ncclFastSocketFlush, 1591 | ncclFastSocketTest, ncclFastSocketClose, 
1592 | ncclFastSocketClose, ncclFastSocketCloseListen}; 1593 | 1594 | volatile ncclNet_v4_t ncclNetPlugin_v4 = { 1595 | "FastSocket", ncclFastSocketInit, 1596 | ncclFastSocketDevices, ncclFastSocketGetProperties, 1597 | ncclFastSocketListen, ncclFastSocketConnect, 1598 | ncclFastSocketAccept, ncclFastSocketRegMr, 1599 | ncclFastSocketDeregMr, ncclFastSocketIsend_v2, 1600 | ncclFastSocketIrecv_v2, ncclFastSocketIflush_v4, 1601 | ncclFastSocketTest, ncclFastSocketClose, 1602 | ncclFastSocketClose, ncclFastSocketCloseListen}; 1603 | 1604 | volatile ncclNet_v5_t ncclNetPlugin_v5 = { 1605 | "FastSocket", ncclFastSocketInit, 1606 | ncclFastSocketDevices, ncclFastSocketGetProperties, 1607 | ncclFastSocketListen, ncclFastSocketConnect, 1608 | ncclFastSocketAccept, ncclFastSocketRegMr, 1609 | ncclFastSocketDeregMr, ncclFastSocketIsend_v5, 1610 | ncclFastSocketIrecv_v5, ncclFastSocketIflush_v5, 1611 | ncclFastSocketTest, ncclFastSocketClose, 1612 | ncclFastSocketClose, ncclFastSocketCloseListen}; 1613 | 1614 | volatile ncclNet_v6_t ncclNetPlugin_v6 = { 1615 | "FastSocket", 1616 | ncclFastSocketInit, 1617 | ncclFastSocketDevices, 1618 | ncclFastSocketGetProperties, 1619 | ncclFastSocketListen, 1620 | ncclFastSocketConnect, 1621 | ncclFastSocketAccept, 1622 | ncclFastSocketRegMr, 1623 | nullptr, // No DMA-BUF support 1624 | ncclFastSocketDeregMr, 1625 | ncclFastSocketIsend_v5, 1626 | ncclFastSocketIrecv_v5, 1627 | ncclFastSocketIflush_v5, 1628 | ncclFastSocketTest, 1629 | ncclFastSocketClose, 1630 | ncclFastSocketClose, 1631 | ncclFastSocketCloseListen}; 1632 | -------------------------------------------------------------------------------- /utilities.cc: -------------------------------------------------------------------------------- 1 | #include "utilities.h" 2 | 3 | static void dummyDebugLog(ncclDebugLogLevel level, uint64_t flags, 4 | const char* filefunc, int line, const char* fmt, 5 | ...) 
{} 6 | 7 | ncclDebugLogger_t nccl_log_func = dummyDebugLog; 8 | -------------------------------------------------------------------------------- /utilities.h: -------------------------------------------------------------------------------- 1 | // Copyright 2021 Google LLC 2 | // 3 | // Use of this source code is governed by a BSD-style 4 | // license that can be found in the LICENSE file or at 5 | // https://developers.google.com/open-source/licenses/bsd 6 | 7 | /************************************************************************* 8 | * Copyright (c) 2016-2020, NVIDIA CORPORATION. All rights reserved. 9 | * 10 | * See LICENSE for license information 11 | ************************************************************************/ 12 | 13 | #ifndef THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_UTIL_H_ 14 | #define THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_UTIL_H_ 15 | 16 | #include "nccl_net.h" 17 | 18 | extern ncclDebugLogger_t nccl_log_func; 19 | 20 | #define WARN(fmt, ...) \ 21 | (*nccl_log_func)(NCCL_LOG_WARN, NCCL_ALL, __PRETTY_FUNCTION__, \ 22 | __LINE__, fmt, ##__VA_ARGS__) 23 | 24 | #define INFO(flags, fmt, ...) \ 25 | (*nccl_log_func)(NCCL_LOG_INFO, flags, \ 26 | __PRETTY_FUNCTION__, __LINE__, fmt, \ 27 | ##__VA_ARGS__) 28 | 29 | #define TRACE(flags, fmt, ...) 
\ 30 | (*nccl_log_func)(NCCL_LOG_TRACE, flags, \ 31 | __PRETTY_FUNCTION__, __LINE__, fmt, \ 32 | ##__VA_ARGS__) 33 | 34 | 35 | #define NCCL_PARAM(name, env, default_value) \ 36 | pthread_mutex_t ncclParamMutex##name = PTHREAD_MUTEX_INITIALIZER; \ 37 | int64_t ncclParam##name() { \ 38 | static_assert(default_value != -1LL, "default value cannot be -1"); \ 39 | static int64_t value = -1LL; \ 40 | pthread_mutex_lock(&ncclParamMutex##name); \ 41 | if (value == -1LL) { \ 42 | value = default_value; \ 43 | char* str = getenv("NCCL_" env); \ 44 | if (str && strlen(str) > 0) { \ 45 | errno = 0; \ 46 | int64_t v = strtoll(str, NULL, 0); \ 47 | if (errno) { \ 48 | INFO(NCCL_ALL,"Invalid value %s for %s, using default %lu.", str, "NCCL_" env, value); \ 49 | } else { \ 50 | value = v; \ 51 | INFO(NCCL_ALL,"%s set by environment to %lu.", "NCCL_" env, value); \ 52 | } \ 53 | } \ 54 | } \ 55 | pthread_mutex_unlock(&ncclParamMutex##name); \ 56 | return value; \ 57 | } 58 | 59 | #include 60 | // Check system calls 61 | #define SYSCHECK(call, name) do { \ 62 | int retval; \ 63 | SYSCHECKVAL(call, name, retval); \ 64 | } while (false) 65 | 66 | #define SYSCHECKVAL(call, name, retval) do { \ 67 | SYSCHECKSYNC(call, name, retval); \ 68 | if (retval == -1) { \ 69 | WARN("Call to " name " failed : %s", strerror(errno)); \ 70 | return ncclSystemError; \ 71 | } \ 72 | } while (false) 73 | 74 | #define SYSCHECKSYNC(call, name, retval) do { \ 75 | retval = call; \ 76 | if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ 77 | INFO(NCCL_ALL,"Call to " name " returned %s, retrying", strerror(errno)); \ 78 | } else { \ 79 | break; \ 80 | } \ 81 | } while(true) 82 | 83 | // Propagate errors up 84 | #define NCCLCHECK(call) do { \ 85 | ncclResult_t res = call; \ 86 | if (res != ncclSuccess) { \ 87 | /* Print the back trace*/ \ 88 | INFO(NCCL_ALL,"%s:%d -> %d", __FILE__, __LINE__, res); \ 89 | return res; \ 90 | } \ 91 | } while (0); 92 | 93 | static __inline__ int 
ncclTypeSize(ncclDataType_t type) { 94 | switch (type) { 95 | case ncclInt8: 96 | case ncclUint8: 97 | return 1; 98 | case ncclFloat16: 99 | #if defined(__CUDA_BF16_TYPES_EXIST__) 100 | case ncclBfloat16: 101 | #endif 102 | return 2; 103 | case ncclInt32: 104 | case ncclUint32: 105 | case ncclFloat32: 106 | return 4; 107 | case ncclInt64: 108 | case ncclUint64: 109 | case ncclFloat64: 110 | return 8; 111 | default: 112 | return -1; 113 | } 114 | } 115 | 116 | #define DIVUP(x, y) (((x) + (y)-1) / (y)) 117 | 118 | #include 119 | 120 | template 121 | static ncclResult_t ncclCalloc(T** ptr, size_t nelem) { 122 | void* p = malloc(nelem*sizeof(T)); 123 | if (p == NULL) { 124 | WARN("Failed to malloc %ld bytes", nelem*sizeof(T)); 125 | return ncclSystemError; 126 | } 127 | memset(p, 0, nelem*sizeof(T)); 128 | *ptr = (T*)p; 129 | return ncclSuccess; 130 | } 131 | 132 | #include 133 | #include 134 | #include 135 | #include 136 | #include 137 | #include 138 | #include 139 | 140 | #define MAX_IFS 16 141 | #define MAX_IF_NAME_SIZE 16 142 | #define SLEEP_INT 1000 // connection retry sleep interval in usec 143 | #define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) 144 | #define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) 145 | #define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) 146 | 147 | struct netIf { 148 | char prefix[64]; 149 | int port; 150 | }; 151 | 152 | #include 153 | 154 | static int parseStringList(const char* string, struct netIf* ifList, int maxList) { 155 | if (!string) return 0; 156 | 157 | const char* ptr = string; 158 | 159 | int ifNum = 0; 160 | int ifC = 0; 161 | char c; 162 | do { 163 | c = *ptr; 164 | if (c == ':') { 165 | if (ifC > 0) { 166 | ifList[ifNum].prefix[ifC] = '\0'; 167 | ifList[ifNum].port = atoi(ptr+1); 168 | ifNum++; ifC = 0; 169 | } 170 | while (c != ',' && c != '\0') c = *(++ptr); 171 | } else if (c == ',' || c == '\0') { 172 | if (ifC > 0) { 173 | 
ifList[ifNum].prefix[ifC] = '\0'; 174 | ifList[ifNum].port = -1; 175 | ifNum++; ifC = 0; 176 | } 177 | } else { 178 | ifList[ifNum].prefix[ifC] = c; 179 | ifC++; 180 | } 181 | ptr++; 182 | } while (ifNum < maxList && c); 183 | return ifNum; 184 | } 185 | 186 | static bool matchIf(const char* string, const char* ref, bool matchExact) { 187 | // Make sure to include '\0' in the exact case 188 | int matchLen = matchExact ? strlen(string) + 1 : strlen(ref); 189 | return strncmp(string, ref, matchLen) == 0; 190 | } 191 | 192 | static bool matchPort(const int port1, const int port2) { 193 | if (port1 == -1) return true; 194 | if (port2 == -1) return true; 195 | if (port1 == port2) return true; 196 | return false; 197 | } 198 | 199 | static bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) { 200 | // Make an exception for the case where no user list is defined 201 | if (listSize == 0) return true; 202 | 203 | for (int i=0; i" 222 | */ 223 | static inline const char *socketToString(struct sockaddr *saddr, char *buf) { 224 | if (buf == NULL || saddr == NULL) return NULL; 225 | if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; } 226 | char host[NI_MAXHOST], service[NI_MAXSERV]; 227 | (void) getnameinfo(saddr, sizeof(union socketAddress), host, NI_MAXHOST, service, NI_MAXSERV, NI_NUMERICHOST|NI_NUMERICSERV); 228 | sprintf(buf, "%s<%s>", host, service); 229 | return buf; 230 | } 231 | 232 | static inline const char* socketToString(union socketAddress* saddr, 233 | char* buf) { 234 | return socketToString(&saddr->sa, buf); 235 | } 236 | 237 | static inline uint16_t socketToPort(struct sockaddr *saddr) { 238 | return ntohs(saddr->sa_family == AF_INET ? 
((struct sockaddr_in*)saddr)->sin_port : ((struct sockaddr_in6*)saddr)->sin6_port); 239 | } 240 | 241 | /* Allow the user to force the IPv4/IPv6 interface selection */ 242 | static inline int envSocketFamily(void) { 243 | int family = -1; // Family selection is not forced, will use first one found 244 | char* env = getenv("NCCL_SOCKET_FAMILY"); 245 | if (env == NULL) 246 | return family; 247 | 248 | INFO(NCCL_ENV, "NCCL_SOCKET_FAMILY set by environment to %s", env); 249 | 250 | if (strcmp(env, "AF_INET") == 0) 251 | family = AF_INET; // IPv4 252 | else if (strcmp(env, "AF_INET6") == 0) 253 | family = AF_INET6; // IPv6 254 | return family; 255 | } 256 | 257 | static int findInterfaces(const char* prefixList, char* names, union socketAddress *addrs, int sock_family, int maxIfNameSize, int maxIfs) { 258 | char line[SOCKET_NAME_MAXLEN+1]; 259 | struct netIf userIfs[MAX_IFS]; 260 | bool searchNot = prefixList && prefixList[0] == '^'; 261 | if (searchNot) prefixList++; 262 | bool searchExact = prefixList && prefixList[0] == '='; 263 | if (searchExact) prefixList++; 264 | int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS); 265 | 266 | int found = 0; 267 | struct ifaddrs *interfaces, *interface; 268 | getifaddrs(&interfaces); 269 | for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) { 270 | if (interface->ifa_addr == NULL) continue; 271 | 272 | /* We only support IPv4 & IPv6 */ 273 | int family = interface->ifa_addr->sa_family; 274 | if (family != AF_INET && family != AF_INET6) 275 | continue; 276 | 277 | TRACE(NCCL_INIT|NCCL_NET,"Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line)); 278 | 279 | /* Allow the caller to force the socket family type */ 280 | if (sock_family != -1 && family != sock_family) 281 | continue; 282 | 283 | /* We also need to skip IPv6 loopback interfaces */ 284 | if (family == AF_INET6) { 285 | struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr); 
286 | if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue; 287 | } 288 | 289 | // check against user specified interfaces 290 | if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) { 291 | continue; 292 | } 293 | 294 | // Check that this interface has not already been saved 295 | // getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link 296 | bool duplicate = false; 297 | for (int i = 0; i < found; i++) { 298 | if (strcmp(interface->ifa_name, names+i*maxIfNameSize) == 0) { duplicate = true; break; } 299 | } 300 | 301 | if (!duplicate) { 302 | // Store the interface name 303 | strncpy(names+found*maxIfNameSize, interface->ifa_name, maxIfNameSize); 304 | // Store the IP address 305 | int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); 306 | memcpy(addrs+found, interface->ifa_addr, salen); 307 | found++; 308 | } 309 | } 310 | 311 | freeifaddrs(interfaces); 312 | return found; 313 | } 314 | 315 | static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) { 316 | /* Check family first */ 317 | int family = local_if.ifa_addr->sa_family; 318 | if (family != remote->sa.sa_family) { 319 | return false; 320 | } 321 | 322 | if (family == AF_INET) { 323 | struct sockaddr_in* local_addr = (struct sockaddr_in*)(local_if.ifa_addr); 324 | struct sockaddr_in* mask = (struct sockaddr_in*)(local_if.ifa_netmask); 325 | struct sockaddr_in& remote_addr = remote->sin; 326 | struct in_addr local_subnet, remote_subnet; 327 | local_subnet.s_addr = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr; 328 | remote_subnet.s_addr = remote_addr.sin_addr.s_addr & mask->sin_addr.s_addr; 329 | return (local_subnet.s_addr ^ remote_subnet.s_addr) ? 
false : true; 330 | } else if (family == AF_INET6) { 331 | struct sockaddr_in6* local_addr = (struct sockaddr_in6*)(local_if.ifa_addr); 332 | struct sockaddr_in6* mask = (struct sockaddr_in6*)(local_if.ifa_netmask); 333 | struct sockaddr_in6& remote_addr = remote->sin6; 334 | struct in6_addr& local_in6 = local_addr->sin6_addr; 335 | struct in6_addr& mask_in6 = mask->sin6_addr; 336 | struct in6_addr& remote_in6 = remote_addr.sin6_addr; 337 | bool same = true; 338 | int len = 16; //IPv6 address is 16 unsigned char 339 | for (int c = 0; c < len; c++) { //Network byte order is big-endian 340 | char c1 = local_in6.s6_addr[c] & mask_in6.s6_addr[c]; 341 | char c2 = remote_in6.s6_addr[c] & mask_in6.s6_addr[c]; 342 | if (c1 ^ c2) { 343 | same = false; 344 | break; 345 | } 346 | } 347 | // At last, we need to compare scope id 348 | // Two Link-type addresses can have the same subnet address even though they are not in the same scope 349 | // For Global type, this field is 0, so a comparison wouldn't matter 350 | same &= (local_addr->sin6_scope_id == remote_addr.sin6_scope_id); 351 | return same; 352 | } else { 353 | WARN("Net : Unsupported address family type"); 354 | return false; 355 | } 356 | } 357 | 358 | static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize, int maxIfs) { 359 | char line[SOCKET_NAME_MAXLEN+1]; 360 | char line_a[SOCKET_NAME_MAXLEN+1]; 361 | int found = 0; 362 | struct ifaddrs *interfaces, *interface; 363 | getifaddrs(&interfaces); 364 | for (interface = interfaces; interface && !found; interface = interface->ifa_next) { 365 | if (interface->ifa_addr == NULL) continue; 366 | 367 | /* We only support IPv4 & IPv6 */ 368 | int family = interface->ifa_addr->sa_family; 369 | if (family != AF_INET && family != AF_INET6) 370 | continue; 371 | 372 | // check against user specified interfaces 373 | if (!matchSubnet(*interface, remoteAddr)) { 374 | continue; 375 | } 376 | 377 | // 
Store the local IP address 378 | int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); 379 | memcpy(localAddrs+found, interface->ifa_addr, salen); 380 | 381 | // Store the interface name 382 | strncpy(ifNames+found*ifNameMaxSize, interface->ifa_name, ifNameMaxSize); 383 | 384 | TRACE(NCCL_INIT|NCCL_NET,"NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name, socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a)); 385 | found++; 386 | if (found == maxIfs) break; 387 | } 388 | 389 | if (found == 0) { 390 | WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a)); 391 | } 392 | freeifaddrs(interfaces); 393 | return found; 394 | } 395 | 396 | static ncclResult_t GetSocketAddrFromString(union socketAddress* ua, const char* ip_port_pair) { 397 | if (!(ip_port_pair && strlen(ip_port_pair) > 1)) { 398 | WARN("Net : string is null"); 399 | return ncclInvalidArgument; 400 | } 401 | 402 | bool ipv6 = ip_port_pair[0] == '['; 403 | /* Construct the sockaddress structure */ 404 | if (!ipv6) { 405 | struct netIf ni; 406 | // parse : string, expect one pair 407 | if (parseStringList(ip_port_pair, &ni, 1) != 1) { 408 | WARN("Net : No valid : pair found"); 409 | return ncclInvalidArgument; 410 | } 411 | 412 | struct addrinfo hints, *p; 413 | int rv; 414 | memset(&hints, 0, sizeof(hints)); 415 | hints.ai_family = AF_UNSPEC; 416 | hints.ai_socktype = SOCK_STREAM; 417 | 418 | if ( (rv = getaddrinfo(ni.prefix, NULL, &hints, &p)) != 0) { 419 | WARN("Net : error encountered when getting address info : %s", gai_strerror(rv)); 420 | return ncclInvalidArgument; 421 | } 422 | 423 | // use the first 424 | if (p->ai_family == AF_INET) { 425 | struct sockaddr_in& sin = ua->sin; 426 | memcpy(&sin, p->ai_addr, sizeof(struct sockaddr_in)); 427 | sin.sin_family = AF_INET; // IPv4 428 | //inet_pton(AF_INET, ni.prefix, &(sin.sin_addr)); // IP address 
429 | sin.sin_port = htons(ni.port); // port 430 | } else if (p->ai_family == AF_INET6) { 431 | struct sockaddr_in6& sin6 = ua->sin6; 432 | memcpy(&sin6, p->ai_addr, sizeof(struct sockaddr_in6)); 433 | sin6.sin6_family = AF_INET6; // IPv6 434 | sin6.sin6_port = htons(ni.port); // port 435 | sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete 436 | sin6.sin6_scope_id = 0; // should be global scope, set to 0 437 | } else { 438 | WARN("Net : unsupported IP family"); 439 | return ncclInvalidArgument; 440 | } 441 | 442 | freeaddrinfo(p); // all done with this structure 443 | 444 | } else { 445 | int i, j = -1, len = strlen(ip_port_pair); 446 | for (i = 1; i < len; i++) { 447 | if (ip_port_pair[i] == '%') j = i; 448 | if (ip_port_pair[i] == ']') break; 449 | } 450 | if (i == len) { 451 | WARN("Net : No valid [IPv6]:port pair found"); 452 | return ncclInvalidArgument; 453 | } 454 | bool global_scope = (j == -1 ? true : false); // If no % found, global scope; otherwise, link scope 455 | 456 | char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ]; 457 | memset(ip_str, '\0', sizeof(ip_str)); 458 | memset(port_str, '\0', sizeof(port_str)); 459 | memset(if_name, '\0', sizeof(if_name)); 460 | strncpy(ip_str, ip_port_pair+1, global_scope ? i-1 : j-1); 461 | strncpy(port_str, ip_port_pair+i+2, len-i-1); 462 | int port = atoi(port_str); 463 | if (!global_scope) strncpy(if_name, ip_port_pair+j+1, i-j-1); // If not global scope, we need the intf name 464 | 465 | struct sockaddr_in6& sin6 = ua->sin6; 466 | sin6.sin6_family = AF_INET6; // IPv6 467 | inet_pton(AF_INET6, ip_str, &(sin6.sin6_addr)); // IP address 468 | sin6.sin6_port = htons(port); // port 469 | sin6.sin6_flowinfo = 0; // needed by IPv6, but possibly obsolete 470 | sin6.sin6_scope_id = global_scope ? 
0 : if_nametoindex(if_name); // 0 if global scope; intf index if link scope 471 | } 472 | return ncclSuccess; 473 | } 474 | 475 | static int findInterfaces(char* ifNames, union socketAddress *ifAddrs, int ifNameMaxSize, int maxIfs) { 476 | static int shownIfName = 0; 477 | int nIfs = 0; 478 | // Allow user to force the INET socket family selection 479 | int sock_family = envSocketFamily(); 480 | // User specified interface 481 | char* env = getenv("NCCL_SOCKET_IFNAME"); 482 | if (env && strlen(env) > 1) { 483 | INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env); 484 | // Specified by user : find or fail 485 | if (shownIfName++ == 0) INFO(NCCL_NET, "NCCL_SOCKET_IFNAME set to %s", env); 486 | nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); 487 | } else { 488 | // Try to automatically pick the right one 489 | // Start with IB 490 | nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); 491 | // else see if we can get some hint from COMM ID 492 | if (nIfs == 0) { 493 | char* commId = getenv("NCCL_COMM_ID"); 494 | if (commId && strlen(commId) > 1) { 495 | INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId); 496 | // Try to find interface that is in the same subnet as the IP in comm id 497 | union socketAddress idAddr; 498 | GetSocketAddrFromString(&idAddr, commId); 499 | nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs); 500 | } 501 | } 502 | // Then look for anything else (but not docker or lo) 503 | if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); 504 | // Finally look for docker, then lo. 
505 | if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); 506 | if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs); 507 | } 508 | return nIfs; 509 | } 510 | 511 | static ncclResult_t createListenSocket(int *fd, union socketAddress *localAddr) { 512 | /* IPv4/IPv6 support */ 513 | int family = localAddr->sa.sa_family; 514 | int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); 515 | 516 | /* Create socket and bind it to a port */ 517 | int sockfd = socket(family, SOCK_STREAM, 0); 518 | if (sockfd == -1) { 519 | WARN("Net : Socket creation failed : %s", strerror(errno)); 520 | return ncclSystemError; 521 | } 522 | 523 | if (socketToPort(&localAddr->sa)) { 524 | // Port is forced by env. Make sure we get the port. 525 | int opt = 1; 526 | #if defined(SO_REUSEPORT) 527 | SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); 528 | #else 529 | SYSCHECK(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); 530 | #endif 531 | } 532 | 533 | // localAddr port should be 0 (Any port) 534 | SYSCHECK(bind(sockfd, &localAddr->sa, salen), "bind"); 535 | 536 | /* Get the assigned Port */ 537 | socklen_t size = salen; 538 | SYSCHECK(getsockname(sockfd, &localAddr->sa, &size), "getsockname"); 539 | 540 | char line[SOCKET_NAME_MAXLEN+1]; 541 | TRACE(NCCL_INIT|NCCL_NET,"Listening on socket %s", socketToString(&localAddr->sa, line)); 542 | 543 | /* Put the socket in listen mode 544 | * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn 545 | */ 546 | SYSCHECK(listen(sockfd, 16384), "listen"); 547 | *fd = sockfd; 548 | return ncclSuccess; 549 | } 550 | 551 | static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) { 552 | /* IPv4/IPv6 support */ 553 | int family = remoteAddr->sa.sa_family; 554 | if (family != AF_INET && family != AF_INET6) 
{ 555 | WARN("Error : connecting to address with family %d is neither AF_INET(%d) nor AF_INET6(%d)", family, AF_INET, AF_INET6); 556 | return ncclInternalError; 557 | } 558 | int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6); 559 | 560 | /* Connect to a hostname / port */ 561 | *fd = socket(family, SOCK_STREAM, 0); 562 | if (*fd == -1) { 563 | WARN("Net : Socket creation failed : %s", strerror(errno)); 564 | return ncclSystemError; 565 | } 566 | 567 | const int one = 1; 568 | SYSCHECK(setsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); 569 | 570 | /* const int bufsize = 128*1024; 571 | SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt"); 572 | SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/ 573 | 574 | char line[SOCKET_NAME_MAXLEN+1]; 575 | TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", socketToString(&remoteAddr->sa, line)); 576 | 577 | int ret; 578 | int timedout_retries = 0; 579 | int refused_retries = 0; 580 | retry: 581 | SYSCHECKSYNC(connect(*fd, &remoteAddr->sa, salen), "connect", ret); 582 | if (ret == 0) return ncclSuccess; 583 | if ((errno == ECONNREFUSED || errno == ETIMEDOUT)) { 584 | if ((errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) || 585 | (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES)) { 586 | if (refused_retries % 1000 == 0) INFO(NCCL_ALL,"Call to connect returned %s, retrying", strerror(errno)); 587 | usleep(SLEEP_INT); 588 | goto retry; 589 | } 590 | } 591 | WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(errno)); 592 | return ncclSystemError; 593 | } 594 | 595 | #define NCCL_SOCKET_SEND 0 596 | #define NCCL_SOCKET_RECV 1 597 | static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) { 598 | int bytes = 0; 599 | char* data = (char*)ptr; 600 | do { 601 | if (op == NCCL_SOCKET_RECV) 
bytes = recv(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); 602 | if (op == NCCL_SOCKET_SEND) bytes = send(fd, data+(*offset), size-(*offset), block ? 0 : MSG_DONTWAIT); 603 | if (op == NCCL_SOCKET_RECV && bytes == 0) { 604 | WARN("Net : Connection closed by remote peer"); 605 | return ncclSystemError; 606 | } 607 | if (bytes == -1) { 608 | if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { 609 | WARN("Call to recv failed : %s", strerror(errno)); 610 | return ncclSystemError; 611 | } else { 612 | bytes = 0; 613 | } 614 | } 615 | (*offset) += bytes; 616 | } while (bytes > 0 && (*offset) < size); 617 | return ncclSuccess; 618 | } 619 | 620 | static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, 621 | int* offset) { 622 | return socketProgressOpt(op, fd, ptr, size, offset, 0); 623 | } 624 | 625 | static ncclResult_t socketProgress(int op, int fd, union socketAddress* addr, 626 | void* ptr, int size, int* offset) { 627 | return socketProgress(op, fd, ptr, size, offset); 628 | } 629 | 630 | static ncclResult_t socketWait(int op, int fd, void* ptr, int size, 631 | int* offset) { 632 | while (*offset < size) 633 | NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1)); 634 | return ncclSuccess; 635 | } 636 | 637 | static ncclResult_t socketWait(int op, int fd, union socketAddress* addr, 638 | void* ptr, int size, int* offset) { 639 | return socketWait(op, fd, ptr, size, offset); 640 | } 641 | 642 | static ncclResult_t socketSend(int fd, void* ptr, int size) { 643 | int offset = 0; 644 | NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &offset)); 645 | return ncclSuccess; 646 | } 647 | 648 | static ncclResult_t socketSend(int fd, union socketAddress* addr, void* ptr, 649 | int size) { 650 | return socketSend(fd, ptr, size); 651 | } 652 | 653 | static ncclResult_t socketRecv(int fd, void* ptr, int size) { 654 | int offset = 0; 655 | NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &offset)); 656 | return 
ncclSuccess; 657 | } 658 | 659 | static ncclResult_t socketRecv(int fd, union socketAddress* addr, void* ptr, 660 | int size) { 661 | return socketRecv(fd, ptr, size); 662 | } 663 | 664 | #endif // THIRD_PARTY_GPUS_NCCL_FASTSOCKET_PLUGIN_UTIL_H_ 665 | --------------------------------------------------------------------------------