// NOTE(review): dump fragment - repository tree, /__init__.py (pkg_resources namespace declaration), /.bazelrc (`debug` config: -c dbg, -g, --strip=never), pygloo/include/rendezvous.h (declares pygloo::rendezvous::def_rendezvous_module) and pygloo/src/barrier.cc.
// pygloo::barrier wraps gloo::barrier: it builds gloo::BarrierOptions from the context and applies `tag`. The export stripped every <...> span (include targets, template arguments), so the C++ in this dump does not compile as shown.
├── __init__.py ├── .bazelrc ├── pygloo ├── include │ ├── rendezvous.h │ ├── transport.h │ └── collective.h ├── src │ ├── barrier.cc │ ├── recv.cc │ ├── send.cc │ ├── gather.cc │ ├── broadcast.cc │ ├── scatter.cc │ ├── reduce.cc │ ├── allreduce.cc │ ├── reduce_scatter.cc │ ├── allgather.cc │ ├── transport.cc │ └── rendezvous.cc ├── BUILD └── main.cc ├── format.sh ├── .github └── workflows │ └── ubuntu_basic.yml ├── tests ├── test_reduce.py ├── test_gather.py ├── test_reduce_scatter.py ├── test_allreduce.py ├── test_barrier.py ├── test_send_recv.py ├── test_broadcast.py ├── test_allgather.py ├── test_scatter.py ├── test_custom_store.py └── test_redis.py ├── .gitignore ├── README.md ├── WORKSPACE ├── setup.py └── LICENSE /__init__.py: -------------------------------------------------------------------------------- 1 | __import__('pkg_resources').declare_namespace(__name__) 2 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | # Debug build: 2 | build:debug -c dbg 3 | build:debug --copt="-g" 4 | build:debug --strip="never" 5 | -------------------------------------------------------------------------------- /pygloo/include/rendezvous.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | namespace rendezvous { 6 | 7 | void def_rendezvous_module(pybind11::module &m); 8 | } // namespace rendezvous 9 | } // namespace pygloo 10 | -------------------------------------------------------------------------------- /pygloo/src/barrier.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | 6 | void barrier(const std::shared_ptr &context, uint32_t tag) { 7 | gloo::BarrierOptions opts_(context); 8 | 9 | opts_.setTag(tag); 10 | 11 | gloo::barrier(opts_); 12 | } 13 | } // namespace pygloo
14 |
--------------------------------------------------------------------------------
/format.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # COPIED from Ray repo.
3 | # Clang formatter: runs clang-format-8 in place on every git-tracked *.cc/*.h/*.proto file (despite the upstream wording, there is no YAPF step and no mergebase filtering).
4 | # You are encouraged to run this locally before pushing changes for review.
5 |
6 | echo "$(date)" "clang-format...."
7 | git ls-files -- '*.cc' '*.h' '*.proto' "${GIT_LS_EXCLUDES[@]}" | xargs -P 5 clang-format-8 -i  # GIT_LS_EXCLUDES is never assigned in this script, so it expands to nothing
8 |
9 | if ! git diff --quiet &>/dev/null; then
10 | echo 'Reformatted changed files. Please review and stage the changes.'
11 | echo 'Files updated:'
12 | echo
13 |
14 | git --no-pager diff --name-only
15 |
16 | exit 1
17 | fi
18 |
--------------------------------------------------------------------------------
/.github/workflows/ubuntu_basic.yml:
--------------------------------------------------------------------------------
1 | name: ubuntu-basic
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | run-unit-tests:
11 | timeout-minutes: 60
12 | runs-on: ubuntu-latest
13 | container: docker.io/library/ubuntu:latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 |
18 | - name: Install bazel
19 | run: |
20 | apt-get update
21 | apt-get install -yq wget gcc g++ python3.7 zlib1g-dev zip libuv1-dev  # was "libuv1.dev", which only resolved through apt-get's regex fallback ('.' matches any char). NOTE(review): "python3.7" relies on the same fallback - confirm it still installs Python 3.7 on ubuntu:latest.
22 | apt-get install -yq pip
23 | wget "https://github.com/bazelbuild/bazel/releases/download/5.1.0/bazel_5.1.0-linux-x86_64.deb" -O bazel_5.1.0-linux-x86_64.deb
24 | dpkg -i bazel_5.1.0-linux-x86_64.deb
25 |
26 | - name: Install dependencies
27 | run: |
28 | python3 -m pip install virtualenv
29 | python3 -m virtualenv -p python3 py3
30 | . py3/bin/activate
31 | which python
32 | pip install pytest torch
33 | pip install ray==1.11.0
34 |
35 | - name: Build and test
36 | run: |
37 | .
# NOTE(review): dump fragment - end of .github/workflows/ubuntu_basic.yml (activate the venv, `python3 setup.py install`, run pytest over tests/*) and tests/test_reduce.py up to its __main__ block.
# test_reduce (a Ray task): rank 0 recreates the FileStore directory while other ranks just sleep 0.5 s - a timing assumption, not a synchronization - then all ranks connect a full TCP mesh on localhost and SUM-reduce a 2x3 float32 array into recvbuf on root 0. `torch` is imported but only used by the commented-out variant.
py3/bin/activate 38 | python3 setup.py install 39 | cd tests && python3 -m pytest * 40 | -------------------------------------------------------------------------------- /tests/test_reduce.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_reduce(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 33 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 34 | sendptr = sendbuf.ctypes.data 35 | recvptr = recvbuf.ctypes.data 36 | 37 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 38 | # recvbuf = torch.zeros_like(sendbuf) 39 | # sendptr = sendbuf.data_ptr() 40 | # recvptr = recvbuf.data_ptr() 41 | 42 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 43 | datatype = pygloo.glooDataType_t.glooFloat32 44 | op = pygloo.ReduceOp.SUM 45 | root = 0 46 | 47 | pygloo.reduce(context, sendptr, recvptr, data_size, datatype, op, root) 48 | 49 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 50 | 51 | 52 | if __name__ == "__main__": 53 | ray.init(num_cpus=6) 54 | world_size = 2 55 | fileStore_path =
// NOTE(review): dump fragment - end of tests/test_reduce.py (store path built via the underscore-private ray.worker._global_node API; launches world_size tasks and waits on ray.get) and most of pygloo/src/recv.cc.
// recv<T> rejects peer == own rank, wraps recvbuf as an unbound buffer of size * sizeof(T) bytes, posts a receive on slot (0x09, tag) and waits with context->getTimeout(); send.cc builds its slot from the same 0x09 prefix and tag. recv_wrapper maps glooDataType_t to T.
f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 56 | 57 | fns = [test_reduce.remote(i, world_size, fileStore_path) for i in range(world_size)] 58 | ray.get(fns) 59 | -------------------------------------------------------------------------------- /pygloo/src/recv.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | 6 | template 7 | void recv(const std::shared_ptr &context, intptr_t recvbuf, 8 | size_t size, int peer, uint32_t tag) { 9 | if (context->rank == peer) 10 | throw std::runtime_error( 11 | "peer equals to current rank. Please specify other peer values."); 12 | 13 | auto outputBuffer = context->createUnboundBuffer( 14 | reinterpret_cast(recvbuf), size * sizeof(T)); 15 | 16 | constexpr uint8_t kSendRecvSlotPrefix = 0x09; 17 | gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag); 18 | 19 | outputBuffer->recv(peer, slot); 20 | outputBuffer->waitRecv(context->getTimeout()); 21 | } 22 | 23 | void recv_wrapper(const std::shared_ptr &context, 24 | intptr_t recvbuf, size_t size, glooDataType_t datatype, 25 | int peer, uint32_t tag) { 26 | switch (datatype) { 27 | case glooDataType_t::glooInt8: 28 | recv(context, recvbuf, size, peer, tag); 29 | break; 30 | case glooDataType_t::glooUint8: 31 | recv(context, recvbuf, size, peer, tag); 32 | break; 33 | case glooDataType_t::glooInt32: 34 | recv(context, recvbuf, size, peer, tag); 35 | break; 36 | case glooDataType_t::glooUint32: 37 | recv(context, recvbuf, size, peer, tag); 38 | break; 39 | case glooDataType_t::glooInt64: 40 | recv(context, recvbuf, size, peer, tag); 41 | break; 42 | case glooDataType_t::glooUint64: 43 | recv(context, recvbuf, size, peer, tag); 44 | break; 45 | case glooDataType_t::glooFloat16: 46 | recv(context, recvbuf, size, peer, tag); 47 | break; 48 | case glooDataType_t::glooFloat32: 49 | recv(context, recvbuf, size, peer, tag); 50 | break; 51 | case
// NOTE(review): dump fragment - end of recv.cc (float64 case; unknown dtypes throw "Unhandled dataType") and most of pygloo/src/send.cc.
// send<T> mirrors recv<T>: rejects peer == own rank, wraps sendbuf (size * sizeof(T) bytes) as an unbound buffer, sends on slot (0x09, tag) and waits with context->getTimeout().
glooDataType_t::glooFloat64: 52 | recv(context, recvbuf, size, peer, tag); 53 | break; 54 | default: 55 | throw std::runtime_error("Unhandled dataType"); 56 | } 57 | } 58 | } // namespace pygloo 59 | -------------------------------------------------------------------------------- /pygloo/src/send.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | namespace pygloo { 5 | 6 | template 7 | void send(const std::shared_ptr &context, intptr_t sendbuf, 8 | size_t size, int peer, uint32_t tag) { 9 | if (context->rank == peer) 10 | throw std::runtime_error( 11 | "peer equals to current rank. Please specify other peer values."); 12 | 13 | auto inputBuffer = context->createUnboundBuffer( 14 | reinterpret_cast(sendbuf), size * sizeof(T)); 15 | 16 | constexpr uint8_t kSendRecvSlotPrefix = 0x09; 17 | gloo::Slot slot = gloo::Slot::build(kSendRecvSlotPrefix, tag); 18 | 19 | inputBuffer->send(peer, slot); 20 | inputBuffer->waitSend(context->getTimeout()); 21 | } 22 | 23 | void send_wrapper(const std::shared_ptr &context, 24 | intptr_t sendbuf, size_t size, glooDataType_t datatype, 25 | int peer, uint32_t tag) { 26 | switch (datatype) { 27 | case glooDataType_t::glooInt8: 28 | send(context, sendbuf, size, peer, tag); 29 | break; 30 | case glooDataType_t::glooUint8: 31 | send(context, sendbuf, size, peer, tag); 32 | break; 33 | case glooDataType_t::glooInt32: 34 | send(context, sendbuf, size, peer, tag); 35 | break; 36 | case glooDataType_t::glooUint32: 37 | send(context, sendbuf, size, peer, tag); 38 | break; 39 | case glooDataType_t::glooInt64: 40 | send(context, sendbuf, size, peer, tag); 41 | break; 42 | case glooDataType_t::glooUint64: 43 | send(context, sendbuf, size, peer, tag); 44 | break; 45 | case glooDataType_t::glooFloat16: 46 | send(context, sendbuf, size, peer, tag); 47 | break; 48 | case glooDataType_t::glooFloat32: 49 | send(context, sendbuf, size, peer, tag); 50 | break; 51 | case
// NOTE(review): dump fragment - end of send.cc and most of pygloo/src/gather.cc.
// gather<T> registers `size` input elements on every rank; only the root registers an output buffer, of context->size * size elements, so recvbuf must hold world_size * size elements on the root and is never handed to gloo on other ranks.
glooDataType_t::glooFloat64: 52 | send(context, sendbuf, size, peer, tag); 53 | break; 54 | default: 55 | throw std::runtime_error("Unhandled dataType"); 56 | } 57 | } 58 | } // namespace pygloo 59 | -------------------------------------------------------------------------------- /pygloo/src/gather.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | 6 | template 7 | void gather(const std::shared_ptr &context, intptr_t sendbuf, 8 | intptr_t recvbuf, size_t size, int root, uint32_t tag) { 9 | // Configure GatherOptions struct 10 | gloo::GatherOptions opts_(context); 11 | 12 | T *input_ptr = reinterpret_cast(sendbuf); 13 | opts_.setInput(input_ptr, size); 14 | 15 | if (root == context->rank) { 16 | T *output_ptr = reinterpret_cast(recvbuf); 17 | opts_.setOutput(output_ptr, context->size * size); 18 | } 19 | opts_.setRoot(root); 20 | opts_.setTag(tag); 21 | 22 | gloo::gather(opts_); 23 | } 24 | 25 | void gather_wrapper(const std::shared_ptr &context, 26 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 27 | glooDataType_t datatype, int root, uint32_t tag) { 28 | switch (datatype) { 29 | case glooDataType_t::glooInt8: 30 | gather(context, sendbuf, recvbuf, size, root, tag); 31 | break; 32 | case glooDataType_t::glooUint8: 33 | gather(context, sendbuf, recvbuf, size, root, tag); 34 | break; 35 | case glooDataType_t::glooInt32: 36 | gather(context, sendbuf, recvbuf, size, root, tag); 37 | break; 38 | case glooDataType_t::glooUint32: 39 | gather(context, sendbuf, recvbuf, size, root, tag); 40 | break; 41 | case glooDataType_t::glooInt64: 42 | gather(context, sendbuf, recvbuf, size, root, tag); 43 | break; 44 | case glooDataType_t::glooUint64: 45 | gather(context, sendbuf, recvbuf, size, root, tag); 46 | break; 47 | case glooDataType_t::glooFloat16: 48 | gather(context, sendbuf, recvbuf, size, root, tag); 49 | break; 50 | case glooDataType_t::glooFloat32: 51 | gather(context, sendbuf,
// NOTE(review): dump fragment - end of gather.cc and most of pygloo/src/broadcast.cc.
// broadcast<T> hands sendbuf to gloo only on the root (non-root callers may pass a dummy value, e.g. the -1 used by tests/test_broadcast.py); every rank registers `size` elements of recvbuf as output.
recvbuf, size, root, tag); 52 | break; 53 | case glooDataType_t::glooFloat64: 54 | gather(context, sendbuf, recvbuf, size, root, tag); 55 | break; 56 | default: 57 | throw std::runtime_error("Unhandled dataType"); 58 | } 59 | } 60 | } // namespace pygloo 61 | -------------------------------------------------------------------------------- /pygloo/src/broadcast.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace pygloo { 6 | 7 | template 8 | void broadcast(const std::shared_ptr &context, intptr_t sendbuf, 9 | intptr_t recvbuf, size_t size, int root, uint32_t tag) { 10 | 11 | // Configure BroadcastOptions struct and call broadcast function 12 | gloo::BroadcastOptions opts_(context); 13 | 14 | if (context->rank == root) { 15 | T *input_ptr = reinterpret_cast(sendbuf); 16 | opts_.setInput(input_ptr, size); 17 | } 18 | T *output_ptr = reinterpret_cast(recvbuf); 19 | opts_.setOutput(output_ptr, size); 20 | 21 | opts_.setRoot(root); 22 | opts_.setTag(tag); 23 | 24 | gloo::broadcast(opts_); 25 | } 26 | 27 | void broadcast_wrapper(const std::shared_ptr &context, 28 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 29 | glooDataType_t datatype, int root, uint32_t tag) { 30 | switch (datatype) { 31 | case glooDataType_t::glooInt8: 32 | broadcast(context, sendbuf, recvbuf, size, root, tag); 33 | break; 34 | case glooDataType_t::glooUint8: 35 | broadcast(context, sendbuf, recvbuf, size, root, tag); 36 | break; 37 | case glooDataType_t::glooInt32: 38 | broadcast(context, sendbuf, recvbuf, size, root, tag); 39 | break; 40 | case glooDataType_t::glooUint32: 41 | broadcast(context, sendbuf, recvbuf, size, root, tag); 42 | break; 43 | case glooDataType_t::glooInt64: 44 | broadcast(context, sendbuf, recvbuf, size, root, tag); 45 | break; 46 | case glooDataType_t::glooUint64: 47 | broadcast(context, sendbuf, recvbuf, size, root, tag); 48 | break; 49 | case glooDataType_t::glooFloat16: 50 |
# NOTE(review): dump fragment - end of broadcast.cc and tests/test_gather.py up to the data_size computation.
# Each rank sends [rank, rank+1]; the root gathers into a (1, world_size*2) float32 buffer. The commented-out torch variant (sum-of-range input, rank+1 output) was copied from the reduce_scatter test and does not fit a gather - TODO fix or delete it.
broadcast(context, sendbuf, recvbuf, size, root, tag); 51 | break; 52 | case glooDataType_t::glooFloat32: 53 | broadcast(context, sendbuf, recvbuf, size, root, tag); 54 | break; 55 | case glooDataType_t::glooFloat64: 56 | broadcast(context, sendbuf, recvbuf, size, root, tag); 57 | break; 58 | default: 59 | throw std::runtime_error("Unhandled dataType"); 60 | } 61 | } 62 | } // namespace pygloo 63 | -------------------------------------------------------------------------------- /tests/test_gather.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_gather(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = np.array([rank, rank+1], dtype=np.float32) 33 | sendptr = sendbuf.ctypes.data 34 | 35 | recvbuf = np.zeros((1, world_size*2), dtype=np.float32) 36 | recvptr = recvbuf.ctypes.data 37 | 38 | # sendbuf = torch.Tensor([i+1 for i in range(sum([j+1 for j in range(world_size)]))]).float() 39 | # sendptr = sendbuf.data_ptr() 40 | # recvbuf = torch.zeros(rank+1).float() 41 | # recvptr = recvbuf.data_ptr() 42 | 43 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size
// NOTE(review): dump fragment - end of tests/test_gather.py (gather to root 0 with world_size = 3) and the start of pygloo/src/scatter.cc.
// scatter<T> reinterprets each element of sendbuf (presumably intptr_t buffer addresses, one per rank - the template arguments were stripped by the export) as a T* input of `size` elements and registers `size` elements of recvbuf as the output on every rank.
44 | datatype = pygloo.glooDataType_t.glooFloat32 45 | 46 | pygloo.gather(context, sendptr, recvptr, data_size, datatype, root = 0) 47 | 48 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 49 | 50 | ## example output 51 | # (pid=23172) rank 2 sends [2. 3.], receives [[0. 0. 0. 0. 0. 0.]] 52 | # (pid=23171) rank 1 sends [1. 2.], receives [[0. 0. 0. 0. 0. 0.]] 53 | # (pid=23173) rank 0 sends [0. 1.], receives [[0. 1. 1. 2. 2. 3.]] 54 | 55 | if __name__ == "__main__": 56 | ray.init(num_cpus=6) 57 | world_size = 3 58 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 59 | 60 | fns = [test_gather.remote(i, world_size, fileStore_path) for i in range(world_size)] 61 | ray.get(fns) 62 | -------------------------------------------------------------------------------- /pygloo/src/scatter.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | 6 | template 7 | void scatter(const std::shared_ptr &context, 8 | std::vector sendbuf, intptr_t recvbuf, size_t size, 9 | int root, uint32_t tag) { 10 | 11 | std::vector input_ptr; 12 | for (size_t i = 0; i < sendbuf.size(); ++i) 13 | input_ptr.emplace_back(reinterpret_cast(sendbuf[i])); 14 | 15 | T *output_ptr = reinterpret_cast(recvbuf); 16 | 17 | // Configure ScatterOptions struct 18 | gloo::ScatterOptions opts_(context); 19 | opts_.setInputs(input_ptr, size); 20 | opts_.setOutput(output_ptr, size); 21 | opts_.setTag(tag); 22 | opts_.setRoot(root); 23 | 24 | gloo::scatter(opts_); 25 | } 26 | 27 | void scatter_wrapper(const std::shared_ptr &context, 28 | std::vector sendbuf, intptr_t recvbuf, 29 | size_t size, glooDataType_t datatype, int root, 30 | uint32_t tag) { 31 | switch (datatype) { 32 | case glooDataType_t::glooInt8: 33 | scatter(context, sendbuf, recvbuf, size, root, tag); 34 | break; 35 | case glooDataType_t::glooUint8: 36 | scatter(context, sendbuf, recvbuf, size, root, tag);
# NOTE(review): dump fragment - rest of scatter_wrapper (dtype dispatch; unknown dtypes throw "Unhandled dataType") and tests/test_reduce_scatter.py up to the send-buffer assignment. The rendezvous (FileStore + PrefixStore keyed by world_size, TCP device on localhost) is the same as in the other tests.
37 | break; 38 | case glooDataType_t::glooInt32: 39 | scatter(context, sendbuf, recvbuf, size, root, tag); 40 | break; 41 | case glooDataType_t::glooUint32: 42 | scatter(context, sendbuf, recvbuf, size, root, tag); 43 | break; 44 | case glooDataType_t::glooInt64: 45 | scatter(context, sendbuf, recvbuf, size, root, tag); 46 | break; 47 | case glooDataType_t::glooUint64: 48 | scatter(context, sendbuf, recvbuf, size, root, tag); 49 | break; 50 | case glooDataType_t::glooFloat16: 51 | scatter(context, sendbuf, recvbuf, size, root, tag); 52 | break; 53 | case glooDataType_t::glooFloat32: 54 | scatter(context, sendbuf, recvbuf, size, root, tag); 55 | break; 56 | case glooDataType_t::glooFloat64: 57 | scatter(context, sendbuf, recvbuf, size, root, tag); 58 | break; 59 | default: 60 | throw std::runtime_error("Unhandled dataType"); 61 | } 62 | } 63 | } // namespace pygloo 64 | -------------------------------------------------------------------------------- /tests/test_reduce_scatter.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_reduce_scatter(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf =
# NOTE(review): dump fragment - body of tests/test_reduce_scatter.py. Every rank sends 1..sum(1..world_size) (1..6 for world_size = 3) and receives recvElems[rank] = rank + 1 SUM-reduced elements, as in the example output at the end of this line.
np.array([i+1 for i in range(sum([j+1 for j in range(world_size)]))], dtype=np.float32) 33 | sendptr = sendbuf.ctypes.data 34 | 35 | recvbuf = np.zeros((rank+1,), dtype=np.float32) 36 | recvptr = recvbuf.ctypes.data 37 | recvElems = [i+1 for i in range(world_size)] 38 | 39 | 40 | # sendbuf = torch.Tensor([i+1 for i in range(sum([j+1 for j in range(world_size)]))]).float() 41 | # sendptr = sendbuf.data_ptr() 42 | # recvbuf = torch.zeros(rank+1).float() 43 | # recvptr = recvbuf.data_ptr() 44 | 45 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 46 | datatype = pygloo.glooDataType_t.glooFloat32 47 | op = pygloo.ReduceOp.SUM 48 | 49 | pygloo.reduce_scatter(context, sendptr, recvptr, data_size, recvElems, datatype, op) 50 | 51 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 52 | 53 | ## example output 54 | # (pid=22653) rank 2 sends [1. 2. 3. 4. 5. 6.], receives [12. 15. 18.] 55 | # (pid=22658) rank 0 sends [1. 2. 3. 4. 5. 6.], receives [3.] 56 | # (pid=22656) rank 1 sends [1. 2. 3. 4. 5. 6.], receives [6. 9.]
# NOTE(review): dump fragment - __main__ of tests/test_reduce_scatter.py (world_size = 3) and tests/test_allreduce.py up to its algorithm choice: a SUM allreduce of a 2x3 float32 array using allreduceAlgorithm.RING.
57 | 58 | if __name__ == "__main__": 59 | ray.init(num_cpus=6) 60 | world_size = 3 61 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 62 | 63 | fns = [test_reduce_scatter.remote(i, world_size, fileStore_path) for i in range(world_size)] 64 | ray.get(fns) 65 | -------------------------------------------------------------------------------- /tests/test_allreduce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ray 4 | import time 5 | import shutil 6 | import torch 7 | import pygloo 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_allreduce(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 33 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 34 | sendptr = sendbuf.ctypes.data 35 | recvptr = recvbuf.ctypes.data 36 | 37 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 38 | # recvbuf = torch.zeros_like(sendbuf) 39 | # sendptr = sendbuf.data_ptr() 40 | # recvptr = recvbuf.data_ptr() 41 | 42 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 43 | datatype = pygloo.glooDataType_t.glooFloat32 44 | op = pygloo.ReduceOp.SUM 45 | algorithm = pygloo.allreduceAlgorithm.RING 46
# NOTE(review): dump fragment - end of tests/test_allreduce.py (world_size = 2, so every rank receives twice its input, per the example output) and the start of tests/test_barrier.py.
| 47 | pygloo.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) 48 | 49 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 50 | ## example output 51 | # (pid=30445) rank 0 sends [[1. 2. 3.] 52 | # (pid=30445) [1. 2. 3.]], 53 | # (pid=30445) receives [[2. 4. 6.] 54 | # (pid=30445) [2. 4. 6.]] 55 | # (pid=30446) rank 1 sends [[1. 2. 3.] 56 | # (pid=30446) [1. 2. 3.]], 57 | # (pid=30446) receives [[2. 4. 6.] 58 | # (pid=30446) [2. 4. 6.]] 59 | 60 | if __name__ == "__main__": 61 | ray.init(num_cpus=6) 62 | world_size = 2 63 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 64 | 65 | fns = [test_allreduce.remote(i, world_size, fileStore_path) for i in range(world_size)] 66 | ray.get(fns) 67 | -------------------------------------------------------------------------------- /tests/test_barrier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ray 4 | import time 5 | import shutil 6 | import torch 7 | import pygloo 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_barrier(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 33 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 34 |
# NOTE(review): dump fragment - rest of tests/test_barrier.py, which repeats the allreduce from test_allreduce and then calls pygloo.barrier(context) before printing, plus the start of tests/test_send_recv.py.
sendptr = sendbuf.ctypes.data 35 | recvptr = recvbuf.ctypes.data 36 | 37 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 38 | # recvbuf = torch.zeros_like(sendbuf) 39 | # sendptr = sendbuf.data_ptr() 40 | # recvptr = recvbuf.data_ptr() 41 | 42 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 43 | datatype = pygloo.glooDataType_t.glooFloat32 44 | op = pygloo.ReduceOp.SUM 45 | algorithm = pygloo.allreduceAlgorithm.RING 46 | 47 | pygloo.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) 48 | pygloo.barrier(context) 49 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 50 | ## example output 51 | # (pid=30445) rank 0 sends [[1. 2. 3.] 52 | # (pid=30445) [1. 2. 3.]], 53 | # (pid=30445) receives [[2. 4. 6.] 54 | # (pid=30445) [2. 4. 6.]] 55 | # (pid=30446) rank 1 sends [[1. 2. 3.] 56 | # (pid=30446) [1. 2. 3.]], 57 | # (pid=30446) receives [[2. 4. 6.] 58 | # (pid=30446) [2. 4. 6.]] 59 | 60 | if __name__ == "__main__": 61 | ray.init(num_cpus=6) 62 | world_size = 2 63 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 64 | 65 | fns = [test_barrier.remote(i, world_size, fileStore_path) for i in range(world_size)] 66 | ray.get(fns) 67 | -------------------------------------------------------------------------------- /tests/test_send_recv.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_send_recv(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context =
# NOTE(review): dump fragment - body of tests/test_send_recv.py and the first lines of pygloo/src/reduce.cc. Rank 0 sends a 2x3 float32 array to peer 1 and rank 1 receives it into a zeroed (2, 3) buffer; any other rank raises, so world_size must be 2. The "## example output" section is empty.
pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | if rank == 0: 33 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 34 | sendptr = sendbuf.ctypes.data 35 | 36 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 37 | # sendptr = sendbuf.data_ptr() 38 | 39 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 40 | datatype = pygloo.glooDataType_t.glooFloat32 41 | peer = 1 42 | pygloo.send(context, sendptr, data_size, datatype, peer) 43 | print(f"rank {rank} sends {sendbuf}") 44 | 45 | elif rank == 1: 46 | recvbuf = np.zeros((2,3), dtype=np.float32) 47 | recvptr = recvbuf.ctypes.data 48 | 49 | # recvbuf = torch.zeros(2,3).float() 50 | # recvptr = recvbuf.data_ptr() 51 | 52 | data_size = recvbuf.size if isinstance(recvbuf, np.ndarray) else recvbuf.numpy().size 53 | datatype = pygloo.glooDataType_t.glooFloat32 54 | peer = 0 55 | 56 | pygloo.recv(context, recvptr, data_size, datatype, peer) 57 | print(f"rank {rank} receives {recvbuf}") 58 | else: 59 | raise Exception("Only support 2 process to test send function and recv function") 60 | ## example output 61 | 62 | 63 | if __name__ == "__main__": 64 | ray.init(num_cpus=6) 65 | world_size = 2 66 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 67 | 68 | fns = [test_send_recv.remote(i, world_size, fileStore_path) for i in range(world_size)] 69 | ray.get(fns) 70 | -------------------------------------------------------------------------------- /pygloo/src/reduce.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5
| 6 | template 7 | void reduce(const std::shared_ptr &context, intptr_t sendbuf, 8 | intptr_t recvbuf, size_t size, ReduceOp reduceop, int root, 9 | uint32_t tag) { 10 | T *input_ptr = reinterpret_cast(sendbuf); 11 | 12 | T *output_ptr; 13 | if(context->rank == root) 14 | output_ptr = reinterpret_cast(recvbuf); 15 | else 16 | output_ptr = new T[size]; 17 | 18 | // Configure reduceOptions struct 19 | gloo::ReduceOptions opts_(context); 20 | opts_.setInput(input_ptr, size); 21 | opts_.setOutput(output_ptr, size); 22 | gloo::ReduceOptions::Func fn = toFunction(reduceop); 23 | opts_.setReduceFunction(fn); 24 | opts_.setRoot(root); 25 | opts_.setTag(tag); 26 | 27 | gloo::reduce(opts_); 28 | 29 | if(context->rank != root) 30 | delete output_ptr; 31 | } 32 | 33 | void reduce_wrapper(const std::shared_ptr &context, 34 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 35 | glooDataType_t datatype, ReduceOp reduceop, int root, 36 | uint32_t tag) { 37 | switch (datatype) { 38 | case glooDataType_t::glooInt8: 39 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 40 | break; 41 | case glooDataType_t::glooUint8: 42 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 43 | break; 44 | case glooDataType_t::glooInt32: 45 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 46 | break; 47 | case glooDataType_t::glooUint32: 48 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 49 | break; 50 | case glooDataType_t::glooInt64: 51 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 52 | break; 53 | case glooDataType_t::glooUint64: 54 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 55 | break; 56 | case glooDataType_t::glooFloat16: 57 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 58 | break; 59 | case glooDataType_t::glooFloat32: 60 | reduce(context, sendbuf, recvbuf, size, reduceop, root, tag); 61 | break; 62 | case glooDataType_t::glooFloat64: 63 | reduce(context, sendbuf, recvbuf, size, 
# NOTE(review): dump fragment - end of pygloo/src/reduce.cc (float64 case; unknown dtypes throw "Unhandled dataType") and tests/test_broadcast.py up to the broadcast call.
# Non-root ranks pass sendptr = -1 because broadcast only hands sendbuf to gloo on the root; the commented-out torch variant uses 0 for the same purpose.
reduceop, root, tag); 64 | break; 65 | default: 66 | throw std::runtime_error("Unhandled dataType"); 67 | } 68 | } 69 | } // namespace pygloo 70 | -------------------------------------------------------------------------------- /tests/test_broadcast.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_broadcast(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | if rank == 0: 33 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 34 | sendptr = sendbuf.ctypes.data 35 | else: 36 | sendbuf = np.zeros((2,3), dtype=np.float32) 37 | sendptr = -1 38 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 39 | recvptr = recvbuf.ctypes.data 40 | 41 | # if rank == 0: 42 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 43 | # sendptr = sendbuf.data_ptr() 44 | # else: 45 | # sendbuf = torch.zeros(2,3) 46 | # sendptr = 0 47 | # recvbuf = torch.zeros_like(sendbuf) 48 | # recvptr = recvbuf.data_ptr() 49 | 50 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 51 | datatype = pygloo.glooDataType_t.glooFloat32 52 | root = 0 53 | 54 | pygloo.broadcast(context, sendptr,
recvptr, data_size, datatype, root) 55 | 56 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 57 | ## example output 58 | # (pid=36435) rank 1 sends [[0. 0. 0.] 59 | # (pid=36435) [0. 0. 0.]], receives [[1. 2. 3.] 60 | # (pid=36435) [1. 2. 3.]] 61 | # (pid=36432) rank 0 sends [[1. 2. 3.] 62 | # (pid=36432) [1. 2. 3.]], receives [[1. 2. 3.] 63 | # (pid=36432) [1. 2. 3.]] 64 | 65 | 66 | if __name__ == "__main__": 67 | ray.init(num_cpus=6) 68 | world_size = 2 69 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 70 | 71 | fns = [test_broadcast.remote(i, world_size, fileStore_path) for i in range(world_size)] 72 | ray.get(fns) 73 | -------------------------------------------------------------------------------- /pygloo/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_foreign_cc//tools/build_defs:cmake.bzl", "cmake_external") 2 | load("@rules_foreign_cc//tools/build_defs:make.bzl", "make") 3 | load("@pybind11_bazel//:build_defs.bzl", "pybind_library") 4 | load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") 5 | 6 | cmake_external( 7 | name = "libuv", 8 | # Values to be passed as -Dkey=value on the CMake command line; 9 | # here are serving to provide some CMake script configuration options 10 | env_vars = {"CFLAGS": "-fPIC"}, 11 | lib_source = "@libuv//:all", 12 | 13 | out_lib_dir = "lib", 14 | # We are selecting the resulting static library to be passed in C/C++ provider 15 | # as the result of the build; 16 | static_libraries = ["libuv_a.a"], 17 | ) 18 | 19 | make( 20 | name = "hiredis", 21 | 22 | out_include_dir = "include/hiredis", 23 | 24 | lib_source = "@hiredis//:all", 25 | # We are selecting the resulting static library to be passed in C/C++ provider 26 | # as the result of the build; 27 | static_libraries = ["libhiredis.a"], 28 | ) 29 | 30 | cmake_external( 31 | name = "gloo", 32 | # Values to be passed as -Dkey=value on the CMake 
command line; 33 | # here are serving to provide some CMake script configuration options 34 | cache_entries = { 35 | "libuv_LIBDIR": "$EXT_BUILD_DEPS/libuv/lib", 36 | "libuv_INCLUDE_DIRS": "$EXT_BUILD_DEPS/libuv/include", 37 | 38 | "HIREDIS_NESTED_INCLUDE": "off", # 'on' use hiredis/hiredis.h, 'off' use hiredis.h. 39 | "HIREDIS_ROOT_DIR":"$EXT_BUILD_DEPS/hiredis", 40 | "HIREDIS_INCLUDE_DIR": "$EXT_BUILD_DEPS/hiredis/include/hiredis", 41 | "HIREDIS_LIB_DIR": "$EXT_BUILD_DEPS/hiredis/lib", 42 | 43 | "USE_REDIS": "on", 44 | "USE_IBVERBS": "off", 45 | "USE_NCCL": "off", 46 | "USE_RCCL": "off", 47 | "USE_LIBUV": "on", 48 | "USE_TCP_OPENSSL": "off", 49 | }, 50 | 51 | deps = [ 52 | ":libuv", 53 | ":hiredis", 54 | ], 55 | 56 | lib_source = "@gloo//:all", 57 | 58 | # We are selecting the resulting static library to be passed in C/C++ provider 59 | # as the result of the build; 60 | static_libraries = ["libgloo.a"], 61 | ) 62 | 63 | pybind_library( 64 | name = "pygloo-deps", 65 | srcs = glob(["src/*.cc"]), 66 | hdrs = glob(["include/*.h"]), 67 | strip_include_prefix = "include", 68 | visibility = ["//visibility:public"], 69 | deps = [":gloo"] 70 | ) 71 | 72 | pybind_extension( 73 | name = "pygloo", 74 | srcs = ["main.cc"], 75 | deps = [":gloo", ":pygloo-deps"], 76 | ) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # pycharm related 132 | .idea/ 133 | 134 | # vscode 135 | .vscode/ 136 | 137 | # bazel 138 | bazel-* 139 | 140 | -------------------------------------------------------------------------------- /tests/test_allgather.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_allgather(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 33 | recvbuf = np.zeros([world_size] + list(sendbuf.shape), dtype=np.float32) 34 | sendptr = sendbuf.ctypes.data 35 | recvptr = 
recvbuf.ctypes.data 36 | 37 | # sendbuf = torch.Tensor([[1,2,3],[1,2,3]]).float() 38 | # recvbuf = torch.zeros([world_size] + list(sendbuf.shape)).float() 39 | # sendptr = sendbuf.data_ptr() 40 | # recvptr = recvbuf.data_ptr() 41 | 42 | assert sendbuf.size() * world_size == recvbuf.size() 43 | 44 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 45 | datatype = pygloo.glooDataType_t.glooFloat32 46 | 47 | pygloo.allgather(context, sendptr, recvptr, data_size, datatype) 48 | 49 | print(f"rank {rank} sends {sendbuf},\nreceives {recvbuf}") 50 | 51 | ## example output 52 | # (pid=29044) rank 0 sends [[1. 2. 3.] 53 | # (pid=29044) [1. 2. 3.]], 54 | # (pid=29044) receives [[[1. 2. 3.] 55 | # (pid=29044) [1. 2. 3.]] 56 | # (pid=29044) [[1. 2. 3.] 57 | # (pid=29044) [1. 2. 3.]]] 58 | # (pid=29046) rank 1 sends [[1. 2. 3.] 59 | # (pid=29046) [1. 2. 3.]], 60 | # (pid=29046) receives [[[1. 2. 3.] 61 | # (pid=29046) [1. 2. 3.]] 62 | # (pid=29046) [[1. 2. 3.] 63 | # (pid=29046) [1. 2. 
3.]]] 64 | 65 | if __name__ == "__main__": 66 | ray.init(num_cpus=6) 67 | world_size = 2 68 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 69 | 70 | fns = [test_allgather.remote(i, world_size, fileStore_path) for i in range(world_size)] 71 | ray.get(fns) 72 | -------------------------------------------------------------------------------- /tests/test_scatter.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | 9 | @ray.remote(num_cpus=1) 10 | def test_scatter(rank, world_size, fileStore_path): 11 | ''' 12 | rank # Rank of this process within list of participating processes 13 | world_size # Number of participating processes 14 | ''' 15 | if rank==0: 16 | if os.path.exists(fileStore_path): 17 | shutil.rmtree(fileStore_path) 18 | os.makedirs(fileStore_path) 19 | else: time.sleep(0.5) 20 | 21 | context = pygloo.rendezvous.Context(rank, world_size) 22 | 23 | attr = pygloo.transport.tcp.attr("localhost") 24 | # Perform rendezvous for TCP pairs 25 | dev = pygloo.transport.tcp.CreateDevice(attr) 26 | 27 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 28 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 29 | 30 | context.connectFullMesh(store, dev) 31 | 32 | sendbuf = [np.array([[1,2,3],[1,2,3]], dtype=np.float32)]*world_size 33 | recvbuf = np.zeros((2, 3), dtype=np.float32) 34 | sendptr = [] 35 | for i in sendbuf: 36 | sendptr.append(i.ctypes.data) 37 | recvptr = recvbuf.ctypes.data 38 | 39 | # sendbuf = [torch.Tensor([[1,2,3],[1,2,3]]).float()]*world_size 40 | # recvbuf = torch.zeros_like(sendbuf) 41 | # sendptr = [] 42 | # for i in sendbuf: 43 | # sendptr.append(i.data_ptr()) 44 | # recvptr = recvbuf.data_ptr() 45 | 46 | data_size = sendbuf[0].size if isinstance(sendbuf[0], np.ndarray) else sendbuf[0].numpy().size 47 | datatype = 
pygloo.glooDataType_t.glooFloat32 48 | root = 0 49 | 50 | pygloo.scatter(context, sendptr, recvptr, data_size, datatype, root) 51 | 52 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 53 | ## example output, root is 0. 54 | # (pid=18951) rank 1 sends [array([[1., 2., 3.], 55 | # (pid=18951) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], 56 | # (pid=18951) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] 57 | # (pid=18951) [1. 2. 3.]] 58 | # (pid=18952) rank 0 sends [array([[1., 2., 3.], 59 | # (pid=18952) [1., 2., 3.]], dtype=float32), array([[1., 2., 3.], 60 | # (pid=18952) [1., 2., 3.]], dtype=float32)], receives [[1. 2. 3.] 61 | # (pid=18952) [1. 2. 3.]] 62 | 63 | if __name__ == "__main__": 64 | ray.init(num_cpus=6) 65 | world_size = 2 66 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 67 | 68 | fns = [test_scatter.remote(i, world_size, fileStore_path) for i in range(world_size)] 69 | ray.get(fns) 70 | -------------------------------------------------------------------------------- /pygloo/src/allreduce.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace pygloo { 6 | 7 | template 8 | void allreduce(const std::shared_ptr &context, intptr_t sendbuf, 9 | intptr_t recvbuf, size_t size, ReduceOp reduceop, 10 | gloo::AllreduceOptions::Algorithm algorithm, uint32_t tag) { 11 | std::vector input_ptr{reinterpret_cast(sendbuf)}; 12 | std::vector output_ptr{reinterpret_cast(recvbuf)}; 13 | 14 | // Configure AllreduceOptions struct and call allreduce function 15 | gloo::AllreduceOptions opts_(context); 16 | opts_.setInputs(input_ptr, size); 17 | opts_.setOutputs(output_ptr, size); 18 | opts_.setAlgorithm(algorithm); 19 | gloo::ReduceOptions::Func fn = toFunction(reduceop); 20 | opts_.setReduceFunction(fn); 21 | opts_.setTag(tag); 22 | 23 | gloo::allreduce(opts_); 24 | } 25 | 26 | void allreduce_wrapper(const 
std::shared_ptr &context, 27 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 28 | glooDataType_t datatype, ReduceOp reduceop, 29 | gloo::AllreduceOptions::Algorithm algorithm, 30 | uint32_t tag) { 31 | switch (datatype) { 32 | case glooDataType_t::glooInt8: 33 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 34 | tag); 35 | break; 36 | case glooDataType_t::glooUint8: 37 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 38 | tag); 39 | break; 40 | case glooDataType_t::glooInt32: 41 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 42 | tag); 43 | break; 44 | case glooDataType_t::glooUint32: 45 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 46 | tag); 47 | break; 48 | case glooDataType_t::glooInt64: 49 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 50 | tag); 51 | break; 52 | case glooDataType_t::glooUint64: 53 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 54 | tag); 55 | break; 56 | case glooDataType_t::glooFloat16: 57 | allreduce(context, sendbuf, recvbuf, size, reduceop, 58 | algorithm, tag); 59 | break; 60 | case glooDataType_t::glooFloat32: 61 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 62 | tag); 63 | break; 64 | case glooDataType_t::glooFloat64: 65 | allreduce(context, sendbuf, recvbuf, size, reduceop, algorithm, 66 | tag); 67 | break; 68 | default: 69 | throw std::runtime_error("Unhandled dataType"); 70 | } 71 | } 72 | } // namespace pygloo 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pygloo 2 | 3 | Pygloo provides Python bindings for [gloo](https://github.com/facebookincubator/gloo). 4 | It is implemented using [pybind11](https://github.com/pybind/pybind11). 
5 | 6 | It is currently used in [Ray for collective communication](https://github.com/ray-project/ray/tree/master/python/ray/util/collective) between CPUs. 7 | 8 | 9 | ## Requirements 10 | ```text 11 | Python >= 3.6 12 | ``` 13 | 14 | ## Installation 15 | ### Install From Wheels 16 | We provide prepackaged Python wheels (`manylinux2014_x86_64`, `manylinux_2_24_x86_64`). To install from wheels: 17 | ```bash 18 | pip install pygloo 19 | ``` 20 | 21 | ### Building from source 22 | One can build pygloo from source if none of the released wheels fits the development environment. 23 | 24 | Pygloo uses [Bazel](https://github.com/bazelbuild/bazel) to automatically manage dependencies and compilation. 25 | To compile from source, install Bazel>=2.0.0 following the [Bazel installation guide](https://docs.bazel.build/versions/master/install.html). 26 | After installing Bazel, build and install pygloo with this command: 27 | ```bash 28 | python setup.py install 29 | ``` 30 | 31 | ## Testing 32 | Pygloo uses [Ray](https://github.com/ray-project/ray) to create multiple distributed processes for collective communication tests. See the `tests` directory. 33 | 34 | ## Example 35 | An example of allreduce:
36 | ```python 37 | import os 38 | import ray 39 | import pygloo 40 | import numpy as np 41 | 42 | @ray.remote(num_cpus=1) 43 | def test_allreduce(rank, world_size, fileStore_path): 44 | ''' 45 | rank # Rank of this process within list of participating processes 46 | world_size # Number of participating processes 47 | fileStore_path # The path to create filestore 48 | ''' 49 | context = pygloo.rendezvous.Context(rank, world_size) 50 | # Prepare device and store for rendezvous 51 | attr = pygloo.transport.tcp.attr("localhost") 52 | dev = pygloo.transport.tcp.CreateDevice(attr) 53 | fileStore = pygloo.rendezvous.FileStore(fileStore_path) 54 | store = pygloo.rendezvous.PrefixStore(str(world_size), fileStore) 55 | 56 | context.connectFullMesh(store, dev) 57 | 58 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 59 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 60 | sendptr = sendbuf.ctypes.data 61 | recvptr = recvbuf.ctypes.data 62 | 63 | pygloo.allreduce(context, sendptr, recvptr, 64 | sendbuf.size, pygloo.glooDataType_t.glooFloat32, 65 | pygloo.ReduceOp.SUM, pygloo.allreduceAlgorithm.RING) 66 | 67 | if __name__ == "__main__": 68 | ray.init(num_cpus=6) 69 | world_size = 2 70 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 71 | os.makedirs(fileStore_path, exist_ok=True) 72 | ray.get([test_allreduce.remote(rank, world_size, fileStore_path) for rank in range(world_size)]) 73 | ``` 74 | 75 | 76 | ## License 77 | Pygloo is licensed under the Apache License, Version 2.0.
-------------------------------------------------------------------------------- /pygloo/include/transport.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #if GLOO_HAVE_TRANSPORT_TCP 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #endif 17 | 18 | #if GLOO_HAVE_TRANSPORT_UV 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #endif 27 | 28 | #if !GLOO_HAVE_TRANSPORT_UV 29 | #if !GLOO_HAVE_TRANSPORT_UV // NOTE(review): repeats the enclosing #if on line 28, so it is a no-op; confirm which macro was intended 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #endif 36 | #endif 37 | 38 | namespace pygloo { 39 | namespace transport { 40 | class PyDevice : public gloo::transport::Device { 41 | public: 42 | using gloo::transport::Device::Device; 43 | 44 | std::string str() const override { 45 | PYBIND11_OVERRIDE_PURE( 46 | std::string, /* Return type */ 47 | gloo::transport::Device, /* Parent class */ 48 | str, /* Name of function in C++ (must match Python name) */ 49 | /* Argument(s) */ 50 | ); 51 | } 52 | 53 | const std::string &getPCIBusID() const override { 54 | PYBIND11_OVERRIDE_PURE( 55 | const std::string &, /* Return type */ 56 | gloo::transport::Device, /* Parent class */ 57 | getPCIBusID, /* Name of function in C++ (must match Python name) */ 58 | /* Argument(s) */ 59 | ); 60 | } 61 | 62 | int getInterfaceSpeed() const override { 63 | PYBIND11_OVERRIDE(int, /* Return type */ 64 | gloo::transport::Device, /* Parent class */ 65 | getInterfaceSpeed, /* Name of function in C++ (must match 66 | Python name) */ 67 | /* Argument(s) */ 68 | ); 69 | } 70 | 71 | bool hasGPUDirect() const override { 72 | PYBIND11_OVERRIDE( 73 | bool, /* Return type */ 74 | gloo::transport::Device, /* Parent class */ 75 | hasGPUDirect, /* Name of function in C++ (must match Python name) */ 76 | /* Argument(s) */ 77 | ); 78 | } 79 | 80 | std::shared_ptr createContext(int
rank, 81 | int size) override { 82 | PYBIND11_OVERRIDE_PURE( 83 | std::shared_ptr, /* Return type */ 84 | gloo::transport::Device, /* Parent class */ 85 | createContext, /* Name of function in C++ (must match Python name) */ 86 | rank, size /* Argument(s) */ 87 | ); 88 | } 89 | }; 90 | 91 | void def_transport_module(pybind11::module &m); 92 | void def_transport_tcp_module(pybind11::module &m); 93 | void def_transport_uv_module(pybind11::module &m); 94 | } // namespace transport 95 | } // namespace pygloo 96 | -------------------------------------------------------------------------------- /tests/test_custom_store.py: -------------------------------------------------------------------------------- 1 | import pygloo 2 | import numpy as np 3 | import os 4 | import ray 5 | import time 6 | import shutil 7 | import torch 8 | import pytest 9 | import sys 10 | import ray.experimental.internal_kv as internal_kv 11 | from ray._private.gcs_utils import GcsClient 12 | 13 | 14 | class MyMockCustomStore: 15 | def __init__(self): 16 | gcs_address = ray.worker._global_node.gcs_address 17 | self._gcs_client = GcsClient(address=gcs_address, nums_reconnect_retry=0) 18 | internal_kv._initialize_internal_kv(self._gcs_client) 19 | 20 | def set(self, key: str, data: bytes) -> bool: 21 | ret = internal_kv._internal_kv_put(key, data) 22 | return ret 23 | 24 | def get(self, key: str) -> bytes: 25 | ret = internal_kv._internal_kv_get(key) 26 | return ret 27 | 28 | def wait(self, keys: list): 29 | while(True): 30 | all_exist = True 31 | for key in keys: 32 | result = internal_kv._internal_kv_exists(key) 33 | if not result: 34 | all_exist = False 35 | break 36 | if all_exist: 37 | return True 38 | time.sleep(1) 39 | 40 | def del_keys(self, keys: list): 41 | for key in keys: 42 | ok = internal_kv._internal_kv_del(key) 43 | if not ok: 44 | return False 45 | return True 46 | 47 | @ray.remote(num_cpus=1) 48 | class Sender: 49 | def __init__(self): 50 | rank = 0 51 | world_size = 2 52 | self._context =
pygloo.rendezvous.Context(rank, world_size) 53 | attr = pygloo.transport.tcp.attr("localhost") 54 | dev = pygloo.transport.tcp.CreateDevice(attr) 55 | real_store = MyMockCustomStore() 56 | custom_store = pygloo.rendezvous.CustomStore(real_store) 57 | self._context.connectFullMesh(custom_store, dev) 58 | 59 | def do_send(self): 60 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 61 | sendptr = sendbuf.ctypes.data 62 | pygloo.send(self._context, sendptr, sendbuf.size, pygloo.glooDataType_t.glooFloat32, 1) 63 | return True 64 | 65 | 66 | @ray.remote(num_cpus=1) 67 | class Recver: 68 | def __init__(self): 69 | rank = 1 70 | world_size = 2 71 | self._context = pygloo.rendezvous.Context(rank, world_size) 72 | attr = pygloo.transport.tcp.attr("localhost") 73 | dev = pygloo.transport.tcp.CreateDevice(attr) 74 | real_store = MyMockCustomStore() 75 | custom_store = pygloo.rendezvous.CustomStore(real_store) 76 | self._context.connectFullMesh(custom_store, dev) 77 | 78 | def do_recv(self): 79 | recvbuf = np.zeros((2, 3), dtype=np.float32) 80 | recvptr = recvbuf.ctypes.data 81 | 82 | data_size = recvbuf.size if isinstance(recvbuf, np.ndarray) else recvbuf.numpy().size 83 | datatype = pygloo.glooDataType_t.glooFloat32 84 | peer = 0 85 | 86 | pygloo.recv(self._context, recvptr, data_size, datatype, peer) 87 | return recvbuf 88 | 89 | 90 | def test_basic(): 91 | ray.init(num_cpus=6) 92 | 93 | sender = Sender.remote() 94 | recver = Recver.remote() 95 | fn1 = sender.do_send.remote() 96 | fn2 = recver.do_recv.remote() 97 | 98 | a, b = ray.get([fn1, fn2]) 99 | assert a 100 | expected = [[1, 2, 3], [1, 2, 3]] 101 | assert np.array_equal(b, expected) # checks both the (2, 3) shape and the received values 102 | assert len(b[0]) == 3 103 | assert len(b[1]) == 3 104 | 105 | 106 | if __name__ == "__main__": 107 | sys.exit(pytest.main(["-v", __file__])) 108 | -------------------------------------------------------------------------------- /WORKSPACE: 
-------------------------------------------------------------------------------- 1 | workspace(name =
"pygloo") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | # Group the sources of the library so that CMake rules have access to them 6 | all_content = """filegroup(name = "all", srcs = glob(["**"]), visibility = ["//visibility:public"])""" 7 | 8 | # Rule repository 9 | http_archive( 10 | name = "rules_foreign_cc", 11 | strip_prefix = "rules_foreign_cc-87df6b25f6c009883da87f07ea680d38780a4d6f", 12 | url = "https://github.com/bazelbuild/rules_foreign_cc/archive/87df6b25f6c009883da87f07ea680d38780a4d6f.zip", 13 | sha256 = "a45511a054598dd9b87d4d5765a18df4e5777736026087cf96ffc30704e6c918", 14 | ) 15 | 16 | load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") 17 | 18 | # Call this function from the WORKSPACE file to initialize rules_foreign_cc 19 | # dependencies and let necessary code generation happen 20 | # (Code generation is needed to support different variants of the C++ Starlark API.) 21 | # 22 | # Args: 23 | # native_tools_toolchains: pass the toolchains for toolchain types 24 | # '@rules_foreign_cc//tools/build_defs:make_toolchain', 25 | # '@rules_foreign_cc//tools/build_defs:cmake_toolchain' and 26 | # '@rules_foreign_cc//tools/build_defs:ninja_toolchain' with the needed platform constraints. 27 | # If you do not pass anything, registered default toolchains will be selected (see below). 28 | # 29 | # register_default_tools: if True, the make, cmake and ninja toolchains, calling corresponding 30 | # preinstalled binaries by name (make, cmake, ninja) will be registered after 31 | # 'native_tools_toolchains' without any platform constraints. 32 | # The default is True.
33 | rules_foreign_cc_dependencies() 34 | 35 | 36 | http_archive( 37 | name = "io_opencensus_proto", # opencensus-proto archive; must not reuse the already-defined "rules_foreign_cc" name 38 | strip_prefix = "opencensus-proto-0.3.0/src", 39 | urls = ["https://github.com/census-instrumentation/opencensus-proto/archive/v0.3.0.tar.gz"], 40 | sha256 = "b7e13f0b4259e80c3070b583c2f39e53153085a6918718b1c710caf7037572b0", 41 | ) 42 | 43 | http_archive( 44 | name = "pybind11_bazel", 45 | strip_prefix = "pybind11_bazel-f4f1bd4fa4b368b79dd6f003f8ef8c5a91fad36b", 46 | urls = ["https://github.com/Ezra-H/pybind11_bazel/archive/f4f1bd4fa4b368b79dd6f003f8ef8c5a91fad36b.zip"], 47 | sha256 = "6ea811e7a7348f7c9d5b59887aa0c65e42222e199049a1ee55db147d2e9ca4a7", 48 | ) 49 | 50 | # We still require the pybind library. 51 | http_archive( 52 | name = "pybind11", 53 | build_file = "@pybind11_bazel//:pybind11.BUILD", 54 | strip_prefix = "pybind11-2.6.1", 55 | urls = ["https://github.com/pybind/pybind11/archive/v2.6.1.tar.gz"], 56 | sha256 = "cdbe326d357f18b83d10322ba202d69f11b2f49e2d87ade0dc2be0c5c34f8e2a", 57 | ) 58 | 59 | http_archive( 60 | name = "libuv", 61 | build_file_content = all_content, 62 | strip_prefix = "libuv-1.40.0", 63 | urls = ["https://github.com/libuv/libuv/archive/v1.40.0.tar.gz"], 64 | sha256 = "70fe1c9ba4f2c509e8166c0ca2351000237da573bb6c82092339207a9715ba6b", 65 | ) 66 | 67 | http_archive( 68 | name = "hiredis", 69 | build_file_content = all_content, 70 | strip_prefix = "hiredis-1.0.0", 71 | urls = ["https://github.com/redis/hiredis/archive/v1.0.0.tar.gz"], 72 | sha256 = "2a0b5fe5119ec973a0c1966bfc4bd7ed39dbce1cb6d749064af9121fe971936f", 73 | ) 74 | 75 | # gloo source code repository 76 | http_archive( 77 | name = "gloo", 78 | build_file_content = all_content, 79 | strip_prefix = "gloo-add3f38c6a2715e9387f4966b4fc3d92bb786adb", 80 | urls = ["https://github.com/Ezra-H/gloo/archive/add3f38c6a2715e9387f4966b4fc3d92bb786adb.tar.gz"], 81 | sha256 = 
"a146136bb6efdac0e3ede952d09aec44b771a87ebc713bd815c3a90a7428c908", 82 | ) 83 | 84 |
load("@pybind11_bazel//:python_configure.bzl", "python_configure_pybind") 85 | python_configure_pybind(name = "local_config_python") 86 | 87 | 88 | -------------------------------------------------------------------------------- /tests/test_redis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ray 3 | import torch 4 | import pygloo 5 | 6 | @ray.remote(num_cpus=1) 7 | def test_redis(rank, world_size, redis_ip_address, redis_port, redis_password): 8 | ''' 9 | rank # Rank of this process within list of participating processes 10 | world_size # Number of participating processes 11 | ''' 12 | context = pygloo.rendezvous.Context(rank, world_size); 13 | 14 | attr = pygloo.transport.tcp.attr("localhost") 15 | # Perform rendezvous for TCP pairs 16 | dev = pygloo.transport.tcp.CreateDevice(attr) 17 | 18 | redisStore = pygloo.rendezvous.RedisStore(redis_ip_address, redis_port) 19 | 20 | redisStore.authorize(redis_password) 21 | store = pygloo.rendezvous.PrefixStore("default", redisStore) 22 | 23 | context.connectFullMesh(store, dev) 24 | 25 | print("Using RedisStore rendezvous, connect successful!!") 26 | 27 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 28 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 29 | sendptr = sendbuf.ctypes.data 30 | recvptr = recvbuf.ctypes.data 31 | 32 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 33 | datatype = pygloo.glooDataType_t.glooFloat32 34 | op = pygloo.ReduceOp.SUM 35 | algorithm = pygloo.allreduceAlgorithm.RING 36 | 37 | pygloo.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) 38 | 39 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 40 | 41 | 42 | @ray.remote(num_cpus=1) 43 | def test_multiGroup(rank, world_size, redis_ip_address, redis_port, redis_password): 44 | ''' 45 | test the multiGroup without prefixStore 46 | rank # Rank of this process within list of participating 
processes 47 | world_size # Number of participating processes 48 | ''' 49 | groups = [f"multiGroup{i}" for i in range(3)] 50 | contexts = {} 51 | for group_name in groups: 52 | context = pygloo.rendezvous.Context(rank, world_size); 53 | 54 | attr = pygloo.transport.tcp.attr("localhost") 55 | # Perform rendezvous for TCP pairs 56 | dev = pygloo.transport.tcp.CreateDevice(attr) 57 | 58 | redisStore = pygloo.rendezvous.RedisStore(redis_ip_address, redis_port) 59 | 60 | redisStore.authorize(redis_password) 61 | 62 | context.connectFullMesh(redisStore, dev) 63 | if rank == 0: 64 | keys = [] 65 | keys += [f"rank_{i}" for i in range(world_size)] 66 | keys += [f"{i}" for i in range(world_size)] 67 | redisStore.delKeys(keys) 68 | contexts[group_name] = context 69 | print("Using RedisStore rendezvous, connect successful!!") 70 | 71 | for group_name in groups: 72 | context = contexts[group_name] 73 | sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32) 74 | recvbuf = np.zeros_like(sendbuf, dtype=np.float32) 75 | sendptr = sendbuf.ctypes.data 76 | recvptr = recvbuf.ctypes.data 77 | 78 | data_size = sendbuf.size if isinstance(sendbuf, np.ndarray) else sendbuf.numpy().size 79 | datatype = pygloo.glooDataType_t.glooFloat32 80 | op = pygloo.ReduceOp.SUM 81 | algorithm = pygloo.allreduceAlgorithm.RING 82 | 83 | pygloo.allreduce(context, sendptr, recvptr, data_size, datatype, op, algorithm) 84 | 85 | print(f"rank {rank} sends {sendbuf}, receives {recvbuf}") 86 | 87 | if __name__ == "__main__": 88 | ray.init(num_cpus=6) 89 | world_size = 2 90 | fileStore_path = f"{ray.worker._global_node.get_session_dir_path()}" + "/collective/gloo/rendezvous" 91 | redis_address = ray.worker._global_node.redis_address 92 | redis_ip, redis_port = redis_address.split(":") 93 | redis_password = ray.worker._global_node.redis_password 94 | print(f"redis_address is {redis_ip}, the port is {redis_port}, redis_password is {redis_password}") 95 | ray.get([test_redis.remote(i, world_size, redis_ip, 
int(redis_port), redis_password) for i in range(world_size)]) 96 | 97 | ray.get([test_multiGroup.remote(i, world_size, redis_ip, int(redis_port), redis_password) for i in range(world_size)]) 98 | -------------------------------------------------------------------------------- /pygloo/src/reduce_scatter.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace pygloo { 7 | 8 | template 9 | const gloo::ReductionFunction *getReductionFunction(ReduceOp reduceop) { 10 | switch (reduceop) { 11 | case ReduceOp::SUM: 12 | return gloo::ReductionFunction::sum; 13 | break; 14 | case ReduceOp::PRODUCT: 15 | return gloo::ReductionFunction::product; 16 | break; 17 | case ReduceOp::MIN: 18 | return gloo::ReductionFunction::min; 19 | break; 20 | case ReduceOp::MAX: 21 | return gloo::ReductionFunction::max; 22 | break; 23 | case ReduceOp::BAND: 24 | throw std::runtime_error( 25 | "Cannot use ReduceOp.BAND with non-integral dtype"); 26 | break; 27 | case ReduceOp::BOR: 28 | throw std::runtime_error("Cannot use ReduceOp.BOR with non-integral dtype"); 29 | break; 30 | case ReduceOp::BXOR: 31 | throw std::runtime_error( 32 | "Cannot use ReduceOp.BXOR with non-integral dtype"); 33 | break; 34 | case ReduceOp::UNUSED: 35 | break; 36 | } 37 | throw std::runtime_error("Unhandled ReduceOp"); 38 | } 39 | 40 | template 41 | void reduce_scatter(const std::shared_ptr &context, 42 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 43 | std::vector recvElems, ReduceOp reduceop) { 44 | T *input_ptr = reinterpret_cast(sendbuf); 45 | 46 | std::vector inputbuf(size); 47 | 48 | memcpy(inputbuf.data(), input_ptr, size * sizeof(T)); 49 | 50 | std::vector dataPtrs{inputbuf.data()}; 51 | 52 | const gloo::ReductionFunction *fn = getReductionFunction(reduceop); 53 | 54 | gloo::ReduceScatterHalvingDoubling algorithm(context, dataPtrs, size, 55 | recvElems, fn); 56 | algorithm.run(); 57 | 58 | 
memcpy(reinterpret_cast(recvbuf), inputbuf.data(), 59 | recvElems[context->rank] * sizeof(T)); 60 | } 61 | 62 | void reduce_scatter_wrapper(const std::shared_ptr &context, 63 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 64 | std::vector recvElems, glooDataType_t datatype, 65 | ReduceOp reduceop) { 66 | switch (datatype) { 67 | case glooDataType_t::glooInt8: 68 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 69 | reduceop); 70 | break; 71 | case glooDataType_t::glooUint8: 72 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 73 | reduceop); 74 | break; 75 | case glooDataType_t::glooInt32: 76 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 77 | reduceop); 78 | break; 79 | case glooDataType_t::glooUint32: 80 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 81 | reduceop); 82 | break; 83 | case glooDataType_t::glooInt64: 84 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 85 | reduceop); 86 | break; 87 | case glooDataType_t::glooUint64: 88 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 89 | reduceop); 90 | break; 91 | case glooDataType_t::glooFloat16: 92 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 93 | reduceop); 94 | break; 95 | case glooDataType_t::glooFloat32: 96 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 97 | reduceop); 98 | break; 99 | case glooDataType_t::glooFloat64: 100 | reduce_scatter(context, sendbuf, recvbuf, size, recvElems, 101 | reduceop); 102 | break; 103 | default: 104 | throw std::runtime_error("Unhandled dataType"); 105 | } 106 | } 107 | } // namespace pygloo 108 | -------------------------------------------------------------------------------- /pygloo/src/allgather.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace pygloo { 8 | 9 | template 10 | void allgather(const std::shared_ptr &context, intptr_t sendbuf, 11 | intptr_t 
recvbuf, size_t size, uint32_t tag) { 12 | T *input_ptr = reinterpret_cast(sendbuf); 13 | T *output_ptr = reinterpret_cast(recvbuf); 14 | 15 | // Configure AllgatherOptions struct and call allgather function 16 | gloo::AllgatherOptions opts_(context); 17 | opts_.setInput(input_ptr, size); 18 | opts_.setOutput(output_ptr, size * context->size); 19 | opts_.setTag(tag); 20 | 21 | gloo::allgather(opts_); 22 | } 23 | 24 | void allgather_wrapper(const std::shared_ptr &context, 25 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 26 | glooDataType_t datatype, uint32_t tag) { 27 | switch (datatype) { 28 | case glooDataType_t::glooInt8: 29 | allgather(context, sendbuf, recvbuf, size, tag); 30 | break; 31 | case glooDataType_t::glooUint8: 32 | allgather(context, sendbuf, recvbuf, size, tag); 33 | break; 34 | case glooDataType_t::glooInt32: 35 | allgather(context, sendbuf, recvbuf, size, tag); 36 | break; 37 | case glooDataType_t::glooUint32: 38 | allgather(context, sendbuf, recvbuf, size, tag); 39 | break; 40 | case glooDataType_t::glooInt64: 41 | allgather(context, sendbuf, recvbuf, size, tag); 42 | break; 43 | case glooDataType_t::glooUint64: 44 | allgather(context, sendbuf, recvbuf, size, tag); 45 | break; 46 | case glooDataType_t::glooFloat16: 47 | allgather(context, sendbuf, recvbuf, size, tag); 48 | break; 49 | case glooDataType_t::glooFloat32: 50 | allgather(context, sendbuf, recvbuf, size, tag); 51 | break; 52 | case glooDataType_t::glooFloat64: 53 | allgather(context, sendbuf, recvbuf, size, tag); 54 | break; 55 | default: 56 | throw std::runtime_error("Unhandled dataType"); 57 | } 58 | } 59 | 60 | template 61 | void allgatherv(const std::shared_ptr &context, intptr_t sendbuf, 62 | intptr_t recvbuf, size_t size, uint32_t tag) { 63 | T *input_ptr = reinterpret_cast(sendbuf); 64 | T *output_ptr = reinterpret_cast(recvbuf); 65 | 66 | // Configure AllgatherOptions struct and call allgather function 67 | gloo::AllgatherOptions opts_(context); 68 | 
opts_.setInput(input_ptr, size); 69 | opts_.setOutput(output_ptr, size * context->size); 70 | opts_.setTag(tag); 71 | 72 | gloo::allgather(opts_); 73 | } 74 | 75 | void allgatherv_wrapper(const std::shared_ptr &context, 76 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 77 | glooDataType_t datatype, uint32_t tag) { 78 | switch (datatype) { 79 | case glooDataType_t::glooInt8: 80 | allgather(context, sendbuf, recvbuf, size, tag); 81 | break; 82 | case glooDataType_t::glooUint8: 83 | allgather(context, sendbuf, recvbuf, size, tag); 84 | break; 85 | case glooDataType_t::glooInt32: 86 | allgather(context, sendbuf, recvbuf, size, tag); 87 | break; 88 | case glooDataType_t::glooUint32: 89 | allgather(context, sendbuf, recvbuf, size, tag); 90 | break; 91 | case glooDataType_t::glooInt64: 92 | allgather(context, sendbuf, recvbuf, size, tag); 93 | break; 94 | case glooDataType_t::glooUint64: 95 | allgather(context, sendbuf, recvbuf, size, tag); 96 | break; 97 | case glooDataType_t::glooFloat16: 98 | allgather(context, sendbuf, recvbuf, size, tag); 99 | break; 100 | case glooDataType_t::glooFloat32: 101 | allgather(context, sendbuf, recvbuf, size, tag); 102 | break; 103 | case glooDataType_t::glooFloat64: 104 | allgather(context, sendbuf, recvbuf, size, tag); 105 | break; 106 | default: 107 | throw std::runtime_error("Unhandled dataType"); 108 | } 109 | } 110 | } // namespace pygloo 111 | -------------------------------------------------------------------------------- /pygloo/include/collective.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | namespace pygloo { 12 | 13 | enum class ReduceOp : std::uint8_t { 14 | SUM = 0, 15 | PRODUCT, 16 | MIN, 17 | MAX, 18 | BAND, // Bitwise AND 19 | BOR, // Bitwise OR 20 | BXOR, // Bitwise XOR 21 | UNUSED, 22 | }; 23 | 24 | typedef void (*ReduceFunc)(void *, const void *, const void *, 
size_t); 25 | 26 | template ReduceFunc toFunction(const ReduceOp &r) { 27 | switch (r) { 28 | case ReduceOp::SUM: 29 | return ReduceFunc(&gloo::sum); 30 | case ReduceOp::PRODUCT: 31 | return ReduceFunc(&gloo::product); 32 | case ReduceOp::MIN: 33 | return ReduceFunc(&gloo::min); 34 | case ReduceOp::MAX: 35 | return ReduceFunc(&gloo::max); 36 | case ReduceOp::BAND: 37 | throw std::runtime_error( 38 | "Cannot use ReduceOp.BAND with non-integral dtype"); 39 | break; 40 | case ReduceOp::BOR: 41 | throw std::runtime_error("Cannot use ReduceOp.BOR with non-integral dtype"); 42 | break; 43 | case ReduceOp::BXOR: 44 | throw std::runtime_error( 45 | "Cannot use ReduceOp.BXOR with non-integral dtype"); 46 | break; 47 | case ReduceOp::UNUSED: 48 | break; 49 | } 50 | 51 | throw std::runtime_error("Unhandled ReduceOp"); 52 | } 53 | 54 | enum class glooDataType_t : std::uint8_t { 55 | glooInt8 = 0, 56 | glooUint8, 57 | glooInt32, 58 | glooUint32, 59 | glooInt64, 60 | glooUint64, 61 | glooFloat16, 62 | glooFloat32, 63 | glooFloat64, 64 | }; 65 | 66 | void allreduce_wrapper(const std::shared_ptr &context, 67 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 68 | glooDataType_t datatype, 69 | ReduceOp reduceop = ReduceOp::SUM, 70 | gloo::AllreduceOptions::Algorithm algorithm = 71 | gloo::AllreduceOptions::Algorithm::RING, 72 | uint32_t tag = 0); 73 | 74 | void allgather_wrapper(const std::shared_ptr &context, 75 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 76 | glooDataType_t datatype, uint32_t tag = 0); 77 | 78 | void allgatherv_wrapper(const std::shared_ptr &context, 79 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 80 | glooDataType_t datatype, uint32_t tag = 0); 81 | 82 | void reduce_wrapper(const std::shared_ptr &context, 83 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 84 | glooDataType_t datatype, 85 | ReduceOp reduceop = pygloo::ReduceOp::SUM, int root = 0, 86 | uint32_t tag = 0); 87 | 88 | void scatter_wrapper(const std::shared_ptr &context, 89 | 
std::vector sendbuf, intptr_t recvbuf, 90 | size_t size, glooDataType_t datatype, int root = 0, 91 | uint32_t tag = 0); 92 | 93 | void gather_wrapper(const std::shared_ptr &context, 94 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 95 | glooDataType_t datatype, int root = 0, uint32_t tag = 0); 96 | 97 | void send_wrapper(const std::shared_ptr &context, 98 | intptr_t sendbuf, size_t size, glooDataType_t datatype, 99 | int peer, uint32_t tag = 0); 100 | 101 | void recv_wrapper(const std::shared_ptr &context, 102 | intptr_t recvbuf, size_t size, glooDataType_t datatype, 103 | int peer, uint32_t tag = 0); 104 | 105 | void broadcast_wrapper(const std::shared_ptr &context, 106 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 107 | glooDataType_t datatype, int root = 0, uint32_t tag = 0); 108 | 109 | void reduce_scatter_wrapper(const std::shared_ptr &context, 110 | intptr_t sendbuf, intptr_t recvbuf, size_t size, 111 | std::vector recvElems, glooDataType_t datatype, 112 | ReduceOp reduceop = pygloo::ReduceOp::SUM); 113 | 114 | void barrier(const std::shared_ptr &context, uint32_t tag = 0); 115 | } // namespace pygloo 116 | -------------------------------------------------------------------------------- /pygloo/src/transport.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace pygloo { 5 | namespace transport { 6 | 7 | #if GLOO_HAVE_TRANSPORT_TCP 8 | template 9 | using overload_cast_ = pybind11::detail::overload_cast_impl; 10 | 11 | void def_transport_tcp_module(pybind11::module &m) { 12 | pybind11::module tcp = m.def_submodule("tcp", "This is a tcp module"); 13 | 14 | tcp.def("CreateDevice", &gloo::transport::tcp::CreateDevice); 15 | 16 | pybind11::class_(tcp, "attr") 17 | .def(pybind11::init<>()) 18 | .def(pybind11::init()) 19 | .def_readwrite("hostname", &gloo::transport::tcp::attr::hostname) 20 | .def_readwrite("iface", &gloo::transport::tcp::attr::iface) 21 | .def_readwrite("ai_family", 
&gloo::transport::tcp::attr::ai_family) 22 | .def_readwrite("hostname", &gloo::transport::tcp::attr::hostname) 23 | .def_readwrite("ai_socktype", &gloo::transport::tcp::attr::ai_socktype) 24 | .def_readwrite("ai_protocol", &gloo::transport::tcp::attr::ai_protocol) 25 | .def_readwrite("ai_addr", &gloo::transport::tcp::attr::ai_addr) 26 | .def_readwrite("ai_addrlen", &gloo::transport::tcp::attr::ai_addrlen); 27 | 28 | pybind11::class_>(tcp, 30 | "Context") 31 | .def(pybind11::init, int, 32 | int>()) 33 | // .def("createPair", &gloo::transport::tcp::Context::createPair) 34 | .def("createUnboundBuffer", 35 | &gloo::transport::tcp::Context::createUnboundBuffer); 36 | 37 | pybind11::class_, 39 | gloo::transport::Device>(tcp, "Device") 40 | .def(pybind11::init()); 41 | } 42 | #else 43 | void def_transport_tcp_module(pybind11::module &m) { 44 | pybind11::module tcp = m.def_submodule("tcp", "This is a tcp module"); 45 | } 46 | #endif 47 | 48 | #if GLOO_HAVE_TRANSPORT_UV 49 | void def_transport_uv_module(pybind11::module &m) { 50 | pybind11::module uv = m.def_submodule("uv", "This is a uv module"); 51 | 52 | uv.def("CreateDevice", &gloo::transport::uv::CreateDevice, "CreateDevice"); 53 | 54 | pybind11::class_(uv, "attr") 55 | .def(pybind11::init<>()) 56 | .def(pybind11::init()) 57 | .def_readwrite("hostname", &gloo::transport::uv::attr::hostname) 58 | .def_readwrite("iface", &gloo::transport::uv::attr::iface) 59 | .def_readwrite("ai_family", &gloo::transport::uv::attr::ai_family) 60 | .def_readwrite("ai_socktype", &gloo::transport::uv::attr::ai_socktype) 61 | .def_readwrite("ai_protocol", &gloo::transport::uv::attr::ai_protocol) 62 | .def_readwrite("ai_addr", &gloo::transport::uv::attr::ai_addr) 63 | .def_readwrite("ai_addrlen", &gloo::transport::uv::attr::ai_addrlen); 64 | 65 | pybind11::class_>(uv, "Context") 67 | .def(pybind11::init, int, 68 | int>()) 69 | .def("createUnboundBuffer", 70 | &gloo::transport::uv::Context::createUnboundBuffer); 71 | 72 | pybind11::class_, 74 
| gloo::transport::Device>(uv, "Device") 75 | .def(pybind11::init()); 76 | } 77 | #else 78 | void def_transport_uv_module(pybind11::module &m) { 79 | pybind11::module uv = m.def_submodule("uv", "This is a uv module"); 80 | } 81 | #endif 82 | 83 | void def_transport_module(pybind11::module &m) { 84 | pybind11::module transport = 85 | m.def_submodule("transport", "This is a transport module"); 86 | 87 | pybind11::class_, 89 | pygloo::transport::PyDevice>(transport, "Device", 90 | pybind11::module_local()) 91 | .def("str", &gloo::transport::Device::str) 92 | .def("getPCIBusID", &gloo::transport::Device::getPCIBusID) 93 | .def("getInterfaceSpeed", &gloo::transport::Device::getInterfaceSpeed) 94 | .def("hasGPUDirect", &gloo::transport::Device::hasGPUDirect) 95 | .def("createContext", &gloo::transport::Device::createContext); 96 | 97 | def_transport_uv_module(transport); 98 | def_transport_tcp_module(transport); 99 | } 100 | } // namespace transport 101 | } // namespace pygloo 102 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import errno 3 | import glob 4 | import io 5 | import logging 6 | import os 7 | import re 8 | import shutil 9 | import subprocess 10 | import sys 11 | import tarfile 12 | import tempfile 13 | import zipfile 14 | import time 15 | 16 | from itertools import chain 17 | from itertools import takewhile 18 | 19 | import urllib.error 20 | import urllib.parse 21 | import urllib.request 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | SUPPORTED_PYTHONS = [(3, 6), (3, 7), (3, 8)] 26 | SUPPORTED_BAZEL = (3, 2, 0) 27 | 28 | ROOT_DIR = os.path.dirname(__file__) 29 | 30 | install_requires = [] 31 | 32 | 33 | # Calls Bazel in PATH, falling back to the standard user installatation path 34 | # (~/.bazel/bin/bazel) if it isn't found. 
35 | def bazel_invoke(invoker, cmdline, *args, **kwargs): 36 | home = os.path.expanduser("~") 37 | first_candidate = os.getenv("BAZEL_PATH", "bazel") 38 | candidates = [first_candidate] 39 | if sys.platform == "win32": 40 | mingw_dir = os.getenv("MINGW_DIR") 41 | if mingw_dir: 42 | candidates.append(mingw_dir + "/bin/bazel.exe") 43 | else: 44 | candidates.append(os.path.join(home, ".bazel", "bin", "bazel")) 45 | result = None 46 | for i, cmd in enumerate(candidates): 47 | try: 48 | result = invoker([cmd] + cmdline, *args, **kwargs) 49 | break 50 | except IOError: 51 | if i >= len(candidates) - 1: 52 | raise 53 | return result 54 | 55 | 56 | def move_file(target_dir, filename): 57 | source = filename 58 | destination = os.path.join(target_dir, filename.split('/')[-1]) 59 | # Create the target directory if it doesn't already exist. 60 | os.makedirs(os.path.dirname(destination), exist_ok=True) 61 | if not os.path.exists(destination): 62 | print("Copying {} to {}.".format(source, destination)) 63 | if sys.platform == "win32": 64 | # Does not preserve file mode (needed to avoid read-only bit) 65 | shutil.copyfile(source, destination, follow_symlinks=True) 66 | else: 67 | # Preserves file mode (needed to copy executable bit) 68 | shutil.copy(source, destination, follow_symlinks=True) 69 | 70 | 71 | def build(): 72 | # no support windows 73 | if tuple(sys.version_info[:2]) not in SUPPORTED_PYTHONS: 74 | msg = ("Detected Python version {}, which is not supported. 
" 75 | "Only Python {} are supported.").format( 76 | ".".join(map(str, sys.version_info[:2])), 77 | ", ".join(".".join(map(str, v)) for v in SUPPORTED_PYTHONS)) 78 | raise RuntimeError(msg) 79 | 80 | bazel_env = dict(os.environ, PYTHON3_BIN_PATH=sys.executable) 81 | 82 | version_info = bazel_invoke(subprocess.check_output, ["--version"]) 83 | bazel_version_str = version_info.rstrip().decode("utf-8").split(" ", 1)[1] 84 | bazel_version_split = bazel_version_str.split(".") 85 | bazel_version_digits = [ 86 | "".join(takewhile(str.isdigit, s)) for s in bazel_version_split 87 | ] 88 | bazel_version = tuple(map(int, bazel_version_digits)) 89 | if bazel_version < SUPPORTED_BAZEL: 90 | logger.warning("Expected Bazel version {} but found {}".format( 91 | ".".join(map(str, SUPPORTED_BAZEL)), bazel_version_str)) 92 | 93 | bazel_targets = ["//pygloo:all"] 94 | return bazel_invoke( 95 | subprocess.check_call, 96 | ["build", "--verbose_failures", "--"] + bazel_targets, 97 | env=bazel_env) 98 | 99 | 100 | def pip_run(build_ext): 101 | build() 102 | 103 | files_to_include = ["./bazel-bin/pygloo/pygloo.so"] 104 | 105 | for filename in files_to_include: 106 | move_file(build_ext.build_lib, filename) 107 | 108 | 109 | 110 | if __name__ == "__main__": 111 | import setuptools 112 | import setuptools.command.build_ext 113 | from setuptools.command.install import install 114 | class build_ext(setuptools.command.build_ext.build_ext): 115 | def run(self): 116 | return pip_run(self) 117 | 118 | class BinaryDistribution(setuptools.Distribution): 119 | def has_ext_modules(self): 120 | return True 121 | 122 | class InstallPlatlib(install): 123 | def finalize_options(self): 124 | install.finalize_options(self) 125 | if self.distribution.has_ext_modules(): 126 | self.install_lib = self.install_platlib 127 | 128 | with open(os.path.join(ROOT_DIR, "README.md"), encoding="utf-8") as f: 129 | long_description = f.read() 130 | 131 | setuptools.setup( 132 | name="pygloo", 133 | version="0.2.0", 134 | 
author="Ray Team", 135 | author_email="ray-dev@googlegroups.com", 136 | description=("A python binding for gloo"), 137 | long_description=long_description, 138 | long_description_content_type="text/markdown", 139 | url="https://github.com/ray-project/pygloo", 140 | classifiers=[ 141 | 'Programming Language :: Python :: 3', 142 | 'Topic :: Scientific/Engineering :: Artificial Intelligence' 143 | ], 144 | keywords=("collective communication"), 145 | packages=setuptools.find_packages(), 146 | cmdclass={"build_ext": build_ext, 147 | "install": InstallPlatlib}, 148 | # The BinaryDistribution argument triggers build_ext. 149 | distclass=BinaryDistribution, 150 | install_requires=install_requires, 151 | setup_requires=["wheel"], 152 | include_package_data=True, 153 | zip_safe=False, 154 | license="Apache 2.0" 155 | ) 156 | -------------------------------------------------------------------------------- /pygloo/main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | // #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace pygloo { 13 | bool transport_tcp_available() { return GLOO_HAVE_TRANSPORT_TCP; } 14 | 15 | bool transport_uv_available() { return GLOO_HAVE_TRANSPORT_UV; } 16 | } // namespace pygloo 17 | 18 | PYBIND11_MODULE(pygloo, m) { 19 | m.doc() = "binding gloo from c to python"; // optional module docstring 20 | 21 | m.def("transport_tcp_available", &pygloo::transport_tcp_available, 22 | "transport_tcp_available"); 23 | 24 | m.def("transport_uv_available", &pygloo::transport_uv_available, 25 | "transport_uv_available"); 26 | 27 | pybind11::enum_(m, "ReduceOp", pybind11::arithmetic()) 28 | .value("SUM", pygloo::ReduceOp::SUM) 29 | .value("PRODUCT", pygloo::ReduceOp::PRODUCT) 30 | .value("MIN", pygloo::ReduceOp::MIN) 31 | .value("MAX", pygloo::ReduceOp::MAX) 32 | .value("BAND", pygloo::ReduceOp::BAND) 33 | .value("BOR", pygloo::ReduceOp::BOR) 34 | 
.value("BXOR", pygloo::ReduceOp::BXOR) 35 | .value("UNUSED", pygloo::ReduceOp::UNUSED) 36 | .export_values(); 37 | 38 | pybind11::enum_( 39 | m, "allreduceAlgorithm", pybind11::arithmetic()) 40 | .value("SUM", gloo::detail::AllreduceOptionsImpl::Algorithm::UNSPECIFIED) 41 | .value("RING", gloo::detail::AllreduceOptionsImpl::Algorithm::RING) 42 | .value("BCUBE", gloo::detail::AllreduceOptionsImpl::Algorithm::BCUBE) 43 | .export_values(); 44 | 45 | pybind11::enum_(m, "glooDataType_t", 46 | pybind11::arithmetic()) 47 | .value("glooInt8", pygloo::glooDataType_t::glooInt8) 48 | .value("glooUint8", pygloo::glooDataType_t::glooUint8) 49 | .value("glooInt32", pygloo::glooDataType_t::glooInt32) 50 | .value("glooUint32", pygloo::glooDataType_t::glooUint32) 51 | .value("glooInt64", pygloo::glooDataType_t::glooInt64) 52 | .value("glooUint64", pygloo::glooDataType_t::glooUint64) 53 | .value("glooFloat16", pygloo::glooDataType_t::glooFloat16) 54 | .value("glooFloat32", pygloo::glooDataType_t::glooFloat32) 55 | .value("glooFloat64", pygloo::glooDataType_t::glooFloat64) 56 | .export_values(); 57 | 58 | m.def("allreduce", &pygloo::allreduce_wrapper, 59 | pybind11::arg("context") = nullptr, pybind11::arg("sendbuf") = nullptr, 60 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 61 | pybind11::arg("datatype") = nullptr, 62 | pybind11::arg("reduceop") = pygloo::ReduceOp::SUM, 63 | pybind11::arg("algorithm") = gloo::AllreduceOptions::Algorithm::RING, 64 | pybind11::arg("tag") = 0); 65 | 66 | m.def("allgather", &pygloo::allgather_wrapper, 67 | pybind11::arg("context") = nullptr, pybind11::arg("sendbuf") = nullptr, 68 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 69 | pybind11::arg("datatype") = nullptr, pybind11::arg("tag") = 0); 70 | m.def("allgatherv", &pygloo::allgatherv_wrapper, 71 | pybind11::arg("context") = nullptr, pybind11::arg("sendbuf") = nullptr, 72 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 73 | 
pybind11::arg("datatype") = nullptr, pybind11::arg("tag") = 0); 74 | 75 | m.def("reduce", &pygloo::reduce_wrapper, pybind11::arg("context") = nullptr, 76 | pybind11::arg("sendbuf") = nullptr, pybind11::arg("recvbuf") = nullptr, 77 | pybind11::arg("size") = nullptr, pybind11::arg("datatype") = nullptr, 78 | pybind11::arg("reduceop") = pygloo::ReduceOp::SUM, 79 | pybind11::arg("root") = 0, pybind11::arg("tag") = 0); 80 | 81 | m.def("scatter", &pygloo::scatter_wrapper, pybind11::arg("context") = nullptr, 82 | pybind11::arg("sendbuf") = nullptr, pybind11::arg("recvbuf") = nullptr, 83 | pybind11::arg("size") = nullptr, pybind11::arg("datatype") = nullptr, 84 | pybind11::arg("root") = 0, pybind11::arg("tag") = 0); 85 | 86 | m.def("gather", &pygloo::gather_wrapper, pybind11::arg("context") = nullptr, 87 | pybind11::arg("sendbuf") = nullptr, pybind11::arg("recvbuf") = nullptr, 88 | pybind11::arg("size") = nullptr, pybind11::arg("datatype") = nullptr, 89 | pybind11::arg("root") = 0, pybind11::arg("tag") = 0); 90 | 91 | m.def("send", &pygloo::send_wrapper, pybind11::arg("context") = nullptr, 92 | pybind11::arg("sendbuf") = nullptr, pybind11::arg("size") = nullptr, 93 | pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr, 94 | pybind11::arg("tag") = 0); 95 | m.def("recv", &pygloo::recv_wrapper, pybind11::arg("context") = nullptr, 96 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 97 | pybind11::arg("datatype") = nullptr, pybind11::arg("peer") = nullptr, 98 | pybind11::arg("tag") = 0); 99 | 100 | m.def("broadcast", &pygloo::broadcast_wrapper, 101 | pybind11::arg("context") = nullptr, pybind11::arg("sendbuf") = nullptr, 102 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 103 | pybind11::arg("datatype") = nullptr, pybind11::arg("root") = 0, 104 | pybind11::arg("tag") = 0); 105 | 106 | m.def("reduce_scatter", &pygloo::reduce_scatter_wrapper, 107 | pybind11::arg("context") = nullptr, pybind11::arg("sendbuf") = nullptr, 
108 | pybind11::arg("recvbuf") = nullptr, pybind11::arg("size") = nullptr, 109 | pybind11::arg("recvElems") = nullptr, 110 | pybind11::arg("datatype") = nullptr, 111 | pybind11::arg("reduceop") = pygloo::ReduceOp::SUM); 112 | 113 | m.def("barrier", &pygloo::barrier, pybind11::arg("context") = nullptr, 114 | pybind11::arg("tag") = 0); 115 | 116 | pybind11::class_>(m, "Context") 117 | .def(pybind11::init(), pybind11::arg("rank") = nullptr, 118 | pybind11::arg("size") = nullptr, pybind11::arg("base") = 2) 119 | .def("getDevice", &gloo::Context::getDevice) 120 | .def_readonly("rank", &gloo::Context::rank) 121 | .def_readonly("size", &gloo::Context::size) 122 | .def_readwrite("base", &gloo::Context::base) 123 | // .def("getPair", &gloo::Context::getPair) 124 | .def("createUnboundBuffer", &gloo::Context::createUnboundBuffer) 125 | .def("nextSlot", &gloo::Context::nextSlot) 126 | .def("closeConnections", &gloo::Context::closeConnections) 127 | .def("setTimeout", &gloo::Context::setTimeout) 128 | .def("getTimeout", &gloo::Context::getTimeout); 129 | 130 | pygloo::transport::def_transport_module(m); 131 | pygloo::rendezvous::def_rendezvous_module(m); 132 | } 133 | -------------------------------------------------------------------------------- /pygloo/src/rendezvous.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #if GLOO_USE_REDIS 17 | #include 18 | #include 19 | #endif 20 | 21 | using namespace gloo; 22 | 23 | namespace pygloo { 24 | namespace rendezvous { 25 | 26 | void def_rendezvous_module(pybind11::module &m) { 27 | pybind11::module rendezvous = 28 | m.def_submodule("rendezvous", "This is a rendezvous module"); 29 | 30 | pybind11::class_>(rendezvous, 32 | "Context") 33 | .def(pybind11::init(), pybind11::arg("rank") = nullptr, 34 | pybind11::arg("size") = nullptr, 
pybind11::arg("base") = 2) 35 | .def("connectFullMesh", &gloo::rendezvous::Context::connectFullMesh); 36 | 37 | pybind11::class_>(rendezvous, 39 | "Store") 40 | .def("set", &gloo::rendezvous::Store::set) 41 | .def("get", &gloo::rendezvous::Store::get); 42 | 43 | pybind11::class_>(rendezvous, 45 | "FileStore") 46 | .def(pybind11::init()) 47 | .def("set", &gloo::rendezvous::FileStore::set) 48 | .def("get", &gloo::rendezvous::FileStore::get); 49 | 50 | pybind11::class_>(rendezvous, 52 | "HashStore") 53 | .def(pybind11::init([]() { return new gloo::rendezvous::HashStore(); })) 54 | .def("set", &gloo::rendezvous::HashStore::set) 55 | .def("get", &gloo::rendezvous::HashStore::get); 56 | 57 | pybind11::class_>( 59 | rendezvous, "PrefixStore") 60 | .def(pybind11::init()) 61 | .def("set", &gloo::rendezvous::PrefixStore::set) 62 | .def("get", &gloo::rendezvous::PrefixStore::get); 63 | 64 | #if GLOO_USE_REDIS 65 | class RedisStoreWithAuth : public gloo::rendezvous::RedisStore { 66 | public: 67 | RedisStoreWithAuth(const std::string &host, int port) 68 | : gloo::rendezvous::RedisStore(host, port){}; 69 | using gloo::rendezvous::RedisStore::check; 70 | using gloo::rendezvous::RedisStore::get; 71 | using gloo::rendezvous::RedisStore::redis_; 72 | using gloo::rendezvous::RedisStore::set; 73 | using gloo::rendezvous::RedisStore::wait; 74 | 75 | void authorize(std::string redis_password) { 76 | void *ptr = 77 | (redisReply *)redisCommand(redis_, "auth %b", redis_password.c_str(), 78 | (size_t)redis_password.size()); 79 | 80 | if (ptr == nullptr) { 81 | GLOO_THROW_IO_EXCEPTION(redis_->errstr); 82 | } 83 | redisReply *reply = static_cast(ptr); 84 | if (reply->type == REDIS_REPLY_ERROR) { 85 | GLOO_THROW_IO_EXCEPTION("Error: ", reply->str); 86 | } 87 | freeReplyObject(reply); 88 | } 89 | 90 | void delKey(const std::string &key) { 91 | void* ptr = redisCommand(redis_, "del %b", key.c_str(), (size_t)key.size()); 92 | 93 | if (ptr == nullptr) { 94 | 
GLOO_THROW_IO_EXCEPTION(redis_->errstr); 95 | } 96 | redisReply *reply = static_cast(ptr); 97 | if (reply->type == REDIS_REPLY_ERROR) { 98 | GLOO_THROW_IO_EXCEPTION("Error: ", reply->str); 99 | } 100 | freeReplyObject(reply); 101 | } 102 | 103 | void delKeys(const std::vector &keys) { 104 | bool result = check(keys); 105 | if(!result) 106 | GLOO_THROW_IO_EXCEPTION("Error: keys not exist"); 107 | 108 | std::vector args; 109 | args.push_back("del"); 110 | for (const auto& key : keys) { 111 | args.push_back(key); 112 | } 113 | 114 | std::vector argv; 115 | std::vector argvlen; 116 | for (const auto& arg : args) { 117 | argv.push_back(arg.c_str()); 118 | argvlen.push_back(arg.length()); 119 | } 120 | 121 | auto argc = argv.size(); 122 | void* ptr = redisCommandArgv(redis_, argc, argv.data(), argvlen.data()); 123 | 124 | if (ptr == nullptr) { 125 | GLOO_THROW_IO_EXCEPTION(redis_->errstr); 126 | } 127 | redisReply *reply = static_cast(ptr); 128 | if (reply->type == REDIS_REPLY_ERROR) { 129 | GLOO_THROW_IO_EXCEPTION("Error: ", reply->str); 130 | } 131 | freeReplyObject(reply); 132 | } 133 | }; 134 | 135 | pybind11::class_>(rendezvous, 137 | "_RedisStore") 138 | .def(pybind11::init()) 139 | .def("set", &gloo::rendezvous::RedisStore::set) 140 | .def("get", &gloo::rendezvous::RedisStore::get); 141 | 142 | pybind11::class_>(rendezvous, 145 | "RedisStore") 146 | .def(pybind11::init()) 147 | .def("set", &RedisStoreWithAuth::set) 148 | .def("get", &RedisStoreWithAuth::get) 149 | .def("authorize", &RedisStoreWithAuth::authorize) 150 | .def("delKey", &RedisStoreWithAuth::delKey) 151 | .def("delKeys", &RedisStoreWithAuth::delKeys); 152 | #endif 153 | 154 | 155 | class CustomStore: public gloo::rendezvous::Store { 156 | public: 157 | explicit CustomStore(const pybind11::object &real_store_py_object) 158 | :real_store_py_object_(real_store_py_object) { 159 | } 160 | 161 | virtual ~CustomStore() {} 162 | 163 | virtual void set(const std::string& key, const std::vector& data) override 
{ 164 | pybind11::str py_key(key.data(), key.size()); 165 | pybind11::bytes py_data(data.data(), data.size()); 166 | auto set_func = real_store_py_object_.attr("set"); 167 | set_func(py_key, py_data); 168 | } 169 | 170 | virtual std::vector get(const std::string& key) override { 171 | /// Wait until key being ready. 172 | wait({key}); 173 | 174 | pybind11::str py_key(key.data(), key.size()); 175 | auto get_func = real_store_py_object_.attr("get"); 176 | pybind11::bytes data = get_func(py_key); 177 | std::string ret_str = data; 178 | std::vector ret(ret_str.data(), ret_str.data() + ret_str.size()); 179 | return ret; 180 | } 181 | 182 | virtual void wait(const std::vector& keys) override { 183 | wait(keys, Store::kDefaultTimeout); 184 | } 185 | 186 | virtual void wait(const std::vector& keys, const std::chrono::milliseconds& timeout) override { 187 | // We now ignore the timeout_ms. 188 | 189 | pybind11::list py_keys = pybind11::cast(keys); 190 | auto wait_func = real_store_py_object_.attr("wait"); 191 | wait_func(py_keys); 192 | } 193 | 194 | void delKeys(const std::vector &keys) { 195 | pybind11::list py_keys = pybind11::cast(keys); 196 | auto del_keys_func = real_store_py_object_.attr("del_keys"); 197 | del_keys_func(py_keys); 198 | } 199 | 200 | protected: 201 | const pybind11::object real_store_py_object_; 202 | }; 203 | 204 | 205 | pybind11::class_>(rendezvous, "CustomStore") 207 | .def(pybind11::init()) 208 | .def("set", &CustomStore::set) 209 | .def("get", &CustomStore::get) 210 | .def("delKeys", &CustomStore::delKeys); 211 | 212 | } 213 | } // namespace rendezvous 214 | } // namespace pygloo 215 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------