├── third_party
│   ├── tensorflow
│   │   ├── BUILD
│   │   ├── BUILD.tpl
│   │   └── configure.py
│   ├── BUILD
│   ├── toolchains
│   │   └── preconfig
│   │       └── ubuntu16.04
│   │           └── gcc7_manylinux2014
│   │               ├── tools
│   │               │   └── cpp
│   │               │       └── empty.cc
│   │               ├── WORKSPACE
│   │               ├── cc_wrapper.sh
│   │               ├── dummy_toolchain.bzl
│   │               └── BUILD
│   ├── cares
│   │   ├── BUILD
│   │   └── cares.BUILD
│   ├── readerwriterqueue.BUILD
│   ├── mpi
│   │   ├── BUILD
│   │   └── mpi.bzl
│   ├── lz4.BUILD
│   ├── uuid.BUILD
│   ├── yaml-cpp.BUILD
│   ├── systemtap-sdt.BUILD
│   ├── fmtlib.BUILD
│   ├── cryptopp.BUILD
│   ├── xfs.BUILD
│   ├── ci_build
│   │   ├── devtoolset
│   │   │   ├── fixlinks.sh
│   │   │   ├── rpm-patch.sh
│   │   │   └── build_devtoolset.sh
│   │   └── Dockerfile.rbe.ubuntu16.04-manylinux2014
│   ├── gnutls.BUILD
│   ├── common.bzl
│   ├── sctp.BUILD
│   ├── ragel.BUILD
│   └── seastar.BUILD
├── MANIFEST.in
├── tensorflow_networking
│   ├── verbs
│   │   ├── design_diagram.png
│   │   ├── verbs_with_0_copies.png
│   │   ├── verbs_with_0_copies_phase1_protocol.jpg
│   │   ├── docker_howto.txt
│   │   ├── verbs_util.h
│   │   ├── verbs_util.cc
│   │   ├── rdma_mgr.h
│   │   ├── grpc_verbs_client.cc
│   │   ├── grpc_verbs_client.h
│   │   ├── verbs_service.proto
│   │   ├── Dockerfile
│   │   ├── verbs_server_lib.h
│   │   ├── verbs_with_0_copies_phase1_protocol.xml
│   │   ├── rdma_rendezvous_mgr.h
│   │   ├── verbs_with_0_copies.xml
│   │   ├── grpc_verbs_service_impl.cc
│   │   ├── grpc_verbs_service.h
│   │   ├── grpc_verbs_service_impl.h
│   │   ├── rdma_rendezvous_mgr.cc
│   │   └── BUILD
│   ├── .clang-format
│   ├── __init__.py
│   ├── mpi_collectives
│   │   ├── README.md
│   │   ├── Dockerfile
│   │   ├── mpi_message.proto
│   │   ├── BUILD
│   │   ├── ring.cc
│   │   ├── kernels
│   │   │   ├── ring.cc
│   │   │   └── ring.cu.cc
│   │   ├── mpi_allgather_test.py
│   │   ├── ops
│   │   │   └── mpi_ops.cc
│   │   ├── ring.cu.cc
│   │   └── python
│   │       └── ops
│   │           └── mpi_ops.py
│   ├── gdr
│   │   ├── gdr.proto
│   │   ├── Dockerfile
│   │   ├── gdr_rendezvous_mgr.h
│   │   ├── gdr_server_lib.h
│   │   ├── gdr_worker.h
│   │   ├── BUILD
│   │   ├── gdr_memory_manager.h
│   │   └── gdr_collective_executor_mgr.h
│   ├── mpi
│   │   ├── mpi_msg.proto
│   │   ├── mpi_server_lib.h
│   │   ├── BUILD
│   │   ├── mpi_utils.h
│   │   ├── Dockerfile
│   │   ├── mpi_utils.cc
│   │   └── mpi_server_lib.cc
│   ├── BUILD
│   ├── seastar
│   │   ├── seastar_worker_service_method.h
│   │   ├── seastar_remote_worker.h
│   │   ├── seastar_cpuset.h
│   │   ├── seastar_worker_interface.h
│   │   ├── seastar_worker_cache.h
│   │   ├── seastar_rendezvous_mgr.h
│   │   ├── seastar_tensor_coding.cc
│   │   ├── seastar_tag_factory.h
│   │   ├── seastar_client.h
│   │   ├── seastar_message.cc
│   │   ├── Dockerfile
│   │   ├── seastar_message.h
│   │   ├── seastar_tag_factory.cc
│   │   ├── seastar_server_lib.h
│   │   ├── seastar_tensor_coding.h
│   │   ├── seastar_channel_cache.h
│   │   ├── seastar_engine.h
│   │   ├── seastar_server_tag.h
│   │   ├── seastar_worker_service.h
│   │   ├── seastar_client_tag.h
│   │   ├── README
│   │   ├── seastar_worker_cache.cc
│   │   ├── BUILD
│   │   └── seastar_cpuset.cc
│   └── repo.bzl
├── configure
├── .gitignore
├── setup.py
├── README.md
└── .bazelrc
--------------------------------------------------------------------------------
/third_party/tensorflow/BUILD:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/third_party/BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["notice"])  # Apache 2.0
 2 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
 1 | recursive-include tensorflow_networking/ *.so
 2 | 
--------------------------------------------------------------------------------
/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014/tools/cpp/empty.cc:
--------------------------------------------------------------------------------
 1 | int main() {}
 2 | 
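Note on how these pieces fit together: MANIFEST.in above packages the built
libtensorflow_networking.so into the Python wheel, and
tensorflow_networking/__init__.py (later in this listing) loads that shared
object at import time, registering the extra networking server types with
TensorFlow. A minimal, hypothetical usage sketch; the cluster addresses and
the "grpc+verbs" protocol string are illustrative assumptions, not taken from
this repository:

    # hypothetical_driver.py (not part of the repository)
    import tensorflow as tf
    import tensorflow_networking  # side effect: loads libtensorflow_networking.so

    cluster = tf.train.ClusterSpec({"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]})
    server = tf.distribute.Server(cluster, job_name="worker", task_index=0,
                                  protocol="grpc+verbs")  # assumed protocol name
    server.join()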
--------------------------------------------------------------------------------
/third_party/cares/BUILD:
--------------------------------------------------------------------------------
 1 | exports_files([
 2 |     "ares_build.h",
 3 |     "ares_config.h",
 4 |     "cares.BUILD",
 5 | ])
 6 | 
--------------------------------------------------------------------------------
/tensorflow_networking/verbs/design_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/networking/HEAD/tensorflow_networking/verbs/design_diagram.png
--------------------------------------------------------------------------------
/tensorflow_networking/verbs/verbs_with_0_copies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/networking/HEAD/tensorflow_networking/verbs/verbs_with_0_copies.png
--------------------------------------------------------------------------------
/tensorflow_networking/.clang-format:
--------------------------------------------------------------------------------
 1 | # Run manually to reformat a file:
 2 | # clang-format -i --style=file <file>
 3 | BasedOnStyle: Google
 4 | DerivePointerAlignment: false
--------------------------------------------------------------------------------
/tensorflow_networking/verbs/verbs_with_0_copies_phase1_protocol.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorflow/networking/HEAD/tensorflow_networking/verbs/verbs_with_0_copies_phase1_protocol.jpg
--------------------------------------------------------------------------------
/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014/WORKSPACE:
--------------------------------------------------------------------------------
 1 | # DO NOT EDIT: automatically generated WORKSPACE file for cc_autoconf rule
 2 | workspace(name = "local_config_cc")
--------------------------------------------------------------------------------
/tensorflow_networking/__init__.py:
--------------------------------------------------------------------------------
 1 | """tensorflow_networking"""
 2 | 
 3 | import os
 4 | import tensorflow as tf
 5 | 
 6 | dirname = os.path.dirname(__file__)
 7 | tf.load_library(os.path.join(dirname, 'libtensorflow_networking.so'))
--------------------------------------------------------------------------------
/third_party/readerwriterqueue.BUILD:
--------------------------------------------------------------------------------
 1 | package(default_visibility = ["//visibility:public"])
 2 | 
 3 | cc_library(
 4 |     name = "readerwriterqueue",
 5 |     hdrs = [
 6 |         "atomicops.h",
 7 |         "readerwriterqueue.h",
 8 |     ],
 9 | )
10 | 
--------------------------------------------------------------------------------
/third_party/mpi/BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["restricted"])
 2 | 
 3 | load("//third_party/mpi:mpi.bzl", "mpi_hdr")
 4 | 
 5 | cc_library(
 6 |     name = "mpi",
 7 |     srcs = ["libmpi.so"],
 8 |     hdrs = mpi_hdr(),
 9 |     visibility = ["//visibility:public"],
10 | )
11 | 
--------------------------------------------------------------------------------
/tensorflow_networking/mpi_collectives/README.md:
--------------------------------------------------------------------------------
 1 | # MPI TensorFlow integration
 2 | 
 3 | TensorFlow MPI integration allows communicating between different TensorFlow
 4 | processes using MPI.
This enables training across multiple nodes and GPUs
 5 | using high-speed interconnects.
 6 | 
--------------------------------------------------------------------------------
/tensorflow_networking/gdr/gdr.proto:
--------------------------------------------------------------------------------
 1 | syntax = "proto3";
 2 | 
 3 | package tensorflow;
 4 | option cc_enable_arenas = true;
 5 | 
 6 | message RemoteMemoryRegion {
 7 |   string host = 1;
 8 |   string port = 2;
 9 |   uint64 addr = 3;
10 |   uint32 rkey = 4;
11 |   uint32 tensor_key = 5;
12 | }
13 | 
--------------------------------------------------------------------------------
/third_party/lz4.BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["notice"])  # BSD 2-clause
 2 | 
 3 | exports_files(["LICENSE"])
 4 | 
 5 | cc_library(
 6 |     name = "lz4",
 7 |     srcs = glob(["*.c"]),
 8 |     hdrs = glob(["*.h"]),
 9 |     includes = ["."],
10 |     textual_hdrs = ["lz4.c"],
11 |     visibility = ["//visibility:public"],
12 | )
13 | 
--------------------------------------------------------------------------------
/third_party/uuid.BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["notice"])  # BSD-3-clause
 2 | 
 3 | exports_files(["licenses/COPYING.BSD-3-clause"])
 4 | 
 5 | cc_library(
 6 |     name = "uuid",
 7 |     hdrs = ["libuuid/src/uuid.h"],
 8 |     include_prefix = "uuid",
 9 |     strip_include_prefix = "libuuid/src",
10 |     visibility = ["//visibility:public"],
11 | )
12 | 
--------------------------------------------------------------------------------
/configure:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | 
 3 | set -e
 4 | set -o pipefail
 5 | 
 6 | if [ -z "$PYTHON_BIN_PATH" ]; then
 7 |   PYTHON_BIN_PATH=$(which python || which python3 || true)
 8 | fi
 9 | 
10 | # Set all env variables
11 | CONFIGURE_DIR=$(dirname "$0")
12 | "$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"
13 | 
14 | echo "Configuration finished"
15 | 
--------------------------------------------------------------------------------
/tensorflow_networking/mpi/mpi_msg.proto:
--------------------------------------------------------------------------------
 1 | syntax = "proto3";
 2 | 
 3 | package tensorflow;
 4 | option cc_enable_arenas = true;
 5 | 
 6 | import "tensorflow/core/protobuf/worker.proto";
 7 | 
 8 | message MPIRecvTensorResponse {
 9 |   RecvTensorResponse response = 1;
10 |   bool singleSend = 2;
11 |   string key = 3;
12 |   int64 step_id = 4;
13 |   uint64 checksum = 5;
14 | }
15 | 
--------------------------------------------------------------------------------
/tensorflow_networking/BUILD:
--------------------------------------------------------------------------------
 1 | config_setting(
 2 |     name = "mpi_library_is_openmpi_based",
 3 |     values = {"define": "mpi_library_is_openmpi_based=true"},
 4 |     visibility = ["//visibility:public"],
 5 | )
 6 | 
 7 | cc_binary(
 8 |     name = "libtensorflow_networking.so",
 9 |     linkopts = ["-z defs"],
10 |     linkshared = 1,
11 |     deps = [
12 |         "//tensorflow_networking/seastar:seastar_server_lib",
13 |         "@local_config_tf//:_pywrap_tensorflow_internal",
14 |     ],
15 | )
16 | 
--------------------------------------------------------------------------------
/third_party/mpi/mpi.bzl:
--------------------------------------------------------------------------------
 1 | # OpenMPI and MVAPICH/MPICH require different headers;
 2 | # based on the configuration options, return one or the other.
 3 | 
 4 | def mpi_hdr():
 5 |     return if_openmpi(
 6 |         
["mpi.h", "mpi_portable_platform.h"], 7 | ["mpi.h", "mpio.h", "mpicxx.h"], 8 | ) 9 | 10 | def if_openmpi(if_true, if_false = []): 11 | return select({ 12 | "//tensorflow_networking:mpi_library_is_openmpi_based": if_true, 13 | "//conditions:default": if_false, 14 | }) 15 | -------------------------------------------------------------------------------- /third_party/yaml-cpp.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # MIT 2 | 3 | exports_files(["LICENSE"]) 4 | 5 | cc_library( 6 | name = "yaml-cpp_internal", 7 | hdrs = glob(["src/**/*.h"]), 8 | strip_include_prefix = "src", 9 | ) 10 | 11 | cc_library( 12 | name = "yaml-cpp", 13 | srcs = glob([ 14 | "src/**/*.cpp", 15 | "src/**/*.h", 16 | ]), 17 | hdrs = glob(["include/**/*.h"]), 18 | strip_include_prefix = "include", 19 | visibility = ["//visibility:public"], 20 | ) 21 | -------------------------------------------------------------------------------- /third_party/systemtap-sdt.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["unencumbered"]) # CC0 1.0 Public Domain 2 | 3 | load("@//third_party:common.bzl", "template_rule") 4 | 5 | cc_library( 6 | name = "systemtap-sdt", 7 | hdrs = [ 8 | "includes/sys/sdt.h", 9 | "includes/sys/sdt-config.h", 10 | ], 11 | strip_include_prefix = "includes", 12 | visibility = ["//visibility:public"], 13 | ) 14 | 15 | template_rule( 16 | name = "sdt_config_h", 17 | src = "includes/sys/sdt-config.h.in", 18 | out = "includes/sys/sdt-config.h", 19 | substitutions = { 20 | "#define _SDT_ASM_SECTION_AUTOGROUP_SUPPORT @support_section_question@": "#define _SDT_ASM_SECTION_AUTOGROUP_SUPPORT 1", 21 | }, 22 | ) 23 | -------------------------------------------------------------------------------- /third_party/fmtlib.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # MIT 2 | 3 | exports_files(["LICENSE.rst"]) 4 | 5 | cc_library( 6 | name = "fmtlib", 7 | srcs = [ 8 | "src/format.cc", 9 | "src/posix.cc", 10 | ], 11 | hdrs = [ 12 | "include/fmt/color.h", 13 | "include/fmt/core.h", 14 | "include/fmt/format.h", 15 | "include/fmt/format-inl.h", 16 | "include/fmt/ostream.h", 17 | "include/fmt/posix.h", 18 | "include/fmt/printf.h", 19 | "include/fmt/ranges.h", 20 | "include/fmt/time.h", 21 | ], 22 | defines = ["FMT_HEADER_ONLY"], 23 | includes = [ 24 | "include", 25 | ], 26 | visibility = ["//visibility:public"], 27 | ) 28 | -------------------------------------------------------------------------------- /third_party/cryptopp.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Boost 2 | 3 | exports_files(["License.txt"]) 4 | 5 | cc_library( 6 | name = "cryptopp_internal", 7 | srcs = glob(["*.cpp"]) + glob(["*.h"]), 8 | copts = [ 9 | "-fopenmp", 10 | "-msha", 11 | "-maes", 12 | "-mavx2", 13 | "-mpclmul", 14 | ], 15 | textual_hdrs = [ 16 | "algebra.cpp", 17 | "strciphr.cpp", 18 | "eprecomp.cpp", 19 | "polynomi.cpp", 20 | "eccrypto.cpp", 21 | ], 22 | ) 23 | 24 | cc_library( 25 | name = "cryptopp", 26 | hdrs = glob(["*.h"]), 27 | include_prefix = "cryptopp", 28 | visibility = ["//visibility:public"], 29 | deps = [":cryptopp_internal"], 30 | ) 31 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_worker_service_method.h: -------------------------------------------------------------------------------- 1 | #ifndef 
TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_SERVICE_METHOD_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_SERVICE_METHOD_H_
 3 | 
 4 | namespace tensorflow {
 5 | 
 6 | enum class SeastarWorkerServiceMethod {
 7 |   kGetStatus = 0,
 8 |   kCreateWorkerSession,
 9 |   kDeleteWorkerSession,
10 |   kRegisterGraph,
11 |   kDeregisterGraph,
12 |   kRunGraph,
13 |   kCleanupGraph,
14 |   kCleanupAll,
15 |   kRecvTensor,
16 |   kFuseRecvTensor,
17 |   kLogging,
18 |   kTracing,
19 |   kRecvBuf,
20 |   kCompleteGroup,
21 |   kCompleteInstance,
22 |   kGetStepSequence,
23 | };
24 | 
25 | }  // namespace tensorflow
26 | 
27 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_SERVICE_METHOD_H_
28 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_remote_worker.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_REMOTE_WORKER_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_REMOTE_WORKER_H_
 3 | 
 4 | #include "seastar/core/channel.hh"
 5 | 
 6 | #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
 7 | #include "tensorflow/core/distributed_runtime/worker_env.h"
 8 | #include "tensorflow/core/distributed_runtime/worker_interface.h"
 9 | 
10 | namespace tensorflow {
11 | 
12 | WorkerInterface* NewSeastarRemoteWorker(seastar::channel* seastar_channel,
13 |                                         WorkerCacheLogger* logger,
14 |                                         WorkerEnv* env);
15 | }  // namespace tensorflow
16 | 
17 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_REMOTE_WORKER_H_
18 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_cpuset.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CPUSET_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CPUSET_H_
 3 | 
 4 | #include <string>
 5 | #include <vector>
 6 | 
 7 | namespace tensorflow {
 8 | 
 9 | class CpusetAllocator {
10 |  public:
11 |   virtual ~CpusetAllocator() {}
12 |   std::string GetCpuset(size_t core_number);
13 | 
14 |  private:
15 |   bool ExistDir();
16 |   void CreateDir();
17 |   void CreateFiles();
18 | 
19 |   std::vector<std::string> LockFiles(size_t core_number);
20 |   std::string ToCpuset(const std::vector<std::string>& locked_files);
21 | 
22 |  private:
23 |   std::string root_dir_;
24 |   std::vector<std::string> files_;
25 | };
26 | 
27 | }  // namespace tensorflow
28 | 
29 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CPUSET_H_
30 | 
--------------------------------------------------------------------------------
/tensorflow_networking/verbs/docker_howto.txt:
--------------------------------------------------------------------------------
 1 | Building a networking component may require libraries not available or installable in your normal environment.
 2 | As of late 2018 the networking contrib extensions to TensorFlow 1.x can be built in a docker container, as follows.
 3 | 
 4 | 1. Ensure docker is installed.
 5 | 2. Invoke a docker container with the latest nightly development build:
 6 |    $ docker run -it -w /tensorflow -v $PWD:/mnt -e HOST_PERMS="$(id -u):$(id -g)" tensorflow/tensorflow:nightly-devel bash
 7 | 3. Configure for bazel build
 8 |    $ ./configure
 9 | 4. Install any necessary additional packages, e.g.
10 |    $ apt-get update
11 |    $ apt-get install libibverbs-dev
12 | 5. 
Build with the desired extension
13 |    $ bazel build --config=verbs //tensorflow/tools/pip_package:build_pip_package
14 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_worker_interface.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_INTERFACE_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_INTERFACE_H_
 3 | 
 4 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h"
 5 | 
 6 | #include "tensorflow/core/distributed_runtime/call_options.h"
 7 | #include "tensorflow/core/distributed_runtime/worker_cache.h"
 8 | 
 9 | namespace tensorflow {
10 | 
11 | class SeastarWorkerInterface {
12 |  public:
13 |   virtual void RecvTensorAsync(CallOptions* call_opts,
14 |                                const RecvTensorRequest* request,
15 |                                SeastarTensorResponse* response,
16 |                                StatusCallback done) = 0;
17 | };
18 | 
19 | }  // namespace tensorflow
20 | 
21 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_WORKER_INTERFACE_H_
22 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_worker_cache.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_WORKER_CACHE_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_WORKER_CACHE_H_
 3 | 
 4 | #include "tensorflow_networking/seastar/seastar_channel_cache.h"
 5 | 
 6 | #include "tensorflow/core/distributed_runtime/worker_cache.h"
 7 | #include "tensorflow/core/distributed_runtime/worker_env.h"
 8 | #include "tensorflow/core/distributed_runtime/worker_interface.h"
 9 | 
10 | namespace tensorflow {
11 | 
12 | WorkerCacheInterface* NewSeastarWorkerCache(SeastarChannelCache* channel_cache,
13 |                                             WorkerEnv* env);
14 | 
15 | WorkerCacheInterface* NewSeastarWorkerCacheWithLocalWorker(
16 |     SeastarChannelCache* channel_cache, WorkerInterface* local_worker,
17 |     const string& local_target, WorkerEnv* env);
18 | 
19 | }  // namespace tensorflow
20 | 
21 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_WORKER_CACHE_H_
22 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_rendezvous_mgr.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_RENDEZVOUS_MGR_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_RENDEZVOUS_MGR_H_
 3 | 
 4 | #include "tensorflow/core/common_runtime/device_mgr.h"
 5 | #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
 6 | #include "tensorflow/core/distributed_runtime/worker_env.h"
 7 | #include "tensorflow/core/platform/macros.h"
 8 | 
 9 | namespace tensorflow {
10 | 
11 | class SeastarRendezvousMgr : public BaseRendezvousMgr {
12 |  public:
13 |   explicit SeastarRendezvousMgr(const WorkerEnv* env);
14 | 
15 |  protected:
16 |   BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
17 | 
18 |  private:
19 |   TF_DISALLOW_COPY_AND_ASSIGN(SeastarRendezvousMgr);
20 | };
21 | }  // namespace tensorflow
22 | 
23 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_RENDEZVOUS_MGR_H_
24 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_tensor_coding.cc:
--------------------------------------------------------------------------------
 1 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h"
 2 | 
 3 | #include "tensorflow/core/common_runtime/device.h"
 4 | 
 5 | namespace tensorflow {
 6 | 
 7 | void SeastarTensorResponse::InitAlloc(Device* d,
 8 |                                       
const AllocatorAttributes& aa) { 9 | Clear(); 10 | device_ = d; 11 | alloc_attrs_ = aa; 12 | const DeviceAttributes& da = d->attributes(); 13 | if (alloc_attrs_.on_host() || da.device_type() == "CPU") { 14 | on_host_ = true; 15 | } 16 | allocator_ = device_->GetAllocator(alloc_attrs_); 17 | } 18 | 19 | void SeastarTensorResponse::Clear() { 20 | on_host_ = false; 21 | device_ = nullptr; 22 | alloc_attrs_ = AllocatorAttributes(); 23 | allocator_ = nullptr; 24 | tensor_ = Tensor(); 25 | tensor_proto_ = TensorProto(); 26 | } 27 | 28 | } // namespace tensorflow 29 | -------------------------------------------------------------------------------- /third_party/xfs.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["permissive"]) # LGPL headers only 2 | 3 | exports_files(["LICENSES/LGPL-2.1"]) 4 | 5 | cc_library( 6 | name = "libxfs", 7 | hdrs = [ 8 | "libxfs/xfs_da_format.h", 9 | "libxfs/xfs_format.h", 10 | "libxfs/xfs_fs.h", 11 | "libxfs/xfs_log_format.h", 12 | "libxfs/xfs_types.h", 13 | ], 14 | include_prefix = "xfs", 15 | strip_include_prefix = "libxfs", 16 | ) 17 | 18 | cc_library( 19 | name = "xfs", 20 | hdrs = [ 21 | "include/handle.h", 22 | "include/jdm.h", 23 | "include/linux.h", 24 | "include/xfs.h", 25 | "include/xfs_arch.h", 26 | "include/xqm.h", 27 | ], 28 | include_prefix = "xfs", 29 | strip_include_prefix = "include", 30 | visibility = ["//visibility:public"], 31 | deps = [ 32 | ":libxfs", 33 | "@uuid", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014/cc_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2015 The Bazel Authors. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Ship the environment to the C++ action 18 | # 19 | set -eu 20 | 21 | # Set-up the environment 22 | 23 | 24 | # Call the C++ compiler 25 | /dt7/usr/bin/gcc "$@" 26 | -------------------------------------------------------------------------------- /third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014/dummy_toolchain.bzl: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2017 The Bazel Authors. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Skylark rule that stubs a toolchain.""" 17 | 18 | def _dummy_toolchain_impl(ctx): 19 | ctx = ctx # unused argument 20 | toolchain = platform_common.ToolchainInfo() 21 | return [toolchain] 22 | 23 | dummy_toolchain = rule(_dummy_toolchain_impl, attrs = {}) 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | node_modules 4 | /.bazelrc.user 5 | /.tf_configure.bazelrc 6 | /.tf_networking_configure.bazelrc 7 | /bazel-* 8 | /bazel_pip 9 | /tools/python_bin_path.sh 10 | /tensorflow/tools/git/gen 11 | /pip_test 12 | /_python_build 13 | *.pyc 14 | __pycache__ 15 | *.swp 16 | .vscode/ 17 | cmake_build/ 18 | tensorflow/contrib/cmake/_build/ 19 | .idea/** 20 | /build/ 21 | [Bb]uild/ 22 | /tensorflow/core/util/version_info.cc 23 | /tensorflow/python/framework/fast_tensor_util.cpp 24 | Pods 25 | Podfile.lock 26 | *.pbxproj 27 | *.xcworkspacedata 28 | /tensorflow/lite/tools/make/downloads/** 29 | /tensorflow/lite/gen/** 30 | /tensorflow/lite/examples/ios/simple/data/*.txt 31 | /tensorflow/lite/examples/ios/simple/data/*.tflite 32 | xcuserdata/** 33 | /api_init_files_list.txt 34 | /estimator_api_init_files_list.txt 35 | *.whl 36 | 37 | # Android 38 | .gradle 39 | .idea 40 | *.iml 41 | local.properties 42 | gradleBuild 43 | 44 | # MPI 45 | third_party/mpi/libmpi.so 46 | third_party/mpi/mpi.h 47 | third_party/mpi/mpi_portable_platform.h 48 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:nightly-py3 2 | 3 | RUN apt-get update && \ 4 | apt-get install -yq --no-install-recommends \ 5 | g++ \ 6 | libibverbs-dev \ 7 | librdmacm-dev \ 8 | openjdk-8-jdk \ 9 | unzip \ 10 | zip \ 11 | && \ 12 | rm -rf '/var/lib/apt/lists/*' 13 | 14 | # Install bazel 15 | ARG BAZEL_VERSION=1.1.0 16 | ARG BAZEL_INSTALLER="bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" 17 | RUN curl -L -O "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/${BAZEL_INSTALLER}" && \ 18 | chmod +x ${BAZEL_INSTALLER} && \ 19 | ./${BAZEL_INSTALLER} && \ 20 | rm ${BAZEL_INSTALLER} 21 | 22 | ADD . 
/tf_networking
23 | RUN cd /tf_networking && \
24 |     python3 third_party/tensorflow/configure.py && \
25 |     bazel build -c opt //tensorflow_networking:libtensorflow_networking.so && \
26 |     cp bazel-bin/tensorflow_networking/libtensorflow_networking.so tensorflow_networking && \
27 |     python3 setup.py bdist_wheel && \
28 |     pip3 install dist/*.whl
29 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_tag_factory.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TAG_FACTORY_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TAG_FACTORY_H_
 3 | 
 4 | #include "seastar/core/channel.hh"
 5 | #include "tensorflow_networking/seastar/seastar_client_tag.h"
 6 | #include "tensorflow_networking/seastar/seastar_server_tag.h"
 7 | #include "tensorflow_networking/seastar/seastar_worker_service.h"
 8 | 
 9 | #include "seastar/core/temporary_buffer.hh"
10 | 
11 | namespace tensorflow {
12 | 
13 | class SeastarTagFactory {
14 |  public:
15 |   explicit SeastarTagFactory(SeastarWorkerService* worker_service);
16 |   virtual ~SeastarTagFactory() {}
17 | 
18 |   SeastarClientTag* CreateSeastarClientTag(
19 |       seastar::temporary_buffer<char>& header);
20 | 
21 |   SeastarServerTag* CreateSeastarServerTag(
22 |       seastar::temporary_buffer<char>& header,
23 |       seastar::channel* seastar_channel);
24 | 
25 |  private:
26 |   SeastarWorkerService* worker_service_;
27 | };
28 | 
29 | }  // namespace tensorflow
30 | 
31 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TAG_FACTORY_H_
32 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_client.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_H_
 3 | 
 4 | #include "tensorflow_networking/seastar/seastar_tag_factory.h"
 5 | 
 6 | #include "seastar/core/channel.hh"
 7 | #include "seastar/core/future-util.hh"
 8 | #include "seastar/net/api.hh"
 9 | 
10 | namespace tensorflow {
11 | 
12 | class SeastarClient {
13 |  public:
14 |   void Connect(seastar::ipv4_addr server_addr, std::string s,
15 |                seastar::channel* chan, SeastarTagFactory* tag_factory);
16 | 
17 |  private:
18 |   struct Connection {
19 |     seastar::connected_socket fd_;
20 |     seastar::input_stream<char> read_buf_;
21 |     seastar::channel* channel_;
22 |     SeastarTagFactory* tag_factory_;
23 |     seastar::socket_address addr_;
24 |     Connection(seastar::connected_socket&& fd, seastar::channel* chan,
25 |                SeastarTagFactory* tag_factory, seastar::socket_address addr);
26 |     seastar::future<> Read();
27 |   };
28 | };
29 | 
30 | }  // namespace tensorflow
31 | 
32 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_H_
33 | 
--------------------------------------------------------------------------------
/third_party/ci_build/devtoolset/fixlinks.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #     http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | #
17 | # Re-direct all links in $1 that point to /lib... to point to $1/lib... instead.
18 | 
19 | BASE="$1"
20 | find "${BASE}" -type l | \
21 |   while read l ; do
22 |     if [[ "$(readlink "$l")" == /lib* ]]; then
23 |       ORIG="$(readlink "$l")";
24 |       rm "$l";
25 |       ln -s "${BASE}${ORIG}" "$l"
26 |     fi
27 |   done
28 | 
29 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_message.cc:
--------------------------------------------------------------------------------
 1 | #include "tensorflow_networking/seastar/seastar_message.h"
 2 | 
 3 | namespace tensorflow {
 4 | 
 5 | void SeastarMessage::DeserializeMessage(SeastarMessage* sm,
 6 |                                         const char* message) {
 7 |   memcpy(&sm->is_dead_, &message[kIsDeadStartIndex], sizeof(sm->is_dead_));
 8 |   memcpy(&sm->data_type_, &message[kDataTypeStartIndex],
 9 |          sizeof(sm->data_type_));
10 |   memcpy(&sm->tensor_shape_, &message[kTensorShapeStartIndex],
11 |          sizeof(sm->tensor_shape_));
12 |   memcpy(&sm->tensor_bytes_, &message[kTensorBytesStartIndex],
13 |          sizeof(sm->tensor_bytes_));
14 | }
15 | 
16 | void SeastarMessage::SerializeMessage(const SeastarMessage& sm, char* message) {
17 |   memcpy(&message[kIsDeadStartIndex], &sm.is_dead_, sizeof(sm.is_dead_));
18 |   memcpy(&message[kDataTypeStartIndex], &sm.data_type_, sizeof(sm.data_type_));
19 |   memcpy(&message[kTensorShapeStartIndex], &sm.tensor_shape_,
20 |          sizeof(sm.tensor_shape_));
21 |   memcpy(&message[kTensorBytesStartIndex], &sm.tensor_bytes_,
22 |          sizeof(sm.tensor_bytes_));
23 | }
24 | 
25 | }  // namespace tensorflow
26 | 
--------------------------------------------------------------------------------
/third_party/gnutls.BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["permissive"])  # LGPL headers only
 2 | 
 3 | load("@//third_party:common.bzl", "template_rule")
 4 | 
 5 | exports_files(["LICENSE"])
 6 | 
 7 | cc_library(
 8 |     name = "gnutls",
 9 |     hdrs = glob([
10 |         "lib/includes/gnutls/*.h",
11 |     ]) + [
12 |         "lib/includes/gnutls/gnutls.h",
13 |     ],
14 |     strip_include_prefix = "lib/includes",
15 |     visibility = ["//visibility:public"],
16 | )
17 | 
18 | template_rule(
19 |     name = "gnutls_h",
20 |     src = "lib/includes/gnutls/gnutls.h.in",
21 |     out = "lib/includes/gnutls/gnutls.h",
22 |     substitutions = {
23 |         "#define GNUTLS_VERSION \"@VERSION@\"": "#define GNUTLS_VERSION \"3.6.12\"",
24 |         "#define GNUTLS_VERSION_MAJOR @MAJOR_VERSION@": "#define GNUTLS_VERSION_MAJOR 3",
25 |         "#define GNUTLS_VERSION_MINOR @MINOR_VERSION@": "#define GNUTLS_VERSION_MINOR 6",
26 |         "#define GNUTLS_VERSION_PATCH @PATCH_VERSION@": "#define GNUTLS_VERSION_PATCH 12",
27 |         "#define GNUTLS_VERSION_NUMBER @NUMBER_VERSION@": "#define GNUTLS_VERSION_NUMBER 0x03060c",
28 |         "@DEFINE_IOVEC_T@": "#include <sys/uio.h>\ntypedef struct iovec giovec_t;",
29 |     },
30 | )
31 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/Dockerfile:
--------------------------------------------------------------------------------
 1 | #FROM tensorflow/tensorflow:nightly-custom-op-ubuntu16
 2 | FROM byronyi/tensorflow:ubuntu16.04-manylinux2014
 3 | 
 4 | RUN apt-get update && \
 5 |     apt-get install -yq --no-install-recommends \
 6 |     libibverbs-dev \
 7 |     librdmacm-dev \
 8 |     && \
 9 |     rm -rf '/var/lib/apt/lists/*'
10 | 
11 | # Install bazel
12 | ARG BAZEL_VERSION=1.2.1
13 | ARG BAZEL_INSTALLER="bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh"
14 | RUN curl -L -O "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/${BAZEL_INSTALLER}" && \
15 |     chmod +x ${BAZEL_INSTALLER} && \
16 |     ./${BAZEL_INSTALLER} && \
17 |     rm ${BAZEL_INSTALLER}
18 | 
19 | ADD . /tf_networking
20 | RUN cd /tf_networking && \
21 |     python3 third_party/tensorflow/configure.py && \
22 |     bazel build \
23 |         -c opt \
24 |         --cxxopt=-std=gnu++14 \
25 |         --crosstool_top=@//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014:toolchain \
26 |         //tensorflow_networking:libtensorflow_networking.so && \
27 |     cp bazel-bin/tensorflow_networking/libtensorflow_networking.so tensorflow_networking && \
28 |     python3.6 setup.py bdist_wheel && \
29 |     pip3.6 install dist/*.whl
30 | 
--------------------------------------------------------------------------------
/third_party/ci_build/devtoolset/rpm-patch.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash -eu
 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #     http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | #
17 | # Given an RPM spec file $1, apply its patches.
18 | 
19 | SPEC="$1"
20 | grep '%patch' "${SPEC}" |while read cmd ; do
21 |   N=$(echo "${cmd}" |sed 's,%patch\([0-9]\+\).*,\1,')
22 |   file=$(grep "Patch$N:" "${SPEC}" |sed 's,.*: ,,')
23 |   parg=$(echo "${cmd}" |sed 's,.*\(-p[0-9]\).*,\1,')
24 |   if [[ ! 
"${file}" =~ doxygen && "${cmd}" != \#* ]]; then 25 | echo "patch ${parg} -s < ${file}" 26 | patch ${parg} -s < "${file}" 27 | fi 28 | done 29 | -------------------------------------------------------------------------------- /third_party/tensorflow/BUILD.tpl: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "grpc_header_lib", 5 | hdrs = [":grpc_header_include"], 6 | includes = ["include"], 7 | visibility = ["//visibility:public"], 8 | ) 9 | 10 | cc_library( 11 | name = "farmhash_header_lib", 12 | hdrs = [":farmhash_header_include"], 13 | includes = ["src"], 14 | visibility = ["//visibility:public"], 15 | ) 16 | 17 | cc_library( 18 | name = "tf_header_lib", 19 | hdrs = [":tf_header_include"], 20 | includes = ["include"], 21 | deps = [ 22 | ":grpc_header_lib", 23 | ":farmhash_header_lib", 24 | ], 25 | visibility = ["//visibility:public"], 26 | ) 27 | 28 | cc_library( 29 | name = "libtensorflow_framework", 30 | srcs = [":libtensorflow_framework.so"], 31 | visibility = ["//visibility:public"], 32 | ) 33 | 34 | cc_library( 35 | name = "_pywrap_tensorflow_internal", 36 | srcs = [":_pywrap_tensorflow_internal.so"], 37 | deps = [":libtensorflow_framework"], 38 | visibility = ["//visibility:public"], 39 | ) 40 | 41 | %{FARMHASH_HEADER_GENRULE} 42 | %{GRPC_HEADER_GENRULE} 43 | %{TF_HEADER_GENRULE} 44 | %{TF_SHARED_LIBRARY_GENRULE} 45 | %{TF_PYWRAP_INTERNAL_LIBRARY_GENRULE} 46 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_util.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_ 18 | 19 | #include 20 | 21 | #include "tensorflow/core/framework/types.h" 22 | 23 | namespace tensorflow { 24 | 25 | class VerbsUtil { 26 | public: 27 | static string AppendStepidToKey(const string& key, int64 step_id); 28 | static void GetKeyAndStepId(const string& key_with_step_id, string& key, 29 | int64& step_id); 30 | }; 31 | 32 | } // namespace tensorflow 33 | #endif // TENSORFLOW_CONTRIB_VERBS_VERBS_UTIL_H_ 34 | -------------------------------------------------------------------------------- /third_party/common.bzl: -------------------------------------------------------------------------------- 1 | # Rule for simple expansion of template files. This performs a simple 2 | # search over the template file for the keys in substitutions, 3 | # and replaces them with the corresponding values. 
 4 | #
 5 | # Typical usage:
 6 | #   load("//third_party:common.bzl", "template_rule")
 7 | #   template_rule(
 8 | #       name = "ExpandMyTemplate",
 9 | #       src = "my.template",
10 | #       out = "my.txt",
11 | #       substitutions = {
12 | #           "$VAR1": "foo",
13 | #           "$VAR2": "bar",
14 | #       }
15 | #   )
16 | #
17 | # Args:
18 | #   name: The name of the rule.
19 | #   src: The template file to expand.
20 | #   out: The destination of the expanded file.
21 | #   substitutions: A dictionary mapping strings to their substitutions.
22 | 
23 | def template_rule_impl(ctx):
24 |     ctx.actions.expand_template(
25 |         template = ctx.file.src,
26 |         output = ctx.outputs.out,
27 |         substitutions = ctx.attr.substitutions,
28 |     )
29 | 
30 | template_rule = rule(
31 |     attrs = {
32 |         "src": attr.label(
33 |             mandatory = True,
34 |             allow_single_file = True,
35 |         ),
36 |         "substitutions": attr.string_dict(mandatory = True),
37 |         "out": attr.output(mandatory = True),
38 |     },
39 |     # output_to_genfiles is required for header files.
40 |     output_to_genfiles = True,
41 |     implementation = template_rule_impl,
42 | )
43 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_message.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_MESSAGE_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_MESSAGE_H_
 3 | 
 4 | #include "tensorflow/core/framework/tensor_shape.h"
 5 | 
 6 | namespace tensorflow {
 7 | 
 8 | // Message for recv tensor response.
 9 | struct SeastarMessage {
10 |   bool is_dead_;
11 |   DataType data_type_;
12 |   TensorShape tensor_shape_;
13 |   uint64_t tensor_bytes_;
14 | 
15 |   // |is_dead|...
16 |   // |  1B   |...
17 |   // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
18 |   // ...|   XB    |     XB     |     8B     |...
19 | 
20 |   static const size_t kIsDeadStartIndex = 0;
21 |   static const size_t kDataTypeStartIndex =
22 |       kIsDeadStartIndex + sizeof(is_dead_);
23 |   static const size_t kTensorShapeStartIndex =
24 |       kDataTypeStartIndex + sizeof(data_type_);
25 |   static const size_t kTensorBytesStartIndex =
26 |       kTensorShapeStartIndex + sizeof(TensorShape);
27 |   static const size_t kTensorBufferStartIndex =
28 |       kTensorBytesStartIndex + sizeof(tensor_bytes_);
29 |   static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
30 |   static const size_t kSeastarMessageBufferSize = kMessageTotalBytes;
31 |   static void SerializeMessage(const SeastarMessage& rm, char* data);
32 |   static void DeserializeMessage(SeastarMessage* rm, const char* data);
33 | };
34 | 
35 | }  // namespace tensorflow
36 | 
37 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_MESSAGE_H_
38 | 
--------------------------------------------------------------------------------
/third_party/sctp.BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["permissive"])  # LGPL headers only
 2 | 
 3 | load("@//third_party:common.bzl", "template_rule")
 4 | 
 5 | exports_files(["COPYING.lib"])
 6 | 
 7 | cc_library(
 8 |     name = "sctp",
 9 |     hdrs = ["src/include/netinet/sctp.h"],
10 |     strip_include_prefix = "src/include",
11 |     visibility = ["//visibility:public"],
12 | )
13 | 
14 | template_rule(
15 |     name = "sctp_h",
16 |     src = "src/include/netinet/sctp.h.in",
17 |     out = "src/include/netinet/sctp.h",
18 |     substitutions = {
19 |         "#undef HAVE_SCTP_STREAM_RESET_EVENT": "#define HAVE_SCTP_STREAM_RESET_EVENT 1",
20 |         "#undef HAVE_SCTP_ASSOC_RESET_EVENT": "/* #undef HAVE_SCTP_ASSOC_RESET_EVENT */",
21 |         "#undef HAVE_SCTP_STREAM_CHANGE_EVENT": "/* #undef HAVE_SCTP_STREAM_CHANGE_EVENT */",
22 |         "#undef HAVE_SCTP_STREAM_RECONFIG": "#define HAVE_SCTP_STREAM_RECONFIG 1",
23 |         "#undef HAVE_SCTP_PEELOFF_FLAGS": "#define HAVE_SCTP_PEELOFF_FLAGS 1",
24 |         "#undef HAVE_SCTP_PDAPI_EVENT_PDAPI_STREAM": "#define HAVE_SCTP_PDAPI_EVENT_PDAPI_STREAM 1",
25 |         "#undef HAVE_SCTP_PDAPI_EVENT_PDAPI_SEQ": "#define HAVE_SCTP_PDAPI_EVENT_PDAPI_SEQ 1",
26 |         "#undef HAVE_SCTP_SENDV": "/* #undef HAVE_SCTP_SENDV */",
27 |         "#undef HAVE_SCTP_AUTH_NO_AUTH": "#define HAVE_SCTP_AUTH_NO_AUTH 1",
28 |         "#undef HAVE_SCTP_SPP_IPV6_FLOWLABEL": "#define HAVE_SCTP_SPP_IPV6_FLOWLABEL 1",
29 |         "#undef HAVE_SCTP_SPP_DSCP": "#define HAVE_SCTP_SPP_DSCP 1",
30 |     },
31 | )
32 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_tag_factory.cc:
--------------------------------------------------------------------------------
 1 | #include "tensorflow_networking/seastar/seastar_tag_factory.h"
 2 | 
 3 | #include "tensorflow_networking/seastar/seastar_message.h"
 4 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h"
 5 | 
 6 | namespace tensorflow {
 7 | 
 8 | SeastarTagFactory::SeastarTagFactory(SeastarWorkerService* worker_service)
 9 |     : worker_service_(worker_service) {}
10 | 
11 | SeastarClientTag* SeastarTagFactory::CreateSeastarClientTag(
12 |     seastar::temporary_buffer<char>& header) {
13 |   char* p = const_cast<char*>(header.get());
14 |   SeastarClientTag* tag = nullptr;
15 |   memcpy(&tag, p + 8, 8);
16 |   // ignore the method segment 4B
17 |   memcpy(&tag->status_, p + 20, 2);
18 |   memcpy(&tag->resp_err_msg_len_, p + 22, 2);
19 | 
20 |   if (!tag->IsRecvTensor()) {
21 |     memcpy(&tag->resp_body_buf_.len_, p + 24, 8);
22 |     tag->resp_body_buf_.data_ = new char[tag->resp_body_buf_.len_];
23 |   }
24 |   return tag;
25 | }
26 | 
27 | 
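// Wire-header layout inferred from the memcpy offsets in this file (an
// assumption drawn from the code itself, not a documented spec): a fixed
// header where bytes 8..15 carry the tag identifier, bytes 16..19 the
// method, bytes 20..21 the status, bytes 22..23 the error-message length,
// and bytes 24..31 the payload (body) length; bytes 0..7 are not read here.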
SeastarServerTag* SeastarTagFactory::CreateSeastarServerTag(
28 |     seastar::temporary_buffer<char>& header,
29 |     seastar::channel* seastar_channel) {
30 |   char* p = const_cast<char*>(header.get());
31 |   SeastarServerTag* tag =
32 |       new SeastarServerTag(seastar_channel, worker_service_);
33 |   memcpy(&tag->client_tag_id_, p + 8, 8);
34 |   memcpy(&tag->method_, p + 16, 4);
35 |   // ignore the status segment 2B
36 |   memcpy(&(tag->req_body_buf_.len_), p + 24, 8);
37 |   tag->req_body_buf_.data_ = new char[tag->req_body_buf_.len_];
38 | 
39 |   return tag;
40 | }
41 | 
42 | }  // namespace tensorflow
43 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_server_lib.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_LIB_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_LIB_H_
 3 | 
 4 | #include "tensorflow_networking/seastar/seastar_channel_cache.h"
 5 | #include "tensorflow_networking/seastar/seastar_engine.h"
 6 | #include "tensorflow_networking/seastar/seastar_worker_service.h"
 7 | 
 8 | #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
 9 | #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
10 | #include "tensorflow/core/platform/env.h"
11 | 
12 | namespace tensorflow {
13 | 
14 | class SeastarPortMgr;
15 | 
16 | class SeastarServer : public GrpcServer {
17 |  protected:
18 |   SeastarServer(const ServerDef& server_def, Env* env);
19 | 
20 |  public:
21 |   static Status Create(const ServerDef& server_def, Env* env,
22 |                        std::unique_ptr<ServerInterface>* out_server);
23 |   virtual ~SeastarServer();
24 |   Status Init();
25 | 
26 |  protected:
27 |   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
28 |                           SeastarChannelSpec* channel_spec);
29 |   Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
30 |                             WorkerCacheInterface** worker_cache) override;
31 | 
32 |  private:
33 |   int seastar_bound_port_ = 0;
34 |   std::unique_ptr<SeastarWorker> seastar_worker_impl_;
35 |   SeastarWorkerService* seastar_worker_service_ = nullptr;
36 |   SeastarEngine* seastar_engine_ = nullptr;
37 |   SeastarPortMgr* seastar_port_mgr_ = nullptr;
38 | };
39 | 
40 | }  // namespace tensorflow
41 | 
42 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_LIB_H_
43 | 
--------------------------------------------------------------------------------
/third_party/ragel.BUILD:
--------------------------------------------------------------------------------
 1 | licenses(["notice"])  # MIT
 2 | 
 3 | exports_files(["COPYING"])
 4 | 
 5 | cc_library(
 6 |     name = "config_h",
 7 |     hdrs = glob(["stub-config/*.h"]),
 8 |     includes = ["stub-config"],
 9 |     visibility = ["//bin:__pkg__"],
10 | )
11 | 
12 | cc_library(
13 |     name = "ragel_aapl",
14 |     hdrs = glob(["aapl/*.h"]),
15 |     strip_include_prefix = "aapl",
16 | )
17 | 
18 | genrule(
19 |     name = "gen_rlreduce",
20 |     srcs = [
21 |         "src/rlparse.lm",
22 |         "src/ragel.lm",
23 |         "src/reducer.lm",
24 |     ],
25 |     outs = [
26 |         "src/parse.c",
27 |         "src/rlreduce.cc",
28 |     ],
29 |     cmd = """
30 | $(location @colm) -c -b rl_parse \
31 |     -o $(location :src/parse.c) \
32 |     -m $(location :src/rlreduce.cc) \
33 |     -I $$(dirname $(location :src/rlparse.lm)) \
34 |     $(location :src/rlparse.lm)
35 | 
36 | # http://www.colm.net/pipermail/colm-users/2018-October/000204.html
37 | # https://trac.macports.org/ticket/57242
38 | sed -i.bak 's/#include //' $(location :src/rlreduce.cc)
39 | """,
40 |     tools = ["@colm"],
41 | )
42 | 
43 | cc_library(
44 |     name = "ragel_lib",
45 |     srcs = glob([
46 |         "src/*.cc",
47 |         "src/*.h",
48 |     ]) + [
49 | "src/parse.c", 50 | "src/rlreduce.cc", 51 | ], 52 | copts = ['-DBINDIR=""'], 53 | features = ["no_copts_tokenization"], 54 | includes = ["src"], 55 | visibility = ["//bin:__pkg__"], 56 | deps = [ 57 | ":config_h", 58 | ":ragel_aapl", 59 | "@colm//:runtime", 60 | ], 61 | ) 62 | 63 | cc_binary( 64 | name = "ragelc", 65 | visibility = ["//visibility:public"], 66 | deps = ["//:ragel_lib"], 67 | ) 68 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/gdr_rendezvous_mgr.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_ 17 | #define TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_ 18 | 19 | #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" 20 | #include "tensorflow/core/distributed_runtime/worker_env.h" 21 | #include "tensorflow/core/platform/macros.h" 22 | 23 | #include "tensorflow_networking/gdr/gdr_memory_manager.h" 24 | 25 | namespace tensorflow { 26 | 27 | class GdrRendezvousMgr : public BaseRendezvousMgr { 28 | public: 29 | explicit GdrRendezvousMgr(const WorkerEnv* env, 30 | RemoteMemoryManager* remote_memory_manager); 31 | 32 | protected: 33 | BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env); 34 | 35 | private: 36 | RemoteMemoryManager* remote_memory_manager_; // Not owned 37 | 38 | TF_DISALLOW_COPY_AND_ASSIGN(GdrRendezvousMgr); 39 | }; 40 | 41 | } // end namespace tensorflow 42 | 43 | #endif // TENSORFLOW_CONTRIB_GDR_GDR_RENDEZVOUS_MGR_H_ 44 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/gdr_server_lib.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/
15 | 
16 | #ifndef TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
17 | #define TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
18 | 
19 | #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
20 | #include "tensorflow_networking/gdr/gdr_memory_manager.h"
21 | 
22 | namespace tensorflow {
23 | 
24 | class GdrServer : public GrpcServer {
25 |  protected:
26 |   GdrServer(const ServerDef& server_def, Env* env);
27 | 
28 |  public:
29 |   static Status Create(const ServerDef& server_def, Env* env,
30 |                        std::unique_ptr<ServerInterface>* out_server);
31 | 
32 |   virtual ~GdrServer() override;
33 | 
34 |   virtual Status Start() override;
35 | 
36 |   virtual Status Stop() override;
37 | 
38 |   virtual Status Join() override;
39 | 
40 |  protected:
41 |   Status Init();
42 | 
43 |  private:
44 |   mutex mu_;
45 | 
46 |   std::unique_ptr<RemoteMemoryManager> remote_memory_manager_;
47 |   std::unique_ptr<Thread> gdr_thread_ TF_GUARDED_BY(mu_);
48 | };
49 | 
50 | }  // namespace tensorflow
51 | 
52 | #endif  // TENSORFLOW_CONTRIB_GDR_GDR_SERVER_LIB_H_
53 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_tensor_coding.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TENSOR_CODING_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TENSOR_CODING_H_
 3 | 
 4 | #include "tensorflow/core/common_runtime/device.h"
 5 | #include "tensorflow/core/framework/allocator.h"
 6 | #include "tensorflow/core/framework/tensor.h"
 7 | #include "tensorflow/core/lib/core/status.h"
 8 | #include "tensorflow/core/platform/types.h"
 9 | 
10 | namespace tensorflow {
11 | 
12 | struct SeastarBuf {
13 |   uint64_t len_ = 0;
14 |   char* data_ = nullptr;
15 |   bool owned_ = true;
16 | };
17 | 
18 | class SeastarTensorResponse {
19 |  public:
20 |   virtual ~SeastarTensorResponse() {}
21 | 
22 |   void SetIsDead(bool is_dead) { is_dead_ = is_dead; }
23 |   bool GetIsDead() const { return is_dead_; }
24 | 
25 |   // for dst device
26 |   void InitAlloc(Device* d, const AllocatorAttributes& aa);
27 |   Allocator* GetAlloc() { return allocator_; }
28 |   AllocatorAttributes GetAllocAttributes() { return alloc_attrs_; }
29 |   Device* GetDevice() const { return device_; }
30 |   bool GetOnHost() const { return on_host_; }
31 | 
32 |   void SetTensor(const Tensor& tensor) { tensor_ = tensor; }
33 |   const Tensor& GetTensor() const { return tensor_; }
34 | 
35 |   TensorProto& GetTensorProto() { return tensor_proto_; }
36 | 
37 |   void Clear();
38 | 
39 |   void SetDataType(DataType data_type) { data_type_ = data_type; }
40 |   DataType GetDataType() { return data_type_; }
41 | 
42 |  private:
43 |   bool is_dead_ = false;
44 |   bool on_host_ = false;
45 | 
46 |   // for dst device
47 |   Device* device_ = nullptr;
48 |   AllocatorAttributes alloc_attrs_;
49 |   Allocator* allocator_ = nullptr;
50 | 
51 |   Tensor tensor_;
52 |   TensorProto tensor_proto_;
53 |   DataType data_type_;
54 | };
55 | 
56 | }  // namespace tensorflow
57 | 
58 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_TENSOR_CODING_H_
59 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_channel_cache.h:
--------------------------------------------------------------------------------
 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CHANNEL_CACHE_H_
 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CHANNEL_CACHE_H_
 3 | 
 4 | #include <set>
 5 | #include <string>
 6 | #include <unordered_map>
 7 | #include <vector>
 8 | 
 9 | #include "seastar/core/channel.hh"
10 | 
#include "tensorflow_networking/seastar/seastar_engine.h"
11 | 
12 | #include "tensorflow/core/lib/core/status.h"
13 | 
14 | namespace tensorflow {
15 | 
16 | class SeastarChannelSpec {
17 |  public:
18 |   struct HostPortsJob {
19 |     HostPortsJob(const std::string& job_id,
20 |                  const std::unordered_map<int, std::string>& host_ports)
21 |         : job_id(job_id), host_ports(host_ports) {}
22 |     const std::string job_id;
23 |     const std::unordered_map<int, std::string> host_ports;
24 |   };
25 |   virtual ~SeastarChannelSpec() {}
26 | 
27 |   Status AddHostPortsJob(
28 |       const std::string& job_id,
29 |       const std::unordered_map<int, std::string>& host_ports);
30 | 
31 |   const std::vector<HostPortsJob>& host_ports_jobs() const {
32 |     return host_ports_jobs_;
33 |   }
34 | 
35 |  private:
36 |   std::vector<HostPortsJob> host_ports_jobs_;
37 |   std::set<std::string> job_ids_;
38 | };
39 | 
40 | class SeastarChannelCache {
41 |  public:
42 |   virtual ~SeastarChannelCache() {}
43 | 
44 |   virtual void ListWorkers(std::vector<std::string>* workers) const = 0;
45 |   virtual void ListWorkersInJob(const string& job_name,
46 |                                 std::vector<std::string>* workers) = 0;
47 |   virtual seastar::channel* FindWorkerChannel(const std::string& target) = 0;
48 |   virtual std::string TranslateTask(const std::string& task) = 0;
49 | };
50 | 
51 | SeastarChannelCache* NewSeastarChannelCache(
52 |     SeastarEngine* engine, const SeastarChannelSpec& channel_spec);
53 | 
54 | }  // namespace tensorflow
55 | 
56 | #endif  // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CHANNEL_CACHE_H_
57 | 
--------------------------------------------------------------------------------
/tensorflow_networking/mpi/mpi_server_lib.h:
--------------------------------------------------------------------------------
 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License");
 4 | you may not use this file except in compliance with the License.
 5 | You may obtain a copy of the License at
 6 | 
 7 |     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #ifndef TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_
17 | #define TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_
18 | 
19 | #include <memory>
20 | 
21 | #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
22 | #include "tensorflow_networking/mpi/mpi_rendezvous_mgr.h"
23 | 
24 | namespace tensorflow {
25 | 
26 | class MPIServer : public GrpcServer {
27 |  protected:
28 |   MPIServer(const ServerDef& server_def, Env* env);
29 | 
30 |  public:
31 |   static Status Create(const ServerDef& server_def, Env* env,
32 |                        std::unique_ptr<ServerInterface>* out_server);
33 | 
34 |   // Destruction is only supported in the factory method. Clean
35 |   // shutdown is not currently implemented for this server type.
36 |   ~MPIServer() override;
37 | 
38 |   // Implementations of ServerInterface methods.
39 | Status Start() override; 40 | Status Join() override; 41 | 42 | protected: 43 | Status Init(ServiceInitFunction service_func, 44 | RendezvousMgrCreationFunction rendezvous_mgr_func); 45 | Status ChannelCacheFactory(const ServerDef& server_def, 46 | GrpcChannelCache** channel_cache); 47 | }; 48 | 49 | } // namespace tensorflow 50 | 51 | #endif // TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_ 52 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_util.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_networking/verbs/verbs_util.h" 17 | 18 | #include <vector> 19 | 20 | #include "tensorflow/core/lib/core/stringpiece.h" 21 | #include "tensorflow/core/lib/strings/numbers.h" 22 | #include "tensorflow/core/lib/strings/str_util.h" 23 | #include "tensorflow/core/lib/strings/strcat.h" 24 | 25 | namespace tensorflow { 26 | 27 | // static 28 | string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) { 29 | return strings::StrCat(key, ";", step_id); 30 | } 31 | 32 | // static 33 | void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key, 34 | int64& step_id) { 35 | StringPiece s(key_with_step_id); 36 | // a key (with step_id) has exactly 6 parts when split by ";" 37 | // part 1: src_device; 38 | // part 2: src_incarnation; 39 | // part 3: dst_device; 40 | // part 4: name; 41 | // part 5: frame_iter.frame_id:frame_iter.iter_id 42 | // part 6: step_id 43 | std::vector<string> parts = str_util::Split(s, ';'); 44 | CHECK(parts.size() == 6) << "Key with step_id must have 6 parts"; 45 | strings::safe_strto64(parts[5], &step_id); 46 | parts.pop_back(); // remove step_id 47 | key.assign(str_util::Join(parts, ";")); // stitch them together 48 | } 49 | 50 | } // namespace tensorflow 51 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/rdma_mgr.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ 18 | 19 | #include <string> 20 | #include <unordered_map> 21 | 22 | #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 23 | #include "tensorflow/core/distributed_runtime/worker_env.h" 24 | #include "tensorflow_networking/verbs/rdma.h" 25 | 26 | namespace tensorflow { 27 | 28 | class RdmaMgr { 29 | friend class RdmaChannel; 30 | friend class RdmaAdapter; 31 | 32 | public: 33 | explicit RdmaMgr(const WorkerEnv* const worker_env, 34 | GrpcChannelCache* const channel_cache); 35 | ~RdmaMgr(); 36 | RdmaChannel* FindChannel(const string& key); 37 | void SetupChannels(); 38 | bool ConnectivityCheck(); 39 | void InitAllocators(); 40 | static void RegMemVisitors(); 41 | const string& local_worker() { return local_worker_; } 42 | 43 | private: 44 | string local_worker_; 45 | size_t num_remote_workers_; 46 | const WorkerEnv* const worker_env_; 47 | GrpcChannelCache* const channel_cache_; 48 | RdmaAdapter* rdma_adapter_; 49 | typedef std::unordered_map<string, RdmaChannel*> ChannelTable; 50 | ChannelTable channel_table_; 51 | TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr); 52 | }; 53 | 54 | } // namespace tensorflow 55 | 56 | #endif // TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_ 57 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/grpc_verbs_client.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_networking/verbs/grpc_verbs_client.h" 17 | 18 | #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 19 | #include "tensorflow/core/lib/core/errors.h" 20 | #include "tensorflow/core/lib/core/status.h" 21 | 22 | namespace tensorflow { 23 | 24 | Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options, 25 | const GetRemoteAddressRequest* request, 26 | GetRemoteAddressResponse* response) { 27 | ::grpc::ClientContext ctx; 28 | ctx.set_fail_fast(false); 29 | SetDeadline(&ctx, call_options->GetTimeout()); 30 | return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response)); 31 | } 32 | 33 | Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request, 34 | GetRemoteAddressResponse* response) { 35 | CallOptions call_options; 36 | call_options.SetTimeout(-1); // no time out 37 | return GetRemoteAddress(&call_options, request, response); 38 | } 39 | 40 | void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx, 41 | int64 time_in_ms) { 42 | if (time_in_ms > 0) { 43 | ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN)); 44 | } 45 | } 46 | 47 | } // namespace tensorflow 48 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_engine.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_ENGINE_H_ 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_ENGINE_H_ 3 | 4 | #include <atomic> 5 | #include <string> 6 | #include <thread> 7 | #include "tensorflow_networking/seastar/seastar_client.h" 8 | #include "tensorflow_networking/seastar/seastar_tag_factory.h" 9 | #include "tensorflow_networking/seastar/seastar_worker_service.h" 10 | 11 | #include "seastar/core/channel.hh" 12 | #include "seastar/core/app-template.hh" 13 | #include "seastar/core/distributed.hh" 14 | 15 | namespace tensorflow { 16 | 17 | class SeastarEngine { 18 | public: 19 | SeastarEngine(uint16_t local, SeastarWorkerService* worker_service); 20 | virtual ~SeastarEngine(); 21 | 22 | seastar::channel* GetChannel(const std::string& server_ip); 23 | 24 | private: 25 | class Server { 26 | public: 27 | // Used by Seastar template class distributed<> 28 | void start(uint16_t port, SeastarTagFactory* tag_factory); 29 | seastar::future<> stop(); 30 | 31 | private: 32 | struct Connection { 33 | seastar::connected_socket fd_; 34 | seastar::input_stream<char> read_buf_; 35 | seastar::channel* channel_; 36 | SeastarTagFactory* tag_factory_; 37 | seastar::socket_address addr_; 38 | 39 | Connection(seastar::connected_socket&& fd, SeastarTagFactory* tag_factory, 40 | seastar::socket_address addr); 41 | seastar::future<> Read(); 42 | ~Connection(); 43 | }; 44 | 45 | seastar::lw_shared_ptr<seastar::server_socket> listener_; 46 | }; 47 | 48 | void AsyncStartServer(); 49 | void ConstructArgs(int* argc, char*** argv); 50 | void GetCpuset(char**); 51 | seastar::channel* AsyncConnect(const std::string& ip); 52 | 53 | seastar::distributed<Server> server_; 54 | SeastarClient* client_; 55 | SeastarTagFactory* tag_factory_; 56 | 57 | std::thread thread_; 58 | std::string cpuset_; 59 | uint16_t local_; 60 | std::atomic_size_t core_id_; 61 | std::atomic<bool> is_server_ready_; 62 | size_t core_number_; 63 | }; 64 | 65 | } // namespace tensorflow 66 | 67 | #endif // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_ENGINE_H_ 68 | --------------------------------------------------------------------------------
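[Editor's note, not a file in the repository] The two seastar headers above (seastar_channel_cache.h and seastar_engine.h) only declare how a channel spec, the cache, and the engine fit together. A minimal usage sketch, assuming a SeastarWorkerService obtained elsewhere; the job name, task addresses, and port below are hypothetical:

    // Sketch only: the addresses and port are illustrative values.
    SeastarWorkerService* worker_service = /* provided by the server setup */;
    SeastarEngine engine(/*local=*/6666, worker_service);
    SeastarChannelSpec spec;
    TF_CHECK_OK(spec.AddHostPortsJob(
        "worker", {{0, "10.0.0.1:6666"}, {1, "10.0.0.2:6666"}}));
    SeastarChannelCache* cache = NewSeastarChannelCache(&engine, spec);
    seastar::channel* ch =
        cache->FindWorkerChannel("/job:worker/replica:0/task:1");

FindWorkerChannel() is declared to map a fully qualified task name to a seastar::channel, presumably via SeastarEngine::GetChannel() after TranslateTask() resolves the task name to an address.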
/tensorflow_networking/gdr/gdr_worker.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_ 17 | #define TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_ 18 | 19 | #include "tensorflow/core/distributed_runtime/recent_request_ids.h" 20 | #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" 21 | 22 | #include "tensorflow_networking/gdr/gdr_memory_manager.h" 23 | 24 | namespace tensorflow { 25 | 26 | class GdrWorker : public GrpcWorker { 27 | public: 28 | GdrWorker(WorkerEnv* env, const ConfigProto& config, 29 | RemoteMemoryManager* remote_memory_manager); 30 | 31 | // Serve the RecvTensorRequest but omit the tensor content and transmit it 32 | // out-of-band using GPU Direct RDMA whenever possible. 33 | // If it's not possible, it falls back to gRPC in-band tensor transport by 34 | // encoding the tensor content into the grpc::ByteBuffer. 35 | // The RecvTensorResponse will carry the necessary information for RDMA. 36 | virtual void GrpcRecvTensorAsync(CallOptions* opts, 37 | const RecvTensorRequest* request, 38 | ::grpc::ByteBuffer* response, 39 | StatusCallback done) override; 40 | 41 | private: 42 | RemoteMemoryManager* remote_memory_manager_; // Not owned 43 | RecentRequestIds recv_tensor_recent_request_ids_; 44 | }; 45 | 46 | } // namespace tensorflow 47 | 48 | #endif // TENSORFLOW_CONTRIB_GDR_GDR_WORKER_H_ 49 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/grpc_verbs_client.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_ 18 | 19 | #include "tensorflow/core/distributed_runtime/call_options.h" 20 | #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 21 | #include "tensorflow/core/lib/core/status.h" 22 | #include "tensorflow_networking/verbs/grpc_verbs_service_impl.h" 23 | #include "tensorflow_networking/verbs/verbs_service.pb.h" 24 | 25 | namespace tensorflow { 26 | 27 | // GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service. 28 | class GrpcVerbsClient { 29 | public: 30 | explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel) 31 | : stub_(grpc::VerbsService::NewStub(client_channel)) {} 32 | ~GrpcVerbsClient() {} 33 | 34 | Status GetRemoteAddress(CallOptions* call_options, 35 | const GetRemoteAddressRequest* request, 36 | GetRemoteAddressResponse* response); 37 | Status GetRemoteAddress(const GetRemoteAddressRequest* request, 38 | GetRemoteAddressResponse* response); 39 | 40 | private: 41 | std::unique_ptr<grpc::VerbsService::Stub> stub_; 42 | 43 | void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms); 44 | 45 | TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient); 46 | }; 47 | 48 | } // namespace tensorflow 49 | 50 | #endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_ 51 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_service.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorflow; 19 | option java_outer_classname = "VerbsServiceProtos"; 20 | option java_multiple_files = true; 21 | option java_package = "org.tensorflow.contrib.verbs"; 22 | 23 | //////////////////////////////////////////////////////////////////////////////// 24 | // 25 | // GRPC Helper messages used to exchange RDMA information.
26 | // 27 | //////////////////////////////////////////////////////////////////////////////// 28 | 29 | message Channel { 30 | int32 lid = 1; 31 | int32 qpn = 2; 32 | int32 psn = 3; 33 | uint64 snp = 4; 34 | uint64 iid = 5; 35 | } 36 | 37 | message MemoryRegion { 38 | uint64 remote_addr = 1; 39 | uint32 rkey = 2; 40 | } 41 | message GetRemoteAddressRequest { 42 | string host_name = 1; 43 | Channel channel = 2; 44 | repeated MemoryRegion mr = 3; 45 | } 46 | 47 | message GetRemoteAddressResponse { 48 | string host_name = 1; 49 | Channel channel = 2; 50 | repeated MemoryRegion mr = 3; 51 | } 52 | 53 | message ErrorStatusProto { 54 | int32 error_code = 1; 55 | string error_message = 2; 56 | string error_details = 3; 57 | } 58 | 59 | //////////////////////////////////////////////////////////////////////////////// 60 | // 61 | // VerbsService 62 | // 63 | //////////////////////////////////////////////////////////////////////////////// 64 | 65 | service VerbsService { 66 | rpc GetRemoteAddress(GetRemoteAddressRequest) 67 | returns (GetRemoteAddressResponse); 68 | } 69 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # MPI based communication interfaces and implementations for TensorFlow. 3 | 4 | licenses(["notice"]) # Apache 2.0 5 | 6 | exports_files(["LICENSE"]) 7 | 8 | filegroup( 9 | name = "c_srcs", 10 | data = glob([ 11 | "**/*.cc", 12 | "**/*.h", 13 | ]), 14 | ) 15 | 16 | # For platform specific build config 17 | load( 18 | "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl", 19 | "tf_proto_library_cc", 20 | ) 21 | 22 | tf_proto_library_cc( 23 | name = "mpi_msg_proto", 24 | srcs = ["mpi_msg.proto"], 25 | cc_api_version = 2, 26 | protodeps = ["@org_tensorflow//tensorflow/core:worker_proto"], 27 | visibility = [ 28 | "//tensorflow_networking:__subpackages__", 29 | ], 30 | ) 31 | 32 | cc_library( 33 | name = "mpi_utils", 34 | srcs = ["mpi_utils.cc"], 35 | hdrs = ["mpi_utils.h"], 36 | deps = [ 37 | "//third_party/mpi", 38 | "@org_tensorflow//tensorflow/core", 39 | ], 40 | ) 41 | 42 | cc_library( 43 | name = "mpi_rendezvous_mgr", 44 | srcs = ["mpi_rendezvous_mgr.cc"], 45 | hdrs = ["mpi_rendezvous_mgr.h"], 46 | deps = [ 47 | ":mpi_msg_proto_cc", 48 | ":mpi_utils", 49 | "//third_party/mpi", 50 | "@org_tensorflow//tensorflow/core", 51 | "@org_tensorflow//tensorflow/core/distributed_runtime:base_rendezvous_mgr", 52 | "@org_tensorflow//tensorflow/core/distributed_runtime:recent_request_ids", 53 | "@org_tensorflow//tensorflow/core/distributed_runtime:request_id", 54 | "@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr", 55 | "@org_tensorflow//tensorflow/core/distributed_runtime:tensor_coding", 56 | "@org_tensorflow//tensorflow/core/distributed_runtime:worker_env", 57 | ], 58 | ) 59 | 60 | cc_library( 61 | name = "mpi_server_lib", 62 | srcs = ["mpi_server_lib.cc"], 63 | hdrs = ["mpi_server_lib.h"], 64 | linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel 65 | deps = [ 66 | ":mpi_rendezvous_mgr", 67 | ":mpi_utils", 68 | "//third_party/mpi", 69 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", 70 | ], 71 | alwayslink = 1, 72 | ) 73 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # GPU 
Direct RDMA Out-of-Band Tensor transport for TensorFlow. 3 | 4 | package( 5 | default_visibility = [ 6 | "//tensorflow:__subpackages__", 7 | ], 8 | licenses = ["notice"], # Apache 2.0 9 | ) 10 | 11 | exports_files(["LICENSE"]) 12 | 13 | filegroup( 14 | name = "c_srcs", 15 | data = glob([ 16 | "**/*.cc", 17 | "**/*.h", 18 | ]), 19 | ) 20 | 21 | proto_library( 22 | name = "gdr_proto", 23 | srcs = ["gdr.proto"], 24 | deps = ["@com_google_protobuf//:any_proto"], 25 | ) 26 | 27 | cc_proto_library( 28 | name = "gdr_proto_cc", 29 | deps = [":gdr_proto"], 30 | ) 31 | 32 | cc_library( 33 | name = "gdr_memory_manager", 34 | srcs = ["gdr_memory_manager.cc"], 35 | hdrs = ["gdr_memory_manager.h"], 36 | linkopts = [ 37 | "-libverbs", 38 | "-lrdmacm", 39 | ], 40 | deps = [ 41 | ":gdr_proto_cc", 42 | "@local_config_tf//:tf_header_lib", 43 | ], 44 | ) 45 | 46 | cc_library( 47 | name = "gdr_worker", 48 | srcs = ["gdr_worker.cc"], 49 | hdrs = ["gdr_worker.h"], 50 | deps = [ 51 | ":gdr_memory_manager", 52 | "@local_config_tf//:tf_header_lib", 53 | ], 54 | ) 55 | 56 | cc_library( 57 | name = "gdr_rendezvous_mgr", 58 | srcs = ["gdr_rendezvous_mgr.cc"], 59 | hdrs = ["gdr_rendezvous_mgr.h"], 60 | deps = [ 61 | ":gdr_memory_manager", 62 | "@local_config_tf//:tf_header_lib", 63 | ], 64 | ) 65 | 66 | cc_library( 67 | name = "gdr_collective_executor_mgr", 68 | srcs = ["gdr_collective_executor_mgr.cc"], 69 | hdrs = ["gdr_collective_executor_mgr.h"], 70 | deps = [ 71 | ":gdr_memory_manager", 72 | "@local_config_tf//:tf_header_lib", 73 | ], 74 | ) 75 | 76 | cc_library( 77 | name = "gdr_server_lib", 78 | srcs = ["gdr_server_lib.cc"], 79 | hdrs = ["gdr_server_lib.h"], 80 | linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel 81 | visibility = [ 82 | "//tensorflow_networking:__subpackages__", 83 | ], 84 | deps = [ 85 | ":gdr_collective_executor_mgr", 86 | ":gdr_memory_manager", 87 | ":gdr_rendezvous_mgr", 88 | ":gdr_worker", 89 | "@local_config_tf//:tf_header_lib", 90 | ], 91 | alwayslink = 1, 92 | ) 93 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG UBUNTU_VERSION=16.04 2 | 3 | FROM ubuntu:${UBUNTU_VERSION} AS base 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | libcurl3-dev \ 10 | libfreetype6-dev \ 11 | libhdf5-serial-dev \ 12 | libpng12-dev \ 13 | libzmq3-dev \ 14 | pkg-config \ 15 | rsync \ 16 | software-properties-common \ 17 | unzip \ 18 | zip \ 19 | zlib1g-dev \ 20 | openjdk-8-jdk \ 21 | openjdk-8-jre-headless \ 22 | libibverbs-dev \ 23 | && \ 24 | apt-get clean && \ 25 | rm -rf /var/lib/apt/lists/* 26 | 27 | ENV CI_BUILD_PYTHON python 28 | 29 | ARG USE_PYTHON_3_NOT_2 30 | ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} 31 | ARG PYTHON=python${_PY_SUFFIX} 32 | ARG PIP=pip${_PY_SUFFIX} 33 | 34 | # See http://bugs.python.org/issue19846 35 | ENV LANG C.UTF-8 36 | 37 | RUN apt-get update && apt-get install -y \ 38 | ${PYTHON} \ 39 | ${PYTHON}-pip 40 | 41 | RUN ${PIP} --no-cache-dir install --upgrade \ 42 | pip \ 43 | setuptools 44 | 45 | # Some TF tools expect a "python" binary 46 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python 47 | 48 | RUN apt-get update && apt-get install -y \ 49 | build-essential \ 50 | curl \ 51 | git \ 52 | wget \ 53 | openjdk-8-jdk \ 54 | ${PYTHON}-dev \ 55 | swig 56 | 57 | RUN ${PIP} --no-cache-dir install \ 58 | Pillow \ 59 | h5py \ 60 | 
keras_applications \ 61 | keras_preprocessing \ 62 | matplotlib \ 63 | mock \ 64 | numpy \ 65 | scipy \ 66 | sklearn \ 67 | pandas \ 68 | && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ 69 | enum34 70 | 71 | # Install bazel 72 | ARG BAZEL_VERSION=0.24.1 73 | RUN mkdir /bazel && \ 74 | wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ 75 | wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ 76 | chmod +x /bazel/installer.sh && \ 77 | /bazel/installer.sh && \ 78 | rm -f /bazel/installer.sh 79 | 80 | ADD . /tf_networking 81 | WORKDIR /tf_networking 82 | RUN bazel build -c opt //tensorflow_networking/verbs:verbs_server_lib 83 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_server_lib.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ 18 | 19 | #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" 20 | #include "tensorflow_networking/verbs/grpc_verbs_service.h" 21 | #include "tensorflow_networking/verbs/rdma_mgr.h" 22 | 23 | namespace tensorflow { 24 | 25 | class VerbsServer : public GrpcServer { 26 | protected: 27 | VerbsServer(const ServerDef& server_def, Env* env); 28 | 29 | public: 30 | static Status Create(const ServerDef& server_def, Env* env, 31 | std::unique_ptr<ServerInterface>* out_server); 32 | 33 | // Destruction is only supported in the factory method. Clean 34 | // shutdown is not currently implemented for this server type. 35 | virtual ~VerbsServer() override; 36 | 37 | // Implementations of ServerInterface methods. 38 | Status Start() override; 39 | Status Join() override; 40 | 41 | protected: 42 | Status Init(ServiceInitFunction service_func, 43 | RendezvousMgrCreationFunction rendezvous_mgr_func); 44 | Status ChannelCacheFactory(const ServerDef& server_def, 45 | GrpcChannelCache** channel_cache); 46 | 47 | private: 48 | RdmaMgr* rdma_mgr_; 49 | 50 | // Guards state transitions.
51 | mutex mu_; 52 | 53 | enum State { DISCONNECTED, CONNECTED }; 54 | State verbs_state_ TF_GUARDED_BY(mu_); 55 | 56 | GrpcVerbsService* verbs_service_ = nullptr; 57 | std::unique_ptr<Thread> verbs_thread_ TF_GUARDED_BY(mu_); 58 | GrpcChannelCache* channel_cache_ = nullptr; 59 | }; 60 | 61 | } // namespace tensorflow 62 | 63 | #endif // TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_ 64 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi/mpi_utils.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ 17 | #define TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ 18 | 19 | #include <map> 20 | #include <string> 21 | #include <vector> 22 | 23 | #include "tensorflow/core/lib/strings/str_util.h" 24 | #include "tensorflow/core/platform/logging.h" 25 | 26 | // Skip MPI C++ bindings support, this matches the usage in other places 27 | #define OMPI_SKIP_MPICXX 28 | #include "third_party/mpi/mpi.h" 29 | #define MPI_CHECK(cmd) \ 30 | do { \ 31 | int mpi_errno = cmd; \ 32 | if (MPI_SUCCESS != mpi_errno) { \ 33 | fprintf(stderr, "[%s:%d] MPI call failed with %d \n", __FILE__, \ 34 | __LINE__, mpi_errno); \ 35 | exit(EXIT_FAILURE); \ 36 | } \ 37 | assert(MPI_SUCCESS == mpi_errno); \ 38 | } while (false) 39 | 40 | namespace tensorflow { 41 | class MPIUtils { 42 | public: 43 | explicit MPIUtils(const std::string& worker_name); 44 | 45 | const int GetSourceID(const std::string& task_id) const { 46 | auto it = name_to_id_.find(task_id); 47 | if (it == name_to_id_.end()) { 48 | LOG(FATAL) << "Failed to convert worker name to MPI index: " << task_id; 49 | } 50 | return it->second; 51 | } 52 | 53 | private: 54 | void InitMPI(); 55 | 56 | std::map<std::string, int> name_to_id_; 57 | }; 58 | } // namespace tensorflow 59 | 60 | #endif // TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ 61 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_with_0_copies_phase1_protocol.xml: -------------------------------------------------------------------------------- 1 |
7Vxbc5s4FP41nuk+pMMd8ujYTrYzTZqNk9nmyaOAbNhgRIWc2P31K4G4yzZ1AHtaZzJjOBLS8TnnOzeRDNTRcn2DQejeIgf6A0Vy1gN1PFAUWVMM+sEom4RiGpcJYYE9h0/KCVPvJ+REiVNXngOj0kSCkE+8sEy0URBAm5RoAGP0Xp42R3551xAsYI0wtYFfp/7rOcTlVFmS8oG/obdw+daWzgdegP26wGgV8P0GijqPf5LhJUjX4vMjFzjovUBSJwN1hBEiydVyPYI+k20qtuS56y2jGd8YBqTJA1bywBvwVzDl2PDpo1eO98b4IxsuE+PHijF1ReCaXADfWwQDdUhn+HBO8lF6teCf8SpRCIKUNiUAk09/pUOUqeJogRxvXaa2z01Ke8ECDnpkDCxDehG8ROzjgs4c+j6yAYGPMIgQjkoC64WBKQycT4+Tu+m3h9nD5J+nyfSxax62a6K0m1LaRYlxBpklS3T43fUInIbAZqPv1C9RmkuWPr2Ts6ffIKZMbQWLnEGQujaIlpDgDZ3CH7A4aLlTkw1+/567CCX1EG7BO2RuA3C3tMiWzqFJLzg6xUhNPUbrUH2A9ltia7eQgDEgoHOTa6juNnZjBj2GoE9M1UHc/X4xZq+erq8nDLPT+29308nWTU8LRqrSJ4xkQwCjikCgQ5MBfoswcdECBcCf5NSrssgK4vkPErLh+QxYEURJ+QpfEQpLYmQb7RYi5QutsJ3O4rzSQLqA6TRNLGwMfUC8t/L6H5Kc0pEDivHiON/wU+hQyDzAKERBBLsHDfN8XylM/WG0Cezu9/syj/XyY+VhZjvDPgP7gKGsJRL7LiMUbm7unx7R6F6UXT2dUp7XqSCmEHuUsZ/wqN/45HMnQztq8qQfw8lT0eDN9+LNM1vss85u1x75Xrp75hsdGBq0emhIqvAPhAb+6D3y6M6ZKi+VsipNo6KhhAf+VK6kIcZgU5gWsglR831Us1LL7plvWLvnW9rO+fQi4Ti3sEyGzQKmVguY1x6OyKO3hMyZmCP2W/MyBbeQYDdNy0cuCBbQoX5GvW6KchctX1bRfoQ7NCTZxEPU2YypWTFEdoH6nnO9y/25XuSCkF3Ofbgess5RDFWHX45tH0SRZ5eFNfd8f4R8hOMl0gZPAe9SHe+HodoS5HuKWOIFieoCgaa0D2K/mrwrVewn3NewX1tI1aXPpiwZpiXLlqFrplFeV27mUw6AZWoEfVlFi/5cGhxR9bpx+VmxLlVFtixZ1XSlpDCtqrCmhrB7WbVhbDnEDtSaHTzDqGYKLA8rKzoiGL3CVNUBCmBF+5zEk7exTfUMKf2KuVKPlRt8YOk5TpxpiJxzOfvowherdV+sCcxHaSP/qofCO/T7itqSjihqUYOjq+KKFUD3KBSW7H2VQBfbUqgiAw/jW7bCO6bap5+fkr7cID5BIlSrv8r4iVVThsDAusurVH1/BO2zvOI1VJZwHRx0FVMQdLsposyqBrVmgW57EfWRUGjWFLqlIXdidq/12kVQ6xlDTSKnXU+kAaZk4KZY5P1klXKloNDMA/PI6kJ6VeMtdSVq+8jtdg3UBgcUnRiZoEl1oJEZdSNTj2pku2sMU+2kdEnbSR2ULmrdX7d9FjxK8qLf6ShY0Fo78FSmvyOWETtiubl/aqSHftgaw0h45LGbq/hJWqzte+TEEm3rmHl2muwIYO7KjYDA62ERziF1p7ggdbbiFqEfXpfTUsr2ggUl6PndY5zBXyjbNIgoY3M/jmQuLdth0E2JXlLsZUO9VrP0g9QqakC2olb2FsifrNRqedCrVkXFAQ9vVSc3R3EaYWfpWK5ImphJEuOxBtnx7XB2O5lOhzeTWfntvILCk5VrLvWlAzM4sZ5b1mNLD5gtgfJFOWYbTTet3t/sTvnZa15n5W9Tvqr1qXxRO6xz5Sfv+J21L9C+1iv0t/fbW9F+loLHK1T71tloVe1na8hydr1Pa+iqMm+54E4JX5bLZH4TE2UGyv6Spboq902/ZH27akDRAUzL3/vag74Tlb96kUGyCeFAGfHOAIzIzKNWuk5IAVjywYjAcEZ1z2cuEYEz4DiYE17hJrmidgRmDiBgb/F7wOHTPuRSzTnGi6EbA1EXULHtE8SwXMawInhvSBVl8nobGO76j6I6wrAIZlJt9p8LdKF8dgL9DNuPwVYVJGLdwVb0tt8Ztn8gbD8Un7e+kXvG/i9hX+8zZKdrnLFf3boCj9P3AA2PAs+424I7Q9D0bgt39Db/1wTJuXX+/x/Uyf8= 2 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/gdr_memory_manager.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_ 17 | #define TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_ 18 | 19 | #include "google/protobuf/any.pb.h" 20 | #include "tensorflow/core/lib/core/status.h" 21 | 22 | namespace tensorflow { 23 | 24 | class Device; 25 | class DeviceContext; 26 | class Tensor; 27 | 28 | // Abstract interface that handles out-of-band tensor transport. 29 | // 30 | // The transport options are encoded into a protocol buffer and transmitted via 31 | // some other communication channels like RPC. 32 | // See RecvTensorRequest in tensorflow/core/protobuf/worker.proto 33 | class RemoteMemoryManager { 34 | public: 35 | virtual ~RemoteMemoryManager() {} 36 | virtual Status Init() = 0; 37 | virtual void Run() = 0; 38 | virtual void Stop() = 0; 39 | 40 | // Encodes the tensor information to an arbitrary protocol buffer 41 | // The protocol buffer needs to be transmitted via some other channel 42 | virtual void TransportOptionsFromTensor( 43 | ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, 44 | Device* device, DeviceContext* device_context, bool on_host, 45 | StatusCallback done) = 0; 46 | 47 | // Retrieve the tensor from the encoded protocol buffer 48 | // Note that the tensor has to be allocated, but not initialized 49 | virtual void TensorFromTransportOptions( 50 | Tensor* tensor, const ::google::protobuf::Any& transport_options, 51 | Device* device, DeviceContext* device_context, bool on_host, 52 | StatusCallback done) = 0; 53 | }; 54 | 55 | RemoteMemoryManager* CreateRemoteMemoryManager(const string& host, 56 | const string& port); 57 | 58 | } // namespace tensorflow 59 | 60 | #endif // TENSORFLOW_CONTRIB_GDR_GDR_MEMORY_MANAGER_H_ 61 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG UBUNTU_VERSION=16.04 2 | 3 | FROM ubuntu:${UBUNTU_VERSION} AS base 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | libcurl3-dev \ 10 | libfreetype6-dev \ 11 | libhdf5-serial-dev \ 12 | libpng12-dev \ 13 | libzmq3-dev \ 14 | pkg-config \ 15 | rsync \ 16 | software-properties-common \ 17 | unzip \ 18 | zip \ 19 | zlib1g-dev \ 20 | openjdk-8-jdk \ 21 | openjdk-8-jre-headless \ 22 | libibverbs-dev \ 23 | librdmacm-dev \ 24 | && \ 25 | apt-get clean && \ 26 | rm -rf /var/lib/apt/lists/* 27 | 28 | ENV CI_BUILD_PYTHON python 29 | 30 | ARG USE_PYTHON_3_NOT_2 31 | ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} 32 | ARG PYTHON=python${_PY_SUFFIX} 33 | ARG PIP=pip${_PY_SUFFIX} 34 | 35 | # See http://bugs.python.org/issue19846 36 | ENV LANG C.UTF-8 37 | 38 | RUN apt-get update && apt-get install -y \ 39 | ${PYTHON} \ 40 | ${PYTHON}-pip 41 | 42 | RUN ${PIP} --no-cache-dir install --upgrade \ 43 | pip \ 44 | setuptools 45 | 46 | # Some TF tools expect a "python" binary 47 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python 48 | 49 | RUN apt-get update && apt-get install -y \ 50 | build-essential \ 51 | curl \ 52 | git \ 53 | wget \ 54 | openjdk-8-jdk \ 55 | ${PYTHON}-dev \ 56 | swig 57 | 58 | RUN ${PIP} --no-cache-dir install \ 59 | Pillow \ 60 | h5py \ 61 | keras_applications \ 62 | keras_preprocessing \ 63 | matplotlib \ 64 | mock \ 65 | numpy \ 66 | scipy \ 67 | sklearn \ 68 | pandas \ 69 | && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} 
--no-cache-dir install \ 70 | enum34 71 | 72 | #RUN apt-get install -y libopenmpi-dev # Test using OpenMPI library 73 | RUN apt-get install -y mpich # Test using the MPICH library 74 | 75 | # Install bazel 76 | ARG BAZEL_VERSION=0.24.1 77 | RUN mkdir /bazel && \ 78 | wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ 79 | wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ 80 | chmod +x /bazel/installer.sh && \ 81 | /bazel/installer.sh && \ 82 | rm -f /bazel/installer.sh 83 | 84 | ADD . /tf_networking 85 | WORKDIR /tf_networking 86 | RUN TF_NEED_MPI=1 ./configure && bazel build -c opt //tensorflow_networking/mpi:all 87 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_server_tag.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_TAG_H_ 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_TAG_H_ 3 | 4 | #include <cstdint> 5 | 6 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h" 7 | #include "tensorflow_networking/seastar/seastar_worker_service_method.h" 8 | 9 | #include "seastar/core/channel.hh" 10 | #include "seastar/core/packet_queue.hh" 11 | #include "seastar/core/temporary_buffer.hh" 12 | #include "tensorflow/core/distributed_runtime/worker_cache.h" 13 | #include "tensorflow/core/lib/core/status.h" 14 | #include "tensorflow/core/protobuf/worker.pb.h" 15 | 16 | namespace tensorflow { 17 | 18 | // Required to break a circular dependency. 19 | class SeastarWorkerService; 20 | 21 | class SeastarServerTag { 22 | public: 23 | // Server Header struct 32B: 24 | // |ID:8B|tag_id:8B|method:4B|status:2B|err_msg_len:2B|body_len:8B|err_msg...| 25 | static const uint64_t HEADER_SIZE = 32; 26 | 27 | SeastarServerTag(seastar::channel* seastar_channel, 28 | SeastarWorkerService* seastar_worker_service); 29 | 30 | virtual ~SeastarServerTag(); 31 | 32 | // Called by the seastar engine; invokes the handler. 33 | void RecvReqDone(Status s); 34 | 35 | // Called by the seastar engine. 36 | void SendRespDone(); 37 | 38 | void ProcessDone(Status s); 39 | 40 | uint64_t GetRequestBodySize(); 41 | 42 | char* GetRequestBodyBuffer(); 43 | 44 | void StartResp(); 45 | void StartRespWithTensor(); 46 | 47 | private: 48 | seastar::user_packet* ToUserPacket(); 49 | seastar::user_packet* ToUserPacketWithTensor(); 50 | 51 | public: 52 | SeastarBuf req_body_buf_; 53 | SeastarBuf resp_header_buf_; 54 | SeastarBuf resp_body_buf_; 55 | SeastarBuf resp_message_buf_; 56 | SeastarBuf resp_tensor_buf_; 57 | 58 | SeastarWorkerServiceMethod method_; 59 | 60 | seastar::channel* seastar_channel_; 61 | int64_t client_tag_id_; 62 | 63 | // Used to serialize and send response data.
64 | StatusCallback send_resp_; 65 | StatusCallback clear_; 66 | int16_t status_; 67 | SeastarWorkerService* seastar_worker_service_; 68 | }; 69 | 70 | void InitSeastarServerTag(protobuf::Message* request, 71 | protobuf::Message* response, SeastarServerTag* tag); 72 | 73 | void InitSeastarServerTag(protobuf::Message* request, 74 | SeastarTensorResponse* response, 75 | SeastarServerTag* tag, StatusCallback clear); 76 | 77 | } // namespace tensorflow 78 | 79 | #endif // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_SERVER_TAG_H_ 80 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/rdma_rendezvous_mgr.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ 18 | 19 | #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" 20 | #include "tensorflow/core/distributed_runtime/worker_env.h" 21 | #include "tensorflow/core/platform/macros.h" 22 | #include "tensorflow_networking/verbs/rdma_mgr.h" 23 | 24 | namespace tensorflow { 25 | 26 | // RendezvousMgr keeps track of a set of local rendezvous instances. 27 | // All tensors sent by this worker are buffered in a RendezvousMgr 28 | // until the tensor is received. Each globally unique "step_id" 29 | // corresponds to one local rendezvous instance managed by a 30 | // RendezvousMgr. 31 | // 32 | // E.g., 33 | // Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); 34 | // fork execution of a graph executor using "rendez" on thread 1; 35 | // fork execution of another graph executor using "rendez" on thread 2; 36 | // ... 37 | // join threads 1 and 2; 38 | // 39 | // In the example above, the executions in threads 1 and 2 communicate with 40 | // each other by send/recv operations through "rendez". 41 | // 42 | // Tensors sent and recved through rendezvous managed by this 43 | // RendezvousMgr must have keys generated by Rendezvous::CreateKey.
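//
// (Editorial example, not in the original header; the values are
// hypothetical.) Combining this with VerbsUtil::AppendStepidToKey and
// VerbsUtil::GetKeyAndStepId in verbs_util.cc above, a full
// key-with-step_id carries six ';'-separated parts:
//
//   /job:worker/replica:0/task:0/cpu:0;0000000000000001;/job:worker/replica:0/task:1/cpu:0;edge_5_mul;0:0;42
//
// where the trailing "42" is the step_id that identifies the rendezvous
// instance this manager looks up.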
44 | class RdmaRendezvousMgr : public BaseRendezvousMgr { 45 | public: 46 | explicit RdmaRendezvousMgr(const WorkerEnv* env); 47 | void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; } 48 | 49 | protected: 50 | BaseRemoteRendezvous* Create(int64 step_id, 51 | const WorkerEnv* worker_env) override; 52 | 53 | private: 54 | RdmaMgr* rdma_mgr_; 55 | TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr); 56 | }; 57 | 58 | } // end namespace tensorflow 59 | 60 | #endif // TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_ 61 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG UBUNTU_VERSION=16.04 2 | 3 | FROM ubuntu:${UBUNTU_VERSION} AS base 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends \ 6 | build-essential \ 7 | curl \ 8 | git \ 9 | libcurl3-dev \ 10 | libfreetype6-dev \ 11 | libhdf5-serial-dev \ 12 | libpng12-dev \ 13 | libzmq3-dev \ 14 | pkg-config \ 15 | rsync \ 16 | software-properties-common \ 17 | unzip \ 18 | zip \ 19 | zlib1g-dev \ 20 | openjdk-8-jdk \ 21 | openjdk-8-jre-headless \ 22 | libibverbs-dev \ 23 | librdmacm-dev \ 24 | && \ 25 | apt-get clean && \ 26 | rm -rf /var/lib/apt/lists/* 27 | 28 | ENV CI_BUILD_PYTHON python 29 | 30 | ARG USE_PYTHON_3_NOT_2 31 | ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3} 32 | ARG PYTHON=python${_PY_SUFFIX} 33 | ARG PIP=pip${_PY_SUFFIX} 34 | 35 | # See http://bugs.python.org/issue19846 36 | ENV LANG C.UTF-8 37 | 38 | RUN apt-get update && apt-get install -y \ 39 | ${PYTHON} \ 40 | ${PYTHON}-pip 41 | 42 | RUN ${PIP} --no-cache-dir install --upgrade \ 43 | pip \ 44 | setuptools 45 | 46 | # Some TF tools expect a "python" binary 47 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python 48 | 49 | RUN apt-get update && apt-get install -y \ 50 | build-essential \ 51 | curl \ 52 | git \ 53 | wget \ 54 | openjdk-8-jdk \ 55 | ${PYTHON}-dev \ 56 | swig 57 | 58 | RUN ${PIP} --no-cache-dir install \ 59 | Pillow \ 60 | h5py \ 61 | keras_applications \ 62 | keras_preprocessing \ 63 | matplotlib \ 64 | mock \ 65 | numpy \ 66 | scipy \ 67 | sklearn \ 68 | pandas \ 69 | && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \ 70 | enum34 71 | 72 | #RUN apt-get install -y libopenmpi-dev # Test using OpenMPI library 73 | RUN apt-get install -y mpich # Test using the MPICH library 74 | 75 | # Install bazel 76 | ARG BAZEL_VERSION=0.24.1 77 | RUN mkdir /bazel && \ 78 | wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ 79 | wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \ 80 | chmod +x /bazel/installer.sh && \ 81 | /bazel/installer.sh && \ 82 | rm -f /bazel/installer.sh 83 | 84 | ADD . /tf_networking 85 | WORKDIR /tf_networking 86 | RUN TF_NEED_MPI=1 ./configure && bazel build -c opt //tensorflow_networking/mpi_collectives:all 87 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/mpi_message.proto: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | syntax = "proto3"; 17 | 18 | package tensorflow.contrib.mpi_collectives; 19 | 20 | import "tensorflow/core/framework/tensor_shape.proto"; 21 | import "tensorflow/core/framework/types.proto"; 22 | 23 | // An MPIRequest is a message sent from a rank greater than zero to the 24 | // coordinator (rank zero), informing the coordinator of an operation that 25 | // the rank wants to do and the tensor that it wants to apply the operation to. 26 | message MPIRequest { 27 | enum RequestType { 28 | ALLREDUCE = 0; 29 | ALLGATHER = 1; 30 | } 31 | 32 | // The request rank is necessary to create a consistent ordering of results, 33 | // for example in the allgather where the order of outputs should be sorted 34 | // by rank. 35 | int32 request_rank = 1; 36 | RequestType request_type = 2; 37 | DataType tensor_type = 3; 38 | string tensor_name = 4; 39 | TensorShapeProto tensor_shape = 5; 40 | }; 41 | 42 | // An MPIResponse is a message sent from the coordinator (rank zero) to a rank 43 | // greater than zero, informing the rank that an operation should be performed 44 | // now. If the operation requested would result in an error (for example, due 45 | // to a type or shape mismatch), then the MPIResponse can contain an error and 46 | // an error message instead. Finally, an MPIResponse can be a DONE message (if 47 | // there are no more tensors to reduce on this tick of the background loop) or 48 | // SHUTDOWN if all MPI processes should shut down. 49 | message MPIResponse { 50 | enum ResponseType { 51 | ALLREDUCE = 0; 52 | ALLGATHER = 1; 53 | ERROR = 2; 54 | DONE = 3; 55 | SHUTDOWN = 4; 56 | } 57 | 58 | // Empty if the type is DONE or SHUTDOWN. 59 | ResponseType response_type = 1; 60 | string tensor_name = 2; 61 | 62 | // Empty unless response_type is ERROR.
63 | string error_message = 3; 64 | }; 65 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/verbs_with_0_copies.xml: -------------------------------------------------------------------------------- 1 | 7Vxtc9o4EP41zKQfmsGW3/hIgPQ60/RyIZ1rPzHClsFXY1FZEOivP8mW8ZsAB2yXtHQ6jb2SJXl3n0e7K6cdMFhsPhC4nD9gB/kdtetsOmDYUVVFUw32g0u2scRUu7FgRjxHdEoFY+8nEsKk28pzUJjrSDH2qbfMC20cBMimORkkBL/ku7nYz8+6hDNUEoxt6Jel/3oOnQup0u2mDX8hbzYXU1u6aJhC+/uM4FUg5uuowI3+xM0LmIwl+odz6OCXjAiMOmBAMKbx1WIzQD7XbaK2+Ln7Pa27dRMU0CoP6CB+Yg39FUqWHC2MbhNlRK+D+APdDrh7mXsUjZfQ5q0vzPxMNqcLn90p7NL1fH+AfUzYfYAD1ulOzIAIRZu9y1R2L8+cCuEFomTLumx2mo8fEf5kiduX1DhWIptn7GIkQigcYrYbOlUKuxB6kevIkqjI8NkMd463zqnK+LHihrtjL0rfQ9+bBR3QZz185NK0lV3NxM9olHAJg0Q2ppDQm3dJE1tatjUjjqbOS+tfTSKbEskK2lqYcsua+r6PbUgRJwIUhJiEt03OqfI5xyhwbp5Hn8d/P02eRv98GY2f37VhAam2c7PVBVDGTg5Elmtzu1OCv6NMi2FbaOru5iuBVQLp/fgFefwqRhnAalcC4B3jngPghGxrR/ATstfPkTs+IAqHkMKbC/GQ2oD3ZenEsGNah+/ZNeTbLrTnqHkAPiHYLuxlHAjKVCBjg8q0+XsBGahtAlkxjkcryGGRnLjFhM7xDAfQH6XSu7yWMxr9D1G6FcEoXFHMROkInzBein579RjiFbGTEFIsjW3oM5R002IZX+NBbRPkQ+qt89HoWZpTG6LAiCQGBMUgJShc4iBshRvs9SfGDX4/3AZ2s7QbUcAAL7dRGsKvH79wyCPisWd/8hf/6EZv/2PlEeTsffs3B3dT+aX7tv4r0M20RbZfszff+GC3Or/dePRr7u6bmKgaJ2hlTkhS5fo4QTz6iD22lJ0pe728KU2jYKF4UeKp1Eh9QuA2023JO4QH5jELO4RZyECP9E9SttRH4hWkHrPTSTXm00rMd++RkD57Cw7cjjngf1UDLjjggmm4LPJGjN4kwhvMYTBDDmMccF8V53O8mK7C4xh3GHvY1MOMjYbMbzjCLgP3LW/zvePbfDiHS37p+mjT5xWfCKyOuBzaPgxDzz5Ioa5lI9vOIr5bRnxJzVNL1/TKiLckQYBaEfAZXesSVSeyM3kBFCI6tcgL8euUeKE0kNbt3jK/MUxLUSxD10wjP65SjW9OgHjiHm35y5k+Id0FuhflFIbSu1WtHlAVy1KApqs5U2pFU1Z1kcPDgoo70ikeUi5zfkNhyUl4OJh3gbypRUFTUuMUMeTQZoZHTH7Hudbj4aloWHiOE8Unsi0gH7M0wd+gzN+axH3UOqK28ob7Gf++qraK8U6bqpblw00VQqJM75GFyfI0r61yMM/+5MFadgVPwwc+xAvxorz0Jq7SdaITI8qM/e61S3/zqZsh8cvmQjigH9+S28zlRMK2i+2q7tWqKdmrW8rYrKIFtWr74/6M7Zwd1CwZdFcaztDBm4eJ1mqFA1Q4fn0jmU6yn+WQYlZESjtRrVpIdTTzxDi2OJBe9IWaaimleZR6ayOgQj29EfeTlNbOdT+j7H7gstzvcPZjgkaSqtIXEPUlVaC8JdR9rDqIo7Xf6VRVUlqMI2uCN9soQOXnDPcOyh4veJWOF0oDyyLj6PTkY7BmYGMXDs+p+IGu7/NPl45Fxa9S0ZEz1aHkdJf7aeBE77rAayReGoX0tEzjzQfxxePWdoP4JN6UAHyuVIKAyNFlCEeMSEnGUHzEPXY6vVbAXMX2ghkT6Ondc5QevFf3GZ05HnH9aHebe46DgibKBoqp5yzbKxt299Fb1rDFHOAku6pN2ZV/J/EnW9U0jlq115RRZamEYO0iL7o4CiDsHWmldgSu2+1y8ihBdvjQnzyMxuP+h9Ek/1Vcxt7xyCUWnjbgBRdWBwRWnqoVyTeq0uiyjkKgVq65Nmf8h9FzfzLss3+eRuPHvz+PR1cHkDgA0Np0AFm9rXH0XwnggP21Vglg/0lALfYv1s+vFpdY3NDbtHhT2XfNSXUi+LhYxP2ruCF3Qv47M8VRBQO9yvsOJYiyVLLm9773kO+E+VfPLpBulyzPHaSp7sRjXrqJRQFciMaQouWE2V70XGCKJtBxiBB8R9v4in+mPYk/0z6SGZ9w6nUMuTwvNqaGbnTKNUDXVaMa4KVh2MhjWJV86QRkGbZeB4ab/tWihjAsg1m31PvPBbpUPweBfoXtebAFkmCrOdjKvk+8wvYPhO3r9ucrsk9Att7mhpyMcUV2ceqC818+viueVF1xtwd3hqR2XRfu2G36XxzEJ8/p/yMBRv8D 2 | -------------------------------------------------------------------------------- /tensorflow_networking/gdr/gdr_collective_executor_mgr.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | #ifndef TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ 16 | #define TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ 17 | 18 | #include "tensorflow_networking/gdr/gdr_memory_manager.h" 19 | 20 | #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" 21 | #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" 22 | #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" 23 | #include "tensorflow/core/framework/collective.h" 24 | 25 | namespace tensorflow { 26 | class ConfigProto; 27 | class DeviceMgr; 28 | class WorkerCacheInterface; 29 | class StepSequenceRequest; 30 | class StepSequenceResponse; 31 | 32 | // An implementation of CollectiveExecutorMgr for a distributed environment 33 | // that uses WorkerInterface::RecvBufAsync to route data transfers over RDMA. 34 | class GdrCollectiveExecutorMgr : public RpcCollectiveExecutorMgr { 35 | public: 36 | GdrCollectiveExecutorMgr( 37 | const ConfigProto& config, const DeviceMgr* dev_mgr, 38 | std::unique_ptr<DeviceResolverDistributed> dev_resolver, 39 | std::unique_ptr<CollectiveParamResolverDistributed> param_resolver, 40 | WorkerCacheInterface* worker_cache, const string& task_name, 41 | RemoteMemoryManager* remote_memory_manager) 42 | : RpcCollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver), 43 | std::move(param_resolver), worker_cache, 44 | task_name), 45 | remote_memory_manager_(remote_memory_manager) {} 46 | 47 | ~GdrCollectiveExecutorMgr() override {} 48 | 49 | protected: 50 | virtual CollectiveExecutor* Create(int64 step_id) override; 51 | 52 | private: 53 | RemoteMemoryManager* remote_memory_manager_; // Not owned. 54 | }; 55 | 56 | } // namespace tensorflow 57 | #endif // TENSORFLOW_CONTRIB_GDR_GDR_COLLECTIVE_EXECUTOR_MGR_H_ 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Setup for pip package.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from setuptools import find_packages 21 | from setuptools import setup 22 | from setuptools.dist import Distribution 23 | 24 | 25 | __version__ = '0.1.0' 26 | REQUIRED_PACKAGES = [ 27 | 'tf-nightly >= 2.1.0.dev20191206' 28 | ] 29 | project_name = 'tensorflow-networking' 30 | 31 | 32 | class BinaryDistribution(Distribution): 33 | """This class is needed in order to create OS specific wheels.""" 34 | 35 | def has_ext_modules(self): 36 | return True 37 | 38 | 39 | setup( 40 | name=project_name, 41 | version=__version__, 42 | description=('TensorFlow Networking'), 43 | author='Google Inc.', 44 | author_email='opensource@google.com', 45 | # Contained modules and scripts. 46 | packages=find_packages(), 47 | install_requires=REQUIRED_PACKAGES, 48 | # Add in any packaged data. 49 | include_package_data=True, 50 | zip_safe=False, 51 | distclass=BinaryDistribution, 52 | # PyPI package information. 53 | classifiers=[ 54 | 'Development Status :: 4 - Beta', 55 | 'Intended Audience :: Developers', 56 | 'Intended Audience :: Education', 57 | 'Intended Audience :: Science/Research', 58 | 'License :: OSI Approved :: Apache Software License', 59 | 'Programming Language :: Python :: 3.5', 60 | 'Programming Language :: Python :: 3.6', 61 | 'Programming Language :: Python :: 3.7', 62 | 'Topic :: Scientific/Engineering', 63 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 64 | 'Topic :: Software Development', 65 | 'Topic :: Software Development :: Libraries', 66 | 'Topic :: Software Development :: Libraries :: Python Modules', 67 | ], 68 | license='Apache 2.0', 69 | keywords='tensorflow networking machine learning', 70 | ) 71 | -------------------------------------------------------------------------------- /third_party/seastar.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | exports_files(["LICENSE"]) 4 | 5 | proto_library( 6 | name = "metrics2_proto", 7 | srcs = ["src/proto/metrics2.proto"], 8 | ) 9 | 10 | cc_proto_library( 11 | name = "metrics2_cc_proto", 12 | deps = [":metrics2_proto"], 13 | ) 14 | 15 | cc_library( 16 | name = "seastar", 17 | srcs = glob( 18 | ["src/**/*.cc"], 19 | exclude = [ 20 | "src/testing/*.cc", 21 | ], 22 | ) + glob( 23 | ["src/**/*.hh"], 24 | ), 25 | hdrs = glob( 26 | ["include/seastar/**/*.hh"], 27 | exclude = [ 28 | "include/seastar/testing/*.hh", 29 | ], 30 | ) + [ 31 | "include/seastar/http/request_parser.hh", 32 | "include/seastar/http/response_parser.hh", 33 | ], 34 | copts = [ 35 | "-DSEASTAR_NO_EXCEPTION_HACK", 36 | "-DNO_EXCEPTION_INTERCEPT", 37 | "-DSEASTAR_DEFAULT_ALLOCATOR", 38 | "-DSEASTAR_HAVE_NUMA", 39 | ], 40 | includes = [ 41 | "src", 42 | ], 43 | linkopts = [ 44 | "-ldl", 45 | "-lm", 46 | "-lrt", 47 | "-lstdc++fs", 48 | ], 49 | strip_include_prefix = "include", 50 | visibility = ["//visibility:public"], 51 | deps = [ 52 | ":metrics2_cc_proto", 53 | "@boost//:asio", 54 | "@boost//:filesystem", 55 | "@boost//:fusion", 56 | "@boost//:lockfree", 57 | "@boost//:program_options", 58 | "@boost//:system", 59 | "@boost//:thread", 60 | "@boost//:variant", 61 | "@cares", 62 | "@cryptopp", 63 | "@fmtlib", 64 | "@gnutls", 65 | "@lz4", 66 | "@org_lzma_lzma//:lzma", 67 | "@readerwriterqueue", 68 | "@sctp", 69 | "@systemtap-sdt", 70 
| "@xfs", 71 | "@yaml-cpp", 72 | "@readerwriterqueue", 73 | ], 74 | ) 75 | 76 | genrule( 77 | name = "generate_http_request_parser", 78 | srcs = ["src/http/request_parser.rl"], 79 | outs = ["include/seastar/http/request_parser.hh"], 80 | cmd = "\n".join([ 81 | "$(location @ragel//:ragelc) -G2 -o $@ $<", 82 | "sed -i -e '1h;2,$$H;$$!d;g' -re 's/static const char _nfa[^;]*;//g' $@", 83 | ]), 84 | tools = ["@ragel//:ragelc"], 85 | ) 86 | 87 | genrule( 88 | name = "generate_http_response_parser", 89 | srcs = ["src/http/response_parser.rl"], 90 | outs = ["include/seastar/http/response_parser.hh"], 91 | cmd = "\n".join([ 92 | "$(location @ragel//:ragelc) -G2 -o $@ $<", 93 | "sed -i -e '1h;2,$$H;$$!d;g' -re 's/static const char _nfa[^;]*;//g' $@", 94 | ]), 95 | tools = ["@ragel//:ragelc"], 96 | ) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Networking 2 | 3 | 4 | This repository is for platform-specific networking extensions to core TensorFlow and related 5 | utilities (e.g. testing). 6 | 7 | The design goal is to work towards separately compilable plugins, but initially we'll just be porting the 8 | networking related contrib directories since TensorFlow 2.0 will be dropping contrib. 9 | 10 | ## Building 11 | 12 | Currently support building GDR, VERBS, and MPI extensions: 13 | 14 | #### GDR 15 | 16 | Using Bazel: 17 | 18 | ```bash 19 | bazel build -c opt //tensorflow_networking/gdr:gdr_server_lib 20 | ``` 21 | 22 | Using Docker: 23 | 24 | ```bash 25 | docker build -t tf_networking -f tensorflow_networking/gdr/Dockerfile . 26 | ``` 27 | 28 | #### VERBS 29 | 30 | Using Bazel: 31 | 32 | ```bash 33 | bazel build -c opt //tensorflow_networking/verbs:verbs_server_lib 34 | ``` 35 | 36 | Using Docker: 37 | 38 | ```bash 39 | docker build -t tf_networking -f tensorflow_networking/verbs/Dockerfile . 40 | ``` 41 | 42 | #### MPI 43 | 44 | 45 | For the MPI extensions the location to the MPI library has to be configured. The `configure` script is used to setup this configuration. The script will attempt to find the location of the `mpirun` binary and from there deduce the include and library paths. You can use the `MPI_HOME` environment variable if `mpirun` is not installed in your PATH or you want to use another base path for the MPI library. The configure script will create symbolic links inside the `third_party/mpi` folder to the relevant MPI header and library files. Furthermore the script will determine if your MPI installation is based on `OpenMPI` or on `MPICH` and sets this in the `.tf_networking_configure.bazelrc` file. 46 | 47 | ##### `grpc+mpi` extension 48 | 49 | Using Bazel: 50 | 51 | By manually answering the relevant configuration questions 52 | ```bash 53 | ./configure 54 | bazel build -c opt //tensorflow_networking/mpi:mpi_server_lib 55 | ``` 56 | or by preset answers to the configuration questions 57 | ```bash 58 | MPI_HOME= TF_NEED_MPI=1 ./configure 59 | bazel build -c opt //tensorflow_networking/mpi:mpi_server_lib 60 | ``` 61 | 62 | Using Docker: 63 | 64 | ```bash 65 | docker build -t tf_networking -f tensorflow_networking/mpi/Dockerfile . 
66 | ```
67 | 
68 | 
69 | ##### `MPI collectives` extension
70 | 
71 | Using Bazel:
72 | 
73 | By manually answering the relevant configuration questions:
74 | ```bash
75 | ./configure
76 | bazel build -c opt //tensorflow_networking/mpi_collectives:all
77 | ```
78 | 
79 | Using Docker:
80 | 
81 | ```bash
82 | docker build -t tf_networking -f tensorflow_networking/mpi_collectives/Dockerfile .
83 | ```
84 | 
85 | ##### `grpc+seastar` extension
86 | 
87 | Using Bazel:
88 | 
89 | ```bash
90 | bazel build -c opt --copt='-std=gnu++14' //tensorflow_networking:libtensorflow_networking.so
91 | ```
--------------------------------------------------------------------------------
/tensorflow_networking/verbs/grpc_verbs_service_impl.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
17 | 
18 | #include "grpcpp/impl/codegen/async_stream.h"
19 | #include "grpcpp/impl/codegen/async_unary_call.h"
20 | #include "grpcpp/impl/codegen/channel_interface.h"
21 | #include "grpcpp/impl/codegen/client_unary_call.h"
22 | #include "grpcpp/impl/codegen/method_handler_impl.h"
23 | #include "grpcpp/impl/codegen/rpc_service_method.h"
24 | #include "grpcpp/impl/codegen/service_type.h"
25 | #include "grpcpp/impl/codegen/sync_stream.h"
26 | 
27 | namespace tensorflow {
28 | 
29 | namespace grpc {
30 | 
31 | static const char* grpcVerbsService_method_names[] = {
32 |     "/tensorflow.VerbsService/GetRemoteAddress",
33 | };
34 | 
35 | std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
36 |     const std::shared_ptr< ::grpc::ChannelInterface>& channel,
37 |     const ::grpc::StubOptions& options) {
38 |   std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
39 |   return stub;
40 | }
41 | 
42 | VerbsService::Stub::Stub(
43 |     const std::shared_ptr< ::grpc::ChannelInterface>& channel)
44 |     : channel_(channel),
45 |       rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
46 |                                   ::grpc::internal::RpcMethod::NORMAL_RPC,
47 |                                   channel) {}
48 | 
49 | ::grpc::Status VerbsService::Stub::GetRemoteAddress(
50 |     ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
51 |     GetRemoteAddressResponse* response) {
52 |   return ::grpc::internal::BlockingUnaryCall(
53 |       channel_.get(), rpcmethod_GetRemoteAddress_, context, request, response);
54 | }
55 | 
56 | VerbsService::AsyncService::AsyncService() {
57 |   for (int i = 0; i < 1; ++i) {
58 |     AddMethod(new ::grpc::internal::RpcServiceMethod(
59 |         grpcVerbsService_method_names[i],
60 |         ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
61 |     ::grpc::Service::MarkMethodAsync(i);
62 |   }
63 | }
64 | 
65 | VerbsService::AsyncService::~AsyncService() {}
66 | 
67 | }  // namespace grpc
68 | 
69 | }  // namespace tensorflow
70 | 
--------------------------------------------------------------------------------
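The file above hand-writes what gRPC's code generator would normally emit for the one-method `tensorflow.VerbsService`. A minimal sketch of driving the synchronous stub from a client follows; the target address, function name, and error handling here are illustrative assumptions, not part of this repository:

```cpp
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"

// Ask a peer worker for its RDMA address over the gRPC side channel.
tensorflow::Status ExchangeRemoteAddress(const std::string& target) {
  auto channel =
      ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials());
  auto stub = tensorflow::grpc::VerbsService::NewStub(channel);

  ::grpc::ClientContext context;
  tensorflow::GetRemoteAddressRequest request;  // filled by RdmaMgr in practice
  tensorflow::GetRemoteAddressResponse response;

  // Issues the BlockingUnaryCall implemented above.
  ::grpc::Status s = stub->GetRemoteAddress(&context, request, &response);
  if (!s.ok()) {
    return tensorflow::errors::Internal("GetRemoteAddress failed: ",
                                        s.error_message());
  }
  return tensorflow::Status::OK();
}
```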
/tensorflow_networking/verbs/grpc_verbs_service.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
17 | #define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
18 | 
19 | #include "grpcpp/alarm.h"
20 | #include "grpcpp/grpcpp.h"
21 | #include "grpcpp/server_builder.h"
22 | #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
23 | #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
24 | #include "tensorflow/core/lib/core/refcount.h"
25 | #include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
26 | #include "tensorflow_networking/verbs/rdma_mgr.h"
27 | #include "tensorflow_networking/verbs/verbs_service.pb.h"
28 | 
29 | namespace tensorflow {
30 | 
31 | class GrpcVerbsService : public AsyncServiceInterface {
32 |  public:
33 |   GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
34 |   ~GrpcVerbsService();
35 |   void HandleRPCsLoop() override;
36 |   void Shutdown() override;
37 |   void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
38 | 
39 |  private:
40 |   template <class RequestMessage, class ResponseMessage>
41 |   using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
42 |                           RequestMessage, ResponseMessage>;
43 |   void GetRemoteAddressHandler(
44 |       WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
45 |   Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
46 |                               GetRemoteAddressResponse* response);
47 | 
48 |   ::grpc::ServerCompletionQueue* cq_;
49 |   grpc::VerbsService::AsyncService verbs_service_;
50 |   mutex shutdown_mu_;
51 |   bool is_shutdown_ TF_GUARDED_BY(shutdown_mu_);
52 |   ::grpc::Alarm* shutdown_alarm_;
53 |   // not owned
54 |   RdmaMgr* rdma_mgr_;
55 |   const WorkerEnv* const worker_env_;
56 | 
57 |   TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
58 | };
59 | 
60 | // Create a GrpcVerbsService, then assign it to a given handle.
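// For example, a usage sketch (hypothetical surrounding setup; worker_env and
// builder would come from the enclosing gRPC server bootstrap):
//
//   GrpcVerbsService* verbs_service = nullptr;
//   SetNewVerbsService(&verbs_service, worker_env, &builder);
//   // On a dedicated thread:
//   verbs_service->HandleRPCsLoop();
//   // During teardown:
//   verbs_service->Shutdown();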
61 | void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
62 |                         ::grpc::ServerBuilder* builder);
63 | 
64 | }  // namespace tensorflow
65 | 
66 | #endif  // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
67 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/seastar_worker_service.h:
--------------------------------------------------------------------------------
1 | #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SEASTAR_SEASTAR_WORKER_SERVICE_H_
2 | #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SEASTAR_SEASTAR_WORKER_SERVICE_H_
3 | 
4 | #include <memory>
5 | 
6 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h"
7 | #include "tensorflow_networking/seastar/seastar_worker_interface.h"
8 | #include "tensorflow_networking/seastar/seastar_worker_service_method.h"
9 | 
10 | #include "tensorflow/core/distributed_runtime/call_options.h"
11 | #include "tensorflow/core/distributed_runtime/worker.h"
12 | #include "tensorflow/core/distributed_runtime/worker_env.h"
13 | #include "tensorflow/core/lib/core/status.h"
14 | 
15 | namespace tensorflow {
16 | 
17 | // Required to break a circular dependency.
18 | class SeastarServerTag;
19 | 
20 | class SeastarWorker : public Worker, public SeastarWorkerInterface {
21 |  public:
22 |   typedef std::function<void(Status)> StatusCallback;
23 |   explicit SeastarWorker(WorkerEnv* worker_env);
24 |   virtual ~SeastarWorker() {}
25 | 
26 |   // Specialized version of RecvTensor for seastar.
27 |   void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
28 |                        SeastarTensorResponse* response, StatusCallback done);
29 |   WorkerEnv* env();
30 | };
31 | 
32 | class SeastarWorkerService {
33 |  public:
34 |   using HandleRequestFunction =
35 |       void (SeastarWorkerService::*)(SeastarServerTag*);
36 | 
37 |   explicit SeastarWorkerService(SeastarWorker* worker);
38 |   virtual ~SeastarWorkerService() {}
39 | 
40 |   HandleRequestFunction GetHandler(SeastarWorkerServiceMethod methodId);
41 | 
42 |   void RunGraphHandler(SeastarServerTag* tag);
43 |   void GetStatusHandler(SeastarServerTag* tag);
44 |   void CreateWorkerSessionHandler(SeastarServerTag* tag);
45 |   void DeleteWorkerSessionHandler(SeastarServerTag* tag);
46 |   void CleanupAllHandler(SeastarServerTag* tag);
47 |   void RegisterGraphHandler(SeastarServerTag* tag);
48 |   void DeregisterGraphHandler(SeastarServerTag* tag);
49 |   void CleanupGraphHandler(SeastarServerTag* tag);
50 |   void LoggingHandler(SeastarServerTag* tag);
51 |   void TracingHandler(SeastarServerTag* tag);
52 |   void RecvTensorHandlerRaw(SeastarServerTag* tag);
53 |   void RecvBufHandler(SeastarServerTag* tag);
54 |   void CompleteGroupHandler(SeastarServerTag* tag);
55 |   void CompleteInstanceHandler(SeastarServerTag* tag);
56 |   void GetStepSequenceHandler(SeastarServerTag* tag);
57 | 
58 |  private:
59 |   void Schedule(std::function<void()> f);
60 | 
61 |   std::unordered_map<SeastarWorkerServiceMethod, HandleRequestFunction>
62 |       handler_map_;
63 |   SeastarWorker* worker_;
64 | };
65 | 
66 | std::unique_ptr<SeastarWorker> NewSeastarWorker(WorkerEnv* worker_env);
67 | std::unique_ptr<SeastarWorkerService> NewSeastarWorkerService(
68 |     SeastarWorker* worker);
69 | 
70 | }  // namespace tensorflow
71 | 
72 | #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SEASTAR_SEASTAR_WORKER_SERVICE_H_
73 | 
--------------------------------------------------------------------------------
/tensorflow_networking/mpi/mpi_utils.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #include "tensorflow_networking/mpi/mpi_utils.h"
17 | namespace tensorflow {
18 | 
19 | #define max_worker_name_length 128
20 | 
21 | MPIUtils::MPIUtils(const std::string& worker_name) {
22 |   InitMPI();
23 |   // Connect the MPI process IDs to the worker names that are used by TF.
24 |   // Gather the names of all the active processes (a name can't be longer
25 |   // than 128 bytes).
26 |   int proc_id = 0, number_of_procs = 1;
27 |   char my_name[max_worker_name_length];
28 |   MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id));
29 |   MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
30 | 
31 |   CHECK(worker_name.size() < max_worker_name_length)
32 |       << "Specified worker name is too long.";
33 |   snprintf(my_name, max_worker_name_length, "%s", worker_name.c_str());
34 |   std::vector<char> worker_names(number_of_procs * max_worker_name_length);
35 |   MPI_CHECK(MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR,
36 |                           &worker_names[0], max_worker_name_length, MPI_CHAR,
37 |                           MPI_COMM_WORLD));
38 | 
39 |   if (proc_id == 0) LOG(INFO) << "MPI process-ID to gRPC server name map: \n";
40 |   for (int i = 0; i < number_of_procs; i++) {
41 |     name_to_id_[std::string(&worker_names[i * max_worker_name_length])] = i;
42 |     if (proc_id == 0)
43 |       LOG(INFO) << "Process: " << i << "\tgRPC-name: "
44 |                 << std::string(&worker_names[i * max_worker_name_length])
45 |                 << std::endl;
46 |   }
47 | }
48 | 
49 | void MPIUtils::InitMPI() {
50 |   // Initialize the MPI environment if that hasn't been done yet.
51 |   int flag = 0;
52 |   MPI_CHECK(MPI_Initialized(&flag));
53 |   if (!flag) {
54 |     int proc_id = 0, number_of_procs = 1, len = -1;
55 |     char my_host_name[max_worker_name_length];
56 |     // MPI_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &flag));
57 |     MPI_CHECK(MPI_Init(0, 0));
58 |     MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id));
59 |     MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs));
60 |     MPI_CHECK(MPI_Get_processor_name(my_host_name, &len));
61 |     fprintf(stderr,
62 |             "MPI Environment initialized. 
Process id: %d Total processes: %d " 63 | "|| Hostname: %s \n", 64 | proc_id, number_of_procs, my_host_name); 65 | } 66 | } 67 | 68 | } // namespace tensorflow 69 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_client_tag.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_TAG_H_ 2 | #define TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_TAG_H_ 3 | 4 | #include "tensorflow_networking/seastar/seastar_tensor_coding.h" 5 | #include "tensorflow_networking/seastar/seastar_worker_service.h" 6 | #include "tensorflow_networking/seastar/seastar_worker_service_method.h" 7 | 8 | #include "seastar/core/channel.hh" 9 | #include "seastar/core/packet_queue.hh" 10 | #include "seastar/core/temporary_buffer.hh" 11 | 12 | #include "tensorflow/core/distributed_runtime/call_options.h" 13 | #include "tensorflow/core/distributed_runtime/worker_cache.h" 14 | #include "tensorflow/core/distributed_runtime/worker_env.h" 15 | #include "tensorflow/core/lib/core/status.h" 16 | #include "tensorflow/core/protobuf/worker.pb.h" 17 | 18 | namespace tensorflow { 19 | 20 | typedef std::function ParseMessageCallback; 21 | 22 | class SeastarClientTag { 23 | public: 24 | // Client Header 32B: 25 | // |ID:8B|tag:8B|method:4B|reserve:4B|body_len:8B| 26 | static const uint64_t HEADER_SIZE = 32; 27 | SeastarClientTag(tensorflow::SeastarWorkerServiceMethod method, 28 | WorkerEnv* env); 29 | virtual ~SeastarClientTag(); 30 | 31 | // Called by seastar remote worker, notify seastar engine to send request. 32 | void StartReq(seastar::channel* seastar_channel); 33 | 34 | bool IsRecvTensor(); 35 | Status ParseMessage(); 36 | 37 | // Called by seastar engine, handle the upper layer callback, ex. callback of 38 | // 'RecvOp'. 39 | void RecvRespDone(Status s); 40 | 41 | uint64_t GetResponseBodySize(); 42 | char* GetResponseBodyBuffer(); 43 | 44 | uint64_t GetResponseMessageSize(); 45 | char* GetResponseMessageBuffer(); 46 | 47 | uint64_t GetResponseTensorSize(); 48 | char* GetResponseTensorBuffer(); 49 | 50 | private: 51 | friend class SeastarTagFactory; 52 | seastar::user_packet* ToUserPacket(); 53 | void Schedule(std::function f); 54 | 55 | public: 56 | // Used to handle the upper layer call back when resp recevied. 57 | StatusCallback done_; 58 | SeastarWorkerServiceMethod method_; 59 | WorkerEnv* env_; 60 | int16_t status_; 61 | uint16_t resp_err_msg_len_; 62 | SeastarBuf req_header_buf_; 63 | SeastarBuf req_body_buf_; 64 | SeastarBuf resp_body_buf_; 65 | SeastarBuf resp_message_buf_; 66 | SeastarBuf resp_tensor_buf_; 67 | ParseMessageCallback parse_message_; 68 | CallOptions* call_opts_; 69 | int timeout_in_ms_; 70 | }; 71 | 72 | void InitSeastarClientTag(protobuf::Message* request, 73 | protobuf::Message* response, StatusCallback done, 74 | SeastarClientTag* tag, CallOptions* call_opts); 75 | 76 | void InitSeastarClientTag(protobuf::Message* request, 77 | SeastarTensorResponse* response, StatusCallback done, 78 | SeastarClientTag* tag, CallOptions* call_opts); 79 | 80 | } // namespace tensorflow 81 | 82 | #endif // TENSORFLOW_CONTRIB_SEASTAR_SEASTAR_CLIENT_TAG_H_ 83 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/BUILD: -------------------------------------------------------------------------------- 1 | # Ops that communicate with other processes via MPI. 
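# The targets below wire the extension together: the mpi_message proto, a
# custom-op shared library (python/ops/_mpi_ops.so) for pip-style loading, a
# tf_kernel_library variant for linking the kernels directly into TensorFlow,
# generated Python op wrappers, and a manually triggered test.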
2 | 3 | #package(default_visibility = [ 4 | # "//tensorflow:__subpackages__", 5 | #]) 6 | 7 | licenses(["notice"]) # Apache 2.0 8 | 9 | load( 10 | "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl", 11 | "tf_proto_library_cc", 12 | ) 13 | 14 | tf_proto_library_cc( 15 | name = "mpi_message_proto", 16 | srcs = ["mpi_message.proto"], 17 | cc_api_version = 2, 18 | protodeps = ["@org_tensorflow//tensorflow/core:protos_all"], 19 | visibility = [ 20 | "//tensorflow_networking:__subpackages__", 21 | ], 22 | ) 23 | 24 | load( 25 | "@org_tensorflow//tensorflow:tensorflow.bzl", 26 | "tf_custom_op_library", 27 | "tf_custom_op_py_library", 28 | "tf_gen_op_libs", 29 | "tf_gen_op_wrapper_py", 30 | "tf_kernel_library", 31 | "tf_py_test", 32 | ) 33 | 34 | tf_custom_op_library( 35 | name = "python/ops/_mpi_ops.so", 36 | srcs = [ 37 | "kernels/mpi_ops.cc", 38 | "kernels/ring.cc", 39 | "kernels/ring.h", 40 | "ops/mpi_ops.cc", 41 | ], 42 | gpu_srcs = [ 43 | "kernels/ring.cu.cc", 44 | "kernels/ring.h", 45 | ], 46 | deps = [ 47 | ":mpi_message_proto_cc", 48 | "//third_party/mpi", 49 | "@org_tensorflow//tensorflow/core:stream_executor_headers_lib", 50 | ], 51 | ) 52 | 53 | tf_kernel_library( 54 | name = "mpi_ops_kernels", 55 | srcs = [ 56 | "kernels/mpi_ops.cc", 57 | "kernels/ring.cc", 58 | ], 59 | hdrs = [ 60 | "kernels/ring.h", 61 | ], 62 | gpu_srcs = [ 63 | "kernels/ring.cu.cc", 64 | ], 65 | deps = [ 66 | ":mpi_message_proto_cc", 67 | "//third_party/mpi", 68 | "@org_tensorflow//tensorflow/core", 69 | ], 70 | # TODO: Include? alwayslink = 1, 71 | ) 72 | 73 | tf_gen_op_libs( 74 | op_lib_names = ["mpi_ops"], 75 | ) 76 | 77 | tf_gen_op_wrapper_py( 78 | name = "mpi_ops", 79 | deps = [":mpi_ops_op_lib"], 80 | ) 81 | 82 | tf_custom_op_py_library( 83 | name = "mpi_collectives_py", 84 | srcs = [ 85 | "__init__.py", 86 | "python/ops/mpi_ops.py", 87 | ], 88 | dso = [ 89 | ":python/ops/_mpi_ops.so", 90 | ], 91 | kernels = [ 92 | ":mpi_ops_kernels", 93 | ":mpi_ops_op_lib", 94 | ], 95 | srcs_version = "PY2AND3", 96 | visibility = ["//visibility:public"], 97 | deps = [ 98 | ":mpi_ops", 99 | "@org_tensorflow//tensorflow/python:platform", 100 | ], 101 | ) 102 | 103 | tf_py_test( 104 | name = "mpi_ops_test", 105 | srcs = ["mpi_ops_test.py"], 106 | additional_deps = [ 107 | "@org_tensorflow//tensorflow:tensorflow_py", 108 | "@org_tensorflow//tensorflow/python:platform", 109 | ], 110 | data = [ 111 | ":python/ops/_mpi_ops.so", 112 | ], 113 | tags = ["manual"], 114 | ) 115 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/README: -------------------------------------------------------------------------------- 1 | # BUILD 2 | ## Build Docker 3 | Dockerfile: tensorflow_networking/seastar/Dockerfile 4 | 5 | ## Build Network whl with seastar: 6 | 1. python3 third_party/tensorflow/configure.py 7 | 2. bazel build -c opt --cxxopt=-std=gnu++14 --crosstool_top=@//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain //tensorflow_networking:libtensorflow_networking.so 8 | 3. cp bazel-bin/tensorflow_networking/libtensorflow_networking.so tensorflow_networking 9 | 4. python3.6 setup.py bdist_wheel 10 | 5. pip3.6 install --upgrade dist/tensorflow_networking-0.1.0-cp36-cp36m-linux_x86_64.whl 11 | 12 | # RUN 13 | Run with grpc+seastar: 14 | Use grpc+seastar you have to create .endpoint_map file in launch dir. 
15 | (42353 and 42354 are gRPC ports; 46068 and 47079 are Seastar ports.)
16 | 127.0.0.1:42353=127.0.0.1:46068
17 | 127.0.0.1:42354=127.0.0.1:47079
18 | 
19 | grpc+seastar's environment variables:
20 | SEASTAR_CORE_NUMBER configures the number of Seastar threads; the default is 4.
21 | 
22 | # Seastar-based RPC for TF Worker Service
23 | 
24 | ## User API
25 | To use it with the low-level API, specify protocol=grpc+seastar in tf.train.Server.
26 | To use it with the higher-level TF 2.0 API, specify rpc_layer=grpc+seastar in the
27 | distribution strategy or in the TF_CONFIG environment variable.
28 | 
29 | ## Design Goal
30 | The grpc+seastar RPC protocol is designed to improve the performance of
31 | distributed sparse model training, typically with hundreds or more worker
32 | nodes and large-scale embedding variables in the model. In particular, the
33 | protocol is proposed to address two performance bottlenecks induced by gRPC:
34 | (1) low QPS for small RPC messages and (2) extra copying of large RPC messages.
35 | 
36 | Some of the existing networking plugins (such as grpc+verbs) can partially
37 | address the second bottleneck, but they do not resolve the first one, which
38 | is the main obstacle to scaling our parameter synchronization process, which
39 | runs once every couple of milliseconds across hundreds of nodes.
40 | 
41 | ## Architecture
42 | The grpc+seastar RPC protocol has three components:
43 | Seastar: an Apache 2.0 licensed high-performance server-side application
44 | framework. It uses kernel-bypass networking and a userspace networking
45 | protocol stack to avoid frequent thread creation, context switching,
46 | and memory copying.
47 | 
48 | TF2Seastar: wrapper classes of Seastar servers and clients that initialize RPC
49 | handlers to process RPC requests and responses. Seastar itself is not an RPC
50 | library like gRPC, so we wrap it to act like an RPC system. We also designed
51 | a zero-copy wire format so that we can avoid extra memory copies when
52 | transmitting tensor data.
53 | 
54 | Distributed runtime: TF WorkerService stubs implemented using TF2Seastar.
55 | Some RPC methods in WorkerService use the zero-copy wire format
56 | above (such as RecvTensor and RunGraph), while others still use Protocol
57 | Buffers. We keep the original gRPC implementation in the TF MasterService stubs.
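As a concrete illustration of the zero-copy wire format, every client request starts with a fixed 32-byte header (the byte layout is documented in seastar_client_tag.h). A hypothetical packed-struct rendering, with illustrative field names:

```cpp
#include <cstdint>

// Sketch only: |ID:8B|tag:8B|method:4B|reserve:4B|body_len:8B|.
#pragma pack(push, 1)
struct SeastarClientHeader {
  uint64_t id;        // protocol identifier
  uint64_t tag;       // correlates the response with its SeastarClientTag
  uint32_t method;    // SeastarWorkerServiceMethod to invoke
  uint32_t reserve;   // reserved for future use
  uint64_t body_len;  // length of the request body that follows
};
#pragma pack(pop)
static_assert(sizeof(SeastarClientHeader) == 32, "header must be 32 bytes");
```

Because the header announces the body length up front, the receiver can read tensor bytes directly into preallocated buffers instead of staging them through an intermediate protobuf copy.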
58 | 59 | ## Design Details 60 | (https://docs.google.com/document/d/1f1m-98rbH33WE0qNb3tP0yt9Jjbb-rprvweLobRbTCA/edit?usp=sharing) 61 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_worker_cache.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow_networking/seastar/seastar_worker_cache.h" 2 | 3 | #include "tensorflow_networking/seastar/seastar_remote_worker.h" 4 | 5 | #include "tensorflow/core/distributed_runtime/worker_cache_logger.h" 6 | #include "tensorflow/core/distributed_runtime/worker_cache_partial.h" 7 | 8 | namespace tensorflow { 9 | 10 | namespace { 11 | 12 | class SeastarWorkerCache : public WorkerCachePartial { 13 | public: 14 | explicit SeastarWorkerCache(SeastarChannelCache* channel_cache, 15 | WorkerInterface* local_worker, 16 | const string& local_target, WorkerEnv* env) 17 | : local_target_(local_target), 18 | local_worker_(local_worker), 19 | channel_cache_(channel_cache), 20 | env_(env) {} 21 | 22 | virtual ~SeastarWorkerCache() {} 23 | 24 | void ListWorkers(std::vector* workers) const override { 25 | channel_cache_->ListWorkers(workers); 26 | } 27 | 28 | void ListWorkersInJob(const string& job_name, 29 | std::vector* workers) const override { 30 | channel_cache_->ListWorkersInJob(job_name, workers); 31 | } 32 | 33 | WorkerInterface* GetOrCreateWorker(const string& target) override { 34 | if (target == local_target_) { 35 | return local_worker_; 36 | } else { 37 | seastar::channel* chan = channel_cache_->FindWorkerChannel(target); 38 | if (!chan) return nullptr; 39 | return NewSeastarRemoteWorker(chan, &logger_, env_); 40 | } 41 | } 42 | 43 | void ReleaseWorker(const string& target, WorkerInterface* worker) { 44 | if (target == local_target_) { 45 | CHECK_EQ(worker, local_worker_) 46 | << "Releasing a worker that was not returned by this WorkerCache"; 47 | } else { 48 | WorkerCacheInterface::ReleaseWorker(target, worker); 49 | } 50 | } 51 | 52 | Status GetEagerClientCache( 53 | std::unique_ptr* eager_client_cache) override { 54 | return errors::Unimplemented( 55 | "Eager client not yet implemented for this protocol"); 56 | } 57 | 58 | void SetLogging(bool v) override { logger_.SetLogging(v); } 59 | void ClearLogs() override { logger_.ClearLogs(); } 60 | bool RetrieveLogs(int64 step_id, StepStats* ss) override { 61 | return logger_.RetrieveLogs(step_id, ss); 62 | } 63 | 64 | private: 65 | const string local_target_; 66 | WorkerInterface* const local_worker_; 67 | WorkerCacheLogger logger_; 68 | SeastarChannelCache* channel_cache_; 69 | WorkerEnv* env_; 70 | }; 71 | 72 | } // namespace 73 | 74 | WorkerCacheInterface* NewSeastarWorkerCache(SeastarChannelCache* channel_cache, 75 | WorkerEnv* env) { 76 | return new SeastarWorkerCache(channel_cache, nullptr, "", env); 77 | } 78 | 79 | WorkerCacheInterface* NewSeastarWorkerCacheWithLocalWorker( 80 | SeastarChannelCache* channel_cache, WorkerInterface* local_worker, 81 | const string& local_target, WorkerEnv* env) { 82 | return new SeastarWorkerCache(channel_cache, local_worker, local_target, env); 83 | } 84 | 85 | } // namespace tensorflow 86 | -------------------------------------------------------------------------------- /third_party/cares/cares.BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # MIT 2 | 3 | exports_files(["LICENSE.md"]) 4 | 5 | cc_library( 6 | name = "cares", 7 | srcs = [ 8 | "ares__close_sockets.c", 9 | 
"ares__get_hostent.c", 10 | "ares__read_line.c", 11 | "ares__timeval.c", 12 | "ares_cancel.c", 13 | "ares_create_query.c", 14 | "ares_data.c", 15 | "ares_destroy.c", 16 | "ares_expand_name.c", 17 | "ares_expand_string.c", 18 | "ares_fds.c", 19 | "ares_free_hostent.c", 20 | "ares_free_string.c", 21 | "ares_getenv.c", 22 | "ares_gethostbyaddr.c", 23 | "ares_gethostbyname.c", 24 | "ares_getnameinfo.c", 25 | "ares_getopt.c", 26 | "ares_getsock.c", 27 | "ares_init.c", 28 | "ares_library_init.c", 29 | "ares_llist.c", 30 | "ares_mkquery.c", 31 | "ares_nowarn.c", 32 | "ares_options.c", 33 | "ares_parse_a_reply.c", 34 | "ares_parse_aaaa_reply.c", 35 | "ares_parse_mx_reply.c", 36 | "ares_parse_naptr_reply.c", 37 | "ares_parse_ns_reply.c", 38 | "ares_parse_ptr_reply.c", 39 | "ares_parse_soa_reply.c", 40 | "ares_parse_srv_reply.c", 41 | "ares_parse_txt_reply.c", 42 | "ares_platform.c", 43 | "ares_process.c", 44 | "ares_query.c", 45 | "ares_search.c", 46 | "ares_send.c", 47 | "ares_strcasecmp.c", 48 | "ares_strdup.c", 49 | "ares_strerror.c", 50 | "ares_timeout.c", 51 | "ares_version.c", 52 | "ares_writev.c", 53 | "bitncmp.c", 54 | "inet_net_pton.c", 55 | "inet_ntop.c", 56 | "windows_port.c", 57 | ], 58 | hdrs = [ 59 | "ares.h", 60 | "ares_build.h", 61 | "ares_config.h", 62 | "ares_data.h", 63 | "ares_dns.h", 64 | "ares_getenv.h", 65 | "ares_getopt.h", 66 | "ares_inet_net_pton.h", 67 | "ares_iphlpapi.h", 68 | "ares_ipv6.h", 69 | "ares_library_init.h", 70 | "ares_llist.h", 71 | "ares_nowarn.h", 72 | "ares_platform.h", 73 | "ares_private.h", 74 | "ares_rules.h", 75 | "ares_setup.h", 76 | "ares_strcasecmp.h", 77 | "ares_strdup.h", 78 | "ares_version.h", 79 | "ares_writev.h", 80 | "bitncmp.h", 81 | "config-win32.h", 82 | "nameser.h", 83 | "setup_once.h", 84 | ], 85 | copts = [ 86 | "-D_GNU_SOURCE", 87 | "-D_HAS_EXCEPTIONS=0", 88 | "-DHAVE_CONFIG_H", 89 | ], 90 | defines = ["CARES_STATICLIB"], 91 | includes = ["."], 92 | visibility = ["//visibility:public"], 93 | ) 94 | 95 | genrule( 96 | name = "ares_build_h", 97 | srcs = ["@//third_party/cares:ares_build.h"], 98 | outs = ["ares_build.h"], 99 | cmd = "cat $< > $@", 100 | ) 101 | 102 | genrule( 103 | name = "ares_config_h", 104 | srcs = ["@//third_party/cares:ares_config.h"], 105 | outs = ["ares_config.h"], 106 | cmd = "cat $< > $@", 107 | ) 108 | -------------------------------------------------------------------------------- /third_party/tensorflow/configure.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # =============================================================================
15 | """Config utility to write .bazelrc based on tensorflow."""
16 | from __future__ import print_function
17 | import re
18 | import sys
19 | import tensorflow as tf
20 | 
21 | 
22 | def write_config():
23 |   """Retrieve compile and link information from tensorflow and write to .bazelrc."""
24 | 
25 |   cflags = tf.sysconfig.get_compile_flags()
26 | 
27 |   inc_regex = re.compile("^-I")
28 |   opt_regex = re.compile("^-D")
29 | 
30 |   include_list = []
31 |   opt_list = []
32 | 
33 |   for arg in cflags:
34 |     if inc_regex.match(arg):
35 |       include_list.append(arg)
36 |     elif opt_regex.match(arg):
37 |       opt_list.append(arg)
38 |     else:
39 |       print("WARNING: Unexpected cflag item {}".format(arg))
40 | 
41 |   if len(include_list) != 1:
42 |     print("ERROR: Expected a single include directory in " +
43 |           "tf.sysconfig.get_compile_flags()")
44 |     sys.exit(1)
45 | 
46 |   library_regex = re.compile("^-l")
47 |   libdir_regex = re.compile("^-L")
48 | 
49 |   library_list = []
50 |   libdir_list = []
51 | 
52 |   lib = tf.sysconfig.get_link_flags()
53 | 
54 |   for arg in lib:
55 |     if library_regex.match(arg):
56 |       library_list.append(arg)
57 |     elif libdir_regex.match(arg):
58 |       libdir_list.append(arg)
59 |     else:
60 |       print("WARNING: Unexpected link flag item {}".format(arg))
61 | 
62 |   if len(library_list) != 1 or len(libdir_list) != 1:
63 |     print("ERROR: Expected exactly one lib and one libdir in " +
64 |           "tf.sysconfig.get_link_flags()")
65 |     sys.exit(1)
66 | 
67 |   try:
68 |     with open(".bazelrc", "w") as bazel_rc:
69 |       for opt in opt_list:
70 |         bazel_rc.write('build --copt="{}"\n'.format(opt))
71 | 
72 |       bazel_rc.write('build --action_env TF_HEADER_DIR="{}"\n'
73 |                      .format(include_list[0][2:]))
74 | 
75 |       bazel_rc.write('build --action_env TF_SHARED_LIBRARY_DIR="{}"\n'
76 |                      .format(libdir_list[0][2:]))
77 |       library_name = library_list[0][2:]
78 |       if library_name.startswith(":"):
79 |         library_name = library_name[1:]
80 |       elif sys.platform == "darwin":
81 |         library_name = "lib" + library_name + ".dylib"
82 |       else:
83 |         library_name = "lib" + library_name + ".so"
84 |       bazel_rc.write('build --action_env TF_SHARED_LIBRARY_NAME="{}"\n'
85 |                      .format(library_name))
86 |   except OSError:
87 |     print("ERROR: Writing .bazelrc")
88 |     sys.exit(1)
89 | 
90 | 
91 | if __name__ == '__main__':
92 |   write_config()
--------------------------------------------------------------------------------
/third_party/ci_build/Dockerfile.rbe.ubuntu16.04-manylinux2014:
--------------------------------------------------------------------------------
1 | # Dockerfile to build a manylinux 2014 compliant cross-compiler.
2 | #
3 | # Builds a devtoolset gcc/libstdc++ that targets manylinux 2014 compatible
4 | # glibc (2.17) and system libstdc++ (4.8).
5 | 
6 | FROM ubuntu:16.04 as devtoolset
7 | 
8 | ENV DEBIAN_FRONTEND=noninteractive
9 | RUN apt-get update && \
10 |     apt-get install -y --no-install-recommends \
11 |       bzip2 \
12 |       cpio \
13 |       curl \
14 |       file \
15 |       flex \
16 |       g++ \
17 |       make \
18 |       patch \
19 |       rpm2cpio \
20 |       unar \
21 |       tar \
22 |       xz-utils \
23 |       && \
24 |     rm -rf /var/lib/apt/lists/*
25 | 
26 | ADD devtoolset/fixlinks.sh fixlinks.sh
27 | ADD devtoolset/build_devtoolset.sh build_devtoolset.sh
28 | ADD devtoolset/rpm-patch.sh rpm-patch.sh
29 | 
30 | # Set up a sysroot for glibc 2.17 / libstdc++ 4.8 / devtoolset-7 in /dt7.
31 | RUN /build_devtoolset.sh devtoolset-7 /dt7
32 | # Set up a sysroot for glibc 2.17 / libstdc++ 4.8 / devtoolset-8 in /dt8.
33 | RUN /build_devtoolset.sh devtoolset-8 /dt8 34 | 35 | FROM ubuntu:16.04 36 | COPY --from=devtoolset /dt7 /dt7 37 | COPY --from=devtoolset /dt8 /dt8 38 | 39 | ARG DEBIAN_FRONTEND=noninteractive 40 | 41 | # Install python 3.5/3.6/3.7/3.8. 42 | RUN apt-get update && \ 43 | apt-get install --no-install-recommends -yq software-properties-common && \ 44 | add-apt-repository ppa:deadsnakes/ppa && \ 45 | apt-get update && \ 46 | apt-get install --no-install-recommends -yq \ 47 | curl \ 48 | python3.5-dev \ 49 | python3.6-dev \ 50 | python3.7-dev \ 51 | python3.8-dev \ 52 | python3.8-distutils \ 53 | && \ 54 | rm -rf /var/lib/apt/lists/* 55 | 56 | RUN curl -O https://bootstrap.pypa.io/get-pip.py &&\ 57 | python3.5 get-pip.py && \ 58 | python3.6 get-pip.py && \ 59 | python3.7 get-pip.py && \ 60 | python3.8 get-pip.py && \ 61 | rm get-pip.py 62 | 63 | RUN mkdir -p "/dt7/usr/include/x86_64-linux-gnu" && \ 64 | ln -s "/usr/include/x86_64-linux-gnu/python3.5m" "/dt7/usr/include/x86_64-linux-gnu/python3.5m" && \ 65 | ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt7/usr/include/x86_64-linux-gnu/python3.6m" && \ 66 | ln -s "/usr/include/x86_64-linux-gnu/python3.7m" "/dt7/usr/include/x86_64-linux-gnu/python3.7m" && \ 67 | ln -s "/usr/include/x86_64-linux-gnu/python3.8" "/dt7/usr/include/x86_64-linux-gnu/python3.8" 68 | 69 | RUN mkdir -p "/dt8/usr/include/x86_64-linux-gnu" && \ 70 | ln -s "/usr/include/x86_64-linux-gnu/python3.5m" "/dt8/usr/include/x86_64-linux-gnu/python3.5m" && \ 71 | ln -s "/usr/include/x86_64-linux-gnu/python3.6m" "/dt8/usr/include/x86_64-linux-gnu/python3.6m" && \ 72 | ln -s "/usr/include/x86_64-linux-gnu/python3.7m" "/dt8/usr/include/x86_64-linux-gnu/python3.7m" && \ 73 | ln -s "/usr/include/x86_64-linux-gnu/python3.8" "/dt8/usr/include/x86_64-linux-gnu/python3.8" 74 | 75 | # TensorFlow dependencies 76 | RUN apt-get update && \ 77 | apt-get install -y --no-install-recommends \ 78 | build-essential \ 79 | default-jdk-headless \ 80 | git \ 81 | patchelf \ 82 | pkg-config \ 83 | python \ 84 | unzip \ 85 | zip \ 86 | zlib1g-dev \ 87 | && \ 88 | rm -rf /var/lib/apt/lists/* 89 | 90 | # Install auditwheel 91 | RUN pip3 install --no-cache-dir -U auditwheel 92 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/grpc_verbs_service_impl.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_ 17 | #define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_ 18 | 19 | #include "grpcpp/impl/codegen/async_stream.h" 20 | #include "grpcpp/impl/codegen/async_unary_call.h" 21 | #include "grpcpp/impl/codegen/proto_utils.h" 22 | #include "grpcpp/impl/codegen/rpc_method.h" 23 | #include "grpcpp/impl/codegen/service_type.h" 24 | #include "grpcpp/impl/codegen/status.h" 25 | #include "grpcpp/impl/codegen/stub_options.h" 26 | #include "grpcpp/impl/codegen/sync_stream.h" 27 | 28 | #include "tensorflow_networking/verbs/verbs_service.pb.h" 29 | 30 | namespace tensorflow { 31 | 32 | namespace grpc { 33 | 34 | // Implementation of `tensorflow.VerbsService`, based on the 35 | // definition in "//tensorflow_networking/verbs/verbs_service.proto", 36 | // and the gRPC generated stub and service classes. 37 | // See the proto file for the definition of methods and messages. 38 | class VerbsService GRPC_FINAL { 39 | public: 40 | class StubInterface { 41 | public: 42 | virtual ~StubInterface() {} 43 | virtual ::grpc::Status GetRemoteAddress( 44 | ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, 45 | GetRemoteAddressResponse* response) = 0; 46 | }; 47 | class Stub GRPC_FINAL : public StubInterface { 48 | public: 49 | Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); 50 | ::grpc::Status GetRemoteAddress( 51 | ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, 52 | GetRemoteAddressResponse* response) GRPC_OVERRIDE; 53 | 54 | private: 55 | std::shared_ptr< ::grpc::ChannelInterface> channel_; 56 | const ::grpc::internal::RpcMethod rpcmethod_GetRemoteAddress_; 57 | }; 58 | static std::unique_ptr NewStub( 59 | const std::shared_ptr< ::grpc::ChannelInterface>& channel, 60 | const ::grpc::StubOptions& options = ::grpc::StubOptions()); 61 | 62 | class AsyncService : public ::grpc::Service { 63 | public: 64 | AsyncService(); 65 | virtual ~AsyncService(); 66 | void RequestGetRemoteAddress( 67 | ::grpc::ServerContext* context, GetRemoteAddressRequest* request, 68 | ::grpc::ServerAsyncResponseWriter* response, 69 | ::grpc::CompletionQueue* new_call_cq, 70 | ::grpc::ServerCompletionQueue* notification_cq, void* tag) { 71 | ::grpc::Service::RequestAsyncUnary(0, context, request, response, 72 | new_call_cq, notification_cq, tag); 73 | } 74 | }; 75 | }; 76 | 77 | } // namespace grpc 78 | 79 | } // namespace tensorflow 80 | 81 | #endif // TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_IMPL_H_ 82 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/rdma_rendezvous_mgr.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include "tensorflow_networking/verbs/rdma_rendezvous_mgr.h" 17 | #include 18 | #include "tensorflow/core/common_runtime/device.h" 19 | #include "tensorflow/core/common_runtime/device_mgr.h" 20 | #include "tensorflow/core/common_runtime/dma_helper.h" 21 | #include "tensorflow/core/lib/core/errors.h" 22 | #include "tensorflow/core/lib/strings/numbers.h" 23 | #include "tensorflow/core/lib/strings/str_util.h" 24 | #include "tensorflow_networking/verbs/verbs_util.h" 25 | 26 | namespace tensorflow { 27 | 28 | class RdmaRemoteRendezvous : public BaseRemoteRendezvous { 29 | public: 30 | RdmaRemoteRendezvous(const WorkerEnv* env, int64 step_id, RdmaMgr* rdma_mgr) 31 | : BaseRemoteRendezvous(env, step_id), rdma_mgr_(rdma_mgr) {} 32 | 33 | protected: 34 | void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, 35 | const Rendezvous::Args& args, 36 | DoneCallback done) override; 37 | 38 | private: 39 | ~RdmaRemoteRendezvous() override {} 40 | RdmaMgr* rdma_mgr_; 41 | 42 | TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous); 43 | }; 44 | 45 | void RdmaRemoteRendezvous::RecvFromRemoteAsync( 46 | const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, 47 | DoneCallback done) { 48 | Status s; 49 | // parse src_name and dst_name 50 | string src_name, dst_name, unused; 51 | if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name, 52 | &unused) || 53 | !DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name, 54 | &unused)) { 55 | s = errors::Internal("Could not parse src or dst name."); 56 | } 57 | if (!s.ok()) { 58 | LOG(ERROR) << "s is not ok, error code " << s.error_message(); 59 | done(s, Args(), recv_args, Tensor{}, false); 60 | return; 61 | } 62 | CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0); 63 | RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); 64 | string key(parsed.FullKey()); 65 | string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); 66 | 67 | Device* dst_dev; 68 | s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev); 69 | CHECK(s.ok()) << "s is not ok, error code " << s.error_message(); 70 | if (!s.ok()) { 71 | done(s, Args(), recv_args, Tensor(), true); 72 | return; 73 | } 74 | 75 | RdmaTensorRequest* request = 76 | rc->InsertTensorRequest(key, step_id_, dst_dev, recv_args, done); 77 | request->Start(); 78 | } 79 | 80 | RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env) 81 | : BaseRendezvousMgr(env) {} 82 | 83 | BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id, 84 | const WorkerEnv* worker_env) { 85 | return new RdmaRemoteRendezvous(worker_env, step_id, rdma_mgr_); 86 | } 87 | 88 | } // end namespace tensorflow 89 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/ring.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #ifdef TENSORFLOW_USE_MPI
17 | 
18 | #define EIGEN_USE_THREADS
19 | 
20 | #include "tensorflow_networking/mpi_collectives/ring.h"
21 | 
22 | namespace tensorflow {
23 | namespace contrib {
24 | namespace mpi {
25 | 
26 | using CPUDevice = Eigen::ThreadPoolDevice;
27 | 
28 | extern template MPI_Datatype MPIType<int>();
29 | extern template MPI_Datatype MPIType<long long>();
30 | extern template MPI_Datatype MPIType<float>();
31 | extern template DataType TensorFlowDataType<int>();
32 | extern template DataType TensorFlowDataType<long long>();
33 | extern template DataType TensorFlowDataType<float>();
34 | 
35 | // Generate all necessary specializations for RingAllreduce.
36 | template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
37 |                                               Tensor*, Tensor*);
38 | template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
39 |                                                     const Tensor*, Tensor*,
40 |                                                     Tensor*);
41 | template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
42 |                                                 Tensor*, Tensor*);
43 | 
44 | // Generate all necessary specializations for RingAllgather.
45 | template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
46 |                                               const std::vector<size_t>&,
47 |                                               Tensor*);
48 | template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
49 |                                                     const Tensor*,
50 |                                                     const std::vector<size_t>&,
51 |                                                     Tensor*);
52 | template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
53 |                                                 const std::vector<size_t>&,
54 |                                                 Tensor*);
55 | 
56 | // Copy data on a CPU using a straight-forward memcpy.
57 | template <>
58 | void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
59 |   std::memcpy(dst, src, size);
60 | };
61 | 
62 | // Accumulate values on a CPU.
63 | #define GENERATE_ACCUMULATE(type)                                      \
64 |   template <>                                                         \
65 |   void AccumulateTensorData<CPUDevice, type>(type * dst, type * src,  \
66 |                                              size_t size) {           \
67 |     for (unsigned int i = 0; i < size; i++) {                         \
68 |       dst[i] += src[i];                                               \
69 |     }                                                                 \
70 |   };
71 | GENERATE_ACCUMULATE(int);
72 | GENERATE_ACCUMULATE(long long);
73 | GENERATE_ACCUMULATE(float);
74 | #undef GENERATE_ACCUMULATE
75 | 
76 | }  // namespace mpi
77 | }  // namespace contrib
78 | }  // namespace tensorflow
79 | 
80 | #endif  // TENSORFLOW_USE_MPI
81 | 
--------------------------------------------------------------------------------
/tensorflow_networking/mpi_collectives/kernels/ring.cc:
--------------------------------------------------------------------------------
1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 | 
16 | #ifdef TENSORFLOW_USE_MPI
17 | 
18 | #define EIGEN_USE_THREADS
19 | 
20 | #include "tensorflow_networking/mpi_collectives/kernels/ring.h"
21 | 
22 | namespace tensorflow {
23 | namespace contrib {
24 | namespace mpi_collectives {
25 | 
26 | using CPUDevice = Eigen::ThreadPoolDevice;
27 | 
28 | extern template MPI_Datatype MPIType<int>();
29 | extern template MPI_Datatype MPIType<long long>();
30 | extern template MPI_Datatype MPIType<float>();
31 | extern template DataType TensorFlowDataType<int>();
32 | extern template DataType TensorFlowDataType<long long>();
33 | extern template DataType TensorFlowDataType<float>();
34 | 
35 | // Generate all necessary specializations for RingAllreduce.
36 | template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
37 |                                               Tensor*, Tensor*);
38 | template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
39 |                                                     const Tensor*, Tensor*,
40 |                                                     Tensor*);
41 | template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
42 |                                                 Tensor*, Tensor*);
43 | 
44 | // Generate all necessary specializations for RingAllgather.
45 | template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
46 |                                               const std::vector<size_t>&,
47 |                                               Tensor*);
48 | template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
49 |                                                     const Tensor*,
50 |                                                     const std::vector<size_t>&,
51 |                                                     Tensor*);
52 | template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
53 |                                                 const std::vector<size_t>&,
54 |                                                 Tensor*);
55 | 
56 | // Copy data on a CPU using a straight-forward memcpy.
57 | template <>
58 | void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
59 |   std::memcpy(dst, src, size);
60 | };
61 | 
62 | // Accumulate values on a CPU.
63 | #define GENERATE_ACCUMULATE(type)                                      \
64 |   template <>                                                         \
65 |   void AccumulateTensorData<CPUDevice, type>(type * dst, type * src,  \
66 |                                              size_t size) {           \
67 |     for (unsigned int i = 0; i < size; i++) {                         \
68 |       dst[i] += src[i];                                               \
69 |     }                                                                 \
70 |   };
71 | GENERATE_ACCUMULATE(int);
72 | GENERATE_ACCUMULATE(long long);
73 | GENERATE_ACCUMULATE(float);
74 | #undef GENERATE_ACCUMULATE
75 | 
76 | }  // namespace mpi_collectives
77 | }  // namespace contrib
78 | }  // namespace tensorflow
79 | 
80 | #endif  // TENSORFLOW_USE_MPI
81 | 
--------------------------------------------------------------------------------
/tensorflow_networking/seastar/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 | 
3 | exports_files(["LICENSE"])
4 | 
5 | filegroup(
6 |     name = "c_srcs",
7 |     data = glob([
8 |         "**/*.cc",
9 |         "**/*.h",
10 |     ]),
11 | )
12 | 
13 | COMMON_COPTS = [
14 |     "-DFMT_HEADER_ONLY",
15 |     "-DNO_EXCEPTION_HACK",
16 |     "-DNO_EXCEPTION_INTERCEPT",
17 |     "-DHAVE_LZ4_COMPRESS_DEFAULT",
18 | ]
19 | 
20 | cc_library(
21 |     name = "seastar_tensor_coding",
22 |     srcs = [
23 |         "seastar_message.cc",
24 |         "seastar_tensor_coding.cc",
25 |     ],
26 |     hdrs = [
27 |         "seastar_message.h",
28 |         "seastar_tensor_coding.h",
29 |     ],
30 |     deps = [
31 |         "@local_config_tf//:tf_header_lib",
32 |     ],
33 | )
34 | 
35 | cc_library(
36 |     name = "seastar_worker_service",
37 |     srcs = [
38 |         "seastar_client_tag.cc",
39 |         "seastar_server_tag.cc",
40 |         "seastar_tag_factory.cc",
41 |         "seastar_worker_service.cc",
42 |     ],
43 |     hdrs = [
44 |         "seastar_client_tag.h",
45 |         "seastar_server_tag.h",
46 |         "seastar_tag_factory.h",
47 |         "seastar_worker_interface.h",
48 |         "seastar_worker_service.h",
49 |         "seastar_worker_service_method.h",
50 |     ],
51 |     copts = COMMON_COPTS,
52 |     linkstatic = 1,
53 |     deps = [
54 |         ":seastar_tensor_coding",
55 | 
"@local_config_tf//:tf_header_lib", 56 | "@seastar", 57 | ], 58 | alwayslink = 1, 59 | ) 60 | 61 | cc_library( 62 | name = "seastar_cpuset", 63 | srcs = ["seastar_cpuset.cc"], 64 | hdrs = ["seastar_cpuset.h"], 65 | linkstatic = 1, 66 | deps = [ 67 | "@local_config_tf//:tf_header_lib", 68 | ], 69 | alwayslink = 1, 70 | ) 71 | 72 | cc_library( 73 | name = "seastar_engine", 74 | srcs = [ 75 | "seastar_client.cc", 76 | "seastar_engine.cc", 77 | ], 78 | hdrs = [ 79 | "seastar_client.h", 80 | "seastar_engine.h", 81 | ], 82 | copts = COMMON_COPTS, 83 | linkstatic = 1, 84 | deps = [ 85 | ":seastar_cpuset", 86 | ":seastar_worker_service", 87 | ], 88 | alwayslink = 1, 89 | ) 90 | 91 | cc_library( 92 | name = "seastar_remote_worker", 93 | srcs = ["seastar_remote_worker.cc"], 94 | hdrs = ["seastar_remote_worker.h"], 95 | copts = COMMON_COPTS, 96 | linkstatic = 1, 97 | deps = [ 98 | ":seastar_worker_service", 99 | "@local_config_tf//:tf_header_lib", 100 | ], 101 | alwayslink = 1, 102 | ) 103 | 104 | cc_library( 105 | name = "seastar_worker_cache", 106 | srcs = [ 107 | "seastar_channel_cache.cc", 108 | "seastar_worker_cache.cc", 109 | ], 110 | hdrs = [ 111 | "seastar_channel_cache.h", 112 | "seastar_worker_cache.h", 113 | ], 114 | copts = COMMON_COPTS, 115 | linkstatic = 1, 116 | deps = [ 117 | ":seastar_engine", 118 | ":seastar_remote_worker", 119 | "@local_config_tf//:tf_header_lib", 120 | ], 121 | alwayslink = 1, 122 | ) 123 | 124 | cc_library( 125 | name = "seastar_rendezvous_mgr", 126 | srcs = ["seastar_rendezvous_mgr.cc"], 127 | hdrs = ["seastar_rendezvous_mgr.h"], 128 | copts = COMMON_COPTS, 129 | linkstatic = 1, 130 | deps = [ 131 | ":seastar_worker_cache", 132 | "@local_config_tf//:tf_header_lib", 133 | ], 134 | alwayslink = 1, 135 | ) 136 | 137 | cc_library( 138 | name = "seastar_server_lib", 139 | srcs = ["seastar_server_lib.cc"], 140 | hdrs = ["seastar_server_lib.h"], 141 | copts = COMMON_COPTS, 142 | linkstatic = 1, 143 | deps = [ 144 | ":seastar_rendezvous_mgr", 145 | "@local_config_tf//:tf_header_lib", 146 | ], 147 | alwayslink = 1, 148 | ) 149 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi/mpi_server_lib.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #include 17 | #include 18 | 19 | #include "grpc/support/alloc.h" 20 | 21 | #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" 22 | #include "tensorflow/core/distributed_runtime/server_lib.h" 23 | #include "tensorflow/core/lib/core/status.h" 24 | #include "tensorflow/core/platform/env.h" 25 | 26 | #include "tensorflow_networking/mpi/mpi_server_lib.h" 27 | 28 | namespace tensorflow { 29 | 30 | namespace { 31 | // static utility function 32 | RendezvousMgrInterface* NewMPIRendezvousMgr(const WorkerEnv* env) { 33 | // Runtime check to disable the MPI path 34 | const char* mpienv = getenv("MPI_DISABLED"); 35 | if (mpienv && mpienv[0] == '1') { 36 | LOG(INFO) << "MPI path disabled by environment variable\n"; 37 | return new RpcRendezvousMgr(env); 38 | } else { 39 | return new MPIRendezvousMgr(env); 40 | } 41 | } 42 | 43 | } // namespace 44 | 45 | MPIServer::MPIServer(const ServerDef& server_def, Env* env) 46 | : GrpcServer(server_def, env) {} 47 | 48 | MPIServer::~MPIServer() { 49 | TF_CHECK_OK(Stop()); 50 | TF_CHECK_OK(Join()); 51 | } 52 | 53 | Status MPIServer::Init(ServiceInitFunction service_func, 54 | RendezvousMgrCreationFunction rendezvous_mgr_func) { 55 | GrpcServerOptions opts; 56 | opts.service_func = service_func; 57 | opts.rendezvous_mgr_func = rendezvous_mgr_func; 58 | Status s = GrpcServer::Init(opts); 59 | return s; 60 | } 61 | 62 | Status MPIServer::Start() { 63 | Status s = GrpcServer::Start(); 64 | return s; 65 | } 66 | 67 | Status MPIServer::Join() { 68 | Status s = GrpcServer::Join(); 69 | return s; 70 | } 71 | 72 | /* static */ 73 | Status MPIServer::Create(const ServerDef& server_def, Env* env, 74 | std::unique_ptr* out_server) { 75 | std::unique_ptr ret(new MPIServer(server_def, Env::Default())); 76 | ServiceInitFunction service_func = nullptr; 77 | TF_RETURN_IF_ERROR(ret->Init(service_func, NewMPIRendezvousMgr)); 78 | *out_server = std::move(ret); 79 | return Status::OK(); 80 | } 81 | 82 | namespace { 83 | 84 | class MPIServerFactory : public ServerFactory { 85 | public: 86 | bool AcceptsOptions(const ServerDef& server_def) override { 87 | return server_def.protocol() == "grpc+mpi"; 88 | } 89 | 90 | Status NewServer(const ServerDef& server_def, 91 | std::unique_ptr* out_server) override { 92 | return MPIServer::Create(server_def, Env::Default(), out_server); 93 | } 94 | }; 95 | 96 | // Registers a `ServerFactory` for `MPIServer` instances. 97 | class MPIServerRegistrar { 98 | public: 99 | MPIServerRegistrar() { 100 | gpr_allocation_functions alloc_fns; 101 | alloc_fns.malloc_fn = port::Malloc; 102 | alloc_fns.realloc_fn = port::Realloc; 103 | alloc_fns.free_fn = port::Free; 104 | gpr_set_allocation_functions(alloc_fns); 105 | ServerFactory::Register("MPI_SERVER", new MPIServerFactory()); 106 | } 107 | }; 108 | static MPIServerRegistrar registrar; 109 | 110 | } // namespace 111 | } // namespace tensorflow 112 | -------------------------------------------------------------------------------- /third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2014/BUILD: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The Bazel Authors. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This becomes the BUILD file for @local_config_cc// under non-FreeBSD unixes. 16 | 17 | package(default_visibility = ["//visibility:public"]) 18 | 19 | load(":cc_toolchain_config.bzl", "cc_toolchain_config") 20 | 21 | licenses(["notice"]) # Apache 2.0 22 | 23 | cc_library( 24 | name = "malloc", 25 | ) 26 | 27 | filegroup( 28 | name = "empty", 29 | srcs = [], 30 | ) 31 | 32 | filegroup( 33 | name = "cc_wrapper", 34 | srcs = ["cc_wrapper.sh"], 35 | ) 36 | 37 | filegroup( 38 | name = "compiler_deps", 39 | srcs = glob(["extra_tools/**"]) + [":empty"], 40 | ) 41 | 42 | # This is the entry point for --crosstool_top. Toolchains are found 43 | # by lopping off the name of --crosstool_top and searching for 44 | # the "${CPU}" entry in the toolchains attribute. 45 | cc_toolchain_suite( 46 | name = "toolchain", 47 | toolchains = { 48 | "k8|/dt7/usr/bin/gcc": ":cc-compiler-k8", 49 | "k8": ":cc-compiler-k8", 50 | "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", 51 | "armeabi-v7a": ":cc-compiler-armeabi-v7a", 52 | }, 53 | ) 54 | 55 | cc_toolchain( 56 | name = "cc-compiler-k8", 57 | all_files = ":compiler_deps", 58 | ar_files = ":empty", 59 | as_files = ":empty", 60 | compiler_files = ":compiler_deps", 61 | dwp_files = ":empty", 62 | linker_files = ":compiler_deps", 63 | objcopy_files = ":empty", 64 | strip_files = ":empty", 65 | supports_param_files = 1, 66 | toolchain_config = ":linux_gnu_x86", 67 | toolchain_identifier = "linux_gnu_x86", 68 | ) 69 | 70 | cc_toolchain_config( 71 | name = "linux_gnu_x86", 72 | compiler = "/dt7/usr/bin/gcc", 73 | cpu = "k8", 74 | ) 75 | 76 | toolchain( 77 | name = "cc-toolchain-k8", 78 | exec_compatible_with = [ 79 | # TODO(katre): add autodiscovered constraints for host CPU and OS. 80 | ], 81 | target_compatible_with = [ 82 | # TODO(katre): add autodiscovered constraints for host CPU and OS. 83 | ], 84 | toolchain = ":cc-compiler-k8", 85 | toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", 86 | ) 87 | 88 | # Android tooling requires a default toolchain for the armeabi-v7a cpu. 89 | cc_toolchain( 90 | name = "cc-compiler-armeabi-v7a", 91 | all_files = ":empty", 92 | ar_files = ":empty", 93 | as_files = ":empty", 94 | compiler_files = ":empty", 95 | dwp_files = ":empty", 96 | linker_files = ":empty", 97 | objcopy_files = ":empty", 98 | strip_files = ":empty", 99 | supports_param_files = 1, 100 | toolchain_config = ":stub_armeabi-v7a", 101 | toolchain_identifier = "stub_armeabi-v7a", 102 | ) 103 | 104 | cc_toolchain_config( 105 | name = "stub_armeabi-v7a", 106 | compiler = "compiler", 107 | cpu = "armeabi-v7a", 108 | ) 109 | 110 | toolchain( 111 | name = "cc-toolchain-armeabi-v7a", 112 | exec_compatible_with = [ 113 | # TODO(katre): add autodiscovered constraints for host CPU and OS. 
114 | ], 115 | target_compatible_with = [ 116 | "@bazel_tools//platforms:arm", 117 | "@bazel_tools//platforms:android", 118 | ], 119 | toolchain = ":cc-compiler-armeabi-v7a", 120 | toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", 121 | ) 122 | -------------------------------------------------------------------------------- /tensorflow_networking/seastar/seastar_cpuset.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow_networking/seastar/seastar_cpuset.h" 2 | 3 | #include <dirent.h> 4 | #include <fcntl.h> 5 | #include <sys/file.h> 6 | #include <sys/stat.h> 7 | #include <unistd.h> 8 | 9 | #include <string> 10 | #include <unordered_map> 11 | 12 | #include "tensorflow/core/lib/strings/str_util.h" 13 | #include "tensorflow/core/lib/strings/strcat.h" 14 | #include "tensorflow/core/platform/cpu_info.h" 15 | #include "tensorflow/core/platform/logging.h" 16 | 17 | namespace tensorflow { 18 | 19 | namespace { 20 | 21 | const char* ROOT_PATH = "/tmp_tf"; 22 | const char* DEFAULT_ROOT_PATH = "/tmp"; 23 | const char* CPUSET_FILE_PATH = "/cpuset"; 24 | const size_t CORES_PER_FILE = 1; 25 | const size_t INIT_CPU_ID = 0; 26 | 27 | } // namespace 28 | 29 | // Claims CPU cores by taking advisory flock() locks on per-core files. A lock 30 | // (and therefore the core reservation) is held for as long as the locking fd 31 | // stays open, i.e. until Unlock() or process exit. 32 | class FileLocker { 33 | public: 34 | explicit FileLocker(const std::string& rd) : root_dir_(rd) {} 35 | virtual ~FileLocker() {} 36 | 37 | bool Lock(const std::string& file_name) { 38 | std::string file_path = root_dir_ + "/" + file_name; 39 | int fd = open(file_path.c_str(), O_RDWR | O_CREAT, 0777); 40 | if (fd < 0) { 41 | VLOG(2) << "can't open file: " << file_path; 42 | return false; 43 | } 44 | if (flock(fd, LOCK_EX | LOCK_NB) != 0) { 45 | close(fd); 46 | return false; 47 | } 48 | // Keep the fd open so the lock stays held; remember it for Unlock(). 49 | fds_[file_name] = fd; 50 | return true; 51 | } 52 | 53 | void Unlock(const std::string& file_name) { 54 | // flock() locks belong to the open file description, so unlocking must go 55 | // through the fd acquired in Lock(); closing it releases the lock. 56 | auto it = fds_.find(file_name); 57 | if (it == fds_.end()) return; 58 | flock(it->second, LOCK_UN); 59 | close(it->second); 60 | fds_.erase(it); 61 | } 62 | 63 | private: 64 | const std::string root_dir_; 65 | std::unordered_map<std::string, int> fds_; 66 | }; 67 | 68 | std::string CpusetAllocator::GetCpuset(size_t core_number) { 69 | // Critical section: the per-core file locks serialize allocation across 70 | // processes. 71 | if (!ExistDir()) { 72 | CreateDir(); 73 | } 74 | CreateFiles(); 75 | auto locked_files = LockFiles(core_number); 76 | return ToCpuset(locked_files); 77 | } 78 | 79 | bool CpusetAllocator::ExistDir() { 80 | if (opendir(ROOT_PATH) != nullptr) { 81 | root_dir_ = ROOT_PATH; 82 | } else if (opendir(DEFAULT_ROOT_PATH) != nullptr) { 83 | root_dir_ = DEFAULT_ROOT_PATH; 84 | } else { 85 | return false; 86 | } 87 | root_dir_ += CPUSET_FILE_PATH; 88 | 89 | return opendir(root_dir_.c_str()) != nullptr; 90 | } 91 | 92 | void CpusetAllocator::CreateDir() { 93 | int flag = mkdir(root_dir_.c_str(), 0777); 94 | if (flag != 0) { 95 | LOG(FATAL) << "Seastar: create cpuset dir failure"; 96 | } 97 | } 98 | 99 | void CpusetAllocator::CreateFiles() { 100 | // TODO: port::NumTotalCPUs() assumes all physical cores are visible inside 101 | // the container; otherwise this over-counts. Kubernetes could be a better 102 | // source of truth for allocating CPU cores.
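103 | // One lock file exists per logical CPU; holding the flock() on file `i` 104 | // claims core `i`. E.g. (hypothetical): GetCpuset(4) may return 105 | // "--cpuset=0,1,2,3", which is then passed on seastar's command line.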
106 | for (auto i = INIT_CPU_ID; i < port::NumTotalCPUs(); ++i) { 107 | auto file_name = std::to_string(i); 108 | 109 | std::string file_path; 110 | file_path += root_dir_ + std::string("/") + file_name; 111 | int fd = open(file_path.c_str(), O_RDWR | O_CREAT, 0777); 112 | if (fd < 0) { 113 | LOG(FATAL) << "Seastar error: can't create lock files for cpuset," 114 | << " please try another protocol, filepath: " << file_path; 115 | } 116 | close(fd); 117 | 118 | files_.emplace_back(file_name); 119 | } 120 | } 121 | 122 | std::vector<std::string> CpusetAllocator::LockFiles(size_t core_number) { 123 | std::vector<std::string> locked_files; 124 | FileLocker locker(root_dir_); 125 | for (const auto& file : files_) { 126 | if (core_number <= 0) break; 127 | if (locker.Lock(file)) { 128 | core_number -= CORES_PER_FILE; 129 | locked_files.emplace_back(file); 130 | } 131 | } 132 | if (core_number > 0) { 133 | LOG(WARNING) << "Seastar: failed to allocate cpuset via file locks," 134 | << " please try another protocol"; 135 | for (const auto& file : locked_files) { 136 | locker.Unlock(file); 137 | } 138 | return std::vector<std::string>(); 139 | } 140 | return locked_files; 141 | } 142 | 143 | std::string CpusetAllocator::ToCpuset( 144 | const std::vector<std::string>& locked_files) { 145 | if (locked_files.empty()) return std::string(); 146 | return strings::StrCat("--cpuset=", str_util::Join(locked_files, ",")); 147 | } 148 | 149 | }  // namespace tensorflow 150 | -------------------------------------------------------------------------------- /tensorflow_networking/repo.bzl: -------------------------------------------------------------------------------- 1 | """ TensorFlow Http Archive 2 | 3 | Modified http_archive that allows us to override the TensorFlow commit that is 4 | downloaded by setting an environment variable. This override is to be used for 5 | testing purposes. 6 | 7 | Add the following to your Bazel build command in order to override the 8 | TensorFlow revision.
9 | 10 | build: --action_env TF_REVISION="<git commit hash>" 11 | 12 | * `TF_REVISION`: tensorflow revision override (git commit hash) 13 | """ 14 | 15 | _TF_REVISION = "TF_REVISION" 16 | 17 | def _get_env_var(ctx, name): 18 | if name in ctx.os.environ: 19 | return ctx.os.environ[name] 20 | else: 21 | return None 22 | 23 | # Checks if we should use the system lib instead of the bundled one 24 | def _use_system_lib(ctx, name): 25 | syslibenv = _get_env_var(ctx, "TF_SYSTEM_LIBS") 26 | if syslibenv: 27 | for n in syslibenv.strip().split(","): 28 | if n.strip() == name: 29 | return True 30 | return False 31 | 32 | # Executes specified command with arguments and calls 'fail' if it exited with 33 | # non-zero code 34 | def _execute_and_check_ret_code(repo_ctx, cmd_and_args): 35 | result = repo_ctx.execute(cmd_and_args, timeout = 60) 36 | if result.return_code != 0: 37 | fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" + 38 | "Stderr: {3}").format( 39 | " ".join(cmd_and_args), 40 | result.return_code, 41 | result.stdout, 42 | result.stderr, 43 | )) 44 | 45 | def _repos_are_siblings(): 46 | return Label("@foo//bar").workspace_root.startswith("../") 47 | 48 | # Apply a patch_file to the repository root directory 49 | # Runs 'patch -p1' 50 | def _apply_patch(ctx, patch_file): 51 | if not ctx.which("patch"): 52 | fail("patch command is not found, please install it") 53 | cmd = ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)] 54 | _execute_and_check_ret_code(ctx, cmd) 55 | 56 | def _tensorflow_http_archive(ctx): 57 | use_syslib = _use_system_lib(ctx, ctx.attr.name) 58 | 59 | # Work around the bazel bug that redownloads the whole library. 60 | # Remove this after https://github.com/bazelbuild/bazel/issues/10515 is fixed. 61 | if ctx.attr.additional_build_files: 62 | for internal_src in ctx.attr.additional_build_files: 63 | _ = ctx.path(Label(internal_src)) 64 | 65 | # End of workaround. 66 | 67 | if not use_syslib: 68 | ctx.download_and_extract( 69 | ctx.attr.urls, 70 | "", 71 | ctx.attr.sha256, 72 | ctx.attr.type, 73 | ctx.attr.strip_prefix, 74 | ) 75 | if ctx.attr.patch_file != None: 76 | _apply_patch(ctx, ctx.attr.patch_file) 77 | 78 | if use_syslib and ctx.attr.system_build_file != None: 79 | # Use BUILD.bazel to avoid conflict with third party projects with 80 | # BUILD or build (directory) underneath. 81 | ctx.template("BUILD.bazel", ctx.attr.system_build_file, { 82 | "%prefix%": ".." if _repos_are_siblings() else "external", 83 | }, False) 84 | 85 | elif ctx.attr.build_file != None: 86 | # Use BUILD.bazel to avoid conflict with third party projects with 87 | # BUILD or build (directory) underneath. 88 | ctx.template("BUILD.bazel", ctx.attr.build_file, { 89 | "%prefix%": ".."
if _repos_are_siblings() else "external", 90 | }, False) 91 | 92 | if use_syslib: 93 | for internal_src, external_dest in ctx.attr.system_link_files.items(): 94 | ctx.symlink(Label(internal_src), ctx.path(external_dest)) 95 | 96 | if ctx.attr.additional_build_files: 97 | for internal_src, external_dest in ctx.attr.additional_build_files.items(): 98 | ctx.symlink(Label(internal_src), ctx.path(external_dest)) 99 | 100 | 101 | tensorflow_http_archive = repository_rule( 102 | attrs = { 103 | "sha256": attr.string(mandatory = True), 104 | "urls": attr.string_list( 105 | mandatory = True, 106 | allow_empty = False, 107 | ), 108 | "strip_prefix": attr.string(), 109 | "type": attr.string(), 110 | "patch_file": attr.label(), 111 | "build_file": attr.label(), 112 | "system_build_file": attr.label(), 113 | "system_link_files": attr.string_dict(), 114 | "additional_build_files": attr.string_dict(), 115 | }, 116 | environ = [ 117 | "TF_SYSTEM_LIBS", 118 | ], 119 | implementation = _tensorflow_http_archive, 120 | ) 121 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/mpi_allgather_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow.contrib.mpi_collectives as mpi 25 | from tensorflow.python.platform import test 26 | 27 | 28 | average_allgather = False 29 | 30 | 31 | class AllgatherTest(test.TestCase): 32 | def checkAllgather(self, num_ranks, all_gathered, local_gathered): 33 | # Ensure that indices match. 34 | all_gat_ind = np.sort(all_gathered.indices) 35 | loc_gat_ind = np.sort(local_gathered.indices) 36 | assert(len(loc_gat_ind) == len(all_gat_ind)) 37 | for i in range(len(loc_gat_ind)): 38 | assert(loc_gat_ind[i] == all_gat_ind[i]) 39 | 40 | # For each index, verify same values. 41 | local_checked = [] 42 | for i in range(len(local_gathered.indices)): 43 | local_checked.append(False) 44 | for i in range(len(all_gathered.indices)): 45 | all_index = all_gathered.indices[i] 46 | # TODO(jthestness): Make this lookup quicker using sorting. 
47 | loc_index = -1 48 | for j in range(len(local_gathered.indices)): 49 | if local_gathered.indices[j] == all_index and not local_checked[j]: 50 | loc_index = j 51 | local_checked[j] = True 52 | break 53 | assert(loc_index >= 0) 54 | correct_output = local_gathered.values[loc_index][0] 55 | if average_allgather: 56 | correct_output = correct_output / float(num_ranks) 57 | assert(all_gathered.values[i][0] == correct_output) 58 | 59 | 60 | def test_mpi_allgather(self): 61 | # Get MPI rank 62 | my_rank = int(os.environ['PMI_RANK']) 63 | num_ranks = int(os.environ['PMI_SIZE']) 64 | 65 | indices_per_rank = 100 66 | tensor_width = 10 67 | 68 | # Create IndexedSlices for each rank, some with overlapping indices. 69 | to_gather_indices = [] 70 | to_gather_values = [] 71 | to_gather = [] 72 | for rank_id in range(num_ranks): 73 | indices = [] 74 | values = [] 75 | my_multiple = rank_id + 1 76 | current_index = my_multiple 77 | for i in range(indices_per_rank): 78 | indices.append(current_index) 79 | ones_tensor = tf.ones([tensor_width]) 80 | values.append(tf.multiply(ones_tensor, 81 | tf.fill(ones_tensor.get_shape(), 82 | float(current_index)))) 83 | current_index += my_multiple 84 | concat_ind = tf.stack(indices) 85 | concat_vals = tf.stack(values) 86 | to_gather_indices.append(concat_ind) 87 | to_gather_values.append(concat_vals) 88 | to_gather.append(tf.IndexedSlices(concat_vals, concat_ind)) 89 | 90 | # Collect the local IndexedSlices (indices and values) to create 91 | # correct IndexedSlices output. 92 | correct_gather_indices = tf.concat(to_gather_indices, 0) 93 | correct_gather_values = tf.concat(to_gather_values, 0) 94 | correct_gather = tf.IndexedSlices(correct_gather_values, 95 | correct_gather_indices) 96 | 97 | all_gather = mpi.allreduce(to_gather[my_rank], average_allgather) 98 | 99 | # NOTE: This assumes that device IDs are numbered the same as ranks. 100 | gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) 101 | config = tf.ConfigProto(gpu_options=gpu_options) 102 | 103 | # MPI Session to test allgather. 104 | with mpi.Session(config=config) as sess: 105 | sess.run(tf.global_variables_initializer()) 106 | 107 | all_gathered, local_gathered = sess.run([all_gather, correct_gather]) 108 | 109 | # Compare all_gathered with local_gathered. 110 | self.checkAllgather(num_ranks, all_gathered, local_gathered) 111 | 112 | 113 | if __name__ == '__main__': 114 | test.main() 115 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/ops/mpi_ops.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifdef TENSORFLOW_USE_MPI 17 | 18 | #include "tensorflow/core/framework/op.h" 19 | #include "tensorflow/core/framework/shape_inference.h" 20 | 21 | namespace tensorflow { 22 | namespace contrib { 23 | namespace mpi_collectives { 24 | 25 | REGISTER_OP("MPIInit").Doc(R"doc( 26 | Initialize MPI for the current process. 27 | 28 | If this is run on a GPU, then that GPU must be used for all future MPI 29 | operations. If it is run on CPU, then all future MPI operations must also 30 | run on CPU. 31 | )doc"); 32 | 33 | REGISTER_OP("MPISize") 34 | .Output("size: int32") 35 | .SetShapeFn([](shape_inference::InferenceContext* c) { 36 | c->set_output(0, c->Scalar()); 37 | return Status::OK(); 38 | }) 39 | .Doc(R"doc( 40 | Returns the number of running MPI processes. 41 | 42 | More precisely, returns the number of MPI processes in the group associated 43 | with the MPI_COMM_WORLD communicator. 44 | 45 | size: Size of the MPI group. 46 | )doc"); 47 | 48 | REGISTER_OP("MPIRank") 49 | .Output("rank: int32") 50 | .SetShapeFn([](shape_inference::InferenceContext* c) { 51 | c->set_output(0, c->Scalar()); 52 | return Status::OK(); 53 | }) 54 | .Doc(R"doc( 55 | Returns the index of the current process in the MPI group. 56 | 57 | More precisely, returns the rank of the calling process in the MPI_COMM_WORLD 58 | communicator. 59 | 60 | rank: Rank of the calling process. 61 | )doc"); 62 | 63 | REGISTER_OP("MPILocalRank") 64 | .Output("rank: int32") 65 | .SetShapeFn([](shape_inference::InferenceContext* c) { 66 | c->set_output(0, c->Scalar()); 67 | return Status::OK(); 68 | }) 69 | .Doc(R"doc( 70 | Returns the index of the current process in the node it is on. 71 | 72 | More precisely, returns the rank of the calling process in communicator that 73 | only spans the MPI processes running on that node. 74 | 75 | rank: Rank of the calling process on the node it is on. 76 | )doc"); 77 | 78 | REGISTER_OP("MPIAllreduce") 79 | .Attr("T: {int32, int64, float32}") 80 | .Input("tensor: T") 81 | .Output("sum: T") 82 | .SetShapeFn([](shape_inference::InferenceContext* c) { 83 | c->set_output(0, c->input(0)); 84 | return Status::OK(); 85 | }) 86 | .Doc(R"doc( 87 | Perform an MPI Allreduce on a tensor. All other processes that do a reduction 88 | on a tensor with the same name must have the same dimension for that tensor. 89 | Tensors are reduced with other tensors that have the same node name for the 90 | allreduce. 91 | 92 | Arguments 93 | tensor: A tensor to reduce. 94 | 95 | Output 96 | sum: A tensor with the same shape as `tensor`, summed across all 97 | MPI processes. 98 | )doc"); 99 | 100 | REGISTER_OP("MPIAllgather") 101 | .Attr("T: {int32, int64, float32}") 102 | .Attr("S: {int64}") 103 | .Input("tensor: T") 104 | .Input("sizes: S") 105 | .Output("gathered: T") 106 | .SetShapeFn([](shape_inference::InferenceContext* c) { 107 | shape_inference::ShapeHandle output; 108 | TF_RETURN_IF_ERROR( 109 | c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); 110 | c->set_output(0, output); 111 | return Status::OK(); 112 | }) 113 | .Doc(R"doc( 114 | Perform an MPI Allgather on a tensor. All other processes that do a gather on a 115 | tensor with the same name must have the same rank for that tensor, and have the 116 | same dimension on all but the first dimension. 117 | 118 | Arguments 119 | tensor: A tensor to gather. 
120 |   sizes: A tensor containing the first-dimension sizes of tensors to be 121 |     gathered from other ranks 122 | 123 | Output 124 |   gathered: A tensor with the same shape as `tensor` except for the first 125 |     dimension, which is the sum of dimensions in `sizes`. 126 | )doc"); 127 | 128 | } // namespace mpi_collectives 129 | } // namespace contrib 130 | } // namespace tensorflow 131 | 132 | #endif // TENSORFLOW_USE_MPI 133 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/kernels/ring.cu.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifdef TENSORFLOW_USE_MPI 17 | 18 | #if GOOGLE_CUDA 19 | 20 | #define EIGEN_USE_GPU 21 | 22 | #include "tensorflow_networking/mpi_collectives/kernels/ring.h" 23 | 24 | namespace tensorflow { 25 | namespace contrib { 26 | namespace mpi_collectives { 27 | 28 | using CPUDevice = Eigen::ThreadPoolDevice; 29 | 30 | template <> 31 | MPI_Datatype MPIType<float>() { 32 | return MPI_FLOAT; 33 | }; 34 | template <> 35 | MPI_Datatype MPIType<int>() { 36 | return MPI_INT; 37 | }; 38 | template <> 39 | MPI_Datatype MPIType<long long>() { 40 | return MPI_LONG_LONG; 41 | }; 42 | 43 | template <> 44 | DataType TensorFlowDataType<float>() { 45 | return DT_FLOAT; 46 | }; 47 | template <> 48 | DataType TensorFlowDataType<int>() { 49 | return DT_INT32; 50 | }; 51 | template <> 52 | DataType TensorFlowDataType<long long>() { 53 | return DT_INT64; 54 | }; 55 | 56 | // Generate all necessary specializations for RingAllreduce. 57 | template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*, 58 | Tensor*, Tensor*); 59 | template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*, 60 | const Tensor*, Tensor*, 61 | Tensor*); 62 | template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*, 63 | Tensor*, Tensor*); 64 | 65 | // Generate all necessary specializations for RingAllgather. 66 | template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*, 67 | const std::vector<size_t>&, 68 | Tensor*); 69 | template Status RingAllgather<GPUDevice, long long>(OpKernelContext*, 70 | const Tensor*, 71 | const std::vector<size_t>&, 72 | Tensor*); 73 | template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*, 74 | const std::vector<size_t>&, 75 | Tensor*); 76 | 77 | // Synchronously copy data on the GPU, using a different stream than the default 78 | // and than TensorFlow to avoid synchronizing on operations unrelated to the 79 | // allreduce. 80 | template <> 81 | void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) { 82 | auto stream = CudaStreamForMPI(); 83 | cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); 84 | cudaStreamSynchronize(stream); 85 | }; 86 | 87 | // Elementwise accumulation kernel for GPU.
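88 | // This is a grid-stride loop: each thread starts at its global index and 89 | // strides by the total grid size (gridDim.x * blockDim.x), so the fixed 90 | // <<<32, 256>>> launch below covers arbitrary N without tuning the launch.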
91 | template <typename T> 92 | __global__ void elemwise_accum(T* out, const T* in, const size_t N) { 93 | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N; 94 | i += blockDim.x * gridDim.x) { 95 | out[i] += in[i]; 96 | } 97 | } 98 | 99 | // Synchronously accumulate tensors on the GPU, using a different stream than 100 | // the default and than TensorFlow to avoid synchronizing on operations 101 | // unrelated to the allreduce. 102 | #define GENERATE_ACCUMULATE(type) \ 103 | template <> \ 104 | void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \ 105 | size_t size) { \ 106 | auto stream = CudaStreamForMPI(); \ 107 | elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ 108 | cudaStreamSynchronize(stream); \ 109 | }; 110 | GENERATE_ACCUMULATE(int); 111 | GENERATE_ACCUMULATE(long long); 112 | GENERATE_ACCUMULATE(float); 113 | #undef GENERATE_ACCUMULATE 114 | 115 | } // namespace mpi_collectives 116 | } // namespace contrib 117 | } // namespace tensorflow 118 | #endif // GOOGLE_CUDA 119 | 120 | #endif // TENSORFLOW_USE_MPI 121 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/ring.cu.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #ifdef TENSORFLOW_USE_MPI 17 | 18 | #if GOOGLE_CUDA 19 | 20 | #define EIGEN_USE_GPU 21 | 22 | #include "tensorflow_networking/mpi_collectives/ring.h" 23 | 24 | #include "tensorflow/core/util/cuda_launch_config.h" 25 | 26 | namespace tensorflow { 27 | namespace contrib { 28 | namespace mpi { 29 | 30 | using CPUDevice = Eigen::ThreadPoolDevice; 31 | 32 | template <> 33 | MPI_Datatype MPIType<float>() { 34 | return MPI_FLOAT; 35 | }; 36 | template <> 37 | MPI_Datatype MPIType<int>() { 38 | return MPI_INT; 39 | }; 40 | template <> 41 | MPI_Datatype MPIType<long long>() { 42 | return MPI_LONG_LONG; 43 | }; 44 | 45 | template <> 46 | DataType TensorFlowDataType<float>() { 47 | return DT_FLOAT; 48 | }; 49 | template <> 50 | DataType TensorFlowDataType<int>() { 51 | return DT_INT32; 52 | }; 53 | template <> 54 | DataType TensorFlowDataType<long long>() { 55 | return DT_INT64; 56 | }; 57 | 58 | // Generate all necessary specializations for RingAllreduce. 59 | template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*, 60 | Tensor*, Tensor*); 61 | template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*, 62 | const Tensor*, Tensor*, 63 | Tensor*); 64 | template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*, 65 | Tensor*, Tensor*); 66 | 67 | // Generate all necessary specializations for RingAllgather.
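68 | // (Explicit instantiation keeps all CUDA-compiled variants of the ring 69 | // collectives in this translation unit, so non-CUDA callers only need the 70 | // declarations from ring.h.)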
71 | template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*, 72 | const std::vector<size_t>&, 73 | Tensor*); 74 | template Status RingAllgather<GPUDevice, long long>(OpKernelContext*, 75 | const Tensor*, 76 | const std::vector<size_t>&, 77 | Tensor*); 78 | template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*, 79 | const std::vector<size_t>&, 80 | Tensor*); 81 | 82 | // Synchronously copy data on the GPU, using a different stream than the default 83 | // and than TensorFlow to avoid synchronizing on operations unrelated to the 84 | // allreduce. 85 | template <> 86 | void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) { 87 | auto stream = CudaStreamForMPI(); 88 | cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); 89 | cudaStreamSynchronize(stream); 90 | }; 91 | 92 | // Elementwise accumulation kernel for GPU. 93 | template <typename T> 94 | __global__ void elemwise_accum(T* out, const T* in, const size_t N) { 95 | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N; 96 | i += blockDim.x * gridDim.x) { 97 | out[i] += in[i]; 98 | } 99 | } 100 | 101 | // Synchronously accumulate tensors on the GPU, using a different stream than 102 | // the default and than TensorFlow to avoid synchronizing on operations 103 | // unrelated to the allreduce. 104 | #define GENERATE_ACCUMULATE(type) \ 105 | template <> \ 106 | void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \ 107 | size_t size) { \ 108 | auto stream = CudaStreamForMPI(); \ 109 | elemwise_accum<<<32, 256, 0, stream>>>(dst, src, size); \ 110 | cudaStreamSynchronize(stream); \ 111 | }; 112 | GENERATE_ACCUMULATE(int); 113 | GENERATE_ACCUMULATE(long long); 114 | GENERATE_ACCUMULATE(float); 115 | #undef GENERATE_ACCUMULATE 116 | 117 | } // namespace mpi 118 | } // namespace contrib 119 | } // namespace tensorflow 120 | #endif // GOOGLE_CUDA 121 | 122 | #endif // TENSORFLOW_USE_MPI 123 | -------------------------------------------------------------------------------- /tensorflow_networking/mpi_collectives/python/ops/mpi_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """Inter-process communication using MPI.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops 24 | from tensorflow.contrib.util import loader 25 | from tensorflow.python.framework import ops 26 | from tensorflow.python.platform import resource_loader 27 | 28 | _mpi_ops_so = loader.load_op_library( 29 | resource_loader.get_path_to_datafile('_mpi_ops.so')) 30 | 31 | 32 | def size(name=None): 33 | """An op which returns the number of MPI processes.
34 | 35 | This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the 36 | size of the global communicator. 37 | 38 | Returns: 39 | An integer scalar containing the number of MPI processes. 40 | """ 41 | return gen_mpi_ops.mpi_size(name=name) 42 | 43 | 44 | ops.NotDifferentiable('MPISize') 45 | 46 | 47 | def rank(name=None): 48 | """An op which returns the MPI rank of the calling process. 49 | 50 | This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the 51 | rank of the current process in the global communicator. 52 | 53 | Returns: 54 | An integer scalar with the MPI rank of the calling process. 55 | """ 56 | return gen_mpi_ops.mpi_rank(name=name) 57 | 58 | 59 | ops.NotDifferentiable('MPIRank') 60 | 61 | 62 | def init(name=None): 63 | """An op which initializes MPI on the device on which it is run. 64 | 65 | All future MPI ops must be run on the same device that the `init` op was run 66 | on. 67 | """ 68 | return gen_mpi_ops.mpi_init(name=name) 69 | 70 | 71 | ops.NotDifferentiable('MPIInit') 72 | 73 | 74 | def local_rank(name=None): 75 | """An op which returns the local MPI rank of the calling process, within the 76 | node that it is running on. For example, if there are seven processes running 77 | on a node, their local ranks will be zero through six, inclusive. 78 | 79 | This is equivalent to running `MPI_Comm_rank(...)` on a new communicator 80 | which only includes processes on the same node. 81 | 82 | Returns: 83 | An integer scalar with the local MPI rank of the calling process. 84 | """ 85 | return gen_mpi_ops.mpi_local_rank(name=name) 86 | 87 | 88 | ops.NotDifferentiable('MPILocalRank') 89 | 90 | 91 | def _allreduce(tensor, name=None): 92 | """An op which sums an input tensor over all the MPI processes. 93 | 94 | The reduction operation is keyed by the name of the op. The tensor type and 95 | shape must be the same on all MPI processes for a given name. The reduction 96 | will not start until all processes are ready to send and receive the tensor. 97 | 98 | Returns: 99 | A tensor of the same shape and type as `tensor`, summed across all 100 | processes. 101 | """ 102 | return gen_mpi_ops.mpi_allreduce(tensor, name=name) 103 | 104 | 105 | ops.NotDifferentiable('MPIAllreduce') 106 | 107 | 108 | def allgather(tensor, name=None): 109 | """An op which concatenates the input tensor with the same input tensor on 110 | all other MPI processes. 111 | 112 | The concatenation is done on the first dimension, so the input tensors on the 113 | different processes must have the same rank and shape, except for the first 114 | dimension, which is allowed to be different. 115 | 116 | Returns: 117 | A tensor of the same type as `tensor`, concatenated on dimension zero 118 | across all processes. The shape is identical to the input shape, except for 119 | the first dimension, which may be greater and is the sum of all first 120 | dimensions of the tensors in different MPI processes. 
121 | """ 122 | # Specify that first allgather is to collect the tensor gather sizes, 123 | # indicated by passing in a scalar (0-D tensor) of value 0 124 | sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') 125 | my_size = tf.slice( 126 | tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') 127 | if name is None: 128 | name = 'allgather' 129 | sizing_name = '{}_sizing'.format(name) 130 | sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name) 131 | return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name) 132 | 133 | 134 | ops.NotDifferentiable('MPIAllgather') 135 | -------------------------------------------------------------------------------- /tensorflow_networking/verbs/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Verbs RDMA communication interfaces and implementations for TensorFlow. 3 | 4 | package(default_visibility = [ 5 | "//tensorflow_networking:__subpackages__", 6 | ]) 7 | 8 | licenses(["notice"]) # Apache 2.0 9 | 10 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_library") 11 | 12 | exports_files(["LICENSE"]) 13 | 14 | filegroup( 15 | name = "c_srcs", 16 | data = glob([ 17 | "**/*.cc", 18 | "**/*.h", 19 | ]), 20 | ) 21 | 22 | # For platform specific build config 23 | load( 24 | "@org_tensorflow//tensorflow/core:platform/default/build_config.bzl", 25 | "tf_proto_library_cc", 26 | ) 27 | 28 | tf_proto_library_cc( 29 | name = "verbs_service_proto", 30 | srcs = ["verbs_service.proto"], 31 | has_services = 1, 32 | cc_api_version = 2, 33 | visibility = [ 34 | "//tensorflow_networking:__subpackages__", 35 | ], 36 | ) 37 | 38 | cc_library( 39 | name = "verbs_util", 40 | srcs = ["verbs_util.cc"], 41 | hdrs = ["verbs_util.h"], 42 | deps = [ 43 | "@org_tensorflow//tensorflow/core:framework", 44 | "@org_tensorflow//tensorflow/core:lib", 45 | ], 46 | ) 47 | 48 | cc_library( 49 | name = "grpc_verbs_service", 50 | srcs = ["grpc_verbs_service.cc"], 51 | hdrs = ["grpc_verbs_service.h"], 52 | deps = [ 53 | ":grpc_verbs_service_impl", 54 | ":rdma_mgr", 55 | ":verbs_service_proto_cc", 56 | "@org_tensorflow//tensorflow:grpc++", 57 | "@org_tensorflow//tensorflow/core:lib", 58 | "@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr", 59 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:async_service_interface", 60 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_call", 61 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_util", 62 | ], 63 | alwayslink = 1, 64 | ) 65 | 66 | cc_library( 67 | name = "grpc_verbs_service_impl", 68 | srcs = ["grpc_verbs_service_impl.cc"], 69 | hdrs = ["grpc_verbs_service_impl.h"], 70 | deps = [ 71 | ":verbs_service_proto_cc", 72 | "@org_tensorflow//tensorflow:grpc++", 73 | ], 74 | ) 75 | 76 | cc_library( 77 | name = "grpc_verbs_client", 78 | srcs = ["grpc_verbs_client.cc"], 79 | hdrs = ["grpc_verbs_client.h"], 80 | deps = [ 81 | ":grpc_verbs_service_impl", 82 | ":verbs_service_proto_cc", 83 | "@org_tensorflow//tensorflow/core:lib", 84 | "@org_tensorflow//tensorflow/core/distributed_runtime:call_options", 85 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_util", 86 | ], 87 | alwayslink = 1, 88 | ) 89 | 90 | cc_library( 91 | name = "rdma_rendezvous_mgr", 92 | srcs = ["rdma_rendezvous_mgr.cc"], 93 | hdrs = ["rdma_rendezvous_mgr.h"], 94 | deps = [ 95 | ":rdma_mgr", 96 | ":verbs_util", 97 | "@org_tensorflow//tensorflow/core", 98 | "@org_tensorflow//tensorflow/core:lib", 99 | 
"@org_tensorflow//tensorflow/core/distributed_runtime:base_rendezvous_mgr", 100 | "@org_tensorflow//tensorflow/core/distributed_runtime:worker_env", 101 | ], 102 | ) 103 | 104 | tf_cuda_library( 105 | name = "rdma_mgr", 106 | srcs = ["rdma_mgr.cc"], 107 | hdrs = ["rdma_mgr.h"], 108 | deps = [ 109 | ":grpc_verbs_client", 110 | ":rdma", 111 | ":verbs_service_proto_cc", 112 | "@org_tensorflow//tensorflow/core", 113 | "@org_tensorflow//tensorflow/core:lib", 114 | "@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr", 115 | "@org_tensorflow//tensorflow/core/distributed_runtime:worker_env", 116 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_channel", 117 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", 118 | ], 119 | ) 120 | 121 | tf_cuda_library( 122 | name = "rdma", 123 | srcs = ["rdma.cc"], 124 | hdrs = ["rdma.h"], 125 | linkopts = ["-libverbs"], 126 | deps = [ 127 | ":grpc_verbs_client", 128 | ":verbs_service_proto_cc", 129 | ":verbs_util", 130 | "@org_tensorflow//tensorflow/core", 131 | "@org_tensorflow//tensorflow/core:framework", 132 | "@org_tensorflow//tensorflow/core:lib", 133 | "@org_tensorflow//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", 134 | "@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr", 135 | "@org_tensorflow//tensorflow/core/distributed_runtime:worker_env", 136 | ], 137 | ) 138 | 139 | cc_library( 140 | name = "verbs_server_lib", 141 | srcs = ["verbs_server_lib.cc"], 142 | hdrs = ["verbs_server_lib.h"], 143 | linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel 144 | deps = [ 145 | ":grpc_verbs_service", 146 | ":rdma_mgr", 147 | ":rdma_rendezvous_mgr", 148 | "@org_tensorflow//tensorflow/core:lib", 149 | "@org_tensorflow//tensorflow/core/distributed_runtime:server_lib", 150 | "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", 151 | ], 152 | alwayslink = 1, 153 | ) 154 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | # Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the 2 | # target CPU to build transient dependencies correctly. See 3 | # https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu 4 | build:android --crosstool_top=//external:android/crosstool 5 | build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain 6 | build:android_arm --config=android 7 | build:android_arm --cpu=armeabi-v7a 8 | build:android_arm --fat_apk_cpu=armeabi-v7a 9 | build:android_arm64 --config=android 10 | build:android_arm64 --cpu=arm64-v8a 11 | build:android_arm64 --fat_apk_cpu=arm64-v8a 12 | 13 | # Sets the default Apple platform to macOS. 14 | build --apple_platform_type=macos 15 | 16 | # Config to use a mostly-static build and disable modular op registration 17 | # support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python). 18 | # By default, TensorFlow will build with a dependence on 19 | # //tensorflow:libtensorflow_framework.so. 20 | build:monolithic --define framework_shared_object=false 21 | 22 | # For projects which use TensorFlow as part of a Bazel build process, putting 23 | # nothing in a bazelrc will default to a monolithic build. The following line 24 | # opts in to modular op registration support by default. 25 | build --define framework_shared_object=true 26 | 27 | # Please note that MKL on MacOS or windows is still not supported. 
28 | # If you would like to use a local MKL instead of downloading, please set the 29 | # environment variable "TF_MKL_ROOT" before every build. 30 | build:mkl --define=build_with_mkl=true --define=enable_mkl=true 31 | build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 32 | build:mkl -c opt 33 | 34 | # This config option is used to enable the MKL-DNN open source library only, 35 | # without depending on the MKL binary version. 36 | build:mkl_open_source_only --define=build_with_mkl_dnn_only=true 37 | build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true 38 | build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=0 39 | 40 | build:download_clang --crosstool_top=@local_config_download_clang//:toolchain 41 | build:download_clang --define=using_clang=true 42 | # Instruct clang to use LLD for linking. 43 | # This only works with GPU builds currently, since Bazel sets -B/usr/bin in 44 | # auto-generated CPU crosstool, forcing /usr/bin/ld.lld to be preferred over 45 | # the downloaded one. 46 | build:download_clang_use_lld --linkopt='-fuse-ld=lld' 47 | 48 | build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain 49 | build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true 50 | 51 | build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain 52 | build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true 53 | 54 | build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain 55 | build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true 56 | 57 | build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain 58 | build:sycl --define=using_sycl=true --define=using_trisycl=false 59 | 60 | build:sycl_nodouble --crosstool_top=@local_config_sycl//crosstool:toolchain 61 | build:sycl_nodouble --define=using_sycl=true --cxxopt -DTENSORFLOW_SYCL_NO_DOUBLE 62 | 63 | build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain 64 | build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address 65 | 66 | build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain 67 | build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true 68 | 69 | # Options extracted from configure script 70 | build:gdr --define=with_gdr_support=true 71 | build:ngraph --define=with_ngraph_support=true 72 | build:verbs --define=with_verbs_support=true 73 | build:numa --define=with_numa_support=true 74 | 75 | # Options to disable default-on features 76 | build:noaws --define=no_aws_support=true 77 | build:nogcp --define=no_gcp_support=true 78 | build:nohdfs --define=no_hdfs_support=true 79 | build:nokafka --define=no_kafka_support=true 80 | build:noignite --define=no_ignite_support=true 81 | build:nonccl --define=no_nccl_support=true 82 | 83 | build --define=use_fast_cpp_protos=true 84 | build --define=allow_oversize_protos=true 85 | 86 | build --spawn_strategy=standalone 87 | build --strategy=Genrule=standalone 88 | build -c opt 89 | 90 | # Other build flags. 91 | build --define=grpc_no_ares=true 92 | 93 | # Modular TF build options 94 | build:dynamic_kernels --define=dynamic_loaded_kernels=true 95 | build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS 96 | 97 | # Build TF with C++ 17 features.
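98 | # For example (hypothetical invocation), configs combine on the command line: 99 | #   bazel build --config=c++17 --config=mkl //tensorflow_networking:all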
100 | build:c++17 --cxxopt=-std=c++1z 101 | build:c++17 --cxxopt=-stdlib=libc++ 102 | build:c++1z --cxxopt=-std=c++1z 103 | build:c++1z --cxxopt=-stdlib=libc++ 104 | 105 | # Default paths for TF_SYSTEM_LIBS 106 | build --define=PREFIX=/usr 107 | build --define=LIBDIR=$(PREFIX)/lib 108 | build --define=INCLUDEDIR=$(PREFIX)/include 109 | 110 | # Default options should come above this line 111 | 112 | # Options from ./configure 113 | try-import %workspace%/.tf_networking_configure.bazelrc 114 | 115 | # Put user-specific options in .bazelrc.user 116 | try-import %workspace%/.bazelrc.user 117 | -------------------------------------------------------------------------------- /third_party/ci_build/devtoolset/build_devtoolset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # Builds a cross-compiler targeting manylinux 2014 (glibc 2.17 / libstdc++ 4.8). 18 | 19 | VERSION="$1" 20 | TARGET="$2" 21 | 22 | case "${VERSION}" in 23 | devtoolset-7) 24 | LIBSTDCXX_VERSION="6.0.24" 25 | ;; 26 | devtoolset-8) 27 | LIBSTDCXX_VERSION="6.0.25" 28 | ;; 29 | *) 30 | echo "Usage: $0 {devtoolset-7|devtoolset-8} <target-directory>" 31 | exit 1 32 | ;; 33 | esac 34 | 35 | mkdir -p "${TARGET}" 36 | # Download binary glibc 2.17 release. 37 | curl -O "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6_2.17-0ubuntu5.1_amd64.deb" && \ 38 | unar "libc6_2.17-0ubuntu5.1_amd64.deb" && \ 39 | tar -C "${TARGET}" -xf "libc6_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ 40 | rm -rf "libc6_2.17-0ubuntu5.1_amd64.deb" "libc6_2.17-0ubuntu5.1_amd64" 41 | curl -O "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ 42 | unar "libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ 43 | tar -C "${TARGET}" -xf "libc6-dev_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ 44 | rm -rf "libc6-dev_2.17-0ubuntu5.1_amd64.deb" "libc6-dev_2.17-0ubuntu5.1_amd64" 45 | 46 | # Put the current kernel headers from Ubuntu in place. 47 | ln -s "/usr/include/linux" "${TARGET}/usr/include/linux" 48 | ln -s "/usr/include/asm-generic" "${TARGET}/usr/include/asm-generic" 49 | ln -s "/usr/include/x86_64-linux-gnu/asm" "${TARGET}/usr/include/asm" 50 | 51 | # Symlinks in the binary distribution are set up for installation in /usr; we 52 | # need to fix up all the links to stay within ${TARGET}. 53 | /fixlinks.sh "${TARGET}" 54 | 55 | # Patch to allow non-glibc 2.17 compatible builds to work. 56 | sed -i '54i#define TCP_USER_TIMEOUT 18' "${TARGET}/usr/include/netinet/tcp.h" 57 | 58 | # Download binary libstdc++ 4.8 release we are going to link against. 59 | # We only need the shared library, as we're going to develop against the 60 | # libstdc++ provided by devtoolset.
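61 | # (Compiling with the newer devtoolset GCC while linking against the glibc 2.17 62 | # sysroot and the libstdc++ 4.8 shared library below is what keeps the 63 | # resulting binaries manylinux2014-compatible.)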
61 | curl -O "http://old-releases.ubuntu.com/ubuntu/pool/main/g/gcc-4.8/libstdc++6_4.8.1-10ubuntu9_amd64.deb" && \ 62 | unar "libstdc++6_4.8.1-10ubuntu9_amd64.deb" && \ 63 | tar -C "${TARGET}" -xf "libstdc++6_4.8.1-10ubuntu9_amd64/data.tar.gz" "./usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.18" && \ 64 | rm -rf "libstdc++6_4.8.1-10ubuntu9_amd64.deb" "libstdc++6_4.8.1-10ubuntu9_amd64" 65 | 66 | mkdir -p "${TARGET}-src" 67 | cd "${TARGET}-src" 68 | 69 | # Build a devtoolset cross-compiler based on our glibc 2.17 sysroot setup. 70 | case "${VERSION}" in 71 | devtoolset-7) 72 | curl -O "http://vault.centos.org/centos/7/sclo/Source/rh/devtoolset-7/devtoolset-7-gcc-7.3.1-5.16.el7.src.rpm" 73 | rpm2cpio "devtoolset-7-gcc-7.3.1-5.16.el7.src.rpm" | cpio -idm 74 | tar -xf "gcc-7.3.1-20180303.tar.bz2" --strip 1 75 | ;; 76 | devtoolset-8) 77 | curl -O "http://vault.centos.org/centos/7/sclo/Source/rh/devtoolset-8/devtoolset-8-gcc-8.3.1-3.el7.src.rpm" 78 | rpm2cpio "devtoolset-8-gcc-8.3.1-3.el7.src.rpm" | cpio -idm 79 | tar -xf "gcc-8.3.1-20190311.tar.xz" --strip 1 80 | ;; 81 | esac 82 | 83 | # Apply the devtoolset patches to gcc. 84 | /rpm-patch.sh "gcc.spec" 85 | 86 | sed -i 's/ftp:\/\/gcc.gnu.org/https:\/\/mirror.math.princeton.edu/g' ./contrib/download_prerequisites 87 | ./contrib/download_prerequisites 88 | 89 | mkdir -p "${TARGET}-build" 90 | cd "${TARGET}-build" 91 | 92 | "${TARGET}-src/configure" \ 93 | --prefix="${TARGET}/usr" \ 94 | --with-sysroot="${TARGET}" \ 95 | --disable-bootstrap \ 96 | --disable-libmpx \ 97 | --disable-libsanitizer \ 98 | --disable-libunwind-exceptions \ 99 | --disable-libunwind-exceptions \ 100 | --disable-lto \ 101 | --disable-multilib \ 102 | --enable-__cxa_atexit \ 103 | --enable-gnu-indirect-function \ 104 | --enable-gnu-unique-object \ 105 | --enable-initfini-array \ 106 | --enable-languages="c,c++" \ 107 | --enable-linker-build-id \ 108 | --enable-plugin \ 109 | --enable-shared \ 110 | --enable-threads=posix \ 111 | --with-default-libstdcxx-abi="gcc4-compatible" \ 112 | --with-gcc-major-version-only \ 113 | --with-linker-hash-style="gnu" \ 114 | --with-tune="generic" \ 115 | && \ 116 | make -j && \ 117 | make install 118 | 119 | # Create the devtoolset libstdc++ linkerscript that links dynamically against 120 | # the system libstdc++ 4.8 and provides all other symbols statically. 121 | mv "${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" \ 122 | "${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}.backup" 123 | echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.18 -lstdc++_nonshared44 )" \ 124 | > "${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" 125 | cp "./x86_64-pc-linux-gnu/libstdc++-v3/src/.libs/libstdc++_nonshared44.a" \ 126 | "${TARGET}/usr/lib" 127 | 128 | # Clean up 129 | rm -rf "${TARGET}-build" 130 | rm -rf "${TARGET}-src" 131 | --------------------------------------------------------------------------------