├── fed ├── tests │ ├── __init__.py │ ├── multi-jobs │ │ ├── __init__.py │ │ ├── test_multi_proxy_actor.py │ │ └── test_ignore_other_job_msg.py │ ├── client_mode_tests │ │ ├── __init__.py │ │ └── test_basic_client_mode.py │ ├── serializations_tests │ │ ├── __init__.py │ │ └── test_unpickle_with_whitelist.py │ ├── without_ray_tests │ │ ├── __init__.py │ │ ├── test_utils.py │ │ └── test_tree_utils.py │ ├── test_internal_kv.py │ ├── test_options.py │ ├── test_api.py │ ├── test_pass_fed_objects_in_containers_in_normal_tasks.py │ ├── test_listening_address.py │ ├── test_pass_fed_objects_in_containers_in_actor.py │ ├── test_async_startup_2_clusters.py │ ├── test_repeat_init.py │ ├── test_retry_policy.py │ ├── test_ping_others.py │ ├── test_enable_tls_across_parties.py │ ├── test_cache_fed_objects.py │ ├── test_basic_pass_fed_objects.py │ ├── simple_example.py │ ├── test_utils.py │ ├── test_exit_on_failure_sending.py │ ├── test_reset_context.py │ ├── test_setup_proxy_actor.py │ ├── test_fed_get.py │ ├── test_transport_proxy_tls.py │ ├── test_grpc_options_on_proxies.py │ ├── test_transport_proxy.py │ └── test_cross_silo_error.py ├── grpc │ ├── fed.proto │ ├── __init__.py │ ├── pb3 │ │ ├── __init__.py │ │ ├── fed_pb2_grpc.py │ │ └── fed_pb2.py │ └── pb4 │ │ ├── __init__.py │ │ ├── fed_pb2.py │ │ └── fed_pb2_grpc.py ├── _private │ ├── __init__.py │ ├── constants.py │ ├── serialization_utils.py │ ├── message_queue.py │ ├── global_context.py │ ├── fed_call_holder.py │ ├── fed_actor.py │ └── compatible_utils.py ├── proxy │ ├── grpc │ │ ├── __init__.py │ │ └── grpc_options.py │ ├── __init__.py │ └── base_proxy.py ├── __init__.py ├── exceptions.py ├── fed_object.py ├── config.py ├── tree_util.py └── cleanup.py ├── dev-requirements.txt ├── docs ├── source │ ├── advanced_topic │ │ ├── principle.md │ │ ├── architecture.md │ │ └── index.rst │ ├── getting_started │ │ ├── quick_start.md │ │ ├── installation.md │ │ └── index.rst │ ├── tutorials │ │ ├── index.rst │ │ ├── split_learning_demo.ipynb │ │ ├── federated_learning_demo.ipynb │ │ └── transition_from_ray_to_rayfed.ipynb │ ├── api.rst │ ├── conf.py │ └── index.rst ├── images │ ├── morse-logo.png │ └── secretflow-logo.png ├── enhancements │ ├── images │ │ ├── dead_lock.png │ │ ├── local_error.png │ │ ├── threading_model.png │ │ └── cross_silo_error_flow.png │ └── 2023-09-01-cross-silo-error.md ├── doc-requirements.txt ├── Makefile └── make.bat ├── requirements.txt ├── .isort.cfg ├── test.sh ├── license_header.txt ├── .flake8 ├── CHANGELOG.md ├── .github └── workflows │ ├── license-checker.yml │ ├── lint.yml │ ├── test_on_ray1.13.0.yml │ ├── building-wheels.yml │ ├── unit_tests_for_protobuf_matrix.yml │ ├── pypi-nightly.yml │ └── unit_tests_on_ray_matrix.yml ├── .readthedocs.yaml ├── benchmarks └── many_tiny_tasks_benchmark.py ├── .gitignore ├── setup.py ├── tool └── generate_tls_certs.py └── README.md /fed/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fed/tests/multi-jobs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | cryptography 2 | numpy -------------------------------------------------------------------------------- /fed/tests/client_mode_tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fed/tests/serializations_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fed/tests/without_ray_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/advanced_topic/principle.md: -------------------------------------------------------------------------------- 1 | # The Principle of RayFed 2 | 3 | TBD -------------------------------------------------------------------------------- /docs/source/advanced_topic/architecture.md: -------------------------------------------------------------------------------- 1 | # The Architecture of RayFed 2 | 3 | TBD -------------------------------------------------------------------------------- /docs/images/morse-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/images/morse-logo.png -------------------------------------------------------------------------------- /docs/images/secretflow-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/images/secretflow-logo.png -------------------------------------------------------------------------------- /docs/enhancements/images/dead_lock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/enhancements/images/dead_lock.png -------------------------------------------------------------------------------- /docs/enhancements/images/local_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/enhancements/images/local_error.png -------------------------------------------------------------------------------- /docs/enhancements/images/threading_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/enhancements/images/threading_model.png -------------------------------------------------------------------------------- /docs/enhancements/images/cross_silo_error_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ray-project/rayfed/HEAD/docs/enhancements/images/cross_silo_error_flow.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ## setup.py read_requirements 2 | ray>=1.13.0 3 | cloudpickle 4 | pickle5==0.0.11; python_version < '3.8' 5 | protobuf 6 | grpcio>=1.42.0 7 | -------------------------------------------------------------------------------- /docs/source/getting_started/quick_start.md: -------------------------------------------------------------------------------- 1 | # Quick Start 2 | 3 | TBD 4 | 5 | ## More Examples 6 | 7 | Check out [tutorials](../tutorials/index.rst) to get more examples. 
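## A minimal example

Until this guide is fleshed out, the sketch below shows the basic shape of a two-party RayFed program. It is adapted from the tests and benchmarks in this repository (e.g. `benchmarks/many_tiny_tasks_benchmark.py`); the party names, addresses, ports and the `demo.py` file name are placeholders. Run the same script once per party, e.g. `python demo.py alice` and `python demo.py bob`.

```python
import sys

import ray

import fed


@fed.remote
def add(x, y):
    return x + y


def main(party):
    # Start a local Ray instance for this party.
    ray.init(address='local')

    # Both parties run the same program with the same address map.
    addresses = {
        'alice': '127.0.0.1:11010',
        'bob': '127.0.0.1:11011',
    }
    fed.init(addresses=addresses, party=party)

    # Each task is pinned to a party; RayFed transfers the intermediate
    # values across parties when they are needed.
    x = add.party("alice").remote(1, 2)
    y = add.party("bob").remote(x, 3)
    print(fed.get(y))  # 6

    fed.shutdown()
    ray.shutdown()


if __name__ == "__main__":
    main(sys.argv[1])
```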
-------------------------------------------------------------------------------- /docs/source/getting_started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | Install it from PyPI. 4 | 5 | ```shell 6 | pip install -U rayfed 7 | ``` 8 | 9 | Install the nightly release from PyPI. 10 | 11 | ```shell 12 | pip install -U rayfed-nightly 13 | ``` -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. _tutorials: 2 | 3 | Tutorials 4 | =============== 5 | 6 | We hope you enjoy these tutorials from RayFed developers. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | federated_learning_demo 12 | split_learning_demo 13 | transition_from_ray_to_rayfed -------------------------------------------------------------------------------- /docs/source/tutorials/split_learning_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Split Learning on RayFed" 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "nbformat": 4, 17 | "nbformat_minor": 2 18 | } 19 | -------------------------------------------------------------------------------- /docs/doc-requirements.txt: -------------------------------------------------------------------------------- 1 | # The following dependencies are required in RayFed 2 | cloudpickle 3 | pickle5==0.0.11; python_version < '3.8' 4 | protobuf>=3.9.2,<3.20 5 | rayfed-nightly 6 | myst-parser==0.18.1 7 | nbsphinx==0.8.9 8 | sphinx==5.3.0 9 | 10 | # The following dependencies are required for doc-building only. 11 | jinja2<3.1.0 12 | pydata_sphinx_theme 13 | -------------------------------------------------------------------------------- /docs/source/tutorials/federated_learning_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Federated Learning on RayFed" 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "nbformat": 4, 17 | "nbformat_minor": 2 18 | } 19 | -------------------------------------------------------------------------------- /docs/source/tutorials/transition_from_ray_to_rayfed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Seamless Federation: Transitioning from Ray to RayFed" 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "nbformat": 4, 17 | "nbformat_minor": 2 18 | } 19 | -------------------------------------------------------------------------------- /docs/source/getting_started/index.rst: -------------------------------------------------------------------------------- 1 | .. _getting_started: 2 | 3 | 4 | Getting Started 5 | ================= 6 | 7 | Please follow the `installation <installation.html>`_ guide to get RayFed ready. 8 | 9 | Then, we encourage you to check `quick_start <quick_start.html>`_ and :ref:`tutorials` to play with RayFed. 10 | 11 | .. 
toctree:: 12 | :maxdepth: 1 13 | 14 | installation 15 | quick_start -------------------------------------------------------------------------------- /docs/source/advanced_topic/index.rst: -------------------------------------------------------------------------------- 1 | .. _advanced_topic: 2 | 3 | 4 | Advanced Topic 5 | ================= 6 | 7 | The architecture of RayFed is illustrated in the `architecture `_ . 8 | For a deeper understanding of the principles behind RayFed, refer to the `principle `_ . 9 | 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | 14 | architecture 15 | principle -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | # This is to make isort compatible with Black. See 3 | # https://black.readthedocs.io/en/stable/the_black_code_style.html#how-black-wraps-lines. 4 | line_length=88 5 | profile=black 6 | multi_line_output=3 7 | include_trailing_comma=True 8 | use_parentheses=True 9 | float_to_top=True 10 | filter_files=True 11 | 12 | known_local_folder=fed 13 | sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | =========== 3 | 4 | fed.api module 5 | -------------- 6 | 7 | .. automodule:: fed 8 | :members: init, remote, get, shutdown, kill 9 | 10 | fed.config module 11 | ----------------- 12 | 13 | .. automodule:: fed.config 14 | :members: ClusterConfig, CrossSiloMessageConfig, GrpcCrossSiloMessageConfig 15 | 16 | .. Module contents 17 | .. --------------- 18 | 19 | .. .. automodule:: fed 20 | .. :members: 21 | -------------------------------------------------------------------------------- /fed/grpc/fed.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option cc_generic_services = true; 4 | 5 | service GrpcService { 6 | rpc SendData (SendDataRequest) returns (SendDataResponse) {} 7 | } 8 | 9 | message SendDataRequest { 10 | bytes data = 1; 11 | string upstream_seq_id = 2; 12 | string downstream_seq_id = 3; 13 | string job_name = 4; 14 | }; 15 | 16 | message SendDataResponse { 17 | int32 code = 1; 18 | string result = 2; 19 | }; 20 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | # All tests should be ran with TLS enabled for Ray cluster. 7 | python tool/generate_tls_certs.py 8 | export RAY_USE_TLS=1 9 | export RAY_TLS_SERVER_CERT="/tmp/rayfed/test-certs/server.crt" 10 | export RAY_TLS_SERVER_KEY="/tmp/rayfed/test-certs/server.key" 11 | export RAY_TLS_CA_CERT="/tmp/rayfed/test-certs/server.crt" 12 | 13 | directory="fed/tests" 14 | 15 | find "$directory" -type f -name "test_*.py" -exec pytest -vs {} \; 16 | 17 | echo "All tests finished." 18 | -------------------------------------------------------------------------------- /fed/grpc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /license_header.txt: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fed/_private/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fed/grpc/pb3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fed/grpc/pb4/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /fed/proxy/grpc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | fed_pb2_grpc.py 4 | fed_pb2.py 5 | build 6 | py3 7 | max-line-length = 88 8 | inline-quotes = " 9 | ignore = 10 | C408 11 | C417 12 | E121 13 | E123 14 | E126 15 | E203 16 | E226 17 | E24 18 | E704 19 | W503 20 | W504 21 | W605 22 | I 23 | N 24 | B001 25 | B002 26 | B003 27 | B004 28 | B005 29 | B007 30 | B008 31 | B009 32 | B010 33 | B011 34 | B012 35 | B013 36 | B014 37 | B015 38 | B016 39 | B017 40 | avoid-escape = no 41 | per-file-ignores = 42 | fed/_private/fed_call_holder.py:E402 43 | fed/proxy/barriers.py:F401 44 | tests/test_transport_proxy.py:F401 -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## Types of changes 8 | `Added ` for new features. 9 | `Changed` for changes in existing functionality. 10 | `Deprecated` for soon-to-be removed features. 11 | `Removed` for now removed features. 12 | `Fixed` for any bug fixes. 13 | `Security` in case of vulnerabilities. 14 | 15 | ## [0.1.0] - 2024-01-09 16 | ### Added 17 | - TBD 18 | 19 | ### Changed 20 | - TBD 21 | 22 | ### Fixed 23 | - TBD 24 | -------------------------------------------------------------------------------- /.github/workflows/license-checker.yml: -------------------------------------------------------------------------------- 1 | name: license-checker 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | timeout-minutes: 20 # Lint should be done in 10 minutes. 
12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out code 15 | uses: actions/checkout@v3 16 | - name: Install license-header-checker 17 | run: curl -s https://raw.githubusercontent.com/lluissm/license-header-checker/master/install.sh | bash 18 | - name: Run license check 19 | run: ./bin/license-header-checker -v -r -i docs ./license_header.txt . py && [[ -z `git status -s` ]] 20 | -------------------------------------------------------------------------------- /fed/proxy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | __all__ = [ 17 | "barriers", 18 | ] 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/doc-requirements.txt 30 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | 3 | # -- Project information 4 | 5 | project = 'RayFed' 6 | copyright = '2022, The RayFed Team' 7 | author = 'The RayFed Authors' 8 | 9 | release = '0.1' 10 | version = '0.1.0' 11 | 12 | # -- General configuration 13 | 14 | extensions = [ 15 | 'sphinx.ext.duration', 16 | 'sphinx.ext.doctest', 17 | 'sphinx.ext.autodoc', 18 | 'sphinx.ext.autosummary', 19 | 'sphinx.ext.intersphinx', 20 | 'sphinx.ext.viewcode', 21 | 'sphinx.ext.extlinks', 22 | 'myst_parser', 23 | 'nbsphinx', 24 | ] 25 | 26 | intersphinx_mapping = { 27 | 'python': ('https://docs.python.org/3/', None), 28 | 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), 29 | } 30 | intersphinx_disabled_domains = ['std'] 31 | 32 | templates_path = ['_templates'] 33 | 34 | # -- Options for HTML output 35 | 36 | html_theme = 'pydata_sphinx_theme' 37 | 38 | # -- Options for EPUB output 39 | epub_show_urls = 'footnote' 40 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | timeout-minutes: 10 # Lint should be done in 10 minutes. 12 | runs-on: ubuntu-22.04 13 | container: docker.io/library/ubuntu:22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Install bazel 19 | run: | 20 | apt-get update 21 | apt-get install -yq wget gcc g++ python3.10 zlib1g-dev zip libuv1.dev 22 | apt-get install -yq pip 23 | 24 | - name: Install dependencies 25 | run: | 26 | python3 -m pip install virtualenv 27 | python3 -m virtualenv -p python3 py3 28 | . py3/bin/activate 29 | which python 30 | pip install ray==2.0.0 31 | pip install black==23.1 32 | 33 | - name: Lint 34 | run: | 35 | . py3/bin/activate 36 | black -S --check --diff . --exclude='fed/grpc|py3' 37 | -------------------------------------------------------------------------------- /fed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from fed.api import get, init, kill, remote, shutdown 16 | from fed.exceptions import FedRemoteError 17 | from fed.fed_object import FedObject 18 | from fed.proxy.barriers import recv, send 19 | 20 | __all__ = [ 21 | "get", 22 | "init", 23 | "kill", 24 | "remote", 25 | "shutdown", 26 | "recv", 27 | "send", 28 | "FedObject", 29 | "FedRemoteError", 30 | ] 31 | -------------------------------------------------------------------------------- /fed/exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class FedRemoteError(Exception): 17 | def __init__(self, src_party: str, cause: Exception) -> None: 18 | self._src_party = src_party 19 | self._cause = cause 20 | 21 | def __str__(self): 22 | error_msg = f'FedRemoteError occurred at {self._src_party}' 23 | if self._cause is not None: 24 | error_msg += f" caused by {str(self._cause)}" 25 | return error_msg 26 | -------------------------------------------------------------------------------- /.github/workflows/test_on_ray1.13.0.yml: -------------------------------------------------------------------------------- 1 | name: test on ray1.13.0 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | run-unit-tests: 11 | timeout-minutes: 60 12 | runs-on: ubuntu-22.04 13 | container: docker.io/library/ubuntu:22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Install bazel 19 | run: | 20 | apt-get update 21 | apt-get install -yq wget gcc g++ python3.10 zlib1g-dev zip libuv1.dev 22 | apt-get install -yq pip 23 | 24 | - name: Install dependencies 25 | run: | 26 | python3 -m pip install virtualenv 27 | python3 -m virtualenv -p python3 py3 28 | . py3/bin/activate 29 | which python 30 | # Revert setuptools for compatibility with ray dashboard. 31 | pip install setuptools==69.5.1 32 | pip install pytest torch cloudpickle cryptography 33 | pip install ray==1.13.0 34 | 35 | - name: Build and test 36 | run: | 37 | . py3/bin/activate 38 | pip install -e . -v 39 | sh test.sh 40 | -------------------------------------------------------------------------------- /.github/workflows/building-wheels.yml: -------------------------------------------------------------------------------- 1 | name: build-wheels 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build-wheels: 11 | timeout-minutes: 60 12 | runs-on: ubuntu-22.04 13 | container: docker.io/library/ubuntu:22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Install bazel 19 | run: | 20 | apt-get update 21 | apt-get install -yq wget gcc g++ python3.10 zlib1g-dev zip libuv1.dev 22 | apt-get install -yq pip 23 | 24 | - name: Install dependencies 25 | run: | 26 | python3 -m pip install virtualenv 27 | python3 -m virtualenv -p python3 py3 28 | . py3/bin/activate 29 | which python 30 | # Revert setuptools for compatibility with ray dashboard. 31 | pip install setuptools==69.5.1 32 | pip install pytest torch cloudpickle cryptography 33 | pip install ray==2.0.0 34 | 35 | - name: Install and export wheels 36 | run: | 37 | . 
py3/bin/activate 38 | python3 setup.py sdist bdist_wheel -d dist 39 | 40 | - name: Archive rayfed-wheel 41 | uses: actions/upload-artifact@v1 42 | with: 43 | name: rayfed_python39_wheel_on_ubuntu 44 | path: dist/ 45 | -------------------------------------------------------------------------------- /.github/workflows/unit_tests_for_protobuf_matrix.yml: -------------------------------------------------------------------------------- 1 | name: test for protobuf 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | run-unit-tests-on-for-protobuf: 11 | strategy: 12 | matrix: 13 | protobuf_ver: ["3.19", "3.20", "4.23"] 14 | 15 | timeout-minutes: 60 16 | runs-on: ubuntu-22.04 17 | container: docker.io/library/ubuntu:22.04 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | 22 | - name: Install basic dependencies 23 | run: | 24 | apt-get update 25 | apt-get install -yq wget gcc g++ python3.10 zlib1g-dev zip libuv1.dev 26 | apt-get install -yq pip 27 | 28 | - name: Install python dependencies 29 | run: | 30 | python3 -m pip install virtualenv 31 | python3 -m virtualenv -p python3 py3 32 | . py3/bin/activate 33 | which python 34 | # Revert setuptools for compatibility with ray dashboard. 35 | pip install setuptools==69.5.1 36 | pip install pytest torch cloudpickle cryptography numpy 37 | pip install protobuf==${{ matrix.protobuf_ver }} 38 | pip install ray==2.4.0 39 | 40 | - name: Build and test 41 | run: | 42 | . py3/bin/activate 43 | pip install -e . -v 44 | sh test.sh -------------------------------------------------------------------------------- /.github/workflows/pypi-nightly.yml: -------------------------------------------------------------------------------- 1 | name: RayFed PyPi Nightly 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | # can manually trigger the workflow 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build-and-publish: 11 | # do not run in forks 12 | if: ${{ github.repository_owner == 'ray-project' }} 13 | name: build wheel and upload 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: 3.10 22 | 23 | - name: days since the commit date 24 | run: | 25 | : 26 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA) 27 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 )) 28 | if [ $days -eq 0 ]; then 29 | echo COMMIT_TODAY=true >> $GITHUB_ENV 30 | fi 31 | 32 | - name: Build wheel 33 | if: env.COMMIT_TODAY == 'true' 34 | env: 35 | RAYFED_BUILD_MODE: nightly 36 | run: | 37 | pip install pytest torch cloudpickle cryptography wheel 38 | pip install ray==2.0.0 39 | python3 setup.py bdist_wheel 40 | 41 | - name: Upload 42 | if: env.COMMIT_TODAY == 'true' 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | with: 45 | password: ${{ secrets.PYPI_API_TOKEN }} 46 | -------------------------------------------------------------------------------- /.github/workflows/unit_tests_on_ray_matrix.yml: -------------------------------------------------------------------------------- 1 | name: test on many rays 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | run-unit-tests-on-many-rays: 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest] 14 | ray_version: [2.31.0, 2.32.0, 2.33.0, 2.34.0, 2.35.0] 15 | 16 | timeout-minutes: 60 17 | runs-on: ubuntu-22.04 18 | container: docker.io/library/ubuntu:22.04 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 
| 23 | - name: Install basic dependencies 24 | run: | 25 | apt-get update 26 | apt-get install -yq wget gcc g++ python3.10 zlib1g-dev zip libuv1.dev 27 | apt-get install -yq pip 28 | 29 | - name: Install python dependencies 30 | run: | 31 | python3 -m pip install virtualenv 32 | python3 -m virtualenv -p python3 py3 33 | . py3/bin/activate 34 | which python 35 | # Revert setuptools for compatibility with ray dashboard. 36 | pip install setuptools==69.5.1 37 | pip install pytest 38 | pip install -r dev-requirements.txt 39 | pip install ray==${{ matrix.ray_version }} 40 | grep -ivE "ray" requirements.txt > temp_requirement.txt 41 | pip install -r temp_requirement.txt 42 | 43 | - name: Build and test 44 | run: | 45 | . py3/bin/activate 46 | pip install -e . -v 47 | sh test.sh 48 | -------------------------------------------------------------------------------- /fed/_private/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | KEY_OF_CLUSTER_CONFIG = "CLUSTER_CONFIG" 17 | 18 | KEY_OF_JOB_CONFIG = "JOB_CONFIG" 19 | 20 | KEY_OF_GRPC_METADATA = "GRPC_METADATA" 21 | 22 | KEY_OF_CLUSTER_ADDRESSES = "CLUSTER_ADDRESSES" 23 | 24 | KEY_OF_CURRENT_PARTY_NAME = "CURRENT_PARTY_NAME" 25 | 26 | KEY_OF_TLS_CONFIG = "TLS_CONFIG" 27 | 28 | KEY_OF_CROSS_SILO_COMM_CONFIG_DICT = "CROSS_SILO_COMM_CONFIG_DICT" 29 | 30 | RAYFED_LOG_FMT = "%(asctime)s.%(msecs)03d %(levelname)s %(filename)s:%(lineno)s [%(party)s] -- [%(jobname)s] %(message)s" # noqa 31 | 32 | RAYFED_DATE_FMT = "%Y-%m-%d %H:%M:%S" 33 | 34 | RAY_VERSION_2_0_0_STR = "2.0.0" 35 | 36 | RAYFED_DEFAULT_JOB_NAME = "Anonymous_job" 37 | 38 | RAYFED_JOB_KV_DATA_KEY_FMT = "RAYFED#{}#{}" 39 | 40 | RAYFED_DEFAULT_SENDER_PROXY_ACTOR_NAME = "SenderProxyActor" 41 | 42 | RAYFED_DEFAULT_RECEIVER_PROXY_ACTOR_NAME = "ReceiverProxyActor" 43 | 44 | RAYFED_DEFAULT_SENDER_RECEIVER_PROXY_ACTOR_NAME = "SenderReceiverProxyActor" 45 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to RayFed documentation! 2 | =================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :titlesonly: 7 | :hidden: 8 | 9 | getting_started/index 10 | tutorials/index 11 | api 12 | advanced_topic/index 13 | 14 | 15 | **RayFed** is a distributed execution engine for multi-party joint computation, based on Ray, 16 | that helps you build your own federated learning frameworks **in minutes**. 17 | 18 | RayFed has the following highlight features: 19 | 20 | 1. Ray Native Programming Pattern 21 | 2. Multiple Controller Execution Mode 22 | 3. Very Restricted and Clear Data Perimeters 23 | 4. Very Large Scale Federated Computing and Training 24 | 25 | Check out the :ref:`getting_started` section for further information, including 26 | how to install the project. 27 | 28 | .. 
note:: 29 | 30 | This project is under active development. 31 | 32 | Why RayFed 33 | ============== 34 | TBD 35 | 36 | Getting Started 37 | ================= 38 | 39 | Please check :ref:`getting_started` for installation and a quick start guide. 40 | 41 | 42 | Tutorials 43 | ============== 44 | - :doc:`tutorials/federated_learning_demo` 45 | - :doc:`tutorials/split_learning_demo` 46 | - :doc:`tutorials/transition_from_ray_to_rayfed` 47 | 48 | 49 | API 50 | ============== 51 | Check out `api `_ to write your own RayFed programs. 52 | 53 | 54 | Architecture and principle of RayFed 55 | ====================================== 56 | 57 | To gain a comprehensive understanding of RayFed's architecture and principles, 58 | we highly recommend reading :ref:`advanced_topic` 59 | 60 | -------------------------------------------------------------------------------- /fed/tests/without_ray_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | 17 | import fed 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "input_address, is_valid_address", 22 | [ 23 | ("192.168.0.1:8080", True), 24 | ("sa127032as:80", True), 25 | ("https://www.example.com", True), 26 | ("http://www.example.com", True), 27 | ("local", True), 28 | ("localhost", True), 29 | (None, False), 30 | ("invalid_string", False), 31 | ("http", False), 32 | ("example.com", False), 33 | ], 34 | ) 35 | def test_validate_address(input_address, is_valid_address): 36 | if is_valid_address: 37 | fed.utils.validate_address(input_address) 38 | else: 39 | try: 40 | fed.utils.validate_address(input_address) 41 | assert False 42 | except Exception as e: 43 | assert isinstance(e, ValueError) 44 | 45 | 46 | if __name__ == "__main__": 47 | import sys 48 | 49 | sys.exit(pytest.main(["-sv", __file__])) 50 | -------------------------------------------------------------------------------- /fed/tests/test_internal_kv.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import time 3 | 4 | import pytest 5 | import ray 6 | import ray.experimental.internal_kv as ray_internal_kv 7 | 8 | import fed 9 | import fed._private.compatible_utils as compatible_utils 10 | 11 | 12 | def run(party): 13 | compatible_utils.init_ray("local") 14 | addresses = { 15 | 'alice': '127.0.0.1:12012', 16 | 'bob': '127.0.0.1:12011', 17 | } 18 | assert compatible_utils.kv is None 19 | fed.init(addresses=addresses, party=party, job_name="test_job_name") 20 | assert compatible_utils.kv 21 | assert not compatible_utils.kv.put("test_key", b"test_val") 22 | assert compatible_utils.kv.get("test_key") == b"test_val" 23 | 24 | # Test that a prefix key name is added under the hood. 
25 | assert ray_internal_kv._internal_kv_get(b"test_key") is None 26 | assert ( 27 | ray_internal_kv._internal_kv_get(b"RAYFED#test_job_name#test_key") 28 | == b"test_val" 29 | ) 30 | 31 | time.sleep(5) 32 | fed.shutdown() 33 | 34 | assert compatible_utils.kv is None 35 | with pytest.raises(ValueError): 36 | # Make sure the kv actor is non-exist no matter whether it's in client mode 37 | ray.get_actor("_INTERNAL_KV_ACTOR") 38 | ray.shutdown() 39 | 40 | 41 | def test_kv_init(): 42 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 43 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 44 | p_alice.start() 45 | p_bob.start() 46 | p_alice.join() 47 | p_bob.join() 48 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 49 | 50 | 51 | if __name__ == "__main__": 52 | import sys 53 | 54 | sys.exit(pytest.main(["-sv", __file__])) 55 | -------------------------------------------------------------------------------- /fed/tests/multi-jobs/test_multi_proxy_actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import multiprocessing 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name 22 | from fed.proxy.grpc.grpc_proxy import GrpcSenderProxy 23 | 24 | 25 | def run(): 26 | job_name = 'job_test' 27 | ray.init(address='local', include_dashboard=False) 28 | fed.init( 29 | addresses={ 30 | 'alice': '127.0.0.1:11012', 31 | }, 32 | party='alice', 33 | job_name=job_name, 34 | sender_proxy_cls=GrpcSenderProxy, 35 | config={ 36 | 'cross_silo_comm': { 37 | 'exit_on_sending_failure': True, 38 | # Create unique proxy for current job 39 | 'use_global_proxy': False, 40 | } 41 | }, 42 | ) 43 | 44 | assert ray.get_actor(sender_proxy_actor_name()) 45 | assert ray.get_actor(receiver_proxy_actor_name()) 46 | 47 | fed.shutdown() 48 | ray.shutdown() 49 | 50 | 51 | def test_multi_proxy_actor(): 52 | p_alice = multiprocessing.Process(target=run) 53 | p_alice.start() 54 | p_alice.join() 55 | assert p_alice.exitcode == 0 56 | 57 | 58 | if __name__ == "__main__": 59 | import sys 60 | 61 | sys.exit(pytest.main(["-sv", __file__])) 62 | -------------------------------------------------------------------------------- /fed/tests/test_options.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | class Foo: 26 | def run(self): 27 | return 2, 3 28 | 29 | 30 | @fed.remote 31 | def bar(x): 32 | return x / 2, x * 2 33 | 34 | 35 | def run(party): 36 | compatible_utils.init_ray(address='local') 37 | addresses = { 38 | 'alice': '127.0.0.1:11012', 39 | 'bob': '127.0.0.1:11011', 40 | } 41 | fed.init(addresses=addresses, party=party) 42 | 43 | foo = Foo.party("alice").remote() 44 | a, b = fed.get(foo.run.options(num_returns=2).remote()) 45 | c, d = fed.get(bar.party("bob").options(num_returns=2).remote(2)) 46 | 47 | assert a == 2 and b == 3 48 | assert c == 1 and d == 4 49 | 50 | fed.shutdown() 51 | ray.shutdown() 52 | 53 | 54 | def test_fed_get_in_2_parties(): 55 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 56 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 57 | p_alice.start() 58 | p_bob.start() 59 | p_alice.join() 60 | p_bob.join() 61 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 62 | 63 | 64 | if __name__ == "__main__": 65 | import sys 66 | 67 | sys.exit(pytest.main(["-sv", __file__])) 68 | -------------------------------------------------------------------------------- /benchmarks/many_tiny_tasks_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | import time 17 | 18 | import ray 19 | 20 | import fed 21 | 22 | 23 | @fed.remote 24 | class MyActor: 25 | def run(self): 26 | return None 27 | 28 | 29 | @fed.remote 30 | class Aggregator: 31 | def aggr(self, val1, val2): 32 | return None 33 | 34 | 35 | def main(party): 36 | ray.init(address='local', include_dashboard=False) 37 | 38 | addresses = { 39 | 'alice': '127.0.0.1:11010', 40 | 'bob': '127.0.0.1:11011', 41 | } 42 | fed.init(addresses=addresses, party=party) 43 | 44 | actor_alice = MyActor.party("alice").remote() 45 | actor_bob = MyActor.party("bob").remote() 46 | aggregator = Aggregator.party("alice").remote() 47 | 48 | start = time.time() 49 | num_calls = 10000 50 | for i in range(num_calls): 51 | val_alice = actor_alice.run.remote() 52 | val_bob = actor_bob.run.remote() 53 | sum_val_obj = aggregator.aggr.remote(val_alice, val_bob) 54 | fed.get(sum_val_obj) 55 | if i % 100 == 0: 56 | print(f"Running {i}th call") 57 | print(f"num calls: {num_calls}") 58 | print("total time (ms) = ", (time.time() - start) * 1000) 59 | print("per task overhead (ms) =", (time.time() - start) * 1000 / num_calls) 60 | 61 | fed.shutdown() 62 | ray.shutdown() 63 | 64 | 65 | if __name__ == "__main__": 66 | assert len(sys.argv) == 2, 'Please run this script with party.' 
67 | main(sys.argv[1]) 68 | -------------------------------------------------------------------------------- /fed/tests/test_api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | import fed.config as fed_config 23 | 24 | 25 | def run(): 26 | compatible_utils.init_ray(address='local') 27 | addresses = { 28 | 'alice': '127.0.0.1:11012', 29 | } 30 | fed.init(addresses=addresses, party="alice") 31 | config = fed_config.get_cluster_config() 32 | assert config.cluster_addresses == addresses 33 | assert config.current_party == "alice" 34 | fed.shutdown() 35 | ray.shutdown() 36 | 37 | 38 | def test_fed_apis(): 39 | p_alice = multiprocessing.Process(target=run) 40 | p_alice.start() 41 | p_alice.join() 42 | assert p_alice.exitcode == 0 43 | 44 | 45 | def _run(): 46 | compatible_utils.init_ray(address='local') 47 | addresses = { 48 | 'alice': '127.0.0.1:11012', 49 | } 50 | fed.init(addresses=addresses, party="alice") 51 | 52 | @fed.remote 53 | class MyActor: 54 | pass 55 | 56 | with pytest.raises(ValueError): 57 | MyActor.remote() 58 | 59 | fed.shutdown() 60 | ray.shutdown() 61 | 62 | 63 | def test_miss_party_name_on_actor(): 64 | p_alice = multiprocessing.Process(target=_run) 65 | p_alice.start() 66 | p_alice.join() 67 | assert p_alice.exitcode == 0 68 | 69 | 70 | if __name__ == "__main__": 71 | import sys 72 | 73 | sys.exit(pytest.main(["-sv", __file__])) 74 | -------------------------------------------------------------------------------- /fed/tests/test_pass_fed_objects_in_containers_in_normal_tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | def foo(i: int): 26 | return f"foo-{i}" 27 | 28 | 29 | @fed.remote 30 | def bar(li): 31 | assert li[0] == "hello" 32 | li1 = li[1] 33 | li2 = li[2] 34 | assert fed.get(li1[0]) == "foo-0" 35 | assert li2[0] == "world" 36 | assert fed.get(li2[1][0]) == "foo-1" 37 | return True 38 | 39 | 40 | def run(party): 41 | compatible_utils.init_ray(address='local') 42 | addresses = { 43 | 'alice': '127.0.0.1:11012', 44 | 'bob': '127.0.0.1:11011', 45 | } 46 | fed.init(addresses=addresses, party=party) 47 | o1 = foo.party("alice").remote(0) 48 | o2 = foo.party("bob").remote(1) 49 | li = ["hello", [o1], ["world", [o2]]] 50 | o3 = bar.party("bob").remote(li) 51 | 52 | result = fed.get(o3) 53 | assert result 54 | fed.shutdown() 55 | ray.shutdown() 56 | 57 | 58 | def test_pass_fed_objects_in_list(): 59 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 60 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 61 | p_alice.start() 62 | p_bob.start() 63 | p_alice.join() 64 | p_bob.join() 65 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 66 | 67 | 68 | if __name__ == "__main__": 69 | import sys 70 | 71 | sys.exit(pytest.main(["-sv", __file__])) 72 | -------------------------------------------------------------------------------- /fed/tests/test_listening_address.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | def _run(party): 25 | import socket 26 | 27 | compatible_utils.init_ray(address='local') 28 | occupied_port = 11020 29 | # NOTE(NKcqx): Firstly try to bind IPv6 because the grpc server will do so. 30 | # Otherwise this UT will fail because socket bind $occupied_port 31 | # on IPv4 address while grpc server listened on the Ipv6 address. 
32 | s_ipv6 = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) 33 | s_ipv6.bind(("::1", occupied_port)) 34 | s_ipv4 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 35 | s_ipv4.bind(("127.0.0.1", occupied_port)) 36 | import time 37 | 38 | time.sleep(5) 39 | 40 | addresses = {'alice': f'127.0.0.1:{occupied_port}'} 41 | 42 | # Starting grpc server on an used port will cause AssertionError 43 | with pytest.raises(AssertionError): 44 | fed.init( 45 | addresses=addresses, 46 | party=party, 47 | ) 48 | 49 | s_ipv6.close() 50 | s_ipv4.close() 51 | fed.shutdown() 52 | ray.shutdown() 53 | 54 | 55 | def test_listen_used_address(): 56 | p_alice = multiprocessing.Process(target=_run, args=('alice',)) 57 | p_alice.start() 58 | p_alice.join() 59 | assert p_alice.exitcode == 0 60 | 61 | 62 | if __name__ == "__main__": 63 | import sys 64 | 65 | sys.exit(pytest.main(["-sv", __file__])) 66 | -------------------------------------------------------------------------------- /fed/tests/test_pass_fed_objects_in_containers_in_actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | class My: 26 | def foo(self, i: int): 27 | return f"foo-{i}" 28 | 29 | def bar(self, li): 30 | assert li[0] == "hello" 31 | li1 = li[1] 32 | li2 = li[2] 33 | assert fed.get(li1[0]) == "foo-0" 34 | assert li2[0] == "world" 35 | 36 | assert fed.get(li2[1][0]) == "foo-1" 37 | return True 38 | 39 | 40 | addresses = { 41 | 'alice': '127.0.0.1:11012', 42 | 'bob': '127.0.0.1:11011', 43 | } 44 | 45 | 46 | def run(party): 47 | compatible_utils.init_ray(address='local') 48 | fed.init(addresses=addresses, party=party) 49 | my1 = My.party("alice").remote() 50 | my2 = My.party("bob").remote() 51 | o1 = my1.foo.remote(0) 52 | o2 = my2.foo.remote(1) 53 | li = ["hello", [o1], ["world", [o2]]] 54 | o3 = my2.bar.remote(li) 55 | 56 | result = fed.get(o3) 57 | assert result 58 | fed.shutdown() 59 | ray.shutdown() 60 | 61 | 62 | def test_pass_fed_objects_in_list(): 63 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 64 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 65 | p_alice.start() 66 | p_bob.start() 67 | p_alice.join() 68 | p_bob.join() 69 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 70 | 71 | 72 | if __name__ == "__main__": 73 | import sys 74 | 75 | sys.exit(pytest.main(["-sv", __file__])) 76 | -------------------------------------------------------------------------------- /fed/tests/test_async_startup_2_clusters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the 
License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | class My: 26 | def __init__(self) -> None: 27 | self._val = 0 28 | 29 | def incr(self, delta): 30 | self._val += delta 31 | return self._val 32 | 33 | 34 | @fed.remote 35 | def add(x, y): 36 | return x + y 37 | 38 | 39 | def _run(party: str): 40 | if party == "alice": 41 | import time 42 | 43 | time.sleep(10) 44 | 45 | compatible_utils.init_ray(address='local') 46 | addresses = { 47 | 'alice': '127.0.0.1:11012', 48 | 'bob': '127.0.0.1:11011', 49 | } 50 | fed.init(addresses=addresses, party=party) 51 | 52 | my1 = My.party("alice").remote() 53 | my2 = My.party("bob").remote() 54 | x = my1.incr.remote(10) 55 | y = my2.incr.remote(20) 56 | o = add.party("alice").remote(x, y) 57 | assert 30 == fed.get(o) 58 | fed.shutdown() 59 | ray.shutdown() 60 | 61 | 62 | # This case is used to test that we start 2 clusters not at the same time. 63 | def test_async_startup_2_clusters(): 64 | p_alice = multiprocessing.Process(target=_run, args=('alice',)) 65 | p_bob = multiprocessing.Process(target=_run, args=('bob',)) 66 | p_alice.start() 67 | p_bob.start() 68 | p_alice.join() 69 | p_bob.join() 70 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 71 | 72 | 73 | if __name__ == "__main__": 74 | import sys 75 | 76 | sys.exit(pytest.main(["-sv", __file__])) 77 | -------------------------------------------------------------------------------- /fed/tests/test_repeat_init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import multiprocessing 17 | 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | 24 | 25 | @fed.remote 26 | class My: 27 | def foo(self, i: int): 28 | return f"foo-{i}" 29 | 30 | def bar(self, li): 31 | assert li[0] == "hello" 32 | li1 = li[1] 33 | li2 = li[2] 34 | assert fed.get(li1[0]) == "foo-0" 35 | assert li2[0] == "world" 36 | 37 | assert fed.get(li2[1][0]) == "foo-1" 38 | return True 39 | 40 | 41 | addresses = { 42 | 'alice': '127.0.0.1:11012', 43 | 'bob': '127.0.0.1:11011', 44 | } 45 | 46 | 47 | def run(party): 48 | def _run(): 49 | compatible_utils.init_ray(address='local') 50 | fed.init(addresses=addresses, party=party) 51 | 52 | my1 = My.party("alice").remote() 53 | my2 = My.party("bob").remote() 54 | o1 = my1.foo.remote(0) 55 | o2 = my2.foo.remote(1) 56 | li = ["hello", [o1], ["world", [o2]]] 57 | o3 = my2.bar.remote(li) 58 | 59 | result = fed.get(o3) 60 | assert result 61 | 62 | fed.shutdown() 63 | ray.shutdown() 64 | 65 | _run() 66 | _run() 67 | 68 | 69 | def test_pass_fed_objects_in_list(): 70 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 71 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 72 | p_alice.start() 73 | p_bob.start() 74 | p_alice.join() 75 | p_bob.join() 76 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 77 | 78 | 79 | if __name__ == "__main__": 80 | import sys 81 | 82 | sys.exit(pytest.main(["-sv", __file__])) 83 | -------------------------------------------------------------------------------- /fed/grpc/pb4/fed_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # -*- coding: utf-8 -*- 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
17 | # source: fed.proto 18 | """Generated protocol buffer code.""" 19 | from google.protobuf import descriptor as _descriptor 20 | from google.protobuf import descriptor_pool as _descriptor_pool 21 | from google.protobuf import symbol_database as _symbol_database 22 | from google.protobuf.internal import builder as _builder 23 | # @@protoc_insertion_point(imports) 24 | 25 | _sym_db = _symbol_database.Default() 26 | 27 | 28 | 29 | 30 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"0\n\x10SendDataResponse\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0e\n\x06result\x18\x02 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3') 31 | 32 | _globals = globals() 33 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 34 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fed_pb2', _globals) 35 | if _descriptor._USE_C_DESCRIPTORS == False: 36 | 37 | DESCRIPTOR._options = None 38 | DESCRIPTOR._serialized_options = b'\200\001\001' 39 | _globals['_SENDDATAREQUEST']._serialized_start=13 40 | _globals['_SENDDATAREQUEST']._serialized_end=114 41 | _globals['_SENDDATARESPONSE']._serialized_start=116 42 | _globals['_SENDDATARESPONSE']._serialized_end=164 43 | _globals['_GRPCSERVICE']._serialized_start=166 44 | _globals['_GRPCSERVICE']._serialized_end=230 45 | # @@protoc_insertion_point(module_scope) 46 | -------------------------------------------------------------------------------- /fed/tests/test_retry_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import multiprocessing 17 | from unittest import TestCase 18 | 19 | import pytest 20 | import ray 21 | 22 | import fed 23 | import fed._private.compatible_utils as compatible_utils 24 | from fed import config 25 | 26 | 27 | @fed.remote 28 | def f(): 29 | return 100 30 | 31 | 32 | @fed.remote 33 | class My: 34 | def __init__(self, value) -> None: 35 | self._value = value 36 | 37 | def get_value(self): 38 | return self._value 39 | 40 | 41 | def run(): 42 | compatible_utils.init_ray(address='local') 43 | addresses = { 44 | 'alice': '127.0.0.1:11012', 45 | 'bob': '127.0.0.1:11011', 46 | } 47 | retry_policy = { 48 | "maxAttempts": 4, 49 | "initialBackoff": "5s", 50 | "maxBackoff": "5s", 51 | "backoffMultiplier": 1, 52 | "retryableStatusCodes": ["UNAVAILABLE"], 53 | } 54 | test_job_name = 'test_retry_policy' 55 | fed.init( 56 | addresses=addresses, 57 | party='alice', 58 | config={ 59 | 'cross_silo_comm': { 60 | 'grpc_retry_policy': retry_policy, 61 | } 62 | }, 63 | ) 64 | 65 | job_config = config.get_job_config(test_job_name) 66 | cross_silo_comm_config = job_config.cross_silo_comm_config_dict 67 | TestCase().assertDictEqual( 68 | cross_silo_comm_config['grpc_retry_policy'], retry_policy 69 | ) 70 | 71 | fed.shutdown() 72 | ray.shutdown() 73 | 74 | 75 | def test_retry_policy(): 76 | p_alice = multiprocessing.Process(target=run) 77 | p_alice.start() 78 | p_alice.join() 79 | assert p_alice.exitcode == 0 80 | 81 | 82 | if __name__ == "__main__": 83 | import sys 84 | 85 | sys.exit(pytest.main(["-sv", __file__])) 86 | -------------------------------------------------------------------------------- /fed/tests/test_ping_others.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | import time 17 | 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | from fed.proxy.barriers import ping_others 24 | 25 | addresses = { 26 | 'alice': '127.0.0.1:11012', 27 | 'bob': '127.0.0.1:11011', 28 | } 29 | 30 | 31 | def test_ping_non_started_party(): 32 | def run(party): 33 | compatible_utils.init_ray(address='local') 34 | fed.init(addresses=addresses, party=party) 35 | if party == 'alice': 36 | with pytest.raises(RuntimeError): 37 | ping_others(addresses, party, 5) 38 | 39 | fed.shutdown() 40 | ray.shutdown() 41 | 42 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 43 | p_alice.start() 44 | p_alice.join() 45 | 46 | 47 | def test_ping_started_party(): 48 | def run(party): 49 | compatible_utils.init_ray(address='local') 50 | fed.init(addresses=addresses, party=party) 51 | if party == 'alice': 52 | ping_success = ping_others(addresses, party, 5) 53 | assert ping_success is True 54 | else: 55 | # Wait for alice to ping, otherwise, bob may 56 | # exit before alice when started first. 
57 | time.sleep(10) 58 | fed.shutdown() 59 | ray.shutdown() 60 | 61 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 62 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 63 | p_alice.start() 64 | p_bob.start() 65 | p_alice.join() 66 | p_bob.join() 67 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 68 | 69 | 70 | if __name__ == "__main__": 71 | import sys 72 | 73 | sys.exit(pytest.main(["-sv", __file__])) 74 | -------------------------------------------------------------------------------- /fed/tests/test_enable_tls_across_parties.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | import os 17 | 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | 24 | 25 | @fed.remote 26 | class My: 27 | def __init__(self) -> None: 28 | self._val = 0 29 | 30 | def incr(self, delta): 31 | self._val += delta 32 | return self._val 33 | 34 | 35 | @fed.remote 36 | def add(x, y): 37 | return x + y 38 | 39 | 40 | def _run(party: str): 41 | compatible_utils.init_ray(address='local') 42 | cert_dir = os.path.join( 43 | os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/" 44 | ) 45 | cert_config = { 46 | "ca_cert": os.path.join(cert_dir, "server.crt"), 47 | "cert": os.path.join(cert_dir, "server.crt"), 48 | "key": os.path.join(cert_dir, "server.key"), 49 | } 50 | 51 | addresses = { 52 | 'alice': '127.0.0.1:11012', 53 | 'bob': '127.0.0.1:11011', 54 | } 55 | fed.init(addresses=addresses, party=party, tls_config=cert_config) 56 | 57 | my1 = My.party("alice").remote() 58 | my2 = My.party("bob").remote() 59 | x = my1.incr.remote(10) 60 | y = my2.incr.remote(20) 61 | o = add.party("alice").remote(x, y) 62 | assert fed.get(o) == 30 63 | fed.shutdown() 64 | ray.shutdown() 65 | 66 | 67 | def test_enable_tls_across_parties(): 68 | p_alice = multiprocessing.Process(target=_run, args=('alice',)) 69 | p_bob = multiprocessing.Process(target=_run, args=('bob',)) 70 | p_alice.start() 71 | p_bob.start() 72 | p_alice.join() 73 | p_bob.join() 74 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 75 | 76 | 77 | if __name__ == "__main__": 78 | import sys 79 | 80 | sys.exit(pytest.main(["-sv", __file__])) 81 | -------------------------------------------------------------------------------- /fed/tests/test_cache_fed_objects.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name 23 | 24 | 25 | @fed.remote 26 | def f(): 27 |     return "hello" 28 | 29 | 30 | @fed.remote 31 | def g(x, index): 32 |     return x + str(index) 33 | 34 | 35 | def run(party): 36 |     compatible_utils.init_ray(address='local') 37 |     addresses = { 38 |         'alice': '127.0.0.1:11012', 39 |         'bob': '127.0.0.1:11011', 40 |     } 41 |     fed.init(addresses=addresses, party=party) 42 | 43 |     o = f.party("alice").remote() 44 |     o1 = g.party("bob").remote(o, 1) 45 |     o2 = g.party("bob").remote(o, 2) 46 | 47 |     a, b, c = fed.get([o, o1, o2]) 48 |     assert a == "hello" 49 |     assert b == "hello1" 50 |     assert c == "hello2" 51 | 52 |     if party == "bob": 53 |         proxy_actor = ray.get_actor(receiver_proxy_actor_name()) 54 |         stats = ray.get(proxy_actor._get_stats.remote()) 55 |         assert stats["receive_op_count"] == 1 56 |     if party == "alice": 57 |         proxy_actor = ray.get_actor(sender_proxy_actor_name()) 58 |         stats = ray.get(proxy_actor._get_stats.remote()) 59 |         assert stats["send_op_count"] == 1 60 |     fed.shutdown() 61 |     ray.shutdown() 62 | 63 | 64 | def test_cache_fed_object_if_sent(): 65 |     p_alice = multiprocessing.Process(target=run, args=('alice',)) 66 |     p_bob = multiprocessing.Process(target=run, args=('bob',)) 67 |     p_alice.start() 68 |     p_bob.start() 69 |     p_alice.join() 70 |     p_bob.join() 71 |     assert p_alice.exitcode == 0 and p_bob.exitcode == 0 72 | 73 | 74 | if __name__ == "__main__": 75 |     import sys 76 | 77 |     sys.exit(pytest.main(["-sv", __file__])) 78 | -------------------------------------------------------------------------------- /docs/enhancements/2023-09-01-cross-silo-error.md: -------------------------------------------------------------------------------- 1 | ## Summary 2 | ### Changelog 3 | 4 | #### [2023-09-01] 5 | - First Version 6 | 7 | #### [2024-01-03] 8 | - Update the threading model: the main thread is signaled and exits only when `exit_on_sending_failure == True`. 9 | 10 | ### General Motivation 11 | Before this proposal, when the execution of a DAG encounters an error in 'alice', the following happens: 12 | ![image](./images/local_error.png) 13 | 14 | In alice, both the main thread and the data-sending thread raise the error, and the process exits. 15 | In bob, since it needs the input from 'alice', it waits for 'alice' forever, regardless of whether 'alice' is still alive. 16 | 17 | Therefore, we need a mechanism to inform the other participant(s) when the DAG execution raises an error. 18 | 19 | ## Design and Architecture 20 | The graph below shows what happens after this proposal: 21 | ![image](./images/cross_silo_error_flow.png) 22 | 23 | In alice, when the data-sending thread encounters a `RayTaskError` indicating an execution failure, it wraps the error as a `FedRemoteError` object and sends it to bob in place of the original data object, as sketched below.
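This sender-side wrapping looks roughly like the following. It is a minimal sketch under assumed names: `FedRemoteError` is shown here as a plain illustrative class and `prepare_payload_for_sending` is a hypothetical helper, not RayFed's actual internals.

```python
import ray


class FedRemoteError(Exception):
    """Illustrative stand-in: carries an upstream party's failure across the silo boundary."""

    def __init__(self, src_party: str, cause: Exception):
        super().__init__(f"Error from party {src_party}: {cause}")
        self.cause = cause


def prepare_payload_for_sending(object_ref: ray.ObjectRef, src_party: str):
    # Hypothetical helper: the data-sending thread resolves the Ray object
    # before pushing it to the remote party; a RayTaskError here means the
    # upstream task failed locally.
    try:
        return ray.get(object_ref)
    except ray.exceptions.RayTaskError as e:
        # Replace the payload with the wrapped error so the receiver can
        # re-raise it instead of waiting for data that will never arrive.
        return FedRemoteError(src_party, e)
```

On bob's side, the matching step is simply a type check on the received payload (e.g. `isinstance(payload, FedRemoteError)`) before re-raising, as described next.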
24 | In bob, the main thread polls data from the receiver actor, finds that the data is of type `FedRemoteError`, and re-raises it, so bob gets an exception just as what happens in "alice". 25 | 26 | The threading model in this proposal is shown below: 27 | 28 | ![image](./images/threading_model.png) 29 | 30 | ### The explanation of the `_atomic_shutdown_flag` 31 | When the failure happens, both the main thread and the data thread get the error and trigger the shutdown, which would execute the "failure handler" twice. The typical way to ensure the `failure_handler` is executed only once is to set a flag recording whether it has already run, and to wrap the check with a `threading.Lock`, because it is a critical section. 32 | 33 | However, this causes a deadlock, as shown in the graph below. 34 | The data thread triggers the shutdown stage by sending a `SIGINT` signal, which is implemented by raising a `KeyboardInterrupt` error (step 8). To handle the exception, the OS holds the context of the current process, including the `threading.Lock` acquired in step 6, and switches to the error handler, i.e. the signal handler, in step 9. Since the lock has not yet been released, acquiring the same lock again causes the deadlock (step 10). 35 | ![image](./images/dead_lock.png) 36 | 37 | The solution is to check the lock before sending the signal. That lock is the `_atomic_shutdown_flag`. 38 | -------------------------------------------------------------------------------- /fed/tests/test_basic_pass_fed_objects.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
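The "check the lock before sending the signal" idea from the enhancement note above can be sketched as follows. This is illustrative only: `_atomic_shutdown_flag`, `trigger_failure_shutdown`, and the handler wiring are assumed names, not RayFed's actual implementation.

```python
import os
import signal
import threading

# Acquired exactly once, by whichever thread reaches the failure path first;
# it is intentionally never released, because shutdown happens at most once.
_atomic_shutdown_flag = threading.Lock()


def trigger_failure_shutdown(error: Exception, failure_handler) -> None:
    # Try-acquire the flag *before* sending the signal: the losing thread
    # returns immediately instead of blocking on a lock that stays held while
    # the KeyboardInterrupt-based signal handler runs.
    if not _atomic_shutdown_flag.acquire(blocking=False):
        return
    failure_handler(error)
    # Interrupt the main thread only after the critical section is done, so
    # the resulting KeyboardInterrupt can no longer deadlock on the flag.
    os.kill(os.getpid(), signal.SIGINT)
```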
14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | def f(): 26 | return 100 27 | 28 | 29 | @fed.remote 30 | class My: 31 | def __init__(self, value) -> None: 32 | self._value = value 33 | 34 | def get_value(self): 35 | return self._value 36 | 37 | 38 | def run(party, is_inner_party): 39 | compatible_utils.init_ray(address='local') 40 | addresses = { 41 | 'alice': '127.0.0.1:11012', 42 | 'bob': '127.0.0.1:11011', 43 | } 44 | fed.init(addresses=addresses, party=party) 45 | 46 | o = f.party("alice").remote() 47 | actor_location = "alice" if is_inner_party else "bob" 48 | my = My.party(actor_location).remote(o) 49 | val = my.get_value.remote() 50 | result = fed.get(val) 51 | assert result == 100 52 | assert fed.get(o) == 100 53 | import time 54 | 55 | time.sleep(5) 56 | fed.shutdown() 57 | ray.shutdown() 58 | 59 | 60 | def test_pass_fed_objects_for_actor_creation_inner_party(): 61 | p_alice = multiprocessing.Process(target=run, args=('alice', True)) 62 | p_bob = multiprocessing.Process(target=run, args=('bob', True)) 63 | p_alice.start() 64 | p_bob.start() 65 | p_alice.join() 66 | p_bob.join() 67 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 68 | 69 | 70 | def test_pass_fed_objects_for_actor_creation_across_party(): 71 | p_alice = multiprocessing.Process(target=run, args=('alice', False)) 72 | p_bob = multiprocessing.Process(target=run, args=('bob', False)) 73 | p_alice.start() 74 | p_bob.start() 75 | p_alice.join() 76 | p_bob.join() 77 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 78 | 79 | 80 | if __name__ == "__main__": 81 | import sys 82 | 83 | sys.exit(pytest.main(["-sv", __file__])) 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac temp files 132 | .DS_Store 133 | 134 | # Vscode 135 | .vscode/ 136 | 137 | # lincense checker binary 138 | bin/ 139 | -------------------------------------------------------------------------------- /fed/tests/simple_example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import multiprocessing 16 | 17 | import ray 18 | 19 | import fed 20 | 21 | 22 | @fed.remote 23 | class MyActor: 24 | def __init__(self, party, data): 25 | self.__data = data 26 | self._party = party 27 | 28 | def f(self): 29 | print(f"=====THIS IS F IN PARTY {self._party}") 30 | return f"f({self._party}, ip is {ray.util.get_node_ip_address()})" 31 | 32 | def g(self, obj): 33 | print(f"=====THIS IS G IN PARTY {self._party}") 34 | return obj + "g" 35 | 36 | def h(self, obj): 37 | print(f"=====THIS IS H IN PARTY {self._party}, obj is {obj}") 38 | return obj + "h" 39 | 40 | 41 | @fed.remote 42 | def agg_fn(obj1, obj2): 43 | print(f"=====THIS IS AGG_FN, obj1={obj1}, obj2={obj2}") 44 | return f"agg-{obj1}-{obj2}" 45 | 46 | 47 | addresses = { 48 | 'alice': '127.0.0.1:11012', 49 | 'bob': '127.0.0.1:11011', 50 | } 51 | 52 | 53 | def run(party): 54 | ray.init(address='local', include_dashboard=False) 55 | fed.init(addresses=addresses, party=party) 56 | print(f"Running the script in party {party}") 57 | 58 | ds1, ds2 = [123, 789] 59 | actor_alice = MyActor.party("alice").remote(party, ds1) 60 | actor_bob = MyActor.party("bob").remote(party, ds2) 61 | 62 | obj_alice_f = actor_alice.f.remote() 63 | obj_bob_f = actor_bob.f.remote() 64 | 65 | obj_alice_g = actor_alice.g.remote(obj_alice_f) 66 | obj_bob_h = actor_bob.h.remote(obj_bob_f) 67 | 68 | obj = agg_fn.party("bob").remote(obj_alice_g, obj_bob_h) 69 | result = fed.get(obj) 70 | print(f"The result in party {party} is :{result}") 71 | fed.shutdown() 72 | ray.shutdown() 73 | 74 | 75 | def main(): 76 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 77 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 78 | p_alice.start() 79 | p_bob.start() 80 | p_alice.join() 81 | p_bob.join() 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /fed/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import time 16 | 17 | import pytest 18 | 19 | import fed.utils as fed_utils 20 | 21 | 22 | def start_ray_cluster( 23 | ray_port, 24 | client_server_port, 25 | dashboard_port, 26 | ): 27 | command = [ 28 | 'ray', 29 | 'start', 30 | '--head', 31 | f'--port={ray_port}', 32 | f'--ray-client-server-port={client_server_port}', 33 | f'--include-dashboard=false', 34 | f'--dashboard-port={dashboard_port}', 35 | ] 36 | command_str = ' '.join(command) 37 | try: 38 | _ = fed_utils.start_command(command_str) 39 | except RuntimeError as e: 40 | # As we should treat the following warning messages is ok to use. 
41 | # E RuntimeError: Failed to start command [ray start --head --port=41012 42 | # --ray-client-server-port=21012 --dashboard-port=9112], the error is: 43 | # E 2023-09-13 13:04:11,520 WARNING services.py:1882 -- WARNING: The 44 | # object store is using /tmp instead of /dev/shm because /dev/shm has only 45 | # 67108864 bytes available. This will harm performance! You may be able to 46 | # free up space by deleting files in /dev/shm. If you are inside a Docker 47 | # container, you can increase /dev/shm size by passing '--shm-size=1.97gb' to 48 | # 'docker run' (or add it to the run_options list in a Ray cluster config). 49 | # Make sure to set this to more than 0% of available RAM. 50 | assert 'Overwriting previous Ray address' in str( 51 | e 52 | ) or 'WARNING: The object store is using /tmp instead of /dev/shm' in str(e) 53 | 54 | 55 | @pytest.fixture 56 | def ray_client_mode_setup(): 57 | # Start 2 Ray clusters. 58 | start_ray_cluster(ray_port=41012, client_server_port=21012, dashboard_port=9112) 59 | time.sleep(1) 60 | start_ray_cluster(ray_port=41011, client_server_port=21011, dashboard_port=9111) 61 | 62 | yield 63 | fed_utils.start_command('ray stop --force') 64 | -------------------------------------------------------------------------------- /fed/tests/test_exit_on_failure_sending.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing 16 | import sys 17 | 18 | import pytest 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | def f(): 26 | raise Exception('By design.') 27 | 28 | 29 | @fed.remote 30 | class My: 31 | def __init__(self, value) -> None: 32 | self._value = value 33 | 34 | def get_value(self): 35 | return self._value 36 | 37 | 38 | def run(party: str, q: multiprocessing.Queue): 39 | compatible_utils.init_ray(address='local') 40 | addresses = { 41 | 'alice': '127.0.0.1:21321', 42 | 'bob': '127.0.0.1:21322', 43 | } 44 | retry_policy = { 45 | "maxAttempts": 2, 46 | "initialBackoff": "1s", 47 | "maxBackoff": "1s", 48 | "backoffMultiplier": 1, 49 | "retryableStatusCodes": ["UNAVAILABLE"], 50 | } 51 | 52 | def failure_handler(error): 53 | q.put('failure handler') 54 | 55 | fed.init( 56 | addresses=addresses, 57 | party=party, 58 | logging_level='debug', 59 | config={ 60 | 'cross_silo_comm': { 61 | 'grpc_retry_policy': retry_policy, 62 | 'exit_on_sending_failure': True, 63 | 'timeout_ms': 20 * 1000, 64 | }, 65 | }, 66 | sending_failure_handler=failure_handler, 67 | ) 68 | o = f.party("alice").remote() 69 | My.party("bob").remote(o) 70 | 71 | import time 72 | 73 | # Wait a long time. 74 | # If the test takes effect, the main loop here will be broken. 
75 | time.sleep(86400) 76 | 77 | 78 | def test_exit_when_failure_on_sending(): 79 | q = multiprocessing.Queue() 80 | p_alice = multiprocessing.Process(target=run, args=('alice', q)) 81 | p_alice.start() 82 | p_alice.join() 83 | assert p_alice.exitcode == 1 84 | assert q.get() == 'failure handler' 85 | 86 | 87 | if __name__ == "__main__": 88 | sys.exit(pytest.main(["-sv", __file__])) 89 | -------------------------------------------------------------------------------- /fed/tests/test_reset_context.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import pytest 4 | import ray 5 | 6 | import fed 7 | import fed._private.compatible_utils as compatible_utils 8 | 9 | addresses = { 10 | 'alice': '127.0.0.1:11012', 11 | 'bob': '127.0.0.1:11011', 12 | } 13 | 14 | 15 | @fed.remote 16 | class A: 17 | def __init__(self, init_val=0) -> None: 18 | self.value = init_val 19 | 20 | def get(self): 21 | return self.value 22 | 23 | 24 | def run(party): 25 | compatible_utils.init_ray(address='local') 26 | fed.init(addresses=addresses, party=party) 27 | 28 | actor = A.party('alice').remote(10) 29 | alice_fed_obj = actor.get.remote() 30 | alice_first_fed_obj_id = alice_fed_obj.get_fed_task_id() 31 | assert fed.get(alice_fed_obj) == 10 32 | 33 | actor = A.party('bob').remote(12) 34 | bob_fed_obj = actor.get.remote() 35 | bob_first_fed_obj_id = bob_fed_obj.get_fed_task_id() 36 | assert fed.get(bob_fed_obj) == 12 37 | 38 | assert compatible_utils.kv.put("key", "val") is False 39 | assert compatible_utils.kv.get("key") == b"val" 40 | fed.shutdown() 41 | ray.shutdown() 42 | with pytest.raises(AttributeError): 43 | # `internal_kv` should be reset, putting to which should raise 44 | # `AttributeError` 45 | compatible_utils.kv.put("key2", "val2") 46 | 47 | compatible_utils.init_ray(address='local') 48 | fed.init(addresses=addresses, party=party) 49 | 50 | actor = A.party('alice').remote(10) 51 | alice_fed_obj = actor.get.remote() 52 | alice_second_fed_obj_id = alice_fed_obj.get_fed_task_id() 53 | assert fed.get(alice_fed_obj) == 10 54 | assert alice_first_fed_obj_id == alice_second_fed_obj_id 55 | 56 | actor = A.party('bob').remote(12) 57 | bob_fed_obj = actor.get.remote() 58 | bob_second_fed_obj_id = bob_fed_obj.get_fed_task_id() 59 | assert fed.get(bob_fed_obj) == 12 60 | assert bob_first_fed_obj_id == bob_second_fed_obj_id 61 | 62 | assert compatible_utils.kv.get("key") is None 63 | assert compatible_utils.kv.put("key", "val") is False 64 | assert compatible_utils.kv.get("key") == b"val" 65 | 66 | fed.shutdown() 67 | ray.shutdown() 68 | 69 | 70 | def test_reset_context(): 71 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 72 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 73 | p_alice.start() 74 | 75 | import time 76 | 77 | time.sleep(5) 78 | p_bob.start() 79 | p_alice.join() 80 | p_bob.join() 81 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 82 | 83 | 84 | if __name__ == "__main__": 85 | import sys 86 | 87 | sys.exit(pytest.main(["-sv", __file__])) 88 | -------------------------------------------------------------------------------- /fed/tests/multi-jobs/test_ignore_other_job_msg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import cloudpickle 17 | import grpc 18 | import pytest 19 | import ray 20 | import multiprocessing 21 | 22 | import fed._private.compatible_utils as compatible_utils 23 | import fed.utils as fed_utils 24 | from fed.proxy.barriers import ReceiverProxyActor 25 | from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy 26 | 27 | if compatible_utils._compare_version_strings( 28 | fed_utils.get_package_version('protobuf'), '4.0.0' 29 | ): 30 | from fed.grpc.pb4 import fed_pb2 as fed_pb2 31 | from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc 32 | else: 33 | from fed.grpc.pb3 import fed_pb2 as fed_pb2 34 | from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc 35 | 36 | 37 | def run(): 38 | # GIVEN 39 | ray.init(address='local', include_dashboard=False) 40 | address = '127.0.0.1:15111' 41 | receiver_proxy_actor = ReceiverProxyActor.remote( 42 | listening_address=address, 43 | party='alice', 44 | job_name='job1', 45 | logging_level='info', 46 | proxy_cls=GrpcReceiverProxy, 47 | ) 48 | receiver_proxy_actor.start.remote() 49 | server_state = ray.get(receiver_proxy_actor.is_ready.remote(), timeout=60) 50 | assert server_state[0], server_state[1] 51 | 52 | # WHEN 53 | channel = grpc.insecure_channel(address) 54 | stub = fed_pb2_grpc.GrpcServiceStub(channel) 55 | 56 | data = cloudpickle.dumps('data') 57 | request = fed_pb2.SendDataRequest( 58 | data=data, 59 | upstream_seq_id=str(1), 60 | downstream_seq_id=str(2), 61 | job_name='job2', 62 | ) 63 | response = stub.SendData(request) 64 | 65 | # THEN 66 | assert response.code == 417 67 | assert "JobName mis-match" in response.result 68 | 69 | ray.shutdown() 70 | 71 | 72 | def test_ignore_other_job_msg(): 73 | p_alice = multiprocessing.Process(target=run) 74 | p_alice.start() 75 | p_alice.join() 76 | assert p_alice.exitcode == 0 77 | 78 | 79 | if __name__ == "__main__": 80 | import sys 81 | 82 | sys.exit(pytest.main(["-sv", __file__])) 83 | -------------------------------------------------------------------------------- /fed/tests/serializations_tests/test_unpickle_with_whitelist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import multiprocessing 16 | 17 | import numpy 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | 24 | 25 | @fed.remote 26 | def generate_wrong_type(): 27 | class WrongType: 28 | pass 29 | 30 | return WrongType() 31 | 32 | 33 | @fed.remote 34 | def generate_allowed_type(): 35 | return numpy.array([1, 2, 3, 4, 5]) 36 | 37 | 38 | @fed.remote 39 | def pass_arg(d): 40 | return True 41 | 42 | 43 | def run(party): 44 | compatible_utils.init_ray(address='local', include_dashboard=False) 45 | addresses = { 46 | 'alice': '127.0.0.1:11355', 47 | 'bob': '127.0.0.1:11356', 48 | } 49 | allowed_list = { 50 | "numpy._core.numeric": ["*"], 51 | "numpy": ["dtype"], 52 | } 53 | fed.init( 54 | addresses=addresses, 55 | party=party, 56 | config={"cross_silo_comm": {'serializing_allowed_list': allowed_list}}, 57 | ) 58 | 59 | # Test passing an allowed type. 60 | o1 = generate_allowed_type.party("alice").remote() 61 | o2 = pass_arg.party("bob").remote(o1) 62 | res = fed.get(o2) 63 | assert res 64 | 65 | # Test passing an unallowed type. 66 | o3 = generate_wrong_type.party("alice").remote() 67 | o4 = pass_arg.party("bob").remote(o3) 68 | if party == "bob": 69 | try: 70 | fed.get(o4) 71 | assert False, "This code path shouldn't be reached." 72 | except Exception as e: 73 | assert "_pickle.UnpicklingError" in str(e) 74 | else: 75 | import time 76 | 77 | time.sleep(10) 78 | fed.shutdown() 79 | ray.shutdown() 80 | 81 | 82 | def test_restricted_loads(): 83 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 84 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 85 | p_alice.start() 86 | p_bob.start() 87 | p_alice.join() 88 | p_bob.join() 89 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 90 | 91 | 92 | if __name__ == "__main__": 93 | import sys 94 | 95 | sys.exit(pytest.main(["-sv", __file__])) 96 | -------------------------------------------------------------------------------- /fed/proxy/grpc/grpc_options.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | 17 | _GRPC_SERVICE = "GrpcService" 18 | 19 | _DEFAULT_GRPC_RETRY_POLICY = { 20 | "maxAttempts": 5, 21 | "initialBackoff": "5s", 22 | "maxBackoff": "30s", 23 | "backoffMultiplier": 2, 24 | "retryableStatusCodes": ["UNAVAILABLE"], 25 | } 26 | 27 | 28 | _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH = 500 * 1024 * 1024 29 | _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH = 500 * 1024 * 1024 30 | 31 | _DEFAULT_GRPC_CHANNEL_OPTIONS = { 32 | 'grpc.enable_retries': 1, 33 | 'grpc.so_reuseport': 0, 34 | 'grpc.max_send_message_length': _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH, 35 | 'grpc.max_receive_message_length': _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH, 36 | 'grpc.service_config': json.dumps( 37 | { 38 | 'methodConfig': [ 39 | { 40 | 'name': [{'service': _GRPC_SERVICE}], 41 | 'retryPolicy': _DEFAULT_GRPC_RETRY_POLICY, 42 | } 43 | ] 44 | } 45 | ), 46 | } 47 | 48 | 49 | def get_grpc_options( 50 | retry_policy=None, max_send_message_length=None, max_receive_message_length=None 51 | ): 52 | if not retry_policy: 53 | retry_policy = _DEFAULT_GRPC_RETRY_POLICY 54 | if not max_send_message_length: 55 | max_send_message_length = _DEFAULT_GRPC_MAX_SEND_MESSAGE_LENGTH 56 | if not max_receive_message_length: 57 | max_receive_message_length = _DEFAULT_GRPC_MAX_RECEIVE_MESSAGE_LENGTH 58 | 59 | return [ 60 | ( 61 | 'grpc.max_send_message_length', 62 | max_send_message_length, 63 | ), 64 | ( 65 | 'grpc.max_receive_message_length', 66 | max_receive_message_length, 67 | ), 68 | ('grpc.enable_retries', 1), 69 | ( 70 | 'grpc.service_config', 71 | json.dumps( 72 | { 73 | 'methodConfig': [ 74 | { 75 | 'name': [{'service': _GRPC_SERVICE}], 76 | 'retryPolicy': retry_policy, 77 | } 78 | ] 79 | } 80 | ), 81 | ), 82 | ('grpc.so_reuseport', 0), 83 | ] 84 | -------------------------------------------------------------------------------- /fed/_private/serialization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import io 16 | 17 | import cloudpickle 18 | 19 | import fed.config as fed_config 20 | 21 | _pickle_whitelist = None 22 | 23 | 24 | def _restricted_loads( 25 | serialized_data, 26 | *, 27 | fix_imports=True, 28 | encoding="ASCII", 29 | errors="strict", 30 | buffers=None, 31 | ): 32 | from sys import version_info 33 | 34 | assert version_info.major == 3 35 | 36 | if version_info.minor >= 8: 37 | import pickle as pickle 38 | else: 39 | import pickle5 as pickle 40 | 41 | class RestrictedUnpickler(pickle.Unpickler): 42 | def find_class(self, module, name): 43 | if _pickle_whitelist is None or ( 44 | module in _pickle_whitelist 45 | and ( 46 | _pickle_whitelist[module] is None 47 | or name in _pickle_whitelist[module] 48 | ) 49 | ): 50 | return super().find_class(module, name) 51 | 52 | if module == "fed._private": # TODO(qwang): Not sure if it works. 53 | return super().find_class(module, name) 54 | 55 | # Forbid everything else. 
56 | raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name)) 57 | 58 | if isinstance(serialized_data, str): 59 | raise TypeError("Can't load pickle from unicode string") 60 | file = io.BytesIO(serialized_data) 61 | return RestrictedUnpickler( 62 | file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors 63 | ).load() 64 | 65 | 66 | def _apply_loads_function_with_whitelist(): 67 | global _pickle_whitelist 68 | 69 | cross_silo_comm_config = fed_config.CrossSiloMessageConfig.from_dict( 70 | fed_config.get_job_config().cross_silo_comm_config_dict 71 | ) 72 | _pickle_whitelist = cross_silo_comm_config.serializing_allowed_list 73 | if _pickle_whitelist is None: 74 | return 75 | 76 | if "*" in _pickle_whitelist: 77 | _pickle_whitelist = None 78 | return 79 | 80 | for module, attr_list in _pickle_whitelist.items(): 81 | if attr_list is not None and "*" in attr_list: 82 | _pickle_whitelist[module] = None 83 | cloudpickle.loads = _restricted_loads 84 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from datetime import datetime 17 | 18 | import setuptools 19 | from setuptools import find_packages, setup 20 | 21 | # Package and version. 
22 | BASE_VERSION = "0.1.0" 23 | build_mode = os.getenv("RAYFED_BUILD_MODE", "") 24 | package_name = os.getenv("RAYFED_PACKAGE_NAME", "rayfed") 25 | 26 | if build_mode == "nightly": 27 | VERSION = BASE_VERSION + datetime.today().strftime("b%Y%m%d.dev0") 28 | package_name = "rayfed-nightly" 29 | else: 30 | VERSION = BASE_VERSION + ".dev0" 31 | 32 | this_directory = os.path.abspath(os.path.dirname(__file__)) 33 | with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: 34 | long_description = f.read() 35 | 36 | plat_name = "any" 37 | 38 | 39 | def read_requirements(): 40 | requirements = [] 41 | with open('requirements.txt') as file: 42 | requirements = file.read().splitlines() 43 | print("Requirements: ", requirements) 44 | return requirements 45 | 46 | 47 | # [ref](https://github.com/perwin/pyimfit/blob/master/setup.py) 48 | # Modified cleanup command to remove build subdirectory 49 | # Based on: https://stackoverflow.com/questions/1710839/custom-distutils-commands 50 | class CleanCommand(setuptools.Command): 51 | description = "custom clean command that forcefully removes dist/build directories" 52 | user_options = [] 53 | 54 | def initialize_options(self): 55 | self._cwd = None 56 | 57 | def finalize_options(self): 58 | self._cwd = os.getcwd() 59 | 60 | def run(self): 61 | assert os.getcwd() == self._cwd, 'Must be in package root: %s' % self._cwd 62 | os.system('rm -rf ./build ./dist') 63 | 64 | 65 | setup( 66 | name=package_name, 67 | version=VERSION, 68 | license='Apache 2.0', 69 | description=( 70 | 'A multiple parties joint, distributed execution engine based on Ray,' 71 | 'to help build your own federated learning frameworks in minutes.' 72 | ), 73 | long_description=long_description, 74 | long_description_content_type='text/markdown', 75 | author='RayFed Team', 76 | author_email='rayfed-dev@googlegroups.com', 77 | url='https://github.com/ray-project/rayfed', 78 | packages=find_packages(exclude=('examples', 'tests', 'tests.*')), 79 | install_requires=read_requirements(), 80 | extras_require={'dev': ['pylint']}, 81 | options={'bdist_wheel': {'plat_name': plat_name}}, 82 | ) 83 | -------------------------------------------------------------------------------- /fed/tests/test_setup_proxy_actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import multiprocessing 17 | 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name 24 | 25 | 26 | def run(party): 27 | compatible_utils.init_ray(address='local', resources={"127.0.0.1": 2}) 28 | addresses = { 29 | 'alice': '127.0.0.1:11010', 30 | 'bob': '127.0.0.1:11011', 31 | } 32 | fed.init( 33 | addresses=addresses, 34 | party=party, 35 | ) 36 | 37 | assert ray.get_actor(sender_proxy_actor_name()) is not None 38 | assert ray.get_actor(receiver_proxy_actor_name()) is not None 39 | 40 | fed.shutdown() 41 | ray.shutdown() 42 | 43 | 44 | def run_failure(party): 45 | compatible_utils.init_ray(address='local', resources={"127.0.0.1": 1}) 46 | addresses = { 47 | 'alice': '127.0.0.1:11010', 48 | 'bob': '127.0.0.1:11011', 49 | } 50 | sender_proxy_resources = {"127.0.0.2": 1} # Insufficient resource 51 | receiver_proxy_resources = {"127.0.0.2": 1} # Insufficient resource 52 | with pytest.raises(ray.exceptions.GetTimeoutError): 53 | fed.init( 54 | addresses=addresses, 55 | party=party, 56 | config={ 57 | 'cross_silo_comm': { 58 | 'send_resource_label': sender_proxy_resources, 59 | 'recv_resource_label': receiver_proxy_resources, 60 | 'timeout_in_ms': 10 * 1000, 61 | } 62 | }, 63 | ) 64 | 65 | fed.shutdown() 66 | ray.shutdown() 67 | 68 | 69 | def test_setup_proxy_success(): 70 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 71 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 72 | p_alice.start() 73 | p_bob.start() 74 | p_alice.join() 75 | p_bob.join() 76 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 77 | 78 | 79 | def test_setup_proxy_failed(): 80 | p_alice = multiprocessing.Process(target=run_failure, args=('alice',)) 81 | p_bob = multiprocessing.Process(target=run_failure, args=('bob',)) 82 | p_alice.start() 83 | p_bob.start() 84 | p_alice.join() 85 | p_bob.join() 86 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 87 | 88 | 89 | if __name__ == "__main__": 90 | import sys 91 | 92 | sys.exit(pytest.main(["-sv", __file__])) 93 | -------------------------------------------------------------------------------- /fed/tests/without_ray_tests/test_tree_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Dict, List, Tuple, Union 16 | 17 | import pytest 18 | 19 | import fed.tree_util as tree_utils 20 | 21 | 22 | def test_flatten_none(): 23 | li, tree_def = tree_utils.tree_flatten(None) 24 | assert isinstance(li, list) 25 | assert len(li) == 1 26 | res = tree_utils.tree_unflatten(li, tree_def) 27 | assert res is None 28 | 29 | 30 | def test_flatten_single_primivite_elements(): 31 | def _assert_flatten_single_element(target: Any): 32 | li, tree_def = tree_utils.tree_flatten(target) 33 | assert isinstance(li, list) 34 | assert len(li) == 1 35 | res = tree_utils.tree_unflatten(li, tree_def) 36 | assert res == target 37 | 38 | _assert_flatten_single_element(1) 39 | _assert_flatten_single_element(0.5) 40 | _assert_flatten_single_element("hello") 41 | _assert_flatten_single_element(b"world") 42 | 43 | 44 | def test_flatten_single_simple_containers(): 45 | def _assert_flatten_single_simple_container(target: Union[List, Tuple, Dict]): 46 | container_len = len(target) 47 | li, tree_def = tree_utils.tree_flatten(target) 48 | assert isinstance(li, list) 49 | assert len(li) == container_len 50 | res = tree_utils.tree_unflatten(li, tree_def) 51 | assert res == target 52 | 53 | _assert_flatten_single_simple_container([1, 2, 3]) 54 | _assert_flatten_single_simple_container((1, 2, 3)) 55 | _assert_flatten_single_simple_container({"a": 1, "b": 2, "c": 3}) 56 | 57 | 58 | def test_flatten_complext_nested_container(): 59 | o = [1, 2, (3, 4), [5, {"b", 6}, 7], 8] 60 | flattened, tree_def = tree_utils.tree_flatten(o) 61 | assert len(flattened) == 8 62 | res = tree_utils.tree_unflatten(flattened, tree_def) 63 | assert o == res 64 | 65 | 66 | def test_flatten_and_replace_element(): 67 | o = [1, 2, (3, 4), [5, {"b": 6}, 7], 8] 68 | flattened, tree_def = tree_utils.tree_flatten(o) 69 | flattened[0] = "hello" 70 | flattened[5] = b"world" 71 | assert len(flattened) == 8 72 | res = tree_utils.tree_unflatten(flattened, tree_def) 73 | assert o != res 74 | assert len(res) == 5 75 | print(res) 76 | 77 | assert res[0] == "hello" 78 | assert res[1] == 2 79 | assert res[2] == (3, 4) 80 | assert res[3] == [5, {"b": b"world"}, 7] 81 | assert res[4] == 8 82 | 83 | 84 | if __name__ == "__main__": 85 | import sys 86 | 87 | sys.exit(pytest.main(["-sv", __file__])) 88 | -------------------------------------------------------------------------------- /fed/fed_object.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ray import ObjectRef 16 | 17 | 18 | class FedObjectSendingContext: 19 | """The class that's used for holding the all contexts about sending side.""" 20 | 21 | def __init__(self) -> None: 22 | # This field holds the target(downstream) parties that this fed object 23 | # is sending or sent to. 24 | # The key is the party name and the value is a boolean indicating whether 25 | # this object is sending or sent to the party. 
26 | self._is_sending_or_sent = {} 27 | 28 | def mark_is_sending_to_party(self, target_party: str): 29 | self._is_sending_or_sent[target_party] = True 30 | 31 | def was_sending_or_sent_to_party(self, target_party: str): 32 | return target_party in self._is_sending_or_sent 33 | 34 | 35 | class FedObjectReceivingContext: 36 | """The class that's used for holding the all contexts about receiving side.""" 37 | 38 | pass 39 | 40 | 41 | class FedObject: 42 | """The class that represents for a fed object handle for the result 43 | of the return value from a fed task. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | node_party: str, 49 | fed_task_id: int, 50 | ray_object_ref: ObjectRef, 51 | idx_in_task: int = 0, 52 | ) -> None: 53 | # The party name to exeute the task which produce this fed object. 54 | self._node_party = node_party 55 | self._ray_object_ref = ray_object_ref 56 | self._fed_task_id = fed_task_id 57 | self._idx_in_task = idx_in_task 58 | self._sending_context = FedObjectSendingContext() 59 | self._receiving_context = FedObjectReceivingContext() 60 | 61 | def get_ray_object_ref(self): 62 | return self._ray_object_ref 63 | 64 | def get_fed_task_id(self): 65 | return f'{self._fed_task_id}#{self._idx_in_task}' 66 | 67 | def get_party(self): 68 | return self._node_party 69 | 70 | def _mark_is_sending_to_party(self, target_party: str): 71 | """Mark this fed object is sending to the target party.""" 72 | self._sending_context.mark_is_sending_to_party(target_party) 73 | 74 | def _was_sending_or_sent_to_party(self, target_party: str): 75 | """Query whether this fed object was sending or sent to the target party.""" 76 | return self._sending_context.was_sending_or_sent_to_party(target_party) 77 | 78 | def _cache_ray_object_ref(self, ray_object_ref): 79 | """Cache the ray object reference for this fed object.""" 80 | self._ray_object_ref = ray_object_ref 81 | -------------------------------------------------------------------------------- /fed/grpc/pb3/fed_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 16 | """Client and server classes corresponding to protobuf-defined services.""" 17 | import grpc 18 | 19 | import fed.grpc.pb3.fed_pb2 as fed__pb2 20 | 21 | 22 | class GrpcServiceStub(object): 23 | """Missing associated documentation comment in .proto file.""" 24 | 25 | def __init__(self, channel): 26 | """Constructor. 27 | 28 | Args: 29 | channel: A grpc.Channel. 
30 | """ 31 | self.SendData = channel.unary_unary( 32 | '/GrpcService/SendData', 33 | request_serializer=fed__pb2.SendDataRequest.SerializeToString, 34 | response_deserializer=fed__pb2.SendDataResponse.FromString, 35 | ) 36 | 37 | 38 | class GrpcServiceServicer(object): 39 | """Missing associated documentation comment in .proto file.""" 40 | 41 | def SendData(self, request, context): 42 | """Missing associated documentation comment in .proto file.""" 43 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 44 | context.set_details('Method not implemented!') 45 | raise NotImplementedError('Method not implemented!') 46 | 47 | 48 | def add_GrpcServiceServicer_to_server(servicer, server): 49 | rpc_method_handlers = { 50 | 'SendData': grpc.unary_unary_rpc_method_handler( 51 | servicer.SendData, 52 | request_deserializer=fed__pb2.SendDataRequest.FromString, 53 | response_serializer=fed__pb2.SendDataResponse.SerializeToString, 54 | ), 55 | } 56 | generic_handler = grpc.method_handlers_generic_handler( 57 | 'GrpcService', rpc_method_handlers) 58 | server.add_generic_rpc_handlers((generic_handler,)) 59 | 60 | 61 | # This class is part of an EXPERIMENTAL API. 62 | class GrpcService(object): 63 | """Missing associated documentation comment in .proto file.""" 64 | 65 | @staticmethod 66 | def SendData(request, 67 | target, 68 | options=(), 69 | channel_credentials=None, 70 | call_credentials=None, 71 | insecure=False, 72 | compression=None, 73 | wait_for_ready=None, 74 | timeout=None, 75 | metadata=None): 76 | return grpc.experimental.unary_unary(request, target, '/GrpcService/SendData', 77 | fed__pb2.SendDataRequest.SerializeToString, 78 | fed__pb2.SendDataResponse.FromString, 79 | options, channel_credentials, 80 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 81 | -------------------------------------------------------------------------------- /fed/grpc/pb4/fed_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 16 | """Client and server classes corresponding to protobuf-defined services.""" 17 | import grpc 18 | 19 | import fed.grpc.pb4.fed_pb2 as fed__pb2 20 | 21 | 22 | class GrpcServiceStub(object): 23 | """Missing associated documentation comment in .proto file.""" 24 | 25 | def __init__(self, channel): 26 | """Constructor. 27 | 28 | Args: 29 | channel: A grpc.Channel. 
30 | """ 31 | self.SendData = channel.unary_unary( 32 | '/GrpcService/SendData', 33 | request_serializer=fed__pb2.SendDataRequest.SerializeToString, 34 | response_deserializer=fed__pb2.SendDataResponse.FromString, 35 | ) 36 | 37 | 38 | class GrpcServiceServicer(object): 39 | """Missing associated documentation comment in .proto file.""" 40 | 41 | def SendData(self, request, context): 42 | """Missing associated documentation comment in .proto file.""" 43 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 44 | context.set_details('Method not implemented!') 45 | raise NotImplementedError('Method not implemented!') 46 | 47 | 48 | def add_GrpcServiceServicer_to_server(servicer, server): 49 | rpc_method_handlers = { 50 | 'SendData': grpc.unary_unary_rpc_method_handler( 51 | servicer.SendData, 52 | request_deserializer=fed__pb2.SendDataRequest.FromString, 53 | response_serializer=fed__pb2.SendDataResponse.SerializeToString, 54 | ), 55 | } 56 | generic_handler = grpc.method_handlers_generic_handler( 57 | 'GrpcService', rpc_method_handlers) 58 | server.add_generic_rpc_handlers((generic_handler,)) 59 | 60 | 61 | # This class is part of an EXPERIMENTAL API. 62 | class GrpcService(object): 63 | """Missing associated documentation comment in .proto file.""" 64 | 65 | @staticmethod 66 | def SendData(request, 67 | target, 68 | options=(), 69 | channel_credentials=None, 70 | call_credentials=None, 71 | insecure=False, 72 | compression=None, 73 | wait_for_ready=None, 74 | timeout=None, 75 | metadata=None): 76 | return grpc.experimental.unary_unary(request, target, '/GrpcService/SendData', 77 | fed__pb2.SendDataRequest.SerializeToString, 78 | fed__pb2.SendDataResponse.FromString, 79 | options, channel_credentials, 80 | insecure, call_credentials, compression, wait_for_ready, timeout, metadata) 81 | -------------------------------------------------------------------------------- /fed/tests/test_fed_get.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
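# Overview: this test drives a minimal two-party federated workflow end to end.
# Each party runs `run()` in its own process, starts a local Ray instance, calls
# `fed.init()` with both parties' listening addresses, and then trains one actor
# per party for three epochs while the `mean` task placed on "alice" averages the
# two weights that are fetched back through `fed.get()`.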
14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | 23 | 24 | @fed.remote 25 | class MyModel: 26 | def __init__(self, party, step_length): 27 | self._trained_steps = 0 28 | self._step_length = step_length 29 | self._weights = 0 30 | self._party = party 31 | 32 | def train(self): 33 | self._trained_steps += 1 34 | self._weights += self._step_length 35 | return self._weights 36 | 37 | def get_weights(self): 38 | return self._weights 39 | 40 | def set_weights(self, new_weights): 41 | self._weights = new_weights 42 | return new_weights 43 | 44 | 45 | @fed.remote 46 | def mean(x, y): 47 | return (x + y) / 2 48 | 49 | 50 | def run(party): 51 | import time 52 | 53 | if party == 'alice': 54 | time.sleep(1.4) 55 | 56 | # address = 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' # noqa 57 | # compatible_utils.init_ray(address=address) 58 | compatible_utils.init_ray(address='local') 59 | 60 | addresses = { 61 | 'alice': '127.0.0.1:31012', 62 | 'bob': '127.0.0.1:31011', 63 | } 64 | fed.init(addresses=addresses, party=party) 65 | 66 | epochs = 3 67 | alice_model = MyModel.party("alice").remote("alice", 2) 68 | bob_model = MyModel.party("bob").remote("bob", 4) 69 | 70 | all_mean_weights = [] 71 | for epoch in range(epochs): 72 | w1 = alice_model.train.remote() 73 | w2 = bob_model.train.remote() 74 | new_weights = mean.party("alice").remote(w1, w2) 75 | result = fed.get(new_weights) 76 | alice_model.set_weights.remote(new_weights) 77 | bob_model.set_weights.remote(new_weights) 78 | all_mean_weights.append(result) 79 | assert all_mean_weights == [3, 6, 9] 80 | latest_weights = fed.get( 81 | [alice_model.get_weights.remote(), bob_model.get_weights.remote()] 82 | ) 83 | assert latest_weights == [9, 9] 84 | fed.shutdown() 85 | ray.shutdown() 86 | 87 | 88 | def test_fed_get_in_2_parties(): 89 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 90 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 91 | p_alice.start() 92 | p_bob.start() 93 | p_alice.join() 94 | p_bob.join() 95 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 96 | 97 | 98 | if __name__ == "__main__": 99 | import sys 100 | 101 | sys.exit(pytest.main(["-sv", __file__])) 102 | -------------------------------------------------------------------------------- /fed/proxy/base_proxy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
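# The classes in this module only define the cross-silo transport interface; the
# default gRPC implementations live in `fed.proxy.grpc.grpc_proxy`
# (`GrpcSenderProxy` / `GrpcReceiverProxy`). A custom backend would subclass
# `SenderProxy` or `ReceiverProxy` roughly as in the sketch below; note that
# `InMemorySenderProxy` is a hypothetical name used only for illustration and is
# not a class shipped with RayFed:
#
#     class InMemorySenderProxy(SenderProxy):
#         async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id):
#             # Deliver `data` to the endpoint recorded in
#             # self._addresses[dest_party] and return a truthy value on success.
#             ...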
14 | 15 | import abc 16 | from typing import Dict 17 | 18 | from fed.config import CrossSiloMessageConfig 19 | 20 | 21 | class SenderProxy(abc.ABC): 22 | def __init__( 23 | self, 24 | addresses: Dict, 25 | party: str, 26 | job_name: str, 27 | tls_config: Dict, 28 | proxy_config: CrossSiloMessageConfig = None, 29 | ) -> None: 30 | self._addresses = addresses 31 | self._party = party 32 | self._tls_config = tls_config 33 | self._proxy_config = proxy_config 34 | self._job_name = job_name 35 | 36 | @abc.abstractmethod 37 | async def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): 38 | pass 39 | 40 | async def is_ready(self): 41 | return True 42 | 43 | async def get_proxy_config(self, dest_party=None): 44 | return self._proxy_config 45 | 46 | 47 | class ReceiverProxy(abc.ABC): 48 | def __init__( 49 | self, 50 | listen_addr: str, 51 | party: str, 52 | job_name: str, 53 | tls_config: Dict, 54 | proxy_config: CrossSiloMessageConfig = None, 55 | ) -> None: 56 | self._listen_addr = listen_addr 57 | self._party = party 58 | self._tls_config = tls_config 59 | self._proxy_config = proxy_config 60 | self._job_name = job_name 61 | 62 | @abc.abstractmethod 63 | def start(self): 64 | pass 65 | 66 | @abc.abstractmethod 67 | async def get_data(self, src_party, upstream_seq_id, curr_seq_id): 68 | pass 69 | 70 | async def is_ready(self): 71 | return True 72 | 73 | async def get_proxy_config(self): 74 | return self._proxy_config 75 | 76 | 77 | class SenderReceiverProxy(abc.ABC): 78 | def __init__( 79 | self, 80 | addresses: Dict, 81 | self_party: str, 82 | tls_config: Dict, 83 | proxy_config: CrossSiloMessageConfig = None, 84 | ) -> None: 85 | self._addresses = addresses 86 | self._party = self_party 87 | self._tls_config = tls_config 88 | self._proxy_config = proxy_config 89 | 90 | @abc.abstractmethod 91 | def start(self): 92 | pass 93 | 94 | def is_ready(self): 95 | return True 96 | 97 | @abc.abstractmethod 98 | def get_data(self, src_party, upstream_seq_id, curr_seq_id): 99 | pass 100 | 101 | @abc.abstractmethod 102 | def send(self, dest_party, data, upstream_seq_id, downstream_seq_id): 103 | pass 104 | 105 | def get_proxy_config(self): 106 | return self._proxy_config 107 | -------------------------------------------------------------------------------- /fed/tests/client_mode_tests/test_basic_client_mode.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
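# This test mirrors `fed/tests/test_fed_get.py`, except that each party connects
# to its Ray cluster through Ray client mode (a `ray://` address) rather than a
# local in-process instance; it is skipped on Ray releases newer than 2.4.0, see
# the `pytestmark` declaration below.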
14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | from fed._private.compatible_utils import _compare_version_strings 23 | from fed.tests.test_utils import ray_client_mode_setup # noqa 24 | 25 | pytestmark = pytest.mark.skipif( 26 | _compare_version_strings( 27 | ray.__version__, 28 | '2.4.0', 29 | ), 30 | reason='Skip client mode when ray > 2.4.0', 31 | ) 32 | 33 | 34 | @fed.remote 35 | class MyModel: 36 | def __init__(self, party, step_length): 37 | self._trained_steps = 0 38 | self._step_length = step_length 39 | self._weights = 0 40 | self._party = party 41 | 42 | def train(self): 43 | self._trained_steps += 1 44 | self._weights += self._step_length 45 | return self._weights 46 | 47 | def get_weights(self): 48 | return self._weights 49 | 50 | def set_weights(self, new_weights): 51 | self._weights = new_weights 52 | return new_weights 53 | 54 | 55 | @fed.remote 56 | def mean(x, y): 57 | return (x + y) / 2 58 | 59 | 60 | def run(party): 61 | import time 62 | 63 | if party == 'alice': 64 | time.sleep(1.4) 65 | 66 | address = ( 67 | 'ray://127.0.0.1:21012' if party == 'alice' else 'ray://127.0.0.1:21011' 68 | ) # noqa 69 | compatible_utils.init_ray(address=address) 70 | 71 | addresses = { 72 | 'alice': '127.0.0.1:31012', 73 | 'bob': '127.0.0.1:31011', 74 | } 75 | fed.init(addresses=addresses, party=party) 76 | 77 | epochs = 3 78 | alice_model = MyModel.party("alice").remote("alice", 2) 79 | bob_model = MyModel.party("bob").remote("bob", 4) 80 | 81 | all_mean_weights = [] 82 | for epoch in range(epochs): 83 | w1 = alice_model.train.remote() 84 | w2 = bob_model.train.remote() 85 | new_weights = mean.party("alice").remote(w1, w2) 86 | result = fed.get(new_weights) 87 | alice_model.set_weights.remote(new_weights) 88 | bob_model.set_weights.remote(new_weights) 89 | all_mean_weights.append(result) 90 | assert all_mean_weights == [3, 6, 9] 91 | latest_weights = fed.get( 92 | [alice_model.get_weights.remote(), bob_model.get_weights.remote()] 93 | ) 94 | assert latest_weights == [9, 9] 95 | fed.shutdown() 96 | ray.shutdown() 97 | 98 | 99 | def test_fed_get_in_2_parties(ray_client_mode_setup): # noqa 100 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 101 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 102 | p_alice.start() 103 | p_bob.start() 104 | p_alice.join() 105 | p_bob.join() 106 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 107 | 108 | 109 | if __name__ == "__main__": 110 | import sys 111 | 112 | sys.exit(pytest.main(["-sv", __file__])) 113 | -------------------------------------------------------------------------------- /fed/tests/test_transport_proxy_tls.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
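# The self-signed certificate and key consumed below are expected under
# /tmp/rayfed/test-certs/ (server.crt / server.key). They can be generated with
# the helper script shipped in this repository, e.g. (assuming the command is run
# from the repository root):
#
#     python tool/generate_tls_certs.py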
14 | 15 | import os 16 | 17 | import cloudpickle 18 | import pytest 19 | import ray 20 | 21 | import fed._private.compatible_utils as compatible_utils 22 | from fed._private import constants, global_context 23 | from fed.proxy.barriers import ( 24 | _start_receiver_proxy, 25 | _start_sender_proxy, 26 | receiver_proxy_actor_name, 27 | send, 28 | ) 29 | from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy, GrpcSenderProxy 30 | 31 | 32 | def test_n_to_1_transport(): 33 | """This case is used to test that we have N send_op barriers, 34 | sending data to the target receiver proxy, and there also have 35 | N receivers to `get_data` from receiver proxy at that time. 36 | """ 37 | compatible_utils.init_ray(address="local") 38 | test_job_name = "test_n_to_1_transport" 39 | cert_dir = os.path.join( 40 | os.path.dirname(os.path.abspath(__file__)), "/tmp/rayfed/test-certs/" 41 | ) 42 | tls_config = { 43 | "ca_cert": os.path.join(cert_dir, "server.crt"), 44 | "cert": os.path.join(cert_dir, "server.crt"), 45 | "key": os.path.join(cert_dir, "server.key"), 46 | } 47 | party = "test_party" 48 | cluster_config = { 49 | constants.KEY_OF_CLUSTER_ADDRESSES: "", 50 | constants.KEY_OF_CURRENT_PARTY_NAME: "", 51 | constants.KEY_OF_TLS_CONFIG: tls_config, 52 | } 53 | global_context.init_global_context(party, test_job_name, False, False) 54 | global_context.get_global_context().get_cleanup_manager().start() 55 | compatible_utils._init_internal_kv(test_job_name) 56 | compatible_utils.kv.put( 57 | constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) 58 | ) 59 | 60 | NUM_DATA = 10 61 | SERVER_ADDRESS = "127.0.0.1:65422" 62 | addresses = {"test_party": SERVER_ADDRESS} 63 | _start_receiver_proxy( 64 | addresses, 65 | party, 66 | logging_level="info", 67 | tls_config=tls_config, 68 | proxy_cls=GrpcReceiverProxy, 69 | proxy_config={}, 70 | ) 71 | _start_sender_proxy( 72 | addresses, 73 | party, 74 | logging_level="info", 75 | tls_config=tls_config, 76 | proxy_cls=GrpcSenderProxy, 77 | proxy_config={}, 78 | ) 79 | 80 | sent_objs = [] 81 | get_objs = [] 82 | receiver_proxy_actor = ray.get_actor(receiver_proxy_actor_name()) 83 | for i in range(NUM_DATA): 84 | sent_obj = send( 85 | party, 86 | f"data-{i}", 87 | i, 88 | i + 1, 89 | ) 90 | sent_objs.append(sent_obj) 91 | get_obj = receiver_proxy_actor.get_data.remote(party, i, i + 1) 92 | get_objs.append(get_obj) 93 | for result in ray.get(sent_objs): 94 | assert result 95 | 96 | for i in range(NUM_DATA): 97 | assert f"data-{i}" in ray.get(get_objs) 98 | 99 | global_context.get_global_context().get_cleanup_manager().stop() 100 | global_context.clear_global_context() 101 | compatible_utils._clear_internal_kv() 102 | ray.shutdown() 103 | 104 | 105 | if __name__ == "__main__": 106 | import sys 107 | 108 | sys.exit(pytest.main(["-sv", __file__])) 109 | -------------------------------------------------------------------------------- /fed/_private/message_queue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import threading
17 | import time
18 | from collections import deque
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 | # NOTE(NKcqx): The symbol that tells the polling thread inside the message queue to stop.
23 | # In Python, the recommended way to stop a sub-thread is to set a flag that is
24 | # checked by the sub-thread itself (see https://stackoverflow.com/a/325528).
25 | STOP_SYMBOL = False
26 |
27 |
28 | class MessageQueueManager:
29 |     def __init__(self, msg_handler, failure_handler=None, thread_name=""):
30 |         assert callable(msg_handler), "msg_handler must be a callable function"
31 |         # `deque()` is thread safe on `popleft` and `append` operations.
32 |         # See https://docs.python.org/3/library/collections.html#deque-objects
33 |         self._queue = deque()
34 |         self._msg_handler = msg_handler
35 |         self._failure_handler = failure_handler
36 |         self._thread = None
37 |         # Assign a name to the thread to better distinguish it from other threads.
38 |         self._thread_name = thread_name
39 |
40 |     def start(self):
41 |         def _loop():
42 |             while True:
43 |                 try:
44 |                     message = self._queue.popleft()
45 |                 except IndexError:
46 |                     time.sleep(0.1)
47 |                     continue
48 |
49 |                 if message == STOP_SYMBOL:
50 |                     break
51 |                 res = self._msg_handler(message)
52 |                 if not res:
53 |                     break
54 |
55 |         if self._thread is None or not self._thread.is_alive():
56 |             logger.debug(
57 |                 f"Starting new thread[{self._thread_name}] for message polling."
58 |             )
59 |             self._queue = deque()
60 |             self._thread = threading.Thread(target=_loop, name=self._thread_name)
61 |             self._thread.start()
62 |
63 |     def append(self, message):
64 |         self._queue.append(message)
65 |
66 |     def appendleft(self, message):
67 |         self._queue.appendleft(message)
68 |
69 |     def _notify_to_exit(self, immediately=False):
70 |         logger.info(f"Notify message polling thread[{self._thread_name}] to exit.")
71 |         if immediately:
72 |             self.appendleft(STOP_SYMBOL)
73 |         else:
74 |             self.append(STOP_SYMBOL)
75 |
76 |     def stop(self, wait_for_sending=True):
77 |         """
78 |         Stop the message queue.
79 |
80 |         Args:
81 |             wait_for_sending (bool): Whether to join the polling thread and wait until
82 |                 pending messages are handled and the loop exits. If False, ask the thread to exit immediately. Defaults to True.
83 |         """
84 |         if threading.current_thread() == self._thread:
85 |             logger.error(
86 |                 f"Can't stop the message queue in the message "
87 |                 f"polling thread[{self._thread_name}]. Ignore it, as this "
88 |                 f"could bring unknown time-sequence problems."
89 |             )
90 |             raise RuntimeError("Thread can't kill itself")
91 |
92 |         # TODO(NKcqx): Force killing the sub-thread by calling `._stop()` will
93 |         # encounter an AssertionError because the sub-thread's lock is not released.
94 |         # Therefore, forcibly killing the thread is currently not supported.
95 |         if self.is_started():
96 |             logger.debug(f"Killing thread[{self._thread_name}].")
97 |             self._notify_to_exit(immediately=not wait_for_sending)
98 |             if wait_for_sending:
99 |                 self._thread.join()
100 |                 logger.info(
101 |                     f"The message polling thread[{self._thread_name}] has exited."
102 |                 )
103 |
104 |     def is_started(self):
105 |         return self._thread is not None and self._thread.is_alive()
106 |
--------------------------------------------------------------------------------
/fed/_private/global_context.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The RayFed Team
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import threading
16 | from typing import Callable
17 |
18 | from fed.cleanup import CleanupManager
19 | from fed.exceptions import FedRemoteError
20 |
21 |
22 | class GlobalContext:
23 |     def __init__(
24 |         self,
25 |         job_name: str,
26 |         current_party: str,
27 |         sending_failure_handler: Callable[[Exception], None],
28 |         exit_on_sending_failure=False,
29 |         continue_waiting_for_data_sending_on_error=False,
30 |     ) -> None:
31 |         self._job_name = job_name
32 |         self._seq_count = 0
33 |         self._sending_failure_handler = sending_failure_handler
34 |         self._exit_on_sending_failure = exit_on_sending_failure
35 |         self._atomic_shutdown_flag_lock = threading.Lock()
36 |         self._atomic_shutdown_flag = True
37 |         self._cleanup_manager = CleanupManager(
38 |             current_party, self.acquire_shutdown_flag
39 |         )
40 |         self._last_received_error: FedRemoteError = None
41 |         self._continue_waiting_for_data_sending_on_error = (
42 |             continue_waiting_for_data_sending_on_error
43 |         )
44 |
45 |     def next_seq_id(self) -> int:
46 |         self._seq_count += 1
47 |         return self._seq_count
48 |
49 |     def get_cleanup_manager(self) -> CleanupManager:
50 |         return self._cleanup_manager
51 |
52 |     def get_job_name(self) -> str:
53 |         return self._job_name
54 |
55 |     def get_sending_failure_handler(self) -> Callable[[], None]:
56 |         return self._sending_failure_handler
57 |
58 |     def get_exit_on_sending_failure(self) -> bool:
59 |         return self._exit_on_sending_failure
60 |
61 |     def get_last_recevied_error(self) -> FedRemoteError:
62 |         return self._last_received_error
63 |
64 |     def set_last_recevied_error(self, err):
65 |         self._last_received_error = err
66 |
67 |     def get_continue_waiting_for_data_sending_on_error(self) -> bool:
68 |         return self._continue_waiting_for_data_sending_on_error
69 |
70 |     def acquire_shutdown_flag(self) -> bool:
71 |         """
72 |         Acquire the lock and set the flag to make sure
73 |         `fed.shutdown()` can be called only once.
74 |
75 |         The unintended shutdown, i.e. `fed.shutdown(intended=False)`, needs to
76 |         be executed only once. However, `fed.shutdown` may get called during
77 |         error handling, where acquiring the lock inside `fed.shutdown` may cause
78 |         a deadlock; see `CleanupManager._signal_exit` for more details.
79 |
80 |         Returns:
81 |             bool: True if the permission for the unintended shutdown was acquired.
82 | """ 83 | with self._atomic_shutdown_flag_lock: 84 | if self._atomic_shutdown_flag: 85 | self._atomic_shutdown_flag = False 86 | return True 87 | return False 88 | 89 | 90 | _global_context = None 91 | 92 | 93 | def init_global_context( 94 | current_party: str, 95 | job_name: str, 96 | exit_on_sending_failure: bool, 97 | continue_waiting_for_data_sending_on_error: bool, 98 | sending_failure_handler: Callable[[Exception], None] = None, 99 | ) -> None: 100 | global _global_context 101 | if _global_context is None: 102 | _global_context = GlobalContext( 103 | job_name, 104 | current_party, 105 | exit_on_sending_failure=exit_on_sending_failure, 106 | continue_waiting_for_data_sending_on_error=continue_waiting_for_data_sending_on_error, 107 | sending_failure_handler=sending_failure_handler, 108 | ) 109 | 110 | 111 | def get_global_context(): 112 | global _global_context 113 | return _global_context 114 | 115 | 116 | def clear_global_context(wait_for_sending=False): 117 | global _global_context 118 | if _global_context is not None: 119 | _global_context.get_cleanup_manager().stop(wait_for_sending=wait_for_sending) 120 | _global_context = None 121 | -------------------------------------------------------------------------------- /fed/_private/fed_call_holder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import fed.config as fed_config 18 | from fed._private.global_context import get_global_context 19 | from fed.fed_object import FedObject 20 | from fed.proxy.barriers import send 21 | from fed.tree_util import tree_flatten 22 | from fed.utils import resolve_dependencies 23 | 24 | # Set config in the very beginning to avoid being overwritten by other packages. 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class FedCallHolder: 32 | """ 33 | `FedCallHolder` represents a call node holder when submitting tasks. 34 | For example, 35 | 36 | f.party("ALICE").remote() 37 | ~~~~~~~~~~~~~~~~ 38 | ^ 39 | | 40 | it's a holder. 41 | 42 | """ 43 | 44 | def __init__( 45 | self, 46 | node_party, 47 | submit_ray_task_func, 48 | options={}, 49 | ) -> None: 50 | # Note(NKcqx): FedCallHolder will only be created in driver process, where 51 | # the GlobalContext must has been initialized. 52 | job_name = get_global_context().get_job_name() 53 | self._party = fed_config.get_cluster_config(job_name).current_party 54 | self._node_party = node_party 55 | self._options = options 56 | self._submit_ray_task_func = submit_ray_task_func 57 | 58 | def options(self, **options): 59 | self._options = options 60 | return self 61 | 62 | def internal_remote(self, *args, **kwargs): 63 | if not self._node_party: 64 | raise ValueError("You should specify a party name on the fed actor.") 65 | 66 | # Generate a new fed task id for this call. 
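        # Then dispatch on where the task runs: if this party is the executing
        # party, resolve the FedObject arguments to local Ray ObjectRefs and
        # submit the task to Ray directly; otherwise, ship every argument that
        # was produced on this party to the executing party (skipping ones that
        # were already sent) and return placeholder FedObject(s) whose
        # ray_object_ref is None.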
67 | fed_task_id = get_global_context().next_seq_id() 68 | if self._party == self._node_party: 69 | resolved_args, resolved_kwargs = resolve_dependencies( 70 | self._party, fed_task_id, *args, **kwargs 71 | ) 72 | # TODO(qwang): Handle kwargs. 73 | ray_obj_ref = self._submit_ray_task_func(resolved_args, resolved_kwargs) 74 | if isinstance(ray_obj_ref, list): 75 | return [ 76 | FedObject(self._node_party, fed_task_id, ref, i) 77 | for i, ref in enumerate(ray_obj_ref) 78 | ] 79 | else: 80 | return FedObject(self._node_party, fed_task_id, ray_obj_ref) 81 | else: 82 | flattened_args, _ = tree_flatten((args, kwargs)) 83 | for arg in flattened_args: 84 | # TODO(qwang): We still need to cosider kwargs and a deeply object_ref 85 | # in this party. 86 | if isinstance(arg, FedObject) and arg.get_party() == self._party: 87 | if arg._was_sending_or_sent_to_party(self._node_party): 88 | # This object was sending or sent to the target party, so no 89 | # need to do it again. 90 | continue 91 | else: 92 | arg._mark_is_sending_to_party(self._node_party) 93 | send( 94 | dest_party=self._node_party, 95 | data=arg.get_ray_object_ref(), 96 | upstream_seq_id=arg.get_fed_task_id(), 97 | downstream_seq_id=fed_task_id, 98 | ) 99 | if ( 100 | self._options 101 | and 'num_returns' in self._options 102 | and self._options['num_returns'] > 1 103 | ): 104 | num_returns = self._options['num_returns'] 105 | return [ 106 | FedObject(self._node_party, fed_task_id, None, i) 107 | for i in range(num_returns) 108 | ] 109 | else: 110 | return FedObject(self._node_party, fed_task_id, None) 111 | -------------------------------------------------------------------------------- /tool/generate_tls_certs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import datetime 16 | import errno 17 | import os 18 | import socket 19 | 20 | 21 | def try_make_directory_shared(directory_path): 22 | try: 23 | os.chmod(directory_path, 0o0777) 24 | except OSError as e: 25 | # Silently suppress the PermissionError that is thrown by the chmod. 26 | # This is done because the user attempting to change the permissions 27 | # on a directory may not own it. The chmod is attempted whether the 28 | # directory is new or not to avoid race conditions. 29 | # ray-project/ray/#3591 30 | if e.errno in [errno.EACCES, errno.EPERM]: 31 | pass 32 | else: 33 | raise 34 | 35 | 36 | def try_to_create_directory(directory_path): 37 | """Attempt to create a directory that is globally readable/writable. 38 | 39 | Args: 40 | directory_path: The path of the directory to create. 41 | """ 42 | directory_path = os.path.expanduser(directory_path) 43 | os.makedirs(directory_path, exist_ok=True) 44 | # Change the log directory permissions so others can use it. This is 45 | # important when multiple people are using the same machine. 
46 | try_make_directory_shared(directory_path) 47 | 48 | 49 | def generate_self_signed_tls_certs(): 50 | """Create self-signed key/cert pair for testing. 51 | This method requires the library ``cryptography`` be installed. 52 | """ 53 | try: 54 | from cryptography import x509 55 | from cryptography.hazmat.backends import default_backend 56 | from cryptography.hazmat.primitives import hashes, serialization 57 | from cryptography.hazmat.primitives.asymmetric import rsa 58 | from cryptography.x509.oid import NameOID 59 | except ImportError: 60 | raise ImportError( 61 | "Using `Security.temporary` requires `cryptography`, please " 62 | "install it using either pip or conda" 63 | ) 64 | key = rsa.generate_private_key( 65 | public_exponent=65537, key_size=2048, backend=default_backend() 66 | ) 67 | key_contents = key.private_bytes( 68 | encoding=serialization.Encoding.PEM, 69 | format=serialization.PrivateFormat.PKCS8, 70 | encryption_algorithm=serialization.NoEncryption(), 71 | ).decode() 72 | 73 | ray_interal = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "ray-internal")]) 74 | # This is the same logic used by the GCS server to acquire a 75 | # private/interal IP address to listen on. If we just use localhost + 76 | # 127.0.0.1 then we won't be able to connect to the GCS and will get 77 | # an error like "No match found for server name: 192.168.X.Y" 78 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 79 | s.connect(("8.8.8.8", 80)) 80 | private_ip_address = s.getsockname()[0] 81 | s.close() 82 | altnames = x509.SubjectAlternativeName( 83 | [ 84 | x509.DNSName( 85 | socket.gethostbyname(socket.gethostname()) 86 | ), # Probably 127.0.0.1 87 | x509.DNSName("127.0.0.1"), 88 | x509.DNSName(private_ip_address), # 192.168.*.* 89 | x509.DNSName("localhost"), 90 | ] 91 | ) 92 | now = datetime.datetime.utcnow() 93 | cert = ( 94 | x509.CertificateBuilder() 95 | .subject_name(ray_interal) 96 | .issuer_name(ray_interal) 97 | .add_extension(altnames, critical=False) 98 | .public_key(key.public_key()) 99 | .serial_number(x509.random_serial_number()) 100 | .not_valid_before(now) 101 | .not_valid_after(now + datetime.timedelta(days=365)) 102 | .sign(key, hashes.SHA256(), default_backend()) 103 | ) 104 | 105 | cert_contents = cert.public_bytes(serialization.Encoding.PEM).decode() 106 | 107 | return cert_contents, key_contents 108 | 109 | 110 | def dump_to_files(cert_contents, key_contents): 111 | temp_dir = "/tmp/rayfed/test-certs" 112 | try_to_create_directory(temp_dir) 113 | cert_filepath = os.path.join(temp_dir, "server.crt") 114 | key_filepath = os.path.join(temp_dir, "server.key") 115 | with open(cert_filepath, "w") as fh: 116 | fh.write(cert_contents) 117 | with open(key_filepath, "w") as fh: 118 | fh.write(key_contents) 119 | 120 | return key_filepath, cert_filepath, temp_dir 121 | 122 | 123 | def main(): 124 | cert_contents, key_centents = generate_self_signed_tls_certs() 125 | dump_to_files(cert_contents, key_centents) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /fed/_private/fed_actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | import ray 18 | from ray.util.client.common import ClientActorHandle 19 | 20 | from fed._private.fed_call_holder import FedCallHolder 21 | from fed.fed_object import FedObject 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class FedActorHandle: 27 | def __init__( 28 | self, 29 | fed_class_task_id, 30 | addresses, 31 | cls, 32 | party, 33 | node_party, 34 | options, 35 | ) -> None: 36 | self._fed_class_task_id = fed_class_task_id 37 | self._addresses = addresses 38 | self._body = cls 39 | self._party = party 40 | self._node_party = node_party 41 | self._options = options 42 | self._ray_actor_handle = None 43 | 44 | def __getattr__(self, method_name: str): 45 | # User trying to call .bind() without a bind class method 46 | if method_name == "remote" and "remote" not in dir(self._body): 47 | raise AttributeError(f".remote() cannot be used again on {type(self)} ") 48 | # Raise an error if the method is invalid. 49 | getattr(self._body, method_name) 50 | 51 | if self._party == self._node_party: 52 | ray_actor_handle = self._ray_actor_handle 53 | try: 54 | ray_wrappered_method = ray_actor_handle.__getattribute__(method_name) 55 | except AttributeError: 56 | # The code path in Ray client mode. 57 | assert isinstance(ray_actor_handle, ClientActorHandle) 58 | ray_wrappered_method = ray_actor_handle.__getattr__(method_name) 59 | 60 | return FedActorMethod( 61 | self._addresses, 62 | self._party, 63 | self._node_party, 64 | self, 65 | method_name, 66 | ray_wrappered_method, 67 | ).options(**self._options) 68 | else: 69 | return FedActorMethod( 70 | self._addresses, 71 | self._party, 72 | self._node_party, 73 | self, 74 | method_name, 75 | None, 76 | ).options(**self._options) 77 | 78 | def _execute_impl(self, cls_args, cls_kwargs): 79 | """Executor of ClassNode by ray.remote() 80 | 81 | Args and kwargs are to match base class signature, but not in the 82 | implementation. All args and kwargs should be resolved and replaced 83 | with value in bound_args and bound_kwargs via bottom-up recursion when 84 | current node is executed. 
85 | """ 86 | if self._node_party == self._party: 87 | self._ray_actor_handle = ( 88 | ray.remote(self._body) 89 | .options(**self._options) 90 | .remote(*cls_args, **cls_kwargs) 91 | ) 92 | 93 | def _execute_remote_method( 94 | self, 95 | method_name, 96 | options, 97 | _ray_wrappered_method, 98 | args, 99 | kwargs, 100 | ): 101 | num_returns = 1 102 | if options and 'num_returns' in options: 103 | num_returns = options['num_returns'] 104 | logger.debug(f"Actor method call: {method_name}, num_returns: {num_returns}") 105 | 106 | return _ray_wrappered_method.options( 107 | name='', 108 | num_returns=num_returns, 109 | ).remote( 110 | *args, 111 | **kwargs, 112 | ) 113 | 114 | 115 | class FedActorMethod: 116 | def __init__( 117 | self, 118 | addresses, 119 | party, 120 | node_party, 121 | fed_actor_handle, 122 | method_name, 123 | ray_wrappered_method, 124 | ) -> None: 125 | self._addresses = addresses 126 | self._party = party # Current party 127 | self._node_party = node_party 128 | self._fed_actor_handle = fed_actor_handle 129 | self._method_name = method_name 130 | self._options = {} 131 | self._ray_wrappered_method = ray_wrappered_method 132 | self._fed_call_holder = FedCallHolder(node_party, self._execute_impl) 133 | 134 | def remote(self, *args, **kwargs) -> FedObject: 135 | return self._fed_call_holder.internal_remote(*args, **kwargs) 136 | 137 | def options(self, **options): 138 | self._options = options 139 | self._fed_call_holder.options(**options) 140 | return self 141 | 142 | def _execute_impl(self, args, kwargs): 143 | return self._fed_actor_handle._execute_remote_method( 144 | self._method_name, self._options, self._ray_wrappered_method, args, kwargs 145 | ) 146 | -------------------------------------------------------------------------------- /fed/tests/test_grpc_options_on_proxies.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
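# These tests check how gRPC channel options reach the sender/receiver proxy
# actors: entries in `grpc_channel_options` are passed through as-is (`run`),
# `messages_max_size_in_bytes` expands to both the send and receive message
# length options (`run2`), and when both are configured the explicit channel
# option wins for the key it overlaps with (`run3`).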
14 | 15 | import multiprocessing 16 | 17 | import pytest 18 | import ray 19 | 20 | import fed 21 | import fed._private.compatible_utils as compatible_utils 22 | from fed.proxy.barriers import receiver_proxy_actor_name, sender_proxy_actor_name 23 | 24 | 25 | @fed.remote 26 | def dummpy(): 27 | return 2 28 | 29 | 30 | def run(party): 31 | compatible_utils.init_ray(address='local') 32 | addresses = { 33 | 'alice': '127.0.0.1:11019', 34 | 'bob': '127.0.0.1:11018', 35 | } 36 | fed.init( 37 | addresses=addresses, 38 | party=party, 39 | config={ 40 | "cross_silo_comm": { 41 | "grpc_channel_options": [('grpc.max_send_message_length', 100)], 42 | }, 43 | }, 44 | ) 45 | 46 | def _assert_on_proxy(proxy_actor): 47 | config = ray.get(proxy_actor._get_proxy_config.remote()) 48 | options = config['grpc_options'] 49 | assert ("grpc.max_send_message_length", 100) in options 50 | assert ('grpc.so_reuseport', 0) in options 51 | 52 | sender_proxy = ray.get_actor(sender_proxy_actor_name()) 53 | receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) 54 | _assert_on_proxy(sender_proxy) 55 | _assert_on_proxy(receiver_proxy) 56 | 57 | a = dummpy.party('alice').remote() 58 | b = dummpy.party('bob').remote() 59 | fed.get([a, b]) 60 | 61 | fed.shutdown() 62 | ray.shutdown() 63 | 64 | 65 | def test_grpc_max_size_by_channel_options(): 66 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 67 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 68 | p_alice.start() 69 | p_bob.start() 70 | p_alice.join() 71 | p_bob.join() 72 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 73 | 74 | 75 | def run2(party): 76 | compatible_utils.init_ray(address='local') 77 | addresses = { 78 | 'alice': '127.0.0.1:11019', 79 | 'bob': '127.0.0.1:11018', 80 | } 81 | fed.init( 82 | addresses=addresses, 83 | party=party, 84 | config={ 85 | "cross_silo_comm": { 86 | "messages_max_size_in_bytes": 100, 87 | }, 88 | }, 89 | ) 90 | 91 | def _assert_on_proxy(proxy_actor): 92 | config = ray.get(proxy_actor._get_proxy_config.remote()) 93 | options = config['grpc_options'] 94 | assert ("grpc.max_send_message_length", 100) in options 95 | assert ("grpc.max_receive_message_length", 100) in options 96 | assert ('grpc.so_reuseport', 0) in options 97 | 98 | sender_proxy = ray.get_actor(sender_proxy_actor_name()) 99 | receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) 100 | _assert_on_proxy(sender_proxy) 101 | _assert_on_proxy(receiver_proxy) 102 | 103 | a = dummpy.party('alice').remote() 104 | b = dummpy.party('bob').remote() 105 | fed.get([a, b]) 106 | 107 | fed.shutdown() 108 | ray.shutdown() 109 | 110 | 111 | def test_grpc_max_size_by_common_config(): 112 | p_alice = multiprocessing.Process(target=run2, args=('alice',)) 113 | p_bob = multiprocessing.Process(target=run2, args=('bob',)) 114 | p_alice.start() 115 | p_bob.start() 116 | p_alice.join() 117 | p_bob.join() 118 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 119 | 120 | 121 | def run3(party): 122 | compatible_utils.init_ray(address='local') 123 | addresses = { 124 | 'alice': '127.0.0.1:11019', 125 | 'bob': '127.0.0.1:11018', 126 | } 127 | fed.init( 128 | addresses=addresses, 129 | party=party, 130 | config={ 131 | "cross_silo_comm": { 132 | "messages_max_size_in_bytes": 100, 133 | "grpc_channel_options": [ 134 | ('grpc.max_send_message_length', 200), 135 | ], 136 | }, 137 | }, 138 | ) 139 | 140 | def _assert_on_proxy(proxy_actor): 141 | config = ray.get(proxy_actor._get_proxy_config.remote()) 142 | options = config['grpc_options'] 143 | assert 
("grpc.max_send_message_length", 200) in options 144 | assert ("grpc.max_receive_message_length", 100) in options 145 | assert ('grpc.so_reuseport', 0) in options 146 | 147 | sender_proxy = ray.get_actor(sender_proxy_actor_name()) 148 | receiver_proxy = ray.get_actor(receiver_proxy_actor_name()) 149 | _assert_on_proxy(sender_proxy) 150 | _assert_on_proxy(receiver_proxy) 151 | 152 | a = dummpy.party('alice').remote() 153 | b = dummpy.party('bob').remote() 154 | fed.get([a, b]) 155 | 156 | fed.shutdown() 157 | ray.shutdown() 158 | 159 | 160 | def test_grpc_max_size_by_both_config(): 161 | p_alice = multiprocessing.Process(target=run3, args=('alice',)) 162 | p_bob = multiprocessing.Process(target=run3, args=('bob',)) 163 | p_alice.start() 164 | p_bob.start() 165 | p_alice.join() 166 | p_bob.join() 167 | assert p_alice.exitcode == 0 and p_bob.exitcode == 0 168 | 169 | 170 | if __name__ == "__main__": 171 | import sys 172 | 173 | sys.exit(pytest.main(["-sv", __file__])) 174 | -------------------------------------------------------------------------------- /fed/_private/compatible_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import abc 16 | 17 | import ray 18 | import ray.experimental.internal_kv as ray_internal_kv 19 | 20 | import fed._private.constants as fed_constants 21 | from fed._private import constants 22 | 23 | 24 | def _compare_version_strings(version1, version2): 25 | """ 26 | This utility function compares two version strings and returns 27 | True if version1 is greater, and False if they're equal, and 28 | False if version2 is greater. 29 | """ 30 | v1_list = version1.split('.') 31 | v2_list = version2.split('.') 32 | len1 = len(v1_list) 33 | len2 = len(v2_list) 34 | 35 | for i in range(min(len1, len2)): 36 | if v1_list[i] == v2_list[i]: 37 | continue 38 | else: 39 | break 40 | 41 | return int(v1_list[i]) > int(v2_list[i]) 42 | 43 | 44 | def _ray_version_less_than_2_0_0(): 45 | """Whther the current ray version is less 2.0.0.""" 46 | return _compare_version_strings( 47 | fed_constants.RAY_VERSION_2_0_0_STR, ray.__version__ 48 | ) 49 | 50 | 51 | def init_ray(address: str = None, **kwargs): 52 | """A compatible API to init Ray.""" 53 | if address == 'local' and _ray_version_less_than_2_0_0(): 54 | # Ignore the `local` when ray < 2.0.0 55 | ray.init(**kwargs) 56 | else: 57 | ray.init(address=address, **kwargs) 58 | 59 | 60 | def _get_gcs_address_from_ray_worker(): 61 | """A compatible API to get the gcs address from the ray worker module.""" 62 | try: 63 | return ray._private.worker._global_node.gcs_address 64 | except AttributeError: 65 | return ray.worker._global_node.gcs_address 66 | 67 | 68 | def wrap_kv_key(job_name, key: str): 69 | """Add an prefix to the key to avoid conflict with other jobs.""" 70 | assert isinstance( 71 | key, str 72 | ), f"The key of KV data must be `str` type, got {type(key)}." 
73 | 74 | return constants.RAYFED_JOB_KV_DATA_KEY_FMT.format(job_name, key) 75 | 76 | 77 | class AbstractInternalKv(abc.ABC): 78 | """An abstract class that represents for bridging Ray internal kv in 79 | both Ray client mode and non Ray client mode. 80 | """ 81 | 82 | def __init__(self) -> None: 83 | pass 84 | 85 | @abc.abstractmethod 86 | def initialize(self): 87 | pass 88 | 89 | @abc.abstractmethod 90 | def put(self, k, v): 91 | pass 92 | 93 | @abc.abstractmethod 94 | def get(self, k): 95 | pass 96 | 97 | @abc.abstractmethod 98 | def delete(self, k): 99 | pass 100 | 101 | @abc.abstractmethod 102 | def reset(self): 103 | pass 104 | 105 | 106 | class InternalKv(AbstractInternalKv): 107 | """The internal kv class for non Ray client mode.""" 108 | 109 | def __init__(self, job_name: str) -> None: 110 | super().__init__() 111 | self._job_name = job_name 112 | 113 | def initialize(self): 114 | try: 115 | from ray._private.gcs_utils import GcsClient 116 | except ImportError: 117 | # The GcsClient was moved to `ray._raylet` module in `ray-2.5.0`. 118 | assert _compare_version_strings(ray.__version__, "2.4.0") 119 | from ray._raylet import GcsClient 120 | 121 | gcs_client = GcsClient( 122 | address=_get_gcs_address_from_ray_worker(), nums_reconnect_retry=10 123 | ) 124 | return ray_internal_kv._initialize_internal_kv(gcs_client) 125 | 126 | def put(self, k, v): 127 | return ray_internal_kv._internal_kv_put(wrap_kv_key(self._job_name, k), v) 128 | 129 | def get(self, k): 130 | return ray_internal_kv._internal_kv_get(wrap_kv_key(self._job_name, k)) 131 | 132 | def delete(self, k): 133 | return ray_internal_kv._internal_kv_del(wrap_kv_key(self._job_name, k)) 134 | 135 | def reset(self): 136 | return ray_internal_kv._internal_kv_reset() 137 | 138 | def _ping(self): 139 | return "pong" 140 | 141 | 142 | class ClientModeInternalKv(AbstractInternalKv): 143 | """The internal kv class for Ray client mode.""" 144 | 145 | def __init__(self) -> None: 146 | super().__init__() 147 | self._client_api = ray.util.client.ray 148 | 149 | def initialize(self): 150 | # Note(NKcqx): internval_kv is always initiated after `ray.init`, 151 | # calling this is equal to directly return "True" 152 | return self._client_api._internal_kv_initialized() 153 | 154 | def put(self, k, v, overwrite=True): 155 | return self._client_api._internal_kv_put(k, v, overwrite) 156 | 157 | def get(self, k): 158 | return self._client_api._internal_kv_get(k) 159 | 160 | def delete(self, k): 161 | return self._client_api._internal_kv_del(k) 162 | 163 | def reset(self): 164 | # Note(NKcqx): No `gcs_client` is instantiated for kv, and the 'initialized' 165 | # flag is also reset after `ray.shutdown`, so calling `reset` will do nothing 166 | pass 167 | 168 | 169 | def _init_internal_kv(job_name): 170 | """An internal API that initialize the internal kv object.""" 171 | global kv 172 | if kv is None: 173 | from ray._private.client_mode_hook import is_client_mode_enabled 174 | 175 | kv = ClientModeInternalKv() if is_client_mode_enabled else InternalKv(job_name) 176 | kv.initialize() 177 | 178 | 179 | def _clear_internal_kv(): 180 | global kv 181 | if kv is not None: 182 | kv.delete(constants.KEY_OF_CLUSTER_CONFIG) 183 | kv.delete(constants.KEY_OF_JOB_CONFIG) 184 | kv.reset() 185 | kv = None 186 | 187 | 188 | kv = None 189 | -------------------------------------------------------------------------------- /fed/grpc/pb3/fed_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | 
# 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # -*- coding: utf-8 -*- 16 | # Generated by the protocol buffer compiler. DO NOT EDIT! 17 | # source: fed.proto 18 | """Generated protocol buffer code.""" 19 | from google.protobuf import descriptor as _descriptor 20 | from google.protobuf import message as _message 21 | from google.protobuf import reflection as _reflection 22 | from google.protobuf import symbol_database as _symbol_database 23 | # @@protoc_insertion_point(imports) 24 | 25 | _sym_db = _symbol_database.Default() 26 | 27 | 28 | 29 | 30 | DESCRIPTOR = _descriptor.FileDescriptor( 31 | name='fed.proto', 32 | package='', 33 | syntax='proto3', 34 | serialized_options=b'\200\001\001', 35 | create_key=_descriptor._internal_create_key, 36 | serialized_pb=b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"0\n\x10SendDataResponse\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0e\n\x06result\x18\x02 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3' 37 | ) 38 | 39 | 40 | 41 | 42 | _SENDDATAREQUEST = _descriptor.Descriptor( 43 | name='SendDataRequest', 44 | full_name='SendDataRequest', 45 | filename=None, 46 | file=DESCRIPTOR, 47 | containing_type=None, 48 | create_key=_descriptor._internal_create_key, 49 | fields=[ 50 | _descriptor.FieldDescriptor( 51 | name='data', full_name='SendDataRequest.data', index=0, 52 | number=1, type=12, cpp_type=9, label=1, 53 | has_default_value=False, default_value=b"", 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 57 | _descriptor.FieldDescriptor( 58 | name='upstream_seq_id', full_name='SendDataRequest.upstream_seq_id', index=1, 59 | number=2, type=9, cpp_type=9, label=1, 60 | has_default_value=False, default_value=b"".decode('utf-8'), 61 | message_type=None, enum_type=None, containing_type=None, 62 | is_extension=False, extension_scope=None, 63 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 64 | _descriptor.FieldDescriptor( 65 | name='downstream_seq_id', full_name='SendDataRequest.downstream_seq_id', index=2, 66 | number=3, type=9, cpp_type=9, label=1, 67 | has_default_value=False, default_value=b"".decode('utf-8'), 68 | message_type=None, enum_type=None, containing_type=None, 69 | is_extension=False, extension_scope=None, 70 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 71 | _descriptor.FieldDescriptor( 72 | name='job_name', full_name='SendDataRequest.job_name', index=3, 73 | number=4, type=9, cpp_type=9, label=1, 74 | has_default_value=False, default_value=b"".decode('utf-8'), 75 | message_type=None, enum_type=None, 
containing_type=None, 76 | is_extension=False, extension_scope=None, 77 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 78 | ], 79 | extensions=[ 80 | ], 81 | nested_types=[], 82 | enum_types=[ 83 | ], 84 | serialized_options=None, 85 | is_extendable=False, 86 | syntax='proto3', 87 | extension_ranges=[], 88 | oneofs=[ 89 | ], 90 | serialized_start=13, 91 | serialized_end=114, 92 | ) 93 | 94 | 95 | _SENDDATARESPONSE = _descriptor.Descriptor( 96 | name='SendDataResponse', 97 | full_name='SendDataResponse', 98 | filename=None, 99 | file=DESCRIPTOR, 100 | containing_type=None, 101 | create_key=_descriptor._internal_create_key, 102 | fields=[ 103 | _descriptor.FieldDescriptor( 104 | name='code', full_name='SendDataResponse.code', index=0, 105 | number=1, type=5, cpp_type=1, label=1, 106 | has_default_value=False, default_value=0, 107 | message_type=None, enum_type=None, containing_type=None, 108 | is_extension=False, extension_scope=None, 109 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 110 | _descriptor.FieldDescriptor( 111 | name='result', full_name='SendDataResponse.result', index=1, 112 | number=2, type=9, cpp_type=9, label=1, 113 | has_default_value=False, default_value=b"".decode('utf-8'), 114 | message_type=None, enum_type=None, containing_type=None, 115 | is_extension=False, extension_scope=None, 116 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 117 | ], 118 | extensions=[ 119 | ], 120 | nested_types=[], 121 | enum_types=[ 122 | ], 123 | serialized_options=None, 124 | is_extendable=False, 125 | syntax='proto3', 126 | extension_ranges=[], 127 | oneofs=[ 128 | ], 129 | serialized_start=116, 130 | serialized_end=164, 131 | ) 132 | 133 | DESCRIPTOR.message_types_by_name['SendDataRequest'] = _SENDDATAREQUEST 134 | DESCRIPTOR.message_types_by_name['SendDataResponse'] = _SENDDATARESPONSE 135 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 136 | 137 | SendDataRequest = _reflection.GeneratedProtocolMessageType('SendDataRequest', (_message.Message,), { 138 | 'DESCRIPTOR' : _SENDDATAREQUEST, 139 | '__module__' : 'fed_pb2' 140 | # @@protoc_insertion_point(class_scope:SendDataRequest) 141 | }) 142 | _sym_db.RegisterMessage(SendDataRequest) 143 | 144 | SendDataResponse = _reflection.GeneratedProtocolMessageType('SendDataResponse', (_message.Message,), { 145 | 'DESCRIPTOR' : _SENDDATARESPONSE, 146 | '__module__' : 'fed_pb2' 147 | # @@protoc_insertion_point(class_scope:SendDataResponse) 148 | }) 149 | _sym_db.RegisterMessage(SendDataResponse) 150 | 151 | 152 | DESCRIPTOR._options = None 153 | 154 | _GRPCSERVICE = _descriptor.ServiceDescriptor( 155 | name='GrpcService', 156 | full_name='GrpcService', 157 | file=DESCRIPTOR, 158 | index=0, 159 | serialized_options=None, 160 | create_key=_descriptor._internal_create_key, 161 | serialized_start=166, 162 | serialized_end=230, 163 | methods=[ 164 | _descriptor.MethodDescriptor( 165 | name='SendData', 166 | full_name='GrpcService.SendData', 167 | index=0, 168 | containing_service=None, 169 | input_type=_SENDDATAREQUEST, 170 | output_type=_SENDDATARESPONSE, 171 | serialized_options=None, 172 | create_key=_descriptor._internal_create_key, 173 | ), 174 | ]) 175 | _sym_db.RegisterServiceDescriptor(_GRPCSERVICE) 176 | 177 | DESCRIPTOR.services_by_name['GrpcService'] = _GRPCSERVICE 178 | 179 | # @@protoc_insertion_point(module_scope) 180 | -------------------------------------------------------------------------------- /fed/config.py: 
-------------------------------------------------------------------------------- 1 | """This module should be cached locally due to all configurations 2 | are mutable. 3 | """ 4 | 5 | import json 6 | from dataclasses import dataclass, fields 7 | from typing import Dict, List, Optional 8 | 9 | import cloudpickle 10 | 11 | import fed._private.compatible_utils as compatible_utils 12 | import fed._private.constants as fed_constants 13 | 14 | 15 | class ClusterConfig: 16 | """A local cache of cluster configuration items.""" 17 | 18 | def __init__(self, raw_bytes: bytes) -> None: 19 | self._data = cloudpickle.loads(raw_bytes) 20 | 21 | @property 22 | def cluster_addresses(self): 23 | return self._data[fed_constants.KEY_OF_CLUSTER_ADDRESSES] 24 | 25 | @property 26 | def current_party(self): 27 | return self._data[fed_constants.KEY_OF_CURRENT_PARTY_NAME] 28 | 29 | @property 30 | def tls_config(self): 31 | return self._data[fed_constants.KEY_OF_TLS_CONFIG] 32 | 33 | 34 | class JobConfig: 35 | def __init__(self, raw_bytes: bytes) -> None: 36 | if raw_bytes is None: 37 | self._data = {} 38 | else: 39 | self._data = cloudpickle.loads(raw_bytes) 40 | 41 | @property 42 | def cross_silo_comm_config_dict(self) -> Dict: 43 | return self._data.get(fed_constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT, {}) 44 | 45 | 46 | # A module level cache for the cluster configurations. 47 | _cluster_config = None 48 | 49 | _job_config = None 50 | 51 | 52 | def get_cluster_config(job_name: str = None) -> ClusterConfig: 53 | """This function is not thread safe to use.""" 54 | global _cluster_config 55 | if _cluster_config is None: 56 | assert ( 57 | job_name is not None 58 | ), "Initializing internal kv need to provide job_name." 59 | compatible_utils._init_internal_kv(job_name) 60 | raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_CLUSTER_CONFIG) 61 | _cluster_config = ClusterConfig(raw_dict) 62 | return _cluster_config 63 | 64 | 65 | def get_job_config(job_name: str = None) -> JobConfig: 66 | """This config still acts like cluster config for now""" 67 | global _job_config 68 | if _job_config is None: 69 | assert ( 70 | job_name is not None 71 | ), "Initializing internal kv need to provide job_name." 72 | compatible_utils._init_internal_kv(job_name) 73 | raw_dict = compatible_utils.kv.get(fed_constants.KEY_OF_JOB_CONFIG) 74 | _job_config = JobConfig(raw_dict) 75 | return _job_config 76 | 77 | 78 | @dataclass 79 | class CrossSiloMessageConfig: 80 | """A class to store parameters used for Proxy Actor. 81 | 82 | Attributes: 83 | proxy_max_restarts: 84 | The max restart times for the send proxy. 85 | serializing_allowed_list: 86 | The package or class list allowed for 87 | serializing(deserializating) cross silos. It's used for avoiding pickle 88 | deserializing execution attack when crossing silos. 89 | send_resource_label: 90 | Customized resource label, the SenderProxyActor 91 | will be scheduled based on the declared resource label. For example, 92 | when setting to `{"my_label": 1}`, then the sender proxy actor will be 93 | started only on nodes with `{"resource": {"my_label": $NUM}}` where 94 | $NUM >= 1. 95 | recv_resource_label: 96 | Customized resource label, the ReceiverProxyActor 97 | will be scheduled based on the declared resource label. For example, 98 | when setting to `{"my_label": 1}`, then the receiver proxy actor will be 99 | started only on nodes with `{"resource": {"my_label": $NUM}}` where 100 | $NUM >= 1. 101 | exit_on_sending_failure: 102 | whether exit when failure on cross-silo sending. 
If True, a SIGINT will be 103 | signaled to self if failed to sending cross-silo data and exit then. 104 | continue_waiting_for_data_sending_on_error: 105 | Whether to continue waiting for data sending if an error occurs, including 106 | data-sending errors and receiving errors from the peer. If True, wait until 107 | all data has been sent. 108 | messages_max_size_in_bytes: 109 | The maximum length in bytes of cross-silo messages. If None, the default 110 | value of 500 MB is specified. 111 | timeout_in_ms: 112 | The timeout in mili-seconds of a cross-silo RPC call. It's 60000 by 113 | default. 114 | http_header: 115 | The HTTP header, e.g. metadata in grpc, sent with the RPC request. 116 | This won't override basic tcp headers, such as `user-agent`, but concat 117 | them together. 118 | max_concurrency: 119 | the max_concurrency of the sender/receiver proxy actor. 120 | use_global_proxy: 121 | Whether using the global proxy actor or create new proxy actor for current 122 | job. 123 | """ 124 | 125 | proxy_max_restarts: int = None 126 | timeout_in_ms: int = 60000 127 | messages_max_size_in_bytes: int = None 128 | exit_on_sending_failure: Optional[bool] = False 129 | continue_waiting_for_data_sending_on_error: Optional[bool] = False 130 | serializing_allowed_list: Optional[Dict[str, str]] = None 131 | send_resource_label: Optional[Dict[str, str]] = None 132 | recv_resource_label: Optional[Dict[str, str]] = None 133 | http_header: Optional[Dict[str, str]] = None 134 | max_concurrency: Optional[int] = None 135 | expose_error_trace: Optional[bool] = False 136 | use_global_proxy: Optional[bool] = True 137 | 138 | def __json__(self): 139 | return json.dumps(self.__dict__) 140 | 141 | @classmethod 142 | def from_json(cls, json_str): 143 | data = json.loads(json_str) 144 | return cls(**data) 145 | 146 | @classmethod 147 | def from_dict(cls, data: Dict) -> 'CrossSiloMessageConfig': 148 | """Initialize CrossSiloMessageConfig from a dictionary. 149 | 150 | Args: 151 | data (Dict): Dictionary with keys as member variable names. 152 | 153 | Returns: 154 | CrossSiloMessageConfig: An instance of CrossSiloMessageConfig. 155 | """ 156 | # Get the attributes of the class 157 | data = data or {} 158 | attrs = [field.name for field in fields(cls)] 159 | # Filter the dictionary to only include keys that are attributes of the class 160 | filtered_data = {key: value for key, value in data.items() if key in attrs} 161 | return cls(**filtered_data) 162 | 163 | 164 | @dataclass 165 | class GrpcCrossSiloMessageConfig(CrossSiloMessageConfig): 166 | """A class to store parameters used for GRPC communication 167 | 168 | Attributes: 169 | grpc_retry_policy: 170 | a dict descibes the retry policy for cross silo rpc call. If None, the 171 | following default retry policy will be used. More details please refer to 172 | `retry-policy `_. # noqa 173 | 174 | .. code:: python 175 | 176 | { 177 | "maxAttempts": 4, 178 | "initialBackoff": "0.1s", 179 | "maxBackoff": "1s", 180 | "backoffMultiplier": 2, 181 | "retryableStatusCodes": [ 182 | "UNAVAILABLE" 183 | ] 184 | } 185 | grpc_channel_options: A list of tuples to store GRPC channel options, e.g. 186 | .. 
code:: python 187 | 188 | [ 189 | ('grpc.enable_retries', 1), 190 | ('grpc.max_send_message_length', 50 * 1024 * 1024) 191 | ] 192 | """ 193 | 194 | grpc_channel_options: List = None 195 | grpc_retry_policy: Dict[str, str] = None 196 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RayFed 2 | ![docs building](https://readthedocs.org/projects/rayfed/badge/?version=latest) ![test on many rays](https://github.com/ray-project/rayfed/actions/workflows/unit_tests_on_ray_matrix.yml/badge.svg) ![test on ray 1.13.0](https://github.com/ray-project/rayfed/actions/workflows/test_on_ray1.13.0.yml/badge.svg) 3 | 4 | A multi-party joint, distributed execution engine based on Ray, to help you build your own federated learning frameworks in minutes. 5 | 6 | ## Overview 7 | **Note: This project is under active development.** 8 | 9 | RayFed is a distributed computing framework for cross-party federated learning. 10 | Built in the Ray ecosystem, RayFed provides a Ray-native programming pattern for federated learning so that users can build distributed programs easily. 11 | 12 | It provides users with the role of "party", so that users can write code belonging to a specific party, explicitly imposing clearer data perimeters. This code is restricted to executing within that party. 13 | 14 | As for code execution, RayFed introduces the multi-controller architecture: 15 | The code view in each party is exactly the same, but execution differs based on the declared party of the code and the party of the current executor. 16 | 17 | 18 | 19 | ## Features 20 | - **Ray Native Programming Pattern** 21 | 22 | Lets you write your federated and distributed computing applications like a single-machine program. 23 | 24 | - **Multiple Controller Execution Mode** 25 | 26 | A RayFed job can run in single-controller mode for development and debugging, and in multi-controller mode for production, without code changes. 27 | 28 | - **Very Restricted and Clear Data Perimeters** 29 | 30 | Because of the PUSH-BASED data transfer mechanism and the multi-controller execution mode, data transmission authority is held by the data owner rather than the data demander. 31 | 32 | - **Very Large Scale Federated Computing and Training** 33 | 34 | Powered by the scalability and distributed capabilities of Ray, large-scale federated computing and training jobs are naturally supported. 35 | 36 | 37 | ## Supported Ray Versions 38 | Due to Ray's aggressive release strategy, RayFed only supports the last 5 Ray versions. 39 | | RayFed Versions | ray-1.13.0 | ray-2.31.0 | ray-2.32.0 | ray-2.33.0 | ray-2.34.0 | ray-2.35.0 | 40 | |:---------------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:| 41 | | 0.1.0 |✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 42 | | 0.2.0 |not released|not released|not released|not released|not released|not released| 43 | 44 | 45 | ## Installation 46 | Install it from PyPI. 47 | 48 | ```shell 49 | pip install -U rayfed 50 | ``` 51 | 52 | Install the nightly release from PyPI. 53 | 54 | ```shell 55 | pip install -U rayfed-nightly 56 | ``` 57 | ## Quick Start 58 | 59 | This example shows how to aggregate values across two participants. 60 | 61 | ### Step 1: Write an Actor that Generates Value 62 | `MyActor` increments its value by `num`. 63 | This actor will be executed within the explicitly declared party.
64 | 65 | ```python 66 | import sys 67 | import ray 68 | import fed 69 | 70 | @fed.remote 71 | class MyActor: 72 | def __init__(self, value): 73 | self.value = value 74 | 75 | def inc(self, num): 76 | self.value = self.value + num 77 | return self.value 78 | ``` 79 | ### Step 2: Define Aggregation Function 80 | The function below collects and aggregates values from the two parties, and will also be executed within the declared party. 81 | 82 | ```python 83 | @fed.remote 84 | def aggregate(val1, val2): 85 | return val1 + val2 86 | ``` 87 | 88 | ### Step 3: Create the actor and call methods in a specific party 89 | 90 | The creation code is similar to `Ray`; the difference is that in `RayFed` the actor must be explicitly created within a party: 91 | 92 | ```python 93 | actor_alice = MyActor.party("alice").remote(1) 94 | actor_bob = MyActor.party("bob").remote(1) 95 | 96 | val_alice = actor_alice.inc.remote(1) 97 | val_bob = actor_bob.inc.remote(2) 98 | 99 | sum_val_obj = aggregate.party("bob").remote(val_alice, val_bob) 100 | ``` 101 | The above code: 102 | 1. Creates two `MyActor`s, one in each party, i.e. 'alice' and 'bob'; 103 | 2. Increments by '1' in 'alice' and by '2' in 'bob'; 104 | 3. Executes the aggregation function in party 'bob'. 105 | 106 | ### Step 4: Declare Cross-party Cluster & Init 107 | ```python 108 | def main(party): 109 | ray.init(address='local', include_dashboard=False) 110 | 111 | addresses = { 112 | 'alice': '127.0.0.1:11012', 113 | 'bob': '127.0.0.1:11011', 114 | } 115 | fed.init(addresses=addresses, party=party) 116 | ``` 117 | This first declares a two-party cluster, whose addresses correspond to '127.0.0.1:11012' for 'alice' and '127.0.0.1:11011' for 'bob'. 118 | Then, `fed.init` creates a cluster in the specified party. 119 | Note that `fed.init` should be called twice, passing in a different party each time. 120 | 121 | When executing the code in steps 1~3, the 'alice' cluster will only execute functions whose "party" is also declared as 'alice'. 122 | 123 | ### Put it together! 124 | Save the code below as `demo.py`: 125 | ```python 126 | import sys 127 | import ray 128 | import fed 129 | 130 | 131 | @fed.remote 132 | class MyActor: 133 | def __init__(self, value): 134 | self.value = value 135 | 136 | def inc(self, num): 137 | self.value = self.value + num 138 | return self.value 139 | 140 | 141 | @fed.remote 142 | def aggregate(val1, val2): 143 | return val1 + val2 144 | 145 | 146 | def main(party): 147 | ray.init(address='local', include_dashboard=False) 148 | 149 | addresses = { 150 | 'alice': '127.0.0.1:11012', 151 | 'bob': '127.0.0.1:11011', 152 | } 153 | fed.init(addresses=addresses, party=party) 154 | 155 | actor_alice = MyActor.party("alice").remote(1) 156 | actor_bob = MyActor.party("bob").remote(1) 157 | 158 | val_alice = actor_alice.inc.remote(1) 159 | val_bob = actor_bob.inc.remote(2) 160 | 161 | sum_val_obj = aggregate.party("bob").remote(val_alice, val_bob) 162 | result = fed.get(sum_val_obj) 163 | print(f"The result in party {party} is {result}") 164 | 165 | fed.shutdown() 166 | ray.shutdown() 167 | 168 | 169 | if __name__ == "__main__": 170 | assert len(sys.argv) == 2, 'Please run this script with party.' 171 | main(sys.argv[1]) 172 | 173 | ``` 174 | 175 | ### Run The Code 176 | 177 | Open a terminal and run the code as `alice`.
It's recommended to run the code with Ray TLS enabled (please refer to [Ray TLS](https://docs.ray.io/en/latest/ray-core/configure.html#tls-authentication)): 178 | ```shell 179 | RAY_USE_TLS=1 \ 180 | RAY_TLS_SERVER_CERT='/path/to/the/server/cert/file' \ 181 | RAY_TLS_SERVER_KEY='/path/to/the/server/key/file' \ 182 | RAY_TLS_CA_CERT='/path/to/the/ca/cert/file' \ 183 | python demo.py alice 184 | ``` 185 | 186 | In the meantime, open another terminal and run the code as `bob`. 187 | ```shell 188 | RAY_USE_TLS=1 \ 189 | RAY_TLS_SERVER_CERT='/path/to/the/server/cert/file' \ 190 | RAY_TLS_SERVER_KEY='/path/to/the/server/key/file' \ 191 | RAY_TLS_CA_CERT='/path/to/the/ca/cert/file' \ 192 | python demo.py bob 193 | ``` 194 | 195 | Then you will get `The result in party alice is 5` on the first terminal screen and `The result in party bob is 5` on the second terminal screen. 196 | 197 | The figure below shows the execution under the hood: 198 |
199 | Figure 200 |
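### Optional: Run Both Parties from One Script

For quick local experiments you can also drive both parties from a single script with `multiprocessing`, which is the same pattern the unit tests in this repository use. This is only a minimal sketch: it assumes the `demo.py` file saved above, so the module name `demo` is an example rather than part of the RayFed API.

```python
# run_local.py -- a local two-party demo driver (a sketch, not part of RayFed).
import multiprocessing

from demo import main  # the `main(party)` function defined in demo.py above

if __name__ == "__main__":
    # Each party runs in its own process with its own Ray instance,
    # mirroring how two separate machines would run the job.
    p_alice = multiprocessing.Process(target=main, args=('alice',))
    p_bob = multiprocessing.Process(target=main, args=('bob',))
    p_alice.start()
    p_bob.start()
    p_alice.join()
    p_bob.join()
```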
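Cross-silo communication can also be tuned through the `config` argument of `fed.init`; the available options are documented in `fed/config.py` (`CrossSiloMessageConfig` and `GrpcCrossSiloMessageConfig`). The snippet below is a hedged sketch whose values are taken from this repository's tests rather than recommended defaults:

```python
# Illustrative only: tune cross-silo communication when initializing RayFed.
fed.init(
    addresses=addresses,
    party=party,
    config={
        'cross_silo_comm': {
            # Timeout of a cross-silo RPC call, in milliseconds.
            'timeout_ms': 20 * 1000,
            # Whether to expose the remote error trace to the other party.
            'expose_error_trace': True,
        },
    },
)
```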
201 | ## Running Untrusted Code 202 | As a general rule: always execute untrusted code inside a sandbox (e.g., [nsjail](https://github.com/google/nsjail)). 203 | 204 | ## Who Uses Us 205 | 206 | Ant Chain Morse 207 | 208 | 209 | SecretFlow 210 | 211 | -------------------------------------------------------------------------------- /fed/tests/test_transport_proxy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import asyncio 16 | 17 | import cloudpickle 18 | import grpc 19 | import pytest 20 | import ray 21 | 22 | import fed._private.compatible_utils as compatible_utils 23 | import fed.utils as fed_utils 24 | from fed._private import constants, global_context 25 | from fed.proxy.barriers import ( 26 | _start_receiver_proxy, 27 | _start_sender_proxy, 28 | receiver_proxy_actor_name, 29 | send, 30 | ) 31 | from fed.proxy.grpc.grpc_proxy import GrpcReceiverProxy, GrpcSenderProxy 32 | 33 | if compatible_utils._compare_version_strings( 34 | fed_utils.get_package_version("protobuf"), "4.0.0" 35 | ): 36 | from fed.grpc.pb4 import fed_pb2 as fed_pb2 37 | from fed.grpc.pb4 import fed_pb2_grpc as fed_pb2_grpc 38 | else: 39 | from fed.grpc.pb3 import fed_pb2 as fed_pb2 40 | from fed.grpc.pb3 import fed_pb2_grpc as fed_pb2_grpc 41 | 42 | 43 | def test_n_to_1_transport(): 44 | """This case tests that we have N send_op barriers 45 | sending data to the target receiver proxy, and that there are also 46 | N receivers calling `get_data` on the receiver proxy at the same time.
47 | """ 48 | compatible_utils.init_ray(address="local") 49 | test_job_name = "test_n_to_1_transport" 50 | party = "test_party" 51 | global_context.init_global_context(party, test_job_name, False, False) 52 | global_context.get_global_context().get_cleanup_manager().start() 53 | cluster_config = { 54 | constants.KEY_OF_CLUSTER_ADDRESSES: "", 55 | constants.KEY_OF_CURRENT_PARTY_NAME: "", 56 | constants.KEY_OF_TLS_CONFIG: "", 57 | } 58 | compatible_utils._init_internal_kv(test_job_name) 59 | compatible_utils.kv.put( 60 | constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) 61 | ) 62 | 63 | NUM_DATA = 10 64 | SERVER_ADDRESS = "127.0.0.1:12344" 65 | 66 | addresses = {"test_party": SERVER_ADDRESS} 67 | _start_receiver_proxy( 68 | addresses, 69 | party, 70 | logging_level="info", 71 | proxy_cls=GrpcReceiverProxy, 72 | proxy_config={}, 73 | ) 74 | _start_sender_proxy( 75 | addresses, 76 | party, 77 | logging_level="info", 78 | proxy_cls=GrpcSenderProxy, 79 | proxy_config={}, 80 | ) 81 | 82 | sent_objs = [] 83 | get_objs = [] 84 | receiver_proxy_actor = ray.get_actor(receiver_proxy_actor_name()) 85 | for i in range(NUM_DATA): 86 | sent_obj = send(party, f"data-{i}", i, i + 1) 87 | sent_objs.append(sent_obj) 88 | get_obj = receiver_proxy_actor.get_data.remote(party, i, i + 1) 89 | get_objs.append(get_obj) 90 | for result in ray.get(sent_objs): 91 | assert result 92 | 93 | for i in range(NUM_DATA): 94 | assert f"data-{i}" in ray.get(get_objs) 95 | 96 | global_context.get_global_context().get_cleanup_manager().stop() 97 | global_context.clear_global_context() 98 | compatible_utils._clear_internal_kv() 99 | ray.shutdown() 100 | 101 | 102 | class TestSendDataService(fed_pb2_grpc.GrpcServiceServicer): 103 | def __init__( 104 | self, all_events, all_data, party, lock, expected_metadata, expected_jobname 105 | ): 106 | self.expected_metadata = expected_metadata or {} 107 | self._expected_jobname = expected_jobname or "" 108 | 109 | async def SendData(self, request, context): 110 | job_name = request.job_name 111 | assert self._expected_jobname == job_name 112 | metadata = dict(context.invocation_metadata()) 113 | for k, v in self.expected_metadata.items(): 114 | assert ( 115 | k in metadata 116 | ), f"The expected key {k} is not in the metadata keys: {metadata.keys()}." 
117 | assert v == metadata[k] 118 | event = asyncio.Event() 119 | event.set() 120 | return fed_pb2.SendDataResponse(code=200, result="OK") 121 | 122 | 123 | async def _test_run_grpc_server( 124 | port, 125 | event, 126 | all_data, 127 | party, 128 | lock, 129 | grpc_options=None, 130 | expected_metadata=None, 131 | expected_jobname=None, 132 | ): 133 | server = grpc.aio.server(options=grpc_options) 134 | fed_pb2_grpc.add_GrpcServiceServicer_to_server( 135 | TestSendDataService( 136 | event, all_data, party, lock, expected_metadata, expected_jobname 137 | ), 138 | server, 139 | ) 140 | server.add_insecure_port(f"[::]:{port}") 141 | await server.start() 142 | await server.wait_for_termination() 143 | 144 | 145 | @ray.remote 146 | class TestReceiverProxyActor: 147 | def __init__( 148 | self, 149 | listen_addr: str, 150 | party: str, 151 | expected_metadata: dict, 152 | expected_jobname: str, 153 | ): 154 | self._listen_addr = listen_addr 155 | self._party = party 156 | self._expected_metadata = expected_metadata 157 | self._expected_jobname = expected_jobname 158 | 159 | async def run_grpc_server(self): 160 | return await _test_run_grpc_server( 161 | self._listen_addr[self._listen_addr.index(":") + 1 :], 162 | None, 163 | None, 164 | self._party, 165 | None, 166 | expected_metadata=self._expected_metadata, 167 | expected_jobname=self._expected_jobname, 168 | ) 169 | 170 | async def is_ready(self): 171 | return True 172 | 173 | 174 | def _test_start_receiver_proxy( 175 | addresses: str, 176 | party: str, 177 | expected_metadata: dict, 178 | expected_jobname: str, 179 | ): 180 | # Create RecevrProxyActor 181 | # Not that this is now a threaded actor. 182 | address = addresses[party] 183 | receiver_proxy_actor = TestReceiverProxyActor.options( 184 | name=receiver_proxy_actor_name(), max_concurrency=1000 185 | ).remote( 186 | listen_addr=address, 187 | party=party, 188 | expected_metadata=expected_metadata, 189 | expected_jobname=expected_jobname, 190 | ) 191 | receiver_proxy_actor.run_grpc_server.remote() 192 | assert ray.get(receiver_proxy_actor.is_ready.remote()) 193 | 194 | 195 | def test_send_grpc_with_meta(): 196 | compatible_utils.init_ray(address="local") 197 | cluster_config = { 198 | constants.KEY_OF_CLUSTER_ADDRESSES: "", 199 | constants.KEY_OF_CURRENT_PARTY_NAME: "", 200 | constants.KEY_OF_TLS_CONFIG: "", 201 | } 202 | metadata = {"key": "value"} 203 | config = {"http_header": metadata} 204 | job_config = { 205 | constants.KEY_OF_CROSS_SILO_COMM_CONFIG_DICT: config, 206 | } 207 | test_job_name = "test_send_grpc_with_meta" 208 | party_name = "test_party" 209 | global_context.init_global_context(party_name, test_job_name, False, False) 210 | compatible_utils._init_internal_kv(test_job_name) 211 | compatible_utils.kv.put( 212 | constants.KEY_OF_CLUSTER_CONFIG, cloudpickle.dumps(cluster_config) 213 | ) 214 | compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG, cloudpickle.dumps(job_config)) 215 | global_context.get_global_context().get_cleanup_manager().start() 216 | 217 | SERVER_ADDRESS = "127.0.0.1:12344" 218 | 219 | addresses = {party_name: SERVER_ADDRESS} 220 | _test_start_receiver_proxy( 221 | addresses, 222 | party_name, 223 | expected_metadata=metadata, 224 | expected_jobname=test_job_name, 225 | ) 226 | _start_sender_proxy( 227 | addresses, 228 | party_name, 229 | logging_level="info", 230 | proxy_cls=GrpcSenderProxy, 231 | proxy_config=config, 232 | ) 233 | sent_objs = [] 234 | sent_obj = send(party_name, "data", 0, 1) 235 | sent_objs.append(sent_obj) 236 | for result in 
ray.get(sent_objs): 237 | assert result 238 | 239 | global_context.get_global_context().get_cleanup_manager().stop() 240 | global_context.clear_global_context() 241 | ray.shutdown() 242 | 243 | 244 | if __name__ == "__main__": 245 | import sys 246 | 247 | sys.exit(pytest.main(["-sv", __file__])) 248 | -------------------------------------------------------------------------------- /fed/tree_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Most code is copied from https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/_pytree.py # noqa 16 | 17 | from collections import OrderedDict, namedtuple 18 | from dataclasses import dataclass 19 | from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type, TypeVar, cast 20 | 21 | T = TypeVar('T') 22 | S = TypeVar('S') 23 | U = TypeVar('U') 24 | R = TypeVar('R') 25 | 26 | """ 27 | Contains utility functions for working with nested Python data structures. 28 | 29 | A *pytree* is a nested Python data structure. It is a tree in the sense that 30 | nodes are Python collections (e.g., list, tuple, dict) and the leaves are 31 | Python values. Furthermore, a pytree should not contain reference cycles. 32 | 33 | pytrees are useful for working with nested collections of Tensors. For example, 34 | one can use `tree_map` to map a function over all Tensors inside some nested 35 | collection of Tensors and `tree_flatten` to get a flat list of all Tensors 36 | inside some nested collection. pytrees are helpful for implementing nested 37 | collection support for PyTorch APIs. 38 | 39 | This pytree implementation is not very performant due to Python overhead. 40 | To improve the performance, we can move parts of the implementation to C++. 41 | """ 42 | 43 | # A NodeDef holds two callables: 44 | # - flatten_fn should take the collection and return a flat list of values. 45 | # It can also return some context that is used in reconstructing the 46 | # collection. 47 | # - unflatten_fn should take a flat list of values and some context 48 | # (returned by flatten_fn). It returns the collection by reconstructing 49 | # it from the list and the context.
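# For example (an illustrative note, based on the dict handlers registered
# further below), flattening a dict yields its values plus its keys as context:
#   _dict_flatten({"a": 1, "b": 2})      -> ([1, 2], ["a", "b"])
#   _dict_unflatten([1, 2], ["a", "b"])  -> {"a": 1, "b": 2}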
50 | Context = Any 51 | PyTree = Any 52 | FlattenFunc = Callable[[PyTree], Tuple[List, Context]] 53 | UnflattenFunc = Callable[[List, Context], PyTree] 54 | 55 | 56 | class NodeDef(NamedTuple): 57 | flatten_fn: FlattenFunc 58 | unflatten_fn: UnflattenFunc 59 | 60 | 61 | SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} 62 | 63 | 64 | def _register_pytree_node( 65 | typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc 66 | ) -> None: 67 | SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) 68 | 69 | 70 | def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: 71 | return list(d.values()), list(d.keys()) 72 | 73 | 74 | def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: 75 | return dict(zip(context, values)) 76 | 77 | 78 | def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: 79 | return d, None 80 | 81 | 82 | def _list_unflatten(values: List[Any], context: Context) -> List[Any]: 83 | return list(values) 84 | 85 | 86 | def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: 87 | return list(d), None 88 | 89 | 90 | def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]: 91 | return tuple(values) 92 | 93 | 94 | def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: 95 | return list(d), type(d) 96 | 97 | 98 | def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple: 99 | return cast(NamedTuple, context(*values)) 100 | 101 | 102 | def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]: 103 | return list(d.values()), list(d.keys()) 104 | 105 | 106 | def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]': 107 | return OrderedDict((key, value) for key, value in zip(context, values)) 108 | 109 | 110 | _register_pytree_node(dict, _dict_flatten, _dict_unflatten) 111 | _register_pytree_node(list, _list_flatten, _list_unflatten) 112 | _register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten) 113 | _register_pytree_node(namedtuple, _namedtuple_flatten, _namedtuple_unflatten) 114 | _register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) 115 | 116 | 117 | # h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple # noqa 118 | def _is_namedtuple_instance(pytree: Any) -> bool: 119 | typ = type(pytree) 120 | bases = typ.__bases__ 121 | if len(bases) != 1 or bases[0] != tuple: 122 | return False 123 | fields = getattr(typ, '_fields', None) 124 | if not isinstance(fields, tuple): 125 | return False 126 | return all(isinstance(entry, str) for entry in fields) 127 | 128 | 129 | def _get_node_type(pytree: Any) -> Any: 130 | if _is_namedtuple_instance(pytree): 131 | return namedtuple 132 | return type(pytree) 133 | 134 | 135 | # A leaf is defined as anything that is not a Node. 136 | def _is_leaf(pytree: PyTree) -> bool: 137 | return _get_node_type(pytree) not in SUPPORTED_NODES 138 | 139 | 140 | # A TreeSpec represents the structure of a pytree. 
It holds: 141 | # "type": the type of root Node of the pytree 142 | # context: some context that is useful in unflattening the pytree 143 | # children_specs: specs for each child of the root Node 144 | # num_leaves: the number of leaves 145 | @dataclass 146 | class TreeSpec: 147 | type: Any 148 | context: Context 149 | children_specs: List['TreeSpec'] 150 | 151 | def __post_init__(self) -> None: 152 | self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) 153 | 154 | def __repr__(self, indent: int = 0) -> str: 155 | repr_prefix: str = f'TreeSpec({self.type.__name__}, {self.context}, [' 156 | children_specs_str: str = '' 157 | if len(self.children_specs): 158 | indent += len(repr_prefix) 159 | children_specs_str += self.children_specs[0].__repr__(indent) 160 | children_specs_str += ',' if len(self.children_specs) > 1 else '' 161 | children_specs_str += ','.join( 162 | [ 163 | '\n' + ' ' * indent + child.__repr__(indent) 164 | for child in self.children_specs[1:] 165 | ] 166 | ) 167 | repr_suffix: str = f'{children_specs_str}])' 168 | return repr_prefix + repr_suffix 169 | 170 | 171 | class LeafSpec(TreeSpec): 172 | def __init__(self) -> None: 173 | super().__init__(None, None, []) 174 | self.num_leaves = 1 175 | 176 | def __repr__(self, indent: int = 0) -> str: 177 | return '*' 178 | 179 | 180 | def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: 181 | """Flattens a pytree into a list of values and a TreeSpec that can be used 182 | to reconstruct the pytree. 183 | """ 184 | if _is_leaf(pytree): 185 | return [pytree], LeafSpec() 186 | 187 | node_type = _get_node_type(pytree) 188 | flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 189 | child_pytrees, context = flatten_fn(pytree) 190 | 191 | # Recursively flatten the children 192 | result: List[Any] = [] 193 | children_specs: List['TreeSpec'] = [] 194 | for child in child_pytrees: 195 | flat, child_spec = tree_flatten(child) 196 | result += flat 197 | children_specs.append(child_spec) 198 | 199 | return result, TreeSpec(node_type, context, children_specs) 200 | 201 | 202 | def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: 203 | """Given a list of values and a TreeSpec, builds a pytree. 204 | This is the inverse operation of `tree_flatten`. 205 | """ 206 | if not isinstance(spec, TreeSpec): 207 | raise ValueError( 208 | f'tree_unflatten(values, spec): Expected `spec` to be instance of ' 209 | f'TreeSpec but got item of type {type(spec)}.' 210 | ) 211 | if len(values) != spec.num_leaves: 212 | raise ValueError( 213 | f'tree_unflatten(values, spec): `values` has length {len(values)} ' 214 | f'but the spec refers to a pytree that holds {spec.num_leaves} ' 215 | f'items ({spec}).' 
216 | ) 217 | if isinstance(spec, LeafSpec): 218 | return values[0] 219 | 220 | unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn 221 | 222 | # Recursively unflatten the children 223 | start = 0 224 | end = 0 225 | child_pytrees = [] 226 | for child_spec in spec.children_specs: 227 | end += child_spec.num_leaves 228 | child_pytrees.append(tree_unflatten(values[start:end], child_spec)) 229 | start = end 230 | 231 | return unflatten_fn(child_pytrees, spec.context) 232 | -------------------------------------------------------------------------------- /fed/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | import os 17 | import signal 18 | import threading 19 | 20 | import ray 21 | from ray.exceptions import RayError 22 | 23 | from fed._private.message_queue import MessageQueueManager 24 | from fed.exceptions import FedRemoteError 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class CleanupManager: 30 | """ 31 | This class is used to manage the related cleanup work when the fed driver is exiting. 32 | It monitors whether the main thread is broken, and it needs to wait until all 33 | sending objects get responses. 34 | 35 | The main logic path is: 36 | A. If `fed.shutdown()` is invoked in the main thread and everything works well, 37 | the `stop()` will be invoked as well and the checking thread will be 38 | notified to exit gracefully. 39 | 40 | B. If the main thread is broken before sending the stop flag to the sending 41 | thread, the monitor thread will detect that and then notify the checking 42 | thread. 43 | """ 44 | 45 | def __init__(self, current_party, acquire_shutdown_flag) -> None: 46 | self._sending_data_q = MessageQueueManager( 47 | lambda msg: self._process_data_sending_task_return(msg), 48 | thread_name="DataSendingQueueThread", 49 | ) 50 | 51 | self._sending_error_q = MessageQueueManager( 52 | lambda msg: self._process_error_sending_task_return(msg), 53 | thread_name="ErrorSendingQueueThread", 54 | ) 55 | 56 | self._monitor_thread = None 57 | 58 | self._current_party = current_party 59 | self._acquire_shutdown_flag = acquire_shutdown_flag 60 | self._last_sending_error = None 61 | 62 | def start(self, exit_on_sending_failure=False, expose_error_trace=False): 63 | self._exit_on_sending_failure = exit_on_sending_failure 64 | self._expose_error_trace = expose_error_trace 65 | 66 | self._sending_data_q.start() 67 | logger.debug("Start check sending thread.") 68 | self._sending_error_q.start() 69 | logger.debug("Start check error sending thread.") 70 | 71 | def stop(self, wait_for_sending=False): 72 | # NOTE(NKcqx): MUST firstly stop the data queue, because it 73 | # may still throw errors during the termination which need to 74 | # be sent to the error queue.
75 | self._sending_data_q.stop(wait_for_sending=wait_for_sending) 76 | self._sending_error_q.stop(wait_for_sending=wait_for_sending) 77 | 78 | def push_to_sending( 79 | self, 80 | obj_ref: ray.ObjectRef, 81 | dest_party: str = None, 82 | upstream_seq_id: int = -1, 83 | downstream_seq_id: int = -1, 84 | is_error: bool = False, 85 | ): 86 | """ 87 | Push the sending remote task's return value, i.e. `obj_ref`, to 88 | the corresponding message queue. 89 | 90 | Args: 91 | obj_ref: The return value of the send remote task. 92 | dest_party: The name of the destination party. 93 | upstream_seq_id: (Optional) This is unnecessary when sending a 94 | normal data message because it was already sent to the target 95 | party. However, if the computation is corrupted, an error object 96 | will be created and sent with the same seq_id to replace the 97 | original data object. This argument is used to send the error. 98 | downstream_seq_id: (Optional) Same as `upstream_seq_id`. 99 | is_error: (Optional) Whether the obj_ref represents an error object or not. 100 | Defaults to False. If True, the obj_ref will be sent to the error 101 | queue instead. 102 | """ 103 | msg_pack = (obj_ref, dest_party, upstream_seq_id, downstream_seq_id) 104 | if is_error: 105 | self._sending_error_q.append(msg_pack) 106 | else: 107 | self._sending_data_q.append(msg_pack) 108 | 109 | def get_last_sending_error(self): 110 | return self._last_sending_error 111 | 112 | def _signal_exit(self): 113 | """ 114 | Exit the current process immediately. The signal will be captured 115 | in the main thread, where `stop` will be called. 116 | """ 117 | # NOTE(NKcqx): The signal is implemented via the error mechanism: 118 | # a `KeyboardInterrupt` will be raised after sending the signal, 119 | # and the OS will hold 120 | # the process's original context and change to the error handler context. 121 | # State that the original context holds, including `threading.lock`, 122 | # will not be released; acquiring the same lock in the signal handler 123 | # will cause a deadlock. In order to ensure `shutdown` executes exactly 124 | # once and to avoid deadlock, the lock must be checked before sending 125 | # signals. 126 | if self._acquire_shutdown_flag(): 127 | logger.warn("Signal SIGINT to exit.") 128 | os.kill(os.getpid(), signal.SIGINT) 129 | 130 | def _process_data_sending_task_return(self, message): 131 | """ 132 | This is the message handler function used in the message queue for 133 | processing each element. The element is pushed from `barriers.send` 134 | and is a quadruple of (obj_ref, dest_party, upstream_seq_id, downstream_seq_id). 135 | 136 | The `obj_ref` is the task return value of `sender_proxy.send.remote`. 137 | The `obj_ref` needs `ray.get` to trigger the execution of the corresponding 138 | task, which is the main functionality of this handler function. 139 | 140 | If any exception occurs during `ray.get`, it indicates that the data cannot 141 | be sent to the other party normally. In order to notify the other party of the 142 | current situation and prevent it from hanging, a FedRemoteError object will be 143 | constructed to replace the original data object, and sending will be retried. 144 | 145 | Returns: 146 | bool: True means the processing succeeded; the message queue will keep 147 | polling. 148 | False makes the message queue stop polling.
149 | """ 150 | obj_ref, dest_party, upstream_seq_id, downstream_seq_id = message 151 | try: 152 | res = ray.get(obj_ref) 153 | except Exception as e: 154 | logger.warn( 155 | f"Failed to send {obj_ref} to {dest_party}, error: {e}," 156 | f"upstream_seq_id: {upstream_seq_id}, " 157 | f"downstream_seq_id: {downstream_seq_id}." 158 | ) 159 | self._last_sending_error = e 160 | if isinstance(e, RayError): 161 | logger.info(f"Sending error {e.cause} to {dest_party}.") 162 | from fed.proxy.barriers import send 163 | 164 | # TODO(NKcqx): Cascade broadcast to all parties 165 | error_trace = e.cause if self._expose_error_trace else None 166 | send( 167 | dest_party, 168 | FedRemoteError(self._current_party, error_trace), 169 | upstream_seq_id, 170 | downstream_seq_id, 171 | True, 172 | ) 173 | 174 | res = False 175 | 176 | if not res and self._exit_on_sending_failure: 177 | # NOTE(NKcqx): Send signal to main thread so that it can 178 | # do some cleaning, e.g. kill the error sending thread. 179 | self._signal_exit() 180 | # Return False to exit the loop in sub-thread. Note that 181 | # the above signal will also make the main thread to kill 182 | # the sub-thread eventually by pushing a stop flag. 183 | return False 184 | return True 185 | 186 | def _process_error_sending_task_return(self, error_msg): 187 | error_ref, dest_party, upstream_seq_id, downstream_seq_id = error_msg 188 | try: 189 | res = ray.get(error_ref) 190 | logger.debug(f"Sending error got response: {res}.") 191 | except Exception: 192 | res = False 193 | 194 | if not res: 195 | logger.warning( 196 | f"Failed to send error {error_ref} to {dest_party}, " 197 | f"upstream_seq_id: {upstream_seq_id} " 198 | f"downstream_seq_id: {downstream_seq_id}. " 199 | "In this case, other parties won't sense " 200 | "this error and may cause unknown behaviour." 201 | ) 202 | # Return True so that remaining error objects can be sent 203 | return True 204 | -------------------------------------------------------------------------------- /fed/tests/test_cross_silo_error.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The RayFed Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import multiprocessing 16 | import sys 17 | 18 | import pytest 19 | import ray 20 | 21 | import fed 22 | import fed._private.compatible_utils as compatible_utils 23 | from fed._private.global_context import get_global_context 24 | from fed.exceptions import FedRemoteError 25 | 26 | 27 | class MyError(Exception): 28 | def __init__(self, message): 29 | super().__init__(message) 30 | 31 | 32 | @fed.remote 33 | def error_func(): 34 | raise MyError("Test normal task Error") 35 | 36 | 37 | @fed.remote 38 | def normal_func(a): 39 | return a 40 | 41 | 42 | @fed.remote 43 | class My: 44 | def __init__(self) -> None: 45 | pass 46 | 47 | def error_func(self): 48 | raise MyError("Test actor task Error") 49 | 50 | 51 | def run(party): 52 | compatible_utils.init_ray(address='local') 53 | addresses = { 54 | 'alice': '127.0.0.1:11012', 55 | 'bob': '127.0.0.1:11011', 56 | } 57 | 58 | fed.init( 59 | addresses=addresses, 60 | party=party, 61 | logging_level='debug', 62 | config={ 63 | 'cross_silo_comm': { 64 | 'timeout_ms': 20 * 1000, 65 | 'expose_error_trace': True, 66 | }, 67 | }, 68 | ) 69 | 70 | # Both party should catch the error 71 | o = error_func.party("alice").remote() 72 | with pytest.raises(Exception) as e: 73 | fed.get(o) 74 | if party == 'bob': 75 | assert isinstance(e.value.cause, FedRemoteError) 76 | assert 'RemoteError occurred at alice' in str(e.value.cause) 77 | assert "normal task Error" in str(e.value.cause) 78 | else: 79 | assert isinstance(e.value.cause, MyError) 80 | assert "normal task Error" in str(e.value.cause) 81 | fed.shutdown() 82 | ray.shutdown() 83 | 84 | 85 | def test_cross_silo_normal_task_error(): 86 | p_alice = multiprocessing.Process(target=run, args=('alice',)) 87 | p_bob = multiprocessing.Process(target=run, args=('bob',)) 88 | p_alice.start() 89 | p_bob.start() 90 | p_alice.join() 91 | p_bob.join() 92 | assert p_alice.exitcode == 0 93 | assert p_bob.exitcode == 0 94 | 95 | 96 | def run2(party): 97 | compatible_utils.init_ray(address='local') 98 | addresses = { 99 | 'alice': '127.0.0.1:11012', 100 | 'bob': '127.0.0.1:11011', 101 | } 102 | fed.init( 103 | addresses=addresses, 104 | party=party, 105 | logging_level='debug', 106 | config={ 107 | 'cross_silo_comm': { 108 | 'timeout_ms': 20 * 1000, 109 | 'expose_error_trace': True, 110 | }, 111 | }, 112 | ) 113 | 114 | # Both party should catch the error 115 | my = My.party('alice').remote() 116 | o = my.error_func.remote() 117 | with pytest.raises(Exception) as e: 118 | fed.get(o) 119 | 120 | if party == 'bob': 121 | assert isinstance(e.value.cause, FedRemoteError) 122 | assert 'RemoteError occurred at alice' in str(e.value.cause) 123 | assert "actor task Error" in str(e.value.cause) 124 | else: 125 | assert isinstance(e.value.cause, MyError) 126 | assert "actor task Error" in str(e.value.cause) 127 | 128 | fed.shutdown() 129 | ray.shutdown() 130 | 131 | 132 | def test_cross_silo_actor_task_error(): 133 | p_alice = multiprocessing.Process(target=run2, args=('alice',)) 134 | p_bob = multiprocessing.Process(target=run2, args=('bob',)) 135 | p_alice.start() 136 | p_bob.start() 137 | p_alice.join() 138 | p_bob.join() 139 | assert p_alice.exitcode == 0 140 | assert p_bob.exitcode == 0 141 | 142 | 143 | def run3(party): 144 | compatible_utils.init_ray(address='local') 145 | addresses = { 146 | 'alice': '127.0.0.1:11012', 147 | 'bob': '127.0.0.1:11011', 148 | } 149 | 150 | fed.init( 151 | addresses=addresses, 152 | party=party, 153 | logging_level='debug', 154 | config={ 155 | 'cross_silo_comm': { 156 | 'timeout_ms': 20 
* 1000, 157 | 'expose_error_trace': False, 158 | }, 159 | }, 160 | ) 161 | 162 | # Both party should catch the error 163 | o = error_func.party("alice").remote() 164 | with pytest.raises(Exception) as e: 165 | fed.get(o) 166 | if party == 'bob': 167 | assert isinstance(e.value.cause, FedRemoteError) 168 | assert 'RemoteError occurred at alice' in str(e.value.cause) 169 | assert 'caused by' not in str(e.value.cause) 170 | else: 171 | assert isinstance(e.value.cause, MyError) 172 | assert "normal task Error" in str(e.value.cause) 173 | fed.shutdown() 174 | ray.shutdown() 175 | 176 | 177 | def test_cross_silo_not_expose_error_trace(): 178 | p_alice = multiprocessing.Process(target=run3, args=('alice',)) 179 | p_bob = multiprocessing.Process(target=run3, args=('bob',)) 180 | p_alice.start() 181 | p_bob.start() 182 | p_alice.join() 183 | p_bob.join() 184 | assert p_alice.exitcode == 0 185 | assert p_bob.exitcode == 0 186 | 187 | 188 | @fed.remote 189 | def foo(e): 190 | print(e) 191 | 192 | 193 | def run4(party): 194 | compatible_utils.init_ray(address='local') 195 | addresses = { 196 | 'alice': '127.0.0.1:11012', 197 | 'bob': '127.0.0.1:11011', 198 | } 199 | 200 | fed.init( 201 | addresses=addresses, 202 | party=party, 203 | logging_level='debug', 204 | config={ 205 | 'cross_silo_comm': { 206 | 'timeout_ms': 20 * 1000, 207 | 'expose_error_trace': False, 208 | }, 209 | }, 210 | ) 211 | 212 | a = error_func.party("alice").remote() 213 | o = foo.party('bob').remote(a) 214 | if party == 'bob': 215 | # Wait a while to receive error from alice. 216 | import time 217 | 218 | time.sleep(10) 219 | # Alice will shutdown once exactly. 220 | fed.shutdown() 221 | ray.shutdown() 222 | 223 | 224 | def test_cross_silo_alice_send_error_and_shutdown_once(): 225 | p_alice = multiprocessing.Process(target=run4, args=('alice',)) 226 | p_bob = multiprocessing.Process(target=run4, args=('bob',)) 227 | p_alice.start() 228 | p_bob.start() 229 | p_alice.join() 230 | p_bob.join() 231 | assert p_alice.exitcode == 0 232 | assert p_bob.exitcode == 0 233 | 234 | 235 | def run5(party: str): 236 | compatible_utils.init_ray(address='local') 237 | addresses = { 238 | 'alice': '127.0.0.1:11012', 239 | 'bob': '127.0.0.1:11011', 240 | } 241 | 242 | fed.init( 243 | addresses=addresses, 244 | party=party, 245 | logging_level='debug', 246 | config={ 247 | 'cross_silo_comm': { 248 | 'timeout_ms': 20 * 1000, 249 | 'expose_error_trace': False, 250 | 'continue_waiting_for_data_sending_on_error': True, 251 | }, 252 | }, 253 | ) 254 | 255 | assert get_global_context().get_continue_waiting_for_data_sending_on_error() 256 | 257 | fed.shutdown() 258 | ray.shutdown() 259 | 260 | 261 | def test_continue_waiting_for_data_sending_on_error(): 262 | p_alice = multiprocessing.Process(target=run5, args=('alice',)) 263 | p_alice.start() 264 | p_alice.join() 265 | assert p_alice.exitcode == 0 266 | 267 | 268 | def run6(party: str): 269 | compatible_utils.init_ray(address='local') 270 | addresses = { 271 | 'alice': '127.0.0.1:11012', 272 | 'bob': '127.0.0.1:11011', 273 | } 274 | 275 | fed.init( 276 | addresses=addresses, 277 | party=party, 278 | logging_level='debug', 279 | config={ 280 | 'cross_silo_comm': { 281 | 'timeout_ms': 20 * 1000, 282 | 'expose_error_trace': False, 283 | 'exit_on_sending_failure': True, 284 | }, 285 | }, 286 | ) 287 | 288 | try: 289 | # Alice ran into an error and broadcast error to bob. And exit then. 290 | a = error_func.party('alice').remote() 291 | b = normal_func.party('bob').remote(a) 292 | 293 | # Bob got the error. 
294 | fed.get(b) 295 | finally: 296 | fed.shutdown() 297 | ray.shutdown() 298 | 299 | 300 | def test_no_wait_for_data_sending_on_error(): 301 | p_alice = multiprocessing.Process(target=run6, args=('alice',)) 302 | p_bob = multiprocessing.Process(target=run6, args=('bob',)) 303 | p_alice.start() 304 | p_bob.start() 305 | p_alice.join() 306 | p_bob.join() 307 | assert p_alice.exitcode == 1 308 | assert p_bob.exitcode == 1 309 | 310 | 311 | if __name__ == "__main__": 312 | sys.exit(pytest.main(["-sv", __file__])) 313 | --------------------------------------------------------------------------------