├── fedray ├── algorithms │ ├── __init__.py │ └── fedopt │ │ ├── server.py │ │ ├── client.py │ │ └── federation.py ├── util │ ├── __init__.py │ ├── exceptions.py │ └── resources.py ├── .DS_Store ├── __init__.py ├── core │ ├── node │ │ ├── __init__.py │ │ ├── callback_node.py │ │ ├── virtual.py │ │ └── fedray_node.py │ ├── federation │ │ ├── __init__.py │ │ ├── hierarchical.py │ │ ├── decentralized.py │ │ ├── client_server.py │ │ └── base.py │ └── communication │ │ ├── message.py │ │ └── topology │ │ └── manager.py └── _private │ └── decorator.py ├── CHANGELOG.md ├── Makefile ├── pyproject.toml ├── docs ├── requirements.txt ├── _static │ └── images │ │ └── fedray_logo_name_color.png ├── source │ ├── core-api │ │ ├── util.rst │ │ ├── index.rst │ │ ├── node.rst │ │ ├── communication.rst │ │ └── federation.rst │ └── tutorials │ │ ├── index.rst │ │ └── message-passing.rst ├── README.md ├── Makefile ├── make.bat ├── index.rst └── conf.py ├── CONTRIBUTING.md ├── .readthedocs.yaml ├── README.rst ├── .pre-commit-config.yaml ├── LICENSE ├── setup.py ├── examples └── messaging.py └── .gitignore /fedray/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedray/algorithms/fedopt/server.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fedray/util/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from . 
import resources 3 | -------------------------------------------------------------------------------- /fedray/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eclypse-org/federact/HEAD/fedray/.DS_Store -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.0.2 (2023-02-09) 2 | 3 | ### Refactor 4 | 5 | - Apply pre-commit hooks 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := check 2 | 3 | check: 4 | pre-commit run -a 5 | 6 | changelog: 7 | cz bump --changelog 8 | -------------------------------------------------------------------------------- /fedray/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ._private.decorator import remote 3 | 4 | __version__ = "0.0.2" 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.commitizen] 2 | name = "cz_conventional_commits" 3 | version = "0.0.2" 4 | tag_format = "$version" 5 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | ray[tune] 2 | numpy 3 | networkx 4 | sphinx 5 | sphinx-copybutton 6 | sphinx-book-theme 7 | Jinja2<3.1 8 | -------------------------------------------------------------------------------- /docs/_static/images/fedray_logo_name_color.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eclypse-org/federact/HEAD/docs/_static/images/fedray_logo_name_color.png -------------------------------------------------------------------------------- /fedray/util/exceptions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | class EndProcessException(Exception): 3 | """Exception to end a federated session.""" 4 | 5 | pass 6 | -------------------------------------------------------------------------------- /fedray/core/node/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .callback_node import CallbackFedRayNode 3 | from .fedray_node import FedRayNode 4 | from .virtual import VirtualNode 5 | -------------------------------------------------------------------------------- /docs/source/core-api/util.rst: -------------------------------------------------------------------------------- 1 | ``fedray.core.util`` API 2 | ========================= 3 | 4 | get_resources_split 5 | ------------------- 6 | .. autofunction:: fedray.util.resources.get_resources_split 7 | -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | The following tutorials are available: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | Message Passing Interface (MPI) 10 | -------------------------------------------------------------------------------- /docs/source/core-api/index.rst: -------------------------------------------------------------------------------- 1 | FedRay Core API 2 | =============== 3 | 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | 8 | node 9 | federation 10 | communication 11 | util -------------------------------------------------------------------------------- /fedray/core/federation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .base import Federation 3 | from .client_server import ClientServerFederation 4 | from .decentralized import DecentralizedFederation 5 | from .hierarchical import HierarchicalFederation 6 | -------------------------------------------------------------------------------- /docs/source/core-api/node.rst: -------------------------------------------------------------------------------- 1 | ``fedray.core.node`` API 2 | ======================== 3 | 4 | 5 | .. currentmodule:: fedray.core.node 6 | 7 | VirtualNode 8 | ----------- 9 | .. autoclass:: VirtualNode 10 | :members: 11 | 12 | 13 | FedRayNode 14 | ---------- 15 | .. autoclass:: FedRayNode 16 | :members: 17 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | Documentation is built with sphinx. 4 | 5 | Build the documentation (it must be executed in the `docs` folder): 6 | ``` 7 | sphinx-build . _build 8 | ``` 9 | 10 | ### Doc coverage 11 | it is possible to check the class coverage to find classes that are missing from the documentation using the command: 12 | ``` 13 | sphinx-build -b coverage . 
_build 14 | ``` 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## CONTRIBUTING.md 2 | 3 | ### Required tools 4 | - [pipx](https://pypa.github.io/pipx/#install-pipx) 5 | - [pre-commit](https://pre-commit.com/#install) 6 | - [commitzen](https://github.com/commitizen-tools/commitizen#installation) 7 | 8 | ### Commit style 9 | This project follows [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary). This is **enforced via commitzen**. 10 | -------------------------------------------------------------------------------- /docs/source/core-api/communication.rst: -------------------------------------------------------------------------------- 1 | ``fedray.core.communication`` API 2 | ================================= 3 | 4 | .. currentmodule:: fedray.core.communication.topology.manager 5 | 6 | TopologyManager 7 | --------------- 8 | .. autoclass:: TopologyManager 9 | :members: 10 | 11 | 12 | .. currentmodule:: fedray.core.communication.message 13 | 14 | Message 15 | ------- 16 | .. autoclass:: Message 17 | :members: -------------------------------------------------------------------------------- /docs/source/core-api/federation.rst: -------------------------------------------------------------------------------- 1 | ``fedray.core.federation`` API 2 | ============================== 3 | 4 | .. currentmodule:: fedray.core.federation 5 | 6 | Federation 7 | ---------- 8 | .. autoclass:: Federation 9 | :members: 10 | 11 | 12 | 13 | Client-Server Federation 14 | ------------------------ 15 | .. 
autoclass:: ClientServerFederation 16 | :members: 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /fedray/core/node/callback_node.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .fedray_node import FedRayNode 3 | 4 | 5 | class CallbackFedRayNode(FedRayNode): 6 | def run(self): 7 | while True: 8 | in_msg = self.receive() 9 | fn = getattr(self, in_msg.header) 10 | args = in_msg.body 11 | args["sender_id"] = in_msg.sender_id 12 | args["timestamp"] = in_msg.timestamp 13 | 14 | response = fn(**args) 15 | if response is not None: 16 | self.send(**response) 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.8" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | 19 | # Optionally declare the Python requirements required to build your docs 20 | python: 21 | install: 22 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /fedray/core/communication/message.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | from dataclasses import dataclass 4 | from dataclasses import field 5 | from typing import Dict 6 | 7 | 8 | @dataclass 9 | class Message: 10 | """ 11 | A message is a simple data structure that is used to communicate between 12 | nodes. 
It contains a header, a sender ID, a timestamp, and a body. The body 13 | is a dictionary that can contain any data that is needed to be communicated 14 | between nodes. 15 | """ 16 | 17 | header: str = None 18 | sender_id: str = None 19 | timestamp: datetime.datetime = field(default_factory=datetime.datetime.now) 20 | body: Dict = field(default_factory=dict) 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /fedray/core/federation/hierarchical.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict 3 | from typing import List 4 | from typing import Type 5 | from typing import Union 6 | 7 | import numpy as np 8 | from ray.util.placement_group import PlacementGroup 9 | 10 | from fedray.core.federation import Federation 11 | from fedray.core.node.fedray_node import FedRayNode 12 | 13 | 14 | class HierarchicalFederation(Federation): 15 | def __init__( 16 | self, 17 | level_templates: List[Type[FedRayNode]], 18 | nodes_per_level: Union[List[int], List[List[str]]], 19 | roles: List[str], 20 | topology: Union[str, np.ndarray], 21 | level_config: List[Dict], 22 | resources: Union[str, PlacementGroup] = "uniform", 23 | ) -> None: 24 | raise NotImplementedError 25 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: docs/_static/images/fedray_logo_name_color.png 2 | 3 | FedRay is a framework for Research in Federated Learning based on Ray. FedRay allows to 4 | easily implement your FL algorithms or use off-the-shelf algorithms, and distribute 5 | them seamlessly on Ray Clusters. 6 | 7 | FedRay is a research project, and it is still under development. 8 | 9 | Installation 10 | ------------ 11 | FedRay can be installed by cloning the repository and running the setup script. 12 | 13 | .. code-block:: console 14 | 15 | $ git clone https://github.com/vdecaro/fedray 16 | $ cd fedray 17 | $ pip install -e . 18 | 19 | Documentation 20 | ------------- 21 | The documentation is hosted on ReadTheDocs and can be found 22 | `here `_. 
23 | 24 | License 25 | ------- 26 | FedRay is released under the MIT license. 27 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - hooks: 4 | - id: commitizen 5 | - id: commitizen-branch 6 | stages: 7 | - push 8 | repo: https://github.com/commitizen-tools/commitizen 9 | rev: v2.40.0 10 | - hooks: 11 | - id: check-added-large-files 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | - id: fix-encoding-pragma 15 | - id: detect-private-key 16 | repo: https://github.com/pre-commit/pre-commit-hooks 17 | rev: v2.3.0 18 | - hooks: 19 | - id: black 20 | repo: https://github.com/psf/black 21 | rev: 22.10.0 22 | - hooks: 23 | - additional_dependencies: 24 | - black==22.12.0 25 | id: blacken-docs 26 | repo: 
https://github.com/adamchainz/blacken-docs 27 | rev: v1.12.1 28 | - hooks: 29 | - id: reorder-python-imports 30 | repo: https://github.com/asottile/reorder_python_imports 31 | rev: v3.9.0 32 | - hooks: 33 | - id: pycln 34 | repo: https://github.com/hadialqattan/pycln 35 | rev: v2.1.3 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Valerio De Caro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 2 3 | :hidden: 4 | 5 | Getting Started 6 | Tutorials 7 | FedRay Core API 8 | 9 | .. 
image:: _static/images/fedray_logo_name_color.png 10 | 11 | =============================================== 12 | 13 | Welcome to the `FedRay `_ documentation! 14 | 15 | 16 | Why FedRay? 17 | =========== 18 | FedRay is **a framework for Research in Federated 19 | Learning based on Ray**. It allows to easily *implement your FL algorithms* or *use 20 | off-the-shelf algorithms*, and distribute them seamlessly on Ray Clusters. 21 | 22 | FedRay is a research project, and it is still under development. 23 | 24 | 25 | .. _getting-started: 26 | 27 | Getting Started 28 | =============== 29 | To get started, you can easily install FedRay by running the following commands: 30 | 31 | .. code-block:: bash 32 | 33 | pip install git+https://github.com/vdecaro/fedray 34 | 35 | To ensure that the installation was successful, you can execute this simple python 36 | script: 37 | 38 | .. code-block:: python 39 | 40 | import fedray 41 | print(fedray.__version__) 42 | -------------------------------------------------------------------------------- /fedray/_private/decorator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ray 3 | 4 | 5 | def remote(*args, **kwargs): 6 | """Defines a remote node of a federation. 7 | 8 | Overrides the default behavior of `ray.remote` by ensuring the value of 9 | max_concurrency is at least 2. This allows to avoid a node being stuck on a single 10 | active training/eval session. The default value of max_concurrency is set to 100. 11 | 12 | The other arguments are the same as `ray.remote`. 
13 | """ 14 | empty_kwargs = len(kwargs) == 0 15 | _default_max_concurrency = 100 16 | if "max_concurrency" not in kwargs: 17 | kwargs["max_concurrency"] = _default_max_concurrency 18 | else: 19 | if kwargs["max_concurrency"] is None: 20 | kwargs["max_concurrency"] = _default_max_concurrency 21 | elif kwargs["max_concurrency"] < 2: 22 | raise ValueError( 23 | "max_concurrency must be greater than 1. Set max_concurrency to None " 24 | f"to use the default value of {_default_max_concurrency}. If you are " 25 | "using a custom max_concurrency value, make sure it is at least 2." 26 | ) 27 | if empty_kwargs: 28 | return ray.remote(**kwargs)(*args) 29 | else: 30 | return lambda fn_or_class: ray.remote(**kwargs)(fn_or_class) 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import codecs 3 | import os 4 | 5 | import setuptools 6 | 7 | with open("README.rst", "r") as fh: 8 | long_description = fh.read() 9 | 10 | 11 | def read(rel_path): 12 | here = os.path.abspath(os.path.dirname(__file__)) 13 | with codecs.open(os.path.join(here, rel_path), "r") as fp: 14 | return fp.read() 15 | 16 | 17 | def get_version(rel_path): 18 | for line in read(rel_path).splitlines(): 19 | if line.startswith("__version__"): 20 | delim = '"' if '"' in line else "'" 21 | return line.split(delim)[1] 22 | else: 23 | raise RuntimeError("Unable to find version string.") 24 | 25 | 26 | setuptools.setup( 27 | name="fedray", 28 | version=get_version("fedray/__init__.py"), 29 | author="Valerio De Caro", 30 | author_email="valerio.decaro@phd.unipi.it", 31 | description="FedRay: a Research Framework for Federated Learning based on Ray", 32 | long_description=long_description, 33 | long_description_content_type="text/markdown", 34 | url="https://github.com/vdecaro/fedray", 35 | packages=setuptools.find_packages(), 36 | classifiers=[ 37 | "Programming 
Language :: Python :: 3", 38 | "License :: OSI Approved :: MIT License", 39 | "Operating System :: OS Independent", 40 | ], 41 | python_requires=">=3.7,<3.11", 42 | install_requires=[ 43 | "ray[tune]", 44 | "numpy", 45 | "networkx", 46 | ], 47 | include_package_data=True, 48 | ) 49 | -------------------------------------------------------------------------------- /examples/messaging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | 4 | import ray 5 | 6 | import fedray 7 | from fedray.core.federation import ClientServerFederation 8 | from fedray.core.node import FedRayNode 9 | 10 | 11 | @fedray.remote 12 | class MessagingServer(FedRayNode): 13 | def train(self, out_msg: str): 14 | n_exchanges = 0 15 | while True: 16 | msg = self.receive() 17 | print( 18 | f"{self.id} received {msg.body['msg']} from {msg.sender_id} at", 19 | {msg.timestamp}, 20 | ) 21 | self.send("exchange", {"msg": out_msg()}, to=msg.sender_id) 22 | self.update_version(n_exchanges=n_exchanges) 23 | 24 | 25 | @fedray.remote 26 | class MessagingClient(FedRayNode): 27 | def train(self, out_msg: str) -> None: 28 | while True: 29 | self.send("exchange", {"msg": out_msg()}) 30 | msg = self.receive() 31 | print( 32 | f"{self.id} received {msg.body['msg']} from {msg.sender_id}", 33 | msg.timestamp, 34 | ) 35 | time.sleep(3) 36 | 37 | 38 | def main(): 39 | ray.init() 40 | federation = ClientServerFederation( 41 | server_template=MessagingServer, 42 | client_template=MessagingClient, 43 | n_clients_or_ids=4, 44 | roles=["train" for _ in range(4)], 45 | ) 46 | report = federation.train( 47 | server_args={"out_msg": lambda: "Hello from server!"}, 48 | client_args={"out_msg": lambda: "Hello from client!"}, 49 | ) 50 | for _ in range(4): 51 | version = federation.pull_version() 52 | print(version) 53 | time.sleep(3) 54 | federation.stop() 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | 
-------------------------------------------------------------------------------- /fedray/algorithms/fedopt/client.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict 3 | from typing import List 4 | 5 | import fedray 6 | 7 | 8 | class FedOptClient: 9 | def build(self, local_epochs: int) -> None: 10 | self.epochs = local_epochs 11 | 12 | def run(self): 13 | while True: 14 | msg = self.get_message(block=True) 15 | self.model.load_state_dict(msg.body["state"]) 16 | 17 | for e in range(self.epochs): 18 | self.train_epoch() 19 | 20 | self.send(type="model", body=self.client_opt.get_params()) 21 | 22 | def train_epoch(self): 23 | raise NotImplementedError 24 | 25 | 26 | class ServerOpt: 27 | def __init__(self) -> None: 28 | self.client_ids: Dict[str, bool] = {} 29 | 30 | def set_iteration(self, client_ids: List[str], **kwargs) -> None: 31 | raise NotImplementedError 32 | 33 | def update(self): 34 | raise NotImplementedError 35 | 36 | def aggregate(self): 37 | raise NotImplementedError 38 | 39 | 40 | class AveragingOpt(ServerOpt): 41 | def __init__(self) -> None: 42 | self.n_samples: int = 0 43 | self.state_dict: Dict = None 44 | 45 | def set_iteration(self, client_ids: List[str]) -> None: 46 | self.client_ids = {c_id: False for c_id in client_ids} 47 | self.n_samples = 0 48 | self.state_dict = None 49 | 50 | def update(self, client_id: str, client_dict: Dict): 51 | local_n_samples = client_dict.pop("n_samples") 52 | for k in client_dict["state"].keys(): 53 | client_dict["state"][k] = client_dict["state"][k] * local_n_samples 54 | self.n_samples += client_dict.pop("n_samples") 55 | 56 | if self.state_dict is None: 57 | self.state_dict = client_dict 58 | 59 | def aggregate(self): 60 | n_samples = sum(self.client_dicts[k]["n_samples"] for k in self.client_ids) 61 | model_dict = {} 62 | 63 | for k in self.client_ids: 64 | pass 65 | 66 | @property 67 | def ready(self): 68 | return 
all([c_id[1] for c_id in self.client_ids]) 69 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Configuration file for the Sphinx documentation builder. 3 | # 4 | # For the full list of built-in configuration values, see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | import os 7 | import sys 8 | from datetime import datetime 9 | 10 | sys.path.insert(0, os.path.abspath("..")) 11 | 12 | import fedray 13 | 14 | # -- Project information ----------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 16 | project = "FedRay" 17 | copyright = str(datetime.now().year) + ", Valerio De Caro" 18 | author = "Valerio De Caro" 19 | release = fedray.__version__ 20 | 21 | # -- General configuration --------------------------------------------------- 22 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 23 | 24 | extensions = [ 25 | "sphinx.ext.autodoc", 26 | "sphinx.ext.viewcode", 27 | "sphinx.ext.napoleon", 28 | "sphinx_copybutton", 29 | "sphinx.ext.autosummary", 30 | "sphinx.ext.coverage", 31 | "sphinx.ext.napoleon", 32 | "sphinx_book_theme", 33 | ] 34 | 35 | 36 | # templates_path = ["_templates"] 37 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 38 | 39 | coverage_show_missing_items = True 40 | # Automatically extract typehints when specified and place them in 41 | # descriptions of the relevant function/method. 42 | autodoc_typehints = "description" 43 | autodoc_member_order = "bysource" 44 | # Don't show class signature with the class' name. 
45 | autodoc_class_signature = "separated" 46 | # -- Options for HTML output ------------------------------------------------- 47 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 48 | 49 | html_theme = "sphinx_book_theme" 50 | html_title = "Version " + release 51 | html_logo = "_static/images/fedray_logo_name_color.png" 52 | # html_sidebars = {"Federation": ["generated/fedray.core.federation.Federation.rst"]} 53 | html_theme_options = { 54 | "repository_url": "https://github.com/vdecaro/fedray", 55 | "use_repository_button": True, 56 | "collapse_navigation": False, 57 | } 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/ 132 | -------------------------------------------------------------------------------- /fedray/core/node/virtual.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict 3 | from typing import Type 4 | 5 | from ray.util.placement_group import PlacementGroup 6 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 7 | 8 | from .fedray_node import FedRayNode 9 | 10 | 11 | class VirtualNode(object): 12 | """ 13 | A VirtualNode is a wrapper around a FedRayNode that is used to represent a node 14 | within a federation. This allows a Federation to perform the lazy initialization 15 | and build of the node, which is deferred to the first call of the node within 16 | a session. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | template: Type[FedRayNode], 22 | id: str, 23 | federation_id: str, 24 | role: str, 25 | config: Dict, 26 | ) -> None: 27 | """Creates a new VirtualNode object. 28 | 29 | Args: 30 | template (Type[FedRayNode]): The template for the node. 31 | id (str): The ID of the node. 32 | federation_id (str): The ID of the federation. 33 | role (str): The role of the node. 34 | config (Dict): The configuration to be passed to the build method of the 35 | node. 
36 | """ 37 | self.template = template 38 | self.fed_id = federation_id 39 | self.id = id 40 | self.role = role 41 | self.config = config 42 | self.handle: FedRayNode = None 43 | 44 | def build(self, bundle_idx: int, placement_group: PlacementGroup): 45 | """Builds the node. 46 | 47 | Args: 48 | bundle_idx (int): The index of the bundle within the placement group. 49 | placement_group (PlacementGroup): The placement group to be used for the 50 | node. 51 | """ 52 | resources = placement_group.bundle_specs[bundle_idx] 53 | num_cpus = resources["CPU"] 54 | num_gpus = resources["GPU"] if "GPU" in resources else 0 55 | self.handle = self.template.options( 56 | name="/".join([self.fed_id, self.id]), 57 | num_cpus=num_cpus, 58 | num_gpus=num_gpus, 59 | scheduling_strategy=PlacementGroupSchedulingStrategy( 60 | placement_group, placement_group_bundle_index=bundle_idx 61 | ), 62 | ).remote( 63 | node_id=self.id, role=self.role, federation_id=self.fed_id, **self.config 64 | ) 65 | 66 | @property 67 | def built(self): 68 | """Returns whether the node has been built.""" 69 | return self.handle is not None 70 | -------------------------------------------------------------------------------- /fedray/algorithms/fedopt/federation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict 3 | from typing import List 4 | 5 | import fedray 6 | 7 | 8 | class FedOptServer: 9 | def __init__(self, global_epochs: int) -> None: 10 | self.global_epochs = global_epochs 11 | 12 | def run(self): 13 | for _ in range(self.global_epochs): 14 | client_sample = self.sample_clients() 15 | 16 | self.send(msg_type="model", body=self.model.state_dict(), ids=client_sample) 17 | self.set_iteration(client_ids=client_sample) 18 | while not self.server_aggregator.ready: 19 | self.on_client_receive(self.get_message(block=True)) 20 | self.aggregate() 21 | 22 | self.aggregate() 23 | 24 | def sample_clients(self): 25 | raise 
class ServerOpt:
    """Base class for server-side optimizers in FedOpt-style algorithms.

    Tracks, per federated round, which clients have already contributed an
    update: ``client_ids`` maps client ID -> received flag.
    """

    def __init__(self) -> None:
        # client ID -> whether this client's update was received this round
        self.client_ids: Dict[str, bool] = {}

    def set_iteration(self, client_ids: List[str], **kwargs) -> None:
        """Resets the per-round state for a new set of participating clients."""
        raise NotImplementedError

    def update(self):
        """Incorporates a single client update into the round state."""
        raise NotImplementedError

    def aggregate(self):
        """Combines the collected client updates into a global state."""
        raise NotImplementedError


class AveragingOpt(ServerOpt):
    """Federated averaging: sample-weighted mean of the client states."""

    def __init__(self) -> None:
        # fix: the base __init__ was skipped, leaving client_ids undefined
        # until the first set_iteration call.
        super().__init__()
        self.n_samples: int = 0
        self.state_dict: Dict = None

    def set_iteration(self, client_ids: List[str]) -> None:
        """Starts a new round over ``client_ids``: nobody has reported yet."""
        self.client_ids = {c_id: False for c_id in client_ids}
        self.n_samples = 0
        self.state_dict = None

    def update(self, client_id: str, client_dict: Dict):
        """Accumulates one client's update, weighted by its sample count.

        Args:
            client_id (str): ID of the reporting client.
            client_dict (Dict): Must contain "n_samples" (int) and "state"
                (mapping of parameter name -> value). NOTE: mutated in place.
        """
        local_n_samples = client_dict.pop("n_samples")
        for k in client_dict["state"].keys():
            client_dict["state"][k] = client_dict["state"][k] * local_n_samples
        # fix: a second pop("n_samples") raised KeyError; reuse the local value.
        self.n_samples += local_n_samples

        if self.state_dict is None:
            self.state_dict = client_dict
        else:
            # fix: updates after the first one were silently dropped.
            for k in client_dict["state"]:
                self.state_dict["state"][k] += client_dict["state"][k]
        # fix: the received flag was never set, so `ready` could never flip.
        self.client_ids[client_id] = True

    def aggregate(self):
        """Returns the sample-weighted average of the collected states.

        fix: the previous implementation referenced a nonexistent
        ``self.client_dicts`` attribute and returned nothing.
        """
        return {
            k: v / self.n_samples for k, v in self.state_dict["state"].items()
        }

    @property
    def ready(self):
        """True once every client selected for this round has reported."""
        # fix: iterating the dict yields keys, and c_id[1] indexed into the
        # ID string (always truthy); check the stored flags instead.
        return all(self.client_ids.values())
/docs/source/tutorials/message-passing.rst: -------------------------------------------------------------------------------- 1 | Tutorial: Simple Message-Passing Federation 2 | =========================================== 3 | 4 | In this tutorial we are going to show how to easily create a simple message-passing 5 | federation. The objective of this tutorial is to show how to use the ``FedRayNode``, 6 | instantiate a ``ClientServerFederation`` and implement a training interface as a 7 | simple message-passing federation. 8 | 9 | We will use the following components: 10 | 11 | * The ``FedRayNode`` abstract class for implementing the client and the server; 12 | * The ``ClientServerFederation`` class for instantiating and running the federation. 13 | 14 | 15 | First, we need to perform all the useful imports for this task: 16 | 17 | .. code-block:: python3 18 | 19 | import time 20 | import ray 21 | import fedray 22 | 23 | Step 1: Create the client 24 | ------------------------- 25 | 26 | The client is a ``FedRayNode`` that sends a message to the server and waits for a reply. 27 | 28 | .. code-block:: python3 29 | 30 | from fedray.core.node import FedRayNode 31 | 32 | @fedray.remote # This is a Ray actor 33 | class MessagingClient(FedRayNode): 34 | 35 | def train(self, out_msg: str) -> None: 36 | while True: 37 | # Send a message to the server (doesn't need to specify sender-id) 38 | self.send("exchange", {"msg": out_msg()}) 39 | 40 | # Wait for a reply 41 | msg = self.receive() 42 | 43 | # Print the reply 44 | print( 45 | f"{self.id} received {msg.body['msg']} from {msg.sender_id}", 46 | msg.timestamp, 47 | ) 48 | 49 | # Wait for 3 seconds 50 | time.sleep(3) 51 | 52 | 53 | Here, we used exploited the communication interface easily by using the ``send`` and 54 | ``receive`` methods. The ``send`` method takes the parameters header and body, while 55 | the ``receive`` method returns a ``Message`` object. 
The ``Message`` object has the 56 | following attributes: 57 | 58 | * ``sender_id``: the id of the sender; 59 | * ``body``: the body of the message; 60 | * ``timestamp``: the timestamp of the message. 61 | 62 | 63 | Step 2: Create the server 64 | ------------------------- 65 | 66 | To create the server, we can employ the same approach as for the client: 67 | 68 | .. code-block:: python3 69 | 70 | from fedray.core.node import FedRayNode 71 | 72 | @fedray.remote # This is a Ray actor 73 | class MessagingServer(FedRayNode): 74 | 75 | def train(self, out_msg: str) -> None: 76 | while True: 77 | # Wait for a message 78 | msg = self.receive() 79 | 80 | # Print the message 81 | print( 82 | f"{self.id} received {msg.body['msg']} from {msg.sender_id}", 83 | msg.timestamp, 84 | ) 85 | 86 | # Send a reply to the client that sent the message 87 | self.send("exchange", {"msg": out_msg()}, msg.sender_id) 88 | 89 | # Wait for 3 seconds 90 | time.sleep(3) 91 | 92 | 93 | Step 3: Create the federation 94 | ----------------------------- 95 | 96 | Now that we have created the client and the server, we can instantiate and run the 97 | federation: 98 | 99 | .. code-block:: python3 100 | 101 | # Create the federation 102 | federation = ClientServerFederation( 103 | server_template=MessagingServer, # Specifies the server template 104 | client=MessagingClient, # Specifies the client template 105 | n_clients_or_ids=3, # Number of clients 106 | roles=["train" for _ in range(3)], # Roles of the nodes 107 | ) 108 | 109 | # Run the federation 110 | federation.train( 111 | server_args = {"out_msg": lambda: "Hello from server"}, 112 | client_args = {"out_msg": lambda: "Hello from client"}, 113 | ) 114 | 115 | Note that the arguments of the training process are passed as keyword arguments to the 116 | ``train`` method. 
Thus, **the keys in the ``server_args`` and ``client_args`` dictionaries 117 | must match the names of the arguments of the ``train`` method of the ``MessagingServer`` 118 | and ``MessagingClient`` classes**. 119 | 120 | Remarks 121 | ------- 122 | Regardless of the complexity of the training process, the logic is always the same: 123 | 124 | * Implement the ``FedRayNode`` abstract class for both the client and the server; 125 | * Instantiate the ``ClientServerFederation`` class with the appropriate arguments; 126 | 127 | The ``ClientServerFederation`` class takes care of the rest, including the creation of 128 | the Ray actors, the communication between the nodes, the synchronization of the nodes, 129 | the termination of the federation, etc. 130 | 131 | **Final note: the values of all the arguments of the FedRayNode need to be serializable.** 132 | 133 | -------------------------------------------------------------------------------- /fedray/util/resources.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import random 4 | from typing import Literal 5 | 6 | import ray 7 | 8 | 9 | def get_resources_split( 10 | num_nodes: int, 11 | num_cpus: int = None, 12 | num_gpus: int = None, 13 | split_strategy: Literal["random", "uniform"] = "uniform", 14 | placement_strategy: Literal[ 15 | "STRICT_PACK", "PACK", "STRICT_SPREAD", "SPREAD" 16 | ] = "PACK", 17 | is_tune: bool = False, 18 | ): 19 | """ 20 | Provides the resources for a federation. The resources are provided as a 21 | PlacementGroup or a PlacementGroupFactory (if is_tune is True). This function 22 | allows to simulate the system heterogeneity across nodes. 23 | 24 | Given the number of nodes, cpus, gpus and split strategy, this function creates a 25 | set of bundles (i.e., dictionaries containing splits of the resources, one split 26 | per node), and then creates a PlacementGroup or a PlacementGroupFactory from the 27 | bundles. 
The placement strategy is used to determine how Ray will allocate the nodes 28 | in the cluster. 29 | 30 | Args: 31 | num_nodes (int): The number of nodes in the federation. 32 | num_cpus (int, optional): The number of CPUs available in the cluster. If None, 33 | it is set to num_nodes. Defaults to None. 34 | num_gpus (int, optional): The number of GPUs available in the cluster. If None, 35 | it is set to the number of GPUs available in the cluster. Defaults to None. 36 | split_strategy (Literal["random", "uniform"], optional): The strategy to split 37 | the resources. Defaults to "uniform". 38 | placement_strategy (Literal["STRICT_PACK", "PACK", "STRICT_SPREAD", "SPREAD"], optional): The 39 | strategy to place the nodes in the cluster. Defaults to "PACK". 40 | is_tune (bool, optional): Whether the resources are used within a Ray Tune 41 | experiment. 42 | 43 | Returns: 44 | Union[PlacementGroup, PlacementGroupFactory]: The resources for the federation. 45 | 46 | Raises: 47 | RuntimeError: If the number of CPUs is less than 2. 48 | ValueError: If the split strategy is not "random" or "uniform". 49 | """ 50 | SAFETY_EPSILON = 0.01 51 | available_resources = ray.available_resources() 52 | if available_resources["CPU"] < 2: 53 | raise RuntimeError( 54 | "At least 2 CPUs are required in the Ray cluster. Please increase the", 55 | "number of CPUs.", 56 | ) 57 | 58 | if not is_tune: 59 | resources = [{"CPU": 1}] 60 | else: 61 | resources = [{"CPU": 0.5}, {"CPU": 0.5}] 62 | 63 | if num_cpus is None: 64 | num_cpus = num_nodes 65 | if num_cpus > available_resources["CPU"]: 66 | num_cpus = available_resources["CPU"] 67 | logging.warn( 68 | "The available CPUs are less than the declared parameter num_cpus.", 69 | f"Parameter num_cpus set to {num_cpus}.", 70 | ) 71 | 72 | if num_gpus is not None: 73 | if "GPU" not in available_resources and num_gpus is not None: 74 | logging.warn( 75 | "GPUs not available in this Ray cluster. Parameter num_gpus set to None." 
76 | ) 77 | num_gpus = 0 78 | elif num_gpus > available_resources["GPU"]: 79 | num_gpus = available_resources["GPU"] 80 | logging.warn( 81 | f"The available GPUs are less than the declared parameter num_gpus.", 82 | f"Parameter num_gpus set to {num_gpus}.", 83 | ) 84 | else: 85 | num_gpus = available_resources["GPU"] if "GPU" in available_resources else 0 86 | 87 | fix_size = lambda x: x if x < 1 else int(x) 88 | if split_strategy == "uniform": 89 | alloc_fn = ( 90 | lambda i, num: num_nodes // num + 1 91 | if i < num_nodes % num 92 | else num_nodes // num 93 | ) 94 | cpu_alloc = [alloc_fn(i, num_cpus) for i in range(num_cpus)] 95 | gpu_alloc = [alloc_fn(i, num_gpus) for i in range(int(num_gpus))] 96 | cpu_i, gpu_i, b_cpu_i, b_gpu_i = 0, 0, 0, 0 97 | for i in range(num_nodes): 98 | resources_i = {} 99 | cpu_alloc_i = (1 - SAFETY_EPSILON) / cpu_alloc[cpu_i] 100 | resources_i["CPU"] = fix_size(cpu_alloc_i) 101 | b_cpu_i = b_cpu_i + 1 102 | if b_cpu_i == cpu_alloc[cpu_i]: 103 | cpu_i = cpu_i + 1 104 | b_cpu_i = 0 105 | 106 | if num_gpus > 0: 107 | gpu_alloc_i = (1 - SAFETY_EPSILON) / gpu_alloc[gpu_i] 108 | resources_i["GPU"] = fix_size(gpu_alloc_i) 109 | b_gpu_i = b_gpu_i + 1 110 | if b_gpu_i == gpu_alloc[gpu_i]: 111 | gpu_i = gpu_i + 1 112 | b_gpu_i = 0 113 | 114 | resources.append(resources_i) 115 | 116 | elif split_strategy == "random": 117 | perc = [random.random() for _ in range(num_nodes)] 118 | total = sum(perc) 119 | perc = [s / total for s in perc] 120 | 121 | else: 122 | raise ValueError(f"Unknown split strategy: {split_strategy}.") 123 | 124 | if not is_tune: 125 | from ray.util.placement_group import placement_group 126 | 127 | return placement_group(bundles=resources, strategy=placement_strategy) 128 | else: 129 | from ray import tune 130 | 131 | return tune.PlacementGroupFactory( 132 | bundles=resources, strategy=placement_strategy 133 | ) 134 | -------------------------------------------------------------------------------- 
@fedray.remote(num_cpus=TP_MANAGER_CPU_RESOURCES)
class TopologyManager:
    """
    The TopologyManager is responsible for managing the network topology of a
    federation. It creates the network topology and forwards messages to the
    appropriate nodes through the neighborhood relationship.

    The actual definition of the network happens within the `build_network`
    method, which is called by the Federation object. Given a list of node IDs
    and a topology, the TopologyManager creates a networkx graph and stores the
    node IDs and the graph. The graph determines the neighborhood relationship
    between nodes; the node IDs are used to retrieve the nodes from the
    federation's node registry.
    """

    def __init__(self, federation_id: str) -> None:
        """
        Creates a new TopologyManager.

        Args:
            federation_id (str): The ID of the federation to which this
                topology manager belongs.
        """
        self._fed_id = federation_id

        self._node_ids: List[str] = []
        # fix: the original annotation `Dict[str]` is evaluated at runtime for
        # attribute targets and raises TypeError (typing.Dict takes exactly two
        # parameters).
        self._nodes: Optional[Dict] = None
        self._topology = None
        self._graph: nx.Graph = None

    def forward(self, msg: Message, to: Optional[Union[str, List[str]]] = None):
        """
        Forwards a message to the appropriate nodes.

        Args:
            msg (Message): The message to forward.
            to (Optional[Union[str, List[str]]], optional): The nodes to which
                the message should be forwarded. If None, the message is
                forwarded to all neighbors of the sender. Defaults to None.

        Raises:
            ValueError: If any node ID in `to` is not a neighbor of the sender.

        Returns:
            List[ObjectRef]: A list of ObjectRefs to the messages that were
                sent.
        """
        neighbors = self.get_neighbors(msg.sender_id)
        if to is None:
            to = neighbors
        else:
            # Accept a single node ID as documented; iterating a bare string
            # would otherwise treat each character as a recipient.
            if isinstance(to, str):
                to = [to]
            # fix: the original check compared the neighbor list with itself
            # (the comprehension shadowed `curr_id`), so invalid recipients
            # were never detected.
            for curr_id in to:
                if curr_id not in neighbors:
                    raise ValueError(
                        f"{curr_id} is not a neighbor of {msg.sender_id}"
                    )
        msg_ref = ray.put(msg)
        return ray.get([self._nodes[neigh].enqueue.remote(msg_ref) for neigh in to])

    def get_neighbors(self, node_id: str):
        """
        Returns the neighbors of a node. This function is implicitly called by
        the `neighbors` property of a FedRayNode.

        Args:
            node_id (str): The ID of the node for which to retrieve neighbors.

        Returns:
            List[str]: A list of node IDs.
        """
        return list(self._graph.neighbors(node_id))

    def build_network(self, node_ids: List[str], topology: Union[str, np.ndarray]):
        """Builds the network topology.

        Args:
            node_ids (List[str]): A list of node IDs.
            topology (Union[str, np.ndarray]): The topology to use. If a
                string, it must be "star". If a numpy array, it must be a
                square symmetric binary matrix of shape (N, N) with a zero
                diagonal, where N is the number of nodes.

        Raises:
            ValueError: If the number of nodes is less than 2, or if a string
                topology is not recognized.
            NotImplementedError: If the topology is a numpy array.
        """
        if len(node_ids) < 2:
            raise ValueError("At least 2 nodes are required to setup the network.")
        self._node_ids = node_ids
        # Resolve each node ID to its named Ray actor within this federation.
        self._nodes = {
            node_id: ray.get_actor("/".join([self._fed_id, node_id]))
            for node_id in self._node_ids
        }

        self._topology = topology
        if isinstance(self._topology, str):
            if self._topology == "star":
                # First node ID is the center of the star.
                self._graph = nx.star_graph(self._node_ids)
            else:
                # fix: unknown topology names previously left the graph unset,
                # deferring the failure to the first forward() call.
                raise ValueError(f"Unknown topology: {self._topology}.")
        elif isinstance(self._topology, np.ndarray):
            raise NotImplementedError
137 | """ 138 | 139 | return TopologyManager.options( 140 | name=federation_id + "/topology_manager", 141 | num_cpus=TP_MANAGER_CPU_RESOURCES, 142 | scheduling_strategy=PlacementGroupSchedulingStrategy( 143 | placement_group, placement_group_bundle_index=0 + bundle_offset 144 | ), 145 | ).remote(federation_id=federation_id) 146 | -------------------------------------------------------------------------------- /fedray/core/federation/decentralized.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import threading 3 | from typing import Dict 4 | from typing import List 5 | from typing import Literal 6 | from typing import Type 7 | from typing import Union 8 | 9 | import numpy as np 10 | import ray 11 | from ray.util.placement_group import PlacementGroup 12 | 13 | from fedray.core.communication.topology.manager import _get_or_create_topology_manager 14 | from fedray.core.federation import Federation 15 | from fedray.core.node import FedRayNode 16 | from fedray.core.node import VirtualNode 17 | 18 | 19 | class DecentralizedFederation(Federation): 20 | """ 21 | A DecentralizedFederation is a special type of Federation that implements a 22 | decentralized federated learning scheme. It consists of multiple nodes, each of 23 | which has a different role. In this scheme, the nodes are connected with a 24 | user-defined topology. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | node_template: Type[FedRayNode], 30 | n_nodes_or_ids: Union[int, List[str]], 31 | roles: List[str], 32 | topology: Union[str, np.ndarray], 33 | node_config: Union[Dict, List[Dict]], 34 | resources: Union[str, PlacementGroup] = "uniform", 35 | federation_id: str = "", 36 | is_tune: bool = True, 37 | bundle_offset: int = 0, 38 | ) -> None: 39 | """Creates a new DecentralizedFederation object. 40 | 41 | Args: 42 | node_template (Type[FedRayNode]): The template for the nodes. 
43 | n_nodes_or_ids (Union[int, List[str]]): The number of nodes, or a list of 44 | node IDs. 45 | roles (List[str]): A list of roles for the nodes. The length of this list 46 | must be equal to the number of nodes. 47 | topology (Union[str, np.ndarray]): The topology to be used for the 48 | federation. This can either be a string, or a numpy array. If it is a 49 | string, it must be one of the following: "fully_connected", "ring", 50 | "star", "line", "mesh", "grid", "tree", "bipartite", "custom". If it is 51 | a numpy array, it must be a square binary matrix of shape (N, N), where 52 | N is the number of nodes. The matrix must be symmetric, and the diagonal 53 | must be all zeros. 54 | node_config (Union[Dict, List[Dict]]): The configuration for the nodes. 55 | This can either be a dictionary, or a list of dictionaries. If it is a 56 | dictionary, the same configuration will be used for all nodes. If it is 57 | a list of dictionaries, the length of the list must be equal to the 58 | number of nodes. 59 | resources (Union[str, PlacementGroup], optional): The resources to be used 60 | for the federation. This can either be a string, or a PlacementGroup. 61 | If it is a string, it must be one of the following: "uniform", "random". 62 | Defaults to "uniform". 63 | federation_id (str, optional): The ID of the federation. Defaults to "". 64 | is_tune (bool, optional): Whether the federation is used for a Ray Tune 65 | experiment. Defaults to False. 66 | bundle_offset (int, optional): The offset for the bundle IDs. Defaults to 0. 67 | 68 | Raises: 69 | ValueError: If the number of nodes does not match the number of roles. 
70 | """ 71 | if isinstance(n_nodes_or_ids, int): 72 | node_ids = [f"node_{i}" for i in range(n_nodes_or_ids)] 73 | else: 74 | node_ids = n_nodes_or_ids 75 | 76 | nodes = [ 77 | VirtualNode( 78 | node_template, 79 | node_id, 80 | federation_id, 81 | role, 82 | node_config[i] if isinstance(node_config, list) else node_config, 83 | ) 84 | for i, (node_id, role) in enumerate(zip(node_ids, roles)) 85 | ] 86 | 87 | super(DecentralizedFederation, self).__init__( 88 | nodes, topology, resources, federation_id, is_tune, bundle_offset 89 | ) 90 | 91 | def train(self, train_args: Union[Dict, List[Dict]], blocking: bool = False): 92 | """Performs a training session in the federation. This method calls the train 93 | method of each node in the federation. 94 | 95 | Args: 96 | train_args (Union[Dict, List[Dict]]): The arguments for the train method. 97 | This can either be a dictionary, or a list of dictionaries. If it is a 98 | dictionary, the same arguments will be used for all nodes. If it is a 99 | list of dictionaries, the length of the list must be equal to the 100 | number of nodes. 101 | blocking (bool, optional): Whether the method should block until the 102 | training session is finished. Defaults to False. 103 | 104 | Raises: 105 | ValueError: If the number of nodes does not match the number of train_args. 
106 | """ 107 | if self._tp_manager is None: 108 | self._tp_manager = _get_or_create_topology_manager( 109 | self._pg, self._fed_id, self._bundle_offset 110 | ) 111 | train_nodes = [] 112 | for i, node in enumerate(self._nodes, start=1 + self._bundle_offset): 113 | if "train" in node.role: 114 | if not node.built: 115 | node.build(i, self._pg) 116 | train_nodes.append(node) 117 | 118 | ray.get( 119 | self._tp_manager.build_network.remote( 120 | [node.id for node in train_nodes], self._topology 121 | ) 122 | ) 123 | ray.get([node.handle._setup_train.remote() for node in train_nodes]) 124 | train_args = [ 125 | (train_args[i] if isinstance(train_args, List) else train_args) 126 | for i, _ in enumerate(train_nodes) 127 | ] 128 | self._runtime = threading.Thread( 129 | target=ray.get, 130 | args=[[node.handle._train.remote(**train_args[i]) for node in train_nodes]], 131 | daemon=True, 132 | ) 133 | self._runtime.start() 134 | if blocking: 135 | self._runtime.join() 136 | 137 | def test( 138 | self, phase: Literal["train", "eval", "test"], aggregate: bool = True, **kwargs 139 | ) -> Union[List[float], float]: 140 | """ 141 | Performs a test session in the federation. 142 | 143 | Args: 144 | phase (Literal["train", "eval", "test"]): the role of the nodes on which 145 | the test should be performed. 146 | aggregate (bool, optional): Whether to aggregate the results weighted by the 147 | number of samples of the local datasets. If False, the results of each 148 | node are returned in a list. Defaults to True. 149 | **kwargs: The arguments to be passed to the test function of the nodes. 150 | 151 | Returns: 152 | Union[List[float], float]: The results of the test session. If aggregate is 153 | True, the results are averaged. 
154 | """ 155 | test_nodes = [] 156 | for i, node in enumerate(self._nodes, start=1 + self._bundle_offset): 157 | if node.role is not None and phase in node.role: 158 | test_nodes.append(node) 159 | if node.handle is None: 160 | node.build(i, self._pg) 161 | remotes = [node.handle.test.remote(phase, **kwargs) for node in test_nodes] 162 | results = ray.get(remotes) 163 | if not aggregate: 164 | return results 165 | 166 | values, weights = zip(*results) 167 | return np.average(values, weights=weights, axis=0) 168 | -------------------------------------------------------------------------------- /fedray/core/federation/client_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import threading 3 | from typing import Dict 4 | from typing import List 5 | from typing import Literal 6 | from typing import Optional 7 | from typing import Type 8 | from typing import Union 9 | 10 | import numpy as np 11 | import ray 12 | from ray.util.placement_group import PlacementGroup 13 | 14 | from fedray.core.communication.topology.manager import _get_or_create_topology_manager 15 | from fedray.core.federation import Federation 16 | from fedray.core.node import FedRayNode 17 | from fedray.core.node import VirtualNode 18 | 19 | 20 | class ClientServerFederation(Federation): 21 | """ 22 | A ClientServerFederation is a special type of Federation that implements a federated 23 | client-server scheme. It consists of a server node, and multiple client nodes. In 24 | this scheme, the nodes are connected in a star topology, where the server node is 25 | the center of the star, and the client nodes are the leaves of the star. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | server_template: Type[FedRayNode], 31 | client_template: Type[FedRayNode], 32 | n_clients_or_ids: Union[int, List[str]], 33 | roles: List[str], 34 | server_config: Dict = {}, 35 | client_config: Dict = {}, 36 | server_id: str = "server", 37 | resources: Union[str, PlacementGroup] = "uniform", 38 | federation_id: str = "", 39 | is_tune: bool = False, 40 | bundle_offset: int = 0, 41 | ) -> None: 42 | """Creates a new ClientServerFederation object. 43 | 44 | Args: 45 | server_template (Type[FedRayNode]): The template for the server node. 46 | client_template (Type[FedRayNode]): The template for the client nodes. 47 | n_clients_or_ids (Union[int, List[str]]): The number of clients, or a list 48 | of client IDs. 49 | roles (List[str]): A list of roles for the client nodes. The length of this 50 | list must be equal to the number of clients. 51 | server_config (Dict, optional): The configuration to be passed to the build 52 | method of the server node. Defaults to {}. 53 | client_config (Dict, optional): The configuration to be passed to the build 54 | method of the client nodes. Defaults to {}. 55 | server_id (str, optional): The ID of the server node. Defaults to "server". 56 | resources (Union[str, PlacementGroup], optional): The resources to be used 57 | for the nodes. Defaults to "uniform". 58 | federation_id (str, optional): The ID of the federation. Defaults to "". 59 | is_tune (bool, optional): Whether the federation is used for a Ray Tune 60 | experiment. Defaults to False. 61 | bundle_offset (int, optional): The offset to be used for the bundle IDs. 62 | This is useful whenever we are allocating multiple federations in the 63 | same PlacementGroup. Defaults to 0. 64 | 65 | Raises: 66 | ValueError: If the number of clients does not match the number of roles. 
67 | """ 68 | if isinstance(n_clients_or_ids, int): 69 | c_ids = [f"client_{i}" for i in range(n_clients_or_ids)] 70 | else: 71 | c_ids = n_clients_or_ids 72 | 73 | nodes = [ 74 | VirtualNode( 75 | server_template, server_id, federation_id, "train", server_config 76 | ) 77 | ] 78 | for c_id, role in zip(c_ids, roles): 79 | nodes.append( 80 | VirtualNode(client_template, c_id, federation_id, role, client_config) 81 | ) 82 | 83 | super(ClientServerFederation, self).__init__( 84 | nodes, "star", resources, federation_id, is_tune, bundle_offset 85 | ) 86 | 87 | def train( 88 | self, server_args: Dict, client_args: Dict, blocking: bool = False 89 | ) -> None: 90 | """ 91 | Performs a training session in the federation. Before calling the train method 92 | of the nodes, the method instantiates the training nodes in the federation by 93 | calling the .build 94 | 95 | 96 | Args: 97 | server_args (Dict): The arguments to be passed to the train function of the 98 | server node. 99 | client_args (Dict): The arguments to be passed to the train function of the 100 | client nodes. 101 | blocking (bool, optional): Whether to block the current thread until the 102 | training session is finished. Defaults to False. 
    def test(
        self, phase: Literal["train", "eval", "test"], aggregate: bool = True, **kwargs
    ) -> Union[List[float], float]:
        """
        Performs a test session in the federation.

        Args:
            phase (Literal["train", "eval", "test"]): the role of the nodes on which
                the test should be performed.
            aggregate (bool, optional): Whether to aggregate the results weighted by the
                number of samples of the local datasets. If False, the results of each
                node are returned in a list. Defaults to True.
            **kwargs: The arguments to be passed to the test function of the nodes.

        Returns:
            Union[List[float], float]: The results of the test session. If aggregate is
                True, the results are averaged.
        """
        test_nodes = []
        # Skip the server (index 0); client bundles start right after the
        # server's bundle, hence the 2 + offset start index.
        for i, node in enumerate(self._nodes[1:], start=2 + self._bundle_offset):
            if phase in node.role:
                test_nodes.append(node)
                # Lazily build clients that were never instantiated.
                if node.handle is None:
                    node.build(i, self._pg)
        # Fan out the test calls, then block on all of them at once.
        remotes = [node.handle.test.remote(phase, **kwargs) for node in test_nodes]

        results = ray.get(remotes)
        if not aggregate:
            return results

        # Each node is assumed to return a (value, n_samples) pair — TODO
        # confirm against FedRayNode.test.
        values, weights = zip(*results)
        return np.average(values, weights=weights, axis=0)
class Federation(object):
    """
    The Federation class is the main class that is used to create a federated
    learning system. It is responsible for creating the nodes, and for managing
    the network topology. It is also responsible for performing training and testing
    of the models within the federation, with the implemented algorithm.
    """

    def __init__(
        self,
        nodes: List[VirtualNode],
        topology: Union[str, np.ndarray],
        resources: Union[str, PlacementGroup] = "uniform",
        federation_id: str = "",
        is_tune: bool = False,
        bundle_offset: int = 0,
    ):
        """Creates a new Federation object.

        Args:
            nodes (List[VirtualNode]): A list of VirtualNode objects that are
                part of the federation.
            topology (Union[str, np.ndarray]): The topology of the network. Can
                be either a string, or a numpy array. If a string, it must be one
                of the following: "star". If a numpy array, it must be a square
                binary matrix of shape (N, N), where N is the number of nodes.
                The matrix must be symmetric, and the diagonal must be all zeros.
            resources (Union[str, PlacementGroup], optional): The resources to use.
                If a string, it must be one of the following: "uniform", "random".
                If a PlacementGroup, it must be a placement group that has been
                created with the same number of bundles as the number of nodes in
                the federation. If parameter `is_tune` is set to True, this
                argument is ignored. Defaults to "uniform".
            federation_id (str, optional): The ID of the federation. Defaults to "".
            is_tune (bool, optional): Whether the federation has to be instantiated
                within a Tune experiment. Defaults to False.
            bundle_offset (int, optional): The bundle offset. This parameter is
                useful whenever multiple federations need to be allocated within
                the same PlacementGroup. Defaults to 0.

        Raises:
            ValueError: If the topology is not a valid topology.
            ValueError: If the resources are not a valid resource specification.
        """
        self._fed_id = federation_id
        self._name = "supervisor"
        self._nodes: List[VirtualNode] = nodes
        self._topology: Union[str, np.ndarray] = topology

        if not is_tune:
            if isinstance(resources, str):
                self._pg = get_resources_split(
                    len(self._nodes), split_strategy=resources
                )
            else:
                self._pg = resources
        else:
            # Inside a Tune trial, reuse the trial's own placement group.
            self._pg = ray.util.get_current_placement_group()
        # Inside Tune the first bundle belongs to the trial itself, hence the +1.
        self._bundle_offset = 1 + bundle_offset if is_tune else bundle_offset

        self._tp_manager: TopologyManager = None
        self._state: Literal["IDLE", "RUNNING"] = "IDLE"
        self._runtime_remotes: List[ray.ObjectRef] = None
        self._runtime: threading.Thread = None

    def __getitem__(self, node_id: str):
        """Returns the handle of the node with the given ID.

        Raises:
            ValueError: If no node with the given ID belongs to the federation.
        """
        for node in self._nodes:
            if node.id == node_id:
                return node.handle
        raise ValueError(f"Identifier {node_id} not found in process.")

    def train(self, blocking: bool = False, **train_args):
        """
        Trains the models in the federation.

        This method is responsible for dispatching the arguments of the training
        algorithm to the nodes. It then starts the training algorithm on the nodes,
        and returns the results of the training.
        """
        raise NotImplementedError

    def test(self, phase: Literal["train", "eval", "test"], **kwargs) -> List:
        """
        Tests the models in the federation.

        This method is responsible for dispatching the arguments of the testing
        algorithm to the nodes. It then starts the testing algorithm on the nodes,
        and returns the results of the testing.
        """
        raise NotImplementedError

    def pull_version(
        self, node_ids: Union[str, List[str]], timeout: Optional[float] = None
    ) -> Union[List, Dict]:
        """
        Pulls the version of the nodes with the given IDs.

        Args:
            node_ids (Union[str, List[str]]): The IDs of the nodes to pull the
                version from.
            timeout (Optional[float], optional): The timeout for the pull. If None,
                the pull is blocking. Defaults to None.

        Returns:
            Union[List, Dict]: The version(s) of the requested nodes, or None if
                the timeout expired before any version became available.
        """
        to_pull = [node_ids] if isinstance(node_ids, str) else node_ids
        to_pull = [
            node.handle._pull_version.remote()
            for node in self._nodes
            if node.id in to_pull
        ]

        if timeout is None:
            new_versions = ray.get(to_pull)
        else:
            # BUG FIX: wait for *all* requested versions (ray.wait defaults to
            # num_returns=1) and resolve the ready refs with ray.get, so that
            # both branches return version dicts rather than raw ObjectRefs.
            ready, _ = ray.wait(to_pull, num_returns=len(to_pull), timeout=timeout)
            if not ready:
                return None
            new_versions = ray.get(ready)
        return new_versions[0] if len(to_pull) == 1 else new_versions

    def send(self, header: str, body: Dict, to: Optional[Union[str, List[str]]] = None):
        """
        Sends a message to the nodes with the given IDs.

        This method is useful whenever the user wishes to interact with the nodes
        in the federation during the training process. For example, the user can
        send a message to the nodes to change the learning rate of the models.

        Args:
            header (str): The header of the message.
            body (Dict): The body of the message.
            to (Optional[Union[str, List[str]]], optional): The IDs of the nodes to
                send the message to. If None, the message is sent to all nodes.
                Defaults to None.
        """
        if isinstance(to, str):
            to = [to]

        msg = Message(header=header, sender_id=self._name, body=body)
        ray.get([self._tp_manager.forward.remote(msg, to)])

    def stop(self) -> None:
        """Stops the federation by sending a STOP signal to every built train node
        and joining the runtime thread."""
        ray.get(
            [
                node.handle.stop.remote()
                for node in self._nodes
                if node.built and "train" in node.role
            ]
        )
        # ROBUSTNESS FIX: stop() may be called before train() ever ran, in which
        # case there is no runtime thread to join.
        if self._runtime is not None:
            self._runtime.join()
        self._state = "IDLE"

    @property
    def running(self) -> bool:
        """Returns whether the federation is running a training process"""
        return (
            self._state == "RUNNING"
            and self._runtime is not None
            and self._runtime.is_alive()
        )

    @property
    def num_nodes(self) -> int:
        """Returns the number of nodes in the federation."""
        return len(self._nodes)

    @property
    def node_ids(self) -> List[str]:
        """Returns the IDs of the nodes in the federation."""
        return [node.id for node in self._nodes]

    @property
    def resources(self) -> Dict[str, Dict[str, Union[int, float]]]:
        """Returns the resources of the federation, keyed by node ID plus the
        aggregate "all" entry and the topology manager's bundle."""
        # NOTE(review): bundle indices here start at 0/1 and ignore
        # self._bundle_offset, unlike the enumerate(..., start=2 + offset) used
        # elsewhere — confirm behavior when federations share a PlacementGroup.
        res_arr = self._pg.bundle_specs
        resources = {
            "all": {"CPU": res_arr[0]["CPU"], "GPU": 0},
            "tp_manager": {"CPU": res_arr[0]["CPU"]},
        }
        for i, node in enumerate(self._nodes, start=1):
            resources[node.id] = res_arr[i]
            resources["all"]["CPU"] += res_arr[i]["CPU"]
            if "GPU" in res_arr[i]:
                resources["all"]["GPU"] += res_arr[i]["GPU"]

        return resources
class FedRayNode(object):
    """Base class for a node in a federation.

    It provides all the base functionalities to allow a node to interact with the
    federation. **It is not meant to be used directly, but rather to be subclassed
    by the user to define a custom node**.

    While subclassing, the user can implement the ``build``, ``train`` and ``test``
    methods.

    The ``train`` method is called when the federation starts the training process.
    It is responsible for the internal logic of the node along the training process
    (e.g., local training and aggregation for client and server nodes respectively).

    The ``test`` function is called to perform a round of evaluation on the
    federation.

    The communication between nodes happens only within the train function by using
    the ``send`` and ``receive`` methods. The ``send`` method is used to broadcast a
    message to the neighbors in the topology, or to a specific node. The ``receive``
    method is used to get or wait for a message from the input queue. The ``receive``
    method is **optionally** blocking, thus allowing to implement asynchronous
    behavior of the node.
    """

    def __init__(self, node_id: str, role: str, federation_id: str = "", **build_args):
        """Creates a node in the federation. Must not be called directly or overridden.

        Args:
            node_id (str): The id of the node.
            role (str): The role of the node. It must be either a single role in
                ``["train", "eval", "test"]``, or a combination of them as a
                dash-separated string. For example, "train-eval" or
                "train-eval-test".
            federation_id (str): The id of the federation the node belongs to.
                Defaults to "".
            **build_args: Additional arguments to be passed to the build function.
        """
        # Node hyperparameters
        self._fed_id: str = federation_id
        self._id: str = node_id
        self._role: str = role

        # Communication interface — populated lazily by _setup_train, since the
        # topology-manager actor only exists once the federation is running.
        self._tp_manager: TopologyManager = None
        self._message_queue: Queue = None

        # Node's version state: a monotonically increasing counter plus a buffer
        # of published versions, consumed by the federation via _pull_version.
        self._version: int = 0
        self._version_buffer: Queue = None
        self._node_metrics: Dict[str, Any] = {}

        # Buildup function: user-defined (possibly resource-intensive) setup.
        self._node_config = build_args
        self.build(**build_args)

    def build(self, **build_args):
        """
        Performs the setup of the node's environment when the node is added to
        a federation.

        The build method has a twofold purpose.

        **Define object-level attributes**. This encloses attributes that are
        independent from whether the node is executing the training method or the
        test method (e.g., choosing the optimizer, the loss function, etc.).

        **Perform all the resource-intensive operations in advance to avoid
        bottlenecks**. An example can be downloading the data from an external
        source, or instantiating a model with computationally-intensive techniques.

        Since it is called within the ``__init__`` method, the user can define
        additional class attributes.

        An example of build function can be the following:

        .. code-block:: python

            def build(self, dataset_name: str):
                self._dataset_name = dataset_name
                self._dataset = load_dataset(self._dataset_name)
        """
        pass

    def _setup_train(self):
        """Prepares the node's environment for the training process.

        Resolves the federation's topology-manager actor (named
        "<federation_id>/topology_manager") on first use, and resets the message
        queue, the version counter and the version buffer.

        Returns:
            bool: True, a dummy acknowledgement for the federation.
        """
        if self._tp_manager is None:
            self._tp_manager = ray.get_actor(
                "/".join([self._fed_id, "topology_manager"])
            )
        self._message_queue = Queue()
        self._version = 0
        self._version_buffer = Queue()
        return True

    def _train(self, **train_args):
        """Wrapper for the training function.

        Catches the EndProcessException raised by ``receive`` on a STOP message,
        so a stopped node exits cleanly and still returns its metrics.
        """
        try:
            self.train(**train_args)
        except EndProcessException:
            print(f"Node {self.id} is exiting.")

        return self._node_metrics

    def train(self, **train_args) -> Dict:
        """Implements the core logic of a node within a training process. It is
        called by the federation when the training process starts.

        An example can be the client in the Federated Averaging algorithm:

        .. code-block:: python

            def train(self, **train_args):
                while True:
                    # Get the model
                    model = self.receive().body["model"]

                    # Get the data
                    data_fn = self.get_data()

                    # Train the model
                    model.train(self.dataset, self.optimizer, self.loss, self.metrics)

                    # Send the model to the server
                    self.send("model", model)
        """
        raise NotImplementedError

    def test(
        self, phase: Literal["train", "eval", "test"], **kwargs
    ) -> Tuple[float, int]:
        """Implements the core logic of a node within a test process. It is
        called by the federation when the test session starts.

        Args:
            phase (Literal["train", "eval", "test"]): The phase of the test process.
                It can be either "train", "eval" or "test".
            **kwargs: Additional arguments to be passed to the test function.

        Returns:
            Tuple(float, int): A tuple containing the average loss and the number
            of samples used for the test.
        """
        raise NotImplementedError

    def send(self, header: str, body: Dict, to: Optional[Union[str, List[str]]] = None):
        """Sends a message to a specific node or to the neighbor nodes in the
        federation.

        Args:
            header (str): The header of the message.
            body (Dict): The body of the message.
            to (Optional[Union[str, List[str]]], optional): The id of the node to
                which the message is sent. If None, the message is sent to the
                neighbor nodes. Defaults to None.
        """
        if isinstance(to, str):
            to = [to]

        msg = Message(header=header, sender_id=self._id, body=body)
        # Blocks until the topology manager has accepted the message.
        ray.get([self._tp_manager.forward.remote(msg, to)])

    def receive(self, timeout: Optional[float] = None) -> Message:
        """Receives a message from the message queue. If the timeout value is
        defined, it waits for a message for the specified amount of time. If no
        message is received within the timeout, it returns None. This allows to
        implement a node with an asynchronous behavior.

        Args:
            timeout (Optional[float], optional): The timeout value. Defaults to None.

        Returns:
            Message: The received message, or None on timeout.

        Raises:
            EndProcessException: If the message received is a "STOP" message, it
                raises an EndProcessException to stop the process. This is handled
                under the hood by the training function.
        """
        try:
            msg = self._message_queue.get(timeout=timeout)
        # Queue.Empty is the stdlib queue.Empty, re-exported on the Queue
        # subclass defined later in this module.
        except Queue.Empty:
            msg = None

        if msg is not None and msg.header == "STOP":
            raise EndProcessException
        return msg

    def update_version(self, **kwargs):
        """Updates the node's version. Whenever this function is called, the
        version is stored in an internal queue. The version is pulled from the
        queue whenever the federation calls the ``pull_version`` method.

        Args:
            **kwargs: The named objects (e.g., model weights) to snapshot; each
                value is deep-copied so later in-place mutations do not alter the
                published version.
        """
        to_save = {k: copy.deepcopy(v) for k, v in kwargs.items()}
        version_dict = {
            "id": self.id,
            "n_version": self.version,
            "timestamp": time.time(),
            "model": to_save,
        }
        self._version_buffer.put(version_dict)
        # Increment only after publishing, so the stored n_version matches the
        # snapshot that was just enqueued.
        self._version += 1

    def stop(self):
        """Stops the node's processes.

        The STOP message is inserted at the *front* of the queue (index=0) so it
        preempts any pending messages.
        """
        # NOTE(review): assumes Message accepts the header as its sole positional
        # argument — confirm against Message's constructor signature.
        self._message_queue.put(Message("STOP"), index=0)

    def enqueue(self, msg: ray.ObjectRef):
        """Enqueues a message in the node's message queue. This method is called
        by the topology manager when a message is sent from a neighbor.

        Args:
            msg (ray.ObjectRef): The message to be enqueued.

        Returns:
            bool: True, a dummy value for the federation.
        """
        self._message_queue.put(msg)
        return True

    def _invalidate_neighbors(self):
        """
        Invalidates the node's neighbors. This method is called by the topology
        manager when the topology changes. In future versions, this will be used
        to implement dynamic topologies.
        """
        # Dropping the cached_property value forces a re-fetch on next access.
        del self.neighbors

    def _pull_version(self):
        """
        Pulls the version from the version buffer. This method is called under
        the hood by the federation when the `pull_version` method is called.
        Blocks until a version has been published.
        """
        return self._version_buffer.get(block=True)

    @property
    def id(self) -> str:
        """Returns the node's id."""
        return self._id

    @property
    def version(self) -> int:
        """Returns the node's current version."""
        return self._version

    @property
    def is_train_node(self) -> bool:
        """True if the node is a training node, False otherwise."""
        return "train" in self._role.split("-")

    @property
    def is_eval_node(self) -> bool:
        """True if the node is an evaluation node, False otherwise."""
        return "eval" in self._role.split("-")

    @property
    def is_test_node(self) -> bool:
        """True if the node is a test node, False otherwise."""
        return "test" in self._role.split("-")

    @cached_property
    def neighbors(self) -> List[str]:
        """Returns the list of the node's neighbor IDs (cached until the
        topology manager invalidates it)."""
        return ray.get(self._tp_manager.get_neighbors.remote(self.id))
class Queue(_Queue, object):
    """A ``queue.Queue`` subclass that supports inserting an item at an
    arbitrary position (used to push STOP messages to the front) and that is
    picklable via the ``copyreg`` registration below."""

    # Re-export the stdlib exceptions so callers can write Queue.Empty/Full.
    Empty = queue.Empty
    Full = queue.Full

    def put(self, item, block=True, timeout=None, index=None):
        """Put an item into the queue, optionally at a given position.

        If optional args 'block' is true and 'timeout' is None (the default),
        block if necessary until a free slot is available. If 'timeout' is
        a non-negative number, it blocks at most 'timeout' seconds and raises
        the Full exception if no free slot was available within that time.
        Otherwise ('block' is false), put an item on the queue if a free slot
        is immediately available, else raise the Full exception ('timeout'
        is ignored in that case).

        Args:
            item: The item to enqueue.
            block (bool): Whether to wait for a free slot on a bounded queue.
            timeout (Optional[float]): Maximum seconds to wait when blocking.
            index (Optional[int]): Position at which to insert the item;
                None appends at the tail (FIFO behavior).

        Raises:
            Full: If no free slot became available.
            ValueError: If ``timeout`` is negative.
        """
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._qsize() >= self.maxsize:
                        raise Full
                elif timeout is None:
                    while self._qsize() >= self.maxsize:
                        self.not_full.wait()
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    # BUG FIX: the original called ``time()`` on the imported
                    # ``time`` *module*, raising TypeError whenever a bounded
                    # queue was used with a timeout. Use a monotonic clock,
                    # as the stdlib queue.Queue does.
                    endtime = time.monotonic() + timeout
                    while self._qsize() >= self.maxsize:
                        remaining = endtime - time.monotonic()
                        if remaining <= 0.0:
                            raise Full
                        self.not_full.wait(remaining)
            self._put(item, index)
            self.unfinished_tasks += 1
            self.not_empty.notify()

    def _put(self, item, index) -> None:
        # Append for plain FIFO puts; insert for positional (e.g. STOP at 0).
        if index is None:
            self.queue.append(item)
        else:
            self.queue.insert(index, item)


def pickle_queue(q):
    """Reduce function: pickle a Queue's state without its (unpicklable)
    lock and condition variables; a fresh Queue() recreates them on load."""
    q_dct = q.__dict__.copy()
    del q_dct["mutex"]
    del q_dct["not_empty"]
    del q_dct["not_full"]
    del q_dct["all_tasks_done"]
    return Queue, (), q_dct


def unpickle_queue(state):
    """Rebuild a Queue from a (factory, args, state) triple, recreating the
    synchronization primitives stripped by ``pickle_queue``.

    NOTE(review): the standard reduce protocol already restores state via
    ``__dict__.update`` on a fresh Queue(), so this constructor appears to be
    unused by pickle itself — kept for the copyreg registration below.
    """
    q = state[0]()
    q.mutex = threading.Lock()
    q.not_empty = threading.Condition(q.mutex)
    q.not_full = threading.Condition(q.mutex)
    q.all_tasks_done = threading.Condition(q.mutex)
    q.__dict__ = state[2]
    return q


copyreg.pickle(Queue, pickle_queue, unpickle_queue)