├── .github └── workflows │ ├── pre-commit.yml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── LICENSE ├── README.md ├── benchmark ├── README.md ├── benchmark_GC.py ├── benchmark_LP.py ├── benchmark_NC.py ├── configs │ ├── config_FedGAT.yaml │ ├── config_FedGCN.yaml │ ├── config_GC_FedAvg.yaml │ ├── config_GC_FedProx.yaml │ ├── config_GC_GCFL+.yaml │ ├── config_GC_GCFL+dWs.yaml │ ├── config_GC_GCFL.yaml │ ├── config_GC_SelfTrain.yaml │ ├── config_LP.yaml │ └── grafana_customized_metric_dashboard.json └── figure │ ├── GC_comm_costs │ ├── extract_GC_log.py │ ├── gc_accuracy_comparison.pdf │ ├── gc_comm_cost_comparison.pdf │ └── gc_train_time_comparison.pdf │ ├── LP_comm_costs │ ├── extract_LP_log.py │ ├── lp_auc_comparison.pdf │ ├── lp_comm_cost_comparison.pdf │ └── lp_train_time_comparison.pdf │ ├── NC_comm_costs │ ├── extract_NC_log.py │ ├── extract_global_test_acc.py │ ├── nc_accuracy_comparison_beta10.pdf │ ├── nc_accuracy_comparison_beta100.pdf │ ├── nc_accuracy_comparison_beta10000.pdf │ ├── nc_accuracy_curve_citeseer.pdf │ ├── nc_accuracy_curve_cora.pdf │ ├── nc_accuracy_curve_ogbn-arxiv.pdf │ ├── nc_accuracy_curve_pubmed.pdf │ ├── nc_comm_cost_comparison_beta10.pdf │ ├── nc_comm_cost_comparison_beta100.pdf │ ├── nc_comm_cost_comparison_beta10000.pdf │ ├── nc_train_time_comparison_beta10.pdf │ ├── nc_train_time_comparison_beta100.pdf │ └── nc_train_time_comparison_beta10000.pdf │ ├── figure_GCbyAlgo.py │ ├── figure_GCbyDataset.py │ ├── figure_GCfinal.py │ ├── figure_NNN.py │ ├── figure_batch.py │ ├── figure_batch_3d.py │ ├── figure_batch_combine.py │ ├── figure_papers100M.py │ └── figure_traintime.py ├── docker_requirements.txt ├── docs ├── Makefile ├── cite.rst ├── conf.py ├── dev_script │ ├── kuberflow_sample.py │ ├── processing_script_GC.py │ ├── processing_script_LP.py │ ├── processing_script_NC.py │ └── save_graph_node_classification.py ├── fedgraph.data_process.rst ├── fedgraph.federated_methods.rst ├── fedgraph.gnn_models.rst ├── fedgraph.monitor_class.rst ├── fedgraph.server_class.rst ├── fedgraph.setup_ray_cluster.rst ├── fedgraph.train_func.rst ├── fedgraph.trainer_class.rst ├── fedgraph.utils_gc.rst ├── fedgraph.utils_lp.rst ├── fedgraph.utils_nc.rst ├── index.rst ├── install.rst ├── make.bat ├── reference.rst ├── requirements.txt ├── sg_execution_times.rst └── zreferences.bib ├── fedgraph ├── __init__.py ├── data_process.py ├── federated_methods.py ├── gnn_models.py ├── he_context.pkl ├── he_training_context.pkl ├── monitor_class.py ├── server_class.py ├── train_func.py ├── trainer_class.py ├── utils_gc.py ├── utils_lp.py ├── utils_nc.py └── version.py ├── generate_he_context.py ├── kuberay ├── config │ ├── grafana │ │ ├── KubeRay-ApiServer-1650105351221.json │ │ ├── KubeRay-Controller-Runtime-Controllers-1650108080992.json │ │ ├── data_grafana_dashboard.json │ │ ├── default_grafana_dashboard.json │ │ ├── serve_deployment_grafana_dashboard.json │ │ └── serve_grafana_dashboard.json │ └── prometheus │ │ ├── podMonitor.yaml │ │ ├── rules │ │ └── prometheusRules.yaml │ │ └── serviceMonitor.yaml └── install │ └── prometheus │ ├── install.sh │ └── overrides.yaml ├── mypy.ini ├── pyproject.toml ├── quickstart.py ├── ray_cluster_configs ├── eks_cluster_config.yaml ├── ray_kubernetes_cluster.yaml └── ray_kubernetes_ingress.yaml ├── setup.cfg ├── setup.py ├── setup_cluster.md ├── setup_cluster.sh └── tutorials ├── FGL_GC.py ├── FGL_LP.py ├── FGL_NC.py ├── FGL_NC_HE.py └── README.txt 
/.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | on: 3 | pull_request: 4 | push: 5 | 6 | jobs: 7 | pre-commit: 8 | runs-on: ubuntu-latest 9 | env: 10 | SKIP: no-commit-to-branch 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | #dashboard 29 | tutorials/prometheus-2.52.0.darwin-arm64/ 30 | prometheus-2.52.0.darwin-arm64/ 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | *.csv 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | .idea/ 165 | 166 | # MacOS specific 167 | # General 168 | .DS_Store 169 | .AppleDouble 170 | .LSOverride 171 | 172 | # Icon must end with two \r 173 | Icon 174 | 175 | 176 | # Thumbnails 177 | ._* 178 | 179 | # Files that might appear in the root of a volume 180 | .DocumentRevisions-V100 181 | .fseventsd 182 | .Spotlight-V100 183 | .TemporaryItems 184 | .Trashes 185 | .VolumeIcon.icns 186 | .com.apple.timemachine.donotpresent 187 | 188 | # Directories potentially created on remote AFP share 189 | .AppleDB 190 | .AppleDesktop 191 | Network Trash Folder 192 | Temporary Items 193 | .apdisk 194 | 195 | data/ 196 | docs/tutorials/ 197 | runs/ 198 | tutorials/dataset 199 | tutorials/figure_* 200 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.5.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | - repo: https://github.com/PyCQA/isort 10 | rev: 5.12.0 11 | hooks: 12 | - id: isort 13 | name: isort 14 | - repo: https://github.com/pre-commit/mirrors-mypy 15 | rev: v1.7.1 16 | hooks: 17 | - id: mypy 18 | - repo: https://github.com/psf/black 19 | rev: 23.11.0 20 | hooks: 21 | - id: black 22 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.10" 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community.
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use the official Python image as a base image 2 | FROM python:3.11.9 3 | 4 | # Set the working directory 5 | WORKDIR /app 6 | 7 | # Install PyTorch early to leverage caching 8 | RUN pip install torch 9 | 10 | # # Copy the wheels directory 11 | # COPY wheels ./wheels 12 | 13 | # # Install torch-geometric related wheels from the local directory 14 | # RUN pip install --no-cache-dir --find-links=./wheels \ 15 | # torch-cluster \ 16 | # torch-scatter \ 17 | # torch-sparse \ 18 | # torch-spline-conv 19 | 20 | # Copy the requirements file (excluding torch-geometric wheels as they are pre-installed) 21 | COPY docker_requirements.txt . 22 | 23 | # Install remaining dependencies from the requirements file 24 | RUN pip install --no-cache-dir -r docker_requirements.txt 25 | 26 | # Copy the remaining application files 27 | COPY fedgraph /app/fedgraph 28 | COPY setup.py . 29 | COPY README.md . 30 | 31 | # Install the application 32 | RUN pip install . 33 | 34 | # Copy documentation and examples 35 | COPY tutorials /app/docs/examples 36 | 37 | # Specify the command to run the application 38 | # CMD ["python", "/app/docs/examples/example_LP.py"] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2024, fedgraph-team 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. 
Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Federated Graph Learning [![PyPI Downloads](https://static.pepy.tech/badge/fedgraph)](https://pepy.tech/projects/fedgraph) 2 | 3 | [pypi-url]: https://pypi.python.org/pypi/fedgraph 4 | 5 | **[Documentation](https://docs.fedgraph.org)** | **[Paper](https://arxiv.org/abs/2410.06340)** | **[Slack](https://join.slack.com/t/fedgraphlibrary/shared_invite/zt-2wztvbo1v-DO81DnUD86q066mxnQuWWw)** 6 | 7 | **FedGraph** *(Federated Graph)* is a library built on top of [PyTorch Geometric (PyG)](https://www.pyg.org/), 8 | [Ray](https://docs.ray.io/), and [PyTorch](https://pytorch.org/) to easily train Graph Neural Networks 9 | under federated or distributed settings. 10 | 11 | It supports various federated training methods for graph neural networks in both simulated and real federated environments, and handles communication between clients and the central server for model updates and information aggregation. 12 | 13 | ## Main Focus 14 | - **Federated Node Classification with Cross-Client Edges**: Our library supports communicating information stored on other clients without compromising user privacy. 15 | - **Federated Link Prediction on Dynamic Graphs**: Our library supports balancing temporal heterogeneity across clients while preserving privacy. 16 | - **Federated Graph Classification**: Our library supports federated graph classification with non-IID graphs. 17 | 18 | 19 | 20 | 21 | ## Cross Platform Training 22 | 23 | - We support federated training across Linux, macOS, and Windows operating systems. 24 | 25 | ## Library Highlights 26 | 27 | Whether you are a federated learning researcher or a first-time user of federated learning toolkits, here are some reasons to try out FedGraph for federated learning on graph-structured data. 28 | 29 | - **Easy-to-use and unified API**: All it takes is 10-20 lines of code to get started with training a federated GNN model. GNN models are PyTorch models provided by PyG and DGL. The federated training process is handled by Ray. We abstract away the complexity of federated graph training and provide a unified API for training and evaluating FedGraph models.
30 | 31 | - **Various FedGraph methods**: Most of the state-of-the-art federated graph training methods have been implemented by library developers or authors of research papers and are ready to be applied. 32 | 33 | - **Great flexibility**: Existing FedGraph models can easily be extended for your own research. Simply inherit from the trainer base class and implement your own methods. 34 | 35 | - **Large-scale real-world FedGraph Training**: We focus on FedGraph applications in challenging real-world scenarios that require privacy preservation, and support learning on large-scale graphs across multiple clients. 36 | 37 | ## Installation 38 | ```bash 39 | pip install fedgraph 40 | ``` 41 | 42 | ## Quick Start 43 | ```python 44 | from fedgraph.federated_methods import run_fedgraph 45 | 46 | import attridict 47 | 48 | config = { 49 | # Task, Method, and Dataset Settings 50 | "fedgraph_task": "NC", 51 | "dataset": "cora", 52 | "method": "FedGCN", # Federated learning method, e.g., "FedGCN" 53 | "iid_beta": 10000, # Dirichlet distribution parameter for label distribution among clients 54 | "distribution_type": "average", # Distribution type among clients 55 | # Training Configuration 56 | "global_rounds": 100, 57 | "local_step": 3, 58 | "learning_rate": 0.5, 59 | "n_trainer": 2, 60 | "batch_size": -1, # -1 indicates full batch training 61 | # Model Structure 62 | "num_layers": 2, 63 | "num_hops": 1, # Number of n-hop neighbors for client communication 64 | # Resource and Hardware Settings 65 | "gpu": False, 66 | "num_cpus_per_trainer": 1, 67 | "num_gpus_per_trainer": 0, 68 | # Logging and Output Configuration 69 | "logdir": "./runs", 70 | # Security and Privacy 71 | "use_encryption": False, # Whether to use Homomorphic Encryption for secure aggregation 72 | # Dataset Handling Options 73 | "use_huggingface": False, # Load dataset directly from Hugging Face Hub 74 | "saveto_huggingface": False, # Save partitioned dataset to Hugging Face Hub 75 | # Scalability and Cluster Configuration 76 | "use_cluster": False, # Use Kubernetes for scalability if True 77 | } 78 | 79 | 80 | config = attridict(config) 81 | run_fedgraph(config) 82 | ``` 83 | 84 | ## Set Up the Ray Cluster 85 | 86 | ```bash 87 | bash setup_cluster.sh 88 | ``` 89 | 90 | ## Delete the Ray Cluster 91 | 92 | Delete the RayCluster Custom Resource: 93 | 94 | ```bash 95 | cd ray_cluster_configs 96 | kubectl delete -f ray_kubernetes_cluster.yaml 97 | kubectl delete -f ray_kubernetes_ingress.yaml 98 | ``` 99 | 100 | Confirm that the RayCluster Pods are Terminated: 101 | 102 | ```bash 103 | kubectl get pods 104 | # Ensure the output shows no Ray pods except kuberay-operator 105 | ``` 106 | 107 | Finally, delete the nodes first and then delete the EKS cluster: 108 | 109 | ```bash 110 | kubectl get nodes -o name | xargs kubectl delete 111 | eksctl delete cluster --region <your-region> --name <your-cluster-name> 112 | ``` 113 | 114 | ## Push Data to the Hugging Face Hub 115 | 116 | Log in to the Hugging Face Hub CLI (if you haven't already) before setting `"saveto_huggingface": True` in node classification tasks: 117 | 118 | ```bash 119 | huggingface-cli login 120 | ``` 121 | 122 | ## Cite 123 | 124 | Please cite [our paper](https://arxiv.org/abs/2410.06340) (and the respective papers of the methods used) if you use this code in your own work: 125 | 126 | ``` 127 | @article{yao2024fedgraph, 128 | title={FedGraph: A Research Library and Benchmark for Federated Graph Learning}, 129 | author={Yao, Yuhang and Li, Yuan and Fan, Xinyi and Li, Junhao and
Liu, Kay and Jin, Weizhao and Ravi, Srivatsan and Yu, Philip S and Joe-Wong, Carlee}, 130 | journal={arXiv preprint arXiv:2410.06340}, 131 | year={2024} 132 | } 133 | @article{yao2023fedgcn, 134 | title={FedGCN: Convergence-Communication Tradeoffs in Federated Training of Graph Convolutional Networks}, 135 | author={Yao, Yuhang and Jin, Weizhao and Ravi, Srivatsan and Joe-Wong, Carlee}, 136 | journal={Advances in Neural Information Processing Systems (NeurIPS)}, 137 | year={2023} 138 | } 139 | ``` 140 | 141 | Feel free to [email us](mailto:yuhangya@andrew.cmu.edu) if you wish your work to be listed in the [external resources](). 142 | If you notice anything unexpected, please open an [issue]() and let us know. 143 | If you have any questions or are missing a specific feature, feel free [to discuss them with us](). 144 | We are motivated to constantly make FedGraph even better. 145 | -------------------------------------------------------------------------------- /benchmark/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/README.md -------------------------------------------------------------------------------- /benchmark/benchmark_GC.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Graph Classification Benchmark 3 | ======================================= 4 | 5 | Run benchmarks for various federated graph classification algorithms using a simplified approach. 6 | 7 | (Time estimate: 30 minutes) 8 | """ 9 | 10 | import os 11 | import time 12 | 13 | import attridict 14 | import ray 15 | import torch 16 | import yaml 17 | 18 | from fedgraph.federated_methods import run_fedgraph 19 | 20 | # Datasets to benchmark 21 | datasets = [ 22 | "IMDB-BINARY", 23 | "IMDB-MULTI", 24 | "MUTAG", 25 | "BZR", 26 | "COX2", 27 | "DHFR", 28 | "AIDS", 29 | # "PTC-MR", # not found 30 | # "ENZYMES", # error with 10 clients 31 | # "DD", 32 | # "PROTEINS", 33 | # "COLLAB", 34 | # "NCI1", 35 | ] 36 | 37 | # Algorithms to benchmark 38 | algorithms = ["SelfTrain", "FedAvg", "FedProx", "GCFL", "GCFL+", "GCFL+dWs"] 39 | 40 | # Number of trainers to test 41 | trainer_numbers = [10] 42 | 43 | # Number of runs per configuration 44 | runs_per_config = 1 45 | 46 | # Define additional required parameters that might be missing from YAML 47 | required_params = { 48 | "fedgraph_task": "GC", 49 | "num_cpus_per_trainer": 20, 50 | "num_gpus_per_trainer": 1 if torch.cuda.is_available() else 0, 51 | "use_cluster": True, # Set to True to enable monitoring 52 | "gpu": torch.cuda.is_available(), 53 | } 54 | 55 | # specifying a target GPU 56 | if torch.cuda.is_available(): 57 | print("using GPU") 58 | else: 59 | print("using CPU") 60 | 61 | # Main benchmark loop 62 | for dataset_name in datasets: 63 | for algorithm in algorithms: 64 | # Load the appropriate configuration file for the algorithm 65 | config_file = os.path.join( 66 | os.path.dirname(__file__), "configs", f"config_GC_{algorithm}.yaml" 67 | ) 68 | with open(config_file, "r") as file: 69 | config = attridict(yaml.safe_load(file)) 70 | 71 | # Update the configuration with specific parameters for this run 72 | config.dataset = dataset_name 73 | 74 | # Add required parameters that might be missing 75 | for param, value in required_params.items(): 76 | if not hasattr(config, param): 77 | setattr(config, param, value) 78 | 79 | for trainer_num in trainer_numbers: 80 | # Set the 
number of trainers 81 | config.num_trainers = trainer_num 82 | 83 | # Run multiple times for statistical significance 84 | for i in range(runs_per_config): 85 | print(f"\n{'-'*80}") 86 | print(f"Running experiment {i+1}/{runs_per_config}:") 87 | print( 88 | f"Algorithm: {algorithm}, Dataset: {dataset_name}, Trainers: {trainer_num}" 89 | ) 90 | print(f"{'-'*80}\n") 91 | 92 | # To ensure each run uses a fresh configuration object 93 | run_config = attridict({}) 94 | for key, value in config.items(): 95 | run_config[key] = value 96 | 97 | # Ensure proper parameter naming 98 | if hasattr(run_config, "model") and not hasattr( 99 | run_config, "algorithm" 100 | ): 101 | run_config.algorithm = run_config.model 102 | elif not hasattr(run_config, "model"): 103 | run_config.model = algorithm 104 | run_config.algorithm = algorithm 105 | 106 | # Run the federated learning process with clean Ray environment 107 | try: 108 | # Make sure Ray is shut down from any previous runs 109 | if ray.is_initialized(): 110 | ray.shutdown() 111 | 112 | # Run the experiment 113 | run_fedgraph(run_config) 114 | except Exception as e: 115 | print(f"Error running experiment: {e}") 116 | print(f"Configuration: {run_config}") 117 | finally: 118 | # Always ensure Ray is shut down before the next experiment 119 | if ray.is_initialized(): 120 | ray.shutdown() 121 | 122 | # Add a short delay between runs 123 | time.sleep(5) 124 | 125 | print("Benchmark completed.") 126 | -------------------------------------------------------------------------------- /benchmark/benchmark_LP.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Link Prediction Benchmark 3 | =================================== 4 | 5 | Run benchmarks for various federated link prediction algorithms using a simplified approach. 
6 | 7 | (Time estimate: 30 minutes) 8 | """ 9 | 10 | import os 11 | import time 12 | 13 | import attridict 14 | import ray 15 | import torch 16 | import yaml 17 | 18 | from fedgraph.federated_methods import run_fedgraph 19 | 20 | # Methods to benchmark 21 | methods = ["4D-FED-GNN+", "STFL", "StaticGNN", "FedLink"] 22 | 23 | # Country code combinations to test 24 | country_codes_list = [["US"], ["US", "BR"], ["US", "BR", "ID", "TR", "JP"]] 25 | 26 | # Number of runs per configuration 27 | runs_per_config = 1 28 | 29 | # Define additional required parameters that might be missing from YAML 30 | required_params = { 31 | "fedgraph_task": "LP", 32 | "num_cpus_per_trainer": 3, 33 | "num_gpus_per_trainer": 1 if torch.cuda.is_available() else 0, 34 | "use_cluster": True, 35 | "gpu": torch.cuda.is_available(), 36 | "ray_address": "auto", 37 | } 38 | 39 | # Main benchmark loop 40 | for method in methods: 41 | for country_codes in country_codes_list: 42 | # Load the base configuration file 43 | config_file = os.path.join( 44 | os.path.dirname(__file__), "configs", "config_LP.yaml" 45 | ) 46 | with open(config_file, "r") as file: 47 | config = attridict(yaml.safe_load(file)) 48 | 49 | # Update the configuration with specific parameters for this run 50 | config.method = method 51 | config.country_codes = country_codes 52 | 53 | # Add required parameters that might be missing 54 | for param, value in required_params.items(): 55 | if not hasattr(config, param): 56 | setattr(config, param, value) 57 | 58 | # Set dataset path 59 | if not hasattr(config, "dataset_path") or not config.dataset_path: 60 | config.dataset_path = os.path.join( 61 | os.path.dirname(os.path.abspath(__file__)), "data", "LPDataset" 62 | ) 63 | 64 | # Run multiple times for statistical significance 65 | for i in range(runs_per_config): 66 | print(f"\n{'-'*80}") 67 | print(f"Running experiment {i+1}/{runs_per_config}:") 68 | print(f"Method: {method}, Countries: {', '.join(country_codes)}") 69 | print(f"{'-'*80}\n") 70 | 71 | # To ensure each run uses a fresh configuration object 72 | run_config = attridict({}) 73 | for key, value in config.items(): 74 | run_config[key] = value 75 | 76 | # Run the federated learning process with clean Ray environment 77 | try: 78 | # Make sure Ray is shut down from any previous runs 79 | if ray.is_initialized(): 80 | ray.shutdown() 81 | 82 | # Run the experiment 83 | run_fedgraph(run_config) 84 | except Exception as e: 85 | print(f"Error running experiment: {e}") 86 | print(f"Configuration: {run_config}") 87 | finally: 88 | # Always ensure Ray is shut down before the next experiment 89 | if ray.is_initialized(): 90 | ray.shutdown() 91 | 92 | # Add a short delay between runs 93 | time.sleep(5) 94 | 95 | print("Benchmark completed.") 96 | -------------------------------------------------------------------------------- /benchmark/benchmark_NC.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Node Classification Benchmark 3 | ======================================= 4 | 5 | Run benchmarks for various federated node classification algorithms using a simplified approach. 
6 | 7 | (Time estimate: 30 minutes) 8 | """ 9 | 10 | import os 11 | import time 12 | 13 | import attridict 14 | import ray 15 | import torch 16 | import yaml 17 | 18 | from fedgraph.federated_methods import run_fedgraph 19 | 20 | # Datasets to benchmark 21 | datasets = [ 22 | "cora", 23 | "citeseer", 24 | "pubmed", 25 | "ogbn-arxiv", 26 | ] # You can add more: ["cora", "citeseer", "ogbn-arxiv", "ogbn-products"] 27 | 28 | # Number of trainers to test 29 | n_trainers = [10] 30 | 31 | # Number of hops for neighbor aggregation 32 | num_hops_list = [0, 1] 33 | 34 | # Distribution types for node partitioning 35 | distribution_list_ogbn = ["average"] 36 | distribution_list_other = ["average"] 37 | # You can expand these: distribution_list_ogbn = ["average", "lognormal", "exponential", "powerlaw"] 38 | 39 | # IID Beta values to test (controls how IID the data distribution is) 40 | iid_betas = [10000.0, 100.0, 10.0] 41 | 42 | # Number of runs per configuration 43 | runs_per_config = 1 44 | 45 | # Define additional required parameters that might be missing from YAML 46 | required_params = { 47 | "fedgraph_task": "NC", 48 | "num_cpus_per_trainer": 4, 49 | "num_gpus_per_trainer": 1 if torch.cuda.is_available() else 0, 50 | "use_cluster": True, 51 | "global_rounds": 200, 52 | "local_step": 1, 53 | "learning_rate": 0.1, 54 | "num_layers": 2, 55 | "logdir": "./runs", 56 | "use_huggingface": False, 57 | "saveto_huggingface": False, 58 | "use_encryption": False, 59 | } 60 | 61 | # Main benchmark loop 62 | for dataset in datasets: 63 | # Determine whether to use GPU based on dataset 64 | gpu = False # Set to "ogbn" in dataset if you want to use GPU for certain datasets 65 | 66 | # Choose distribution list based on dataset and number of trainers 67 | distribution_list = ( 68 | distribution_list_other 69 | if n_trainers[0] > 10 or not gpu 70 | else distribution_list_ogbn 71 | ) 72 | 73 | # Set batch sizes based on dataset 74 | if dataset == "ogbn-arxiv": 75 | batch_sizes = [-1] 76 | elif dataset == "ogbn-products": 77 | batch_sizes = [-1] 78 | elif dataset == "ogbn-papers100M": 79 | batch_sizes = [16, 32, 64, -1] 80 | else: 81 | batch_sizes = [-1] 82 | 83 | for n_trainer in n_trainers: 84 | for num_hops in num_hops_list: 85 | for distribution_type in distribution_list: 86 | for iid_beta in iid_betas: 87 | for batch_size in batch_sizes: 88 | # Load the base configuration 89 | config = attridict({}) 90 | 91 | # Set all required parameters 92 | for param, value in required_params.items(): 93 | setattr(config, param, value) 94 | 95 | # Set experiment-specific parameters 96 | config.dataset = dataset 97 | config.method = "fedgcn" if num_hops > 0 else "FedAvg" 98 | config.batch_size = batch_size 99 | config.n_trainer = n_trainer 100 | config.num_hops = num_hops 101 | config.iid_beta = iid_beta 102 | config.distribution_type = distribution_type 103 | config.gpu = gpu 104 | 105 | # Run multiple times for statistical significance 106 | for i in range(runs_per_config): 107 | print(f"\n{'-'*80}") 108 | print(f"Running experiment {i+1}/{runs_per_config}:") 109 | print( 110 | f"Dataset: {dataset}, Trainers: {n_trainer}, Distribution: {distribution_type}, " 111 | + f"IID Beta: {iid_beta}, Hops: {num_hops}, Batch Size: {batch_size}" 112 | ) 113 | print(f"{'-'*80}\n") 114 | 115 | # Run the federated learning process with clean Ray environment 116 | try: 117 | # Make sure Ray is shut down from any previous runs 118 | if ray.is_initialized(): 119 | ray.shutdown() 120 | 121 | # Run the experiment 122 | run_fedgraph(config) 123 
| except Exception as e: 124 | print(f"Error running experiment: {e}") 125 | print(f"Configuration: {config}") 126 | finally: 127 | # Always ensure Ray is shut down before the next experiment 128 | if ray.is_initialized(): 129 | ray.shutdown() 130 | 131 | # Add a short delay between runs 132 | time.sleep(5) 133 | 134 | print("Benchmark completed.") 135 | -------------------------------------------------------------------------------- /benchmark/configs/config_FedGAT.yaml: -------------------------------------------------------------------------------- 1 | dual_weight: 5.e-4 2 | aug_lagrange_rho: 6.e-4 3 | model_lr: 0.06 4 | model_regularisation: 2.e-3 5 | dual_lr: 1.e-2 6 | num_local_iters: 1 7 | train_rounds: 35 8 | global_rounds: 35 9 | gamma: 0.2 10 | attn_func_parameter: 0.2 11 | # lambda x: AttnFunction(x, 0.2) 12 | attn_func_domain: [-5, 5, 500] 13 | sample_probab: 1 14 | hidden_dim: 8 15 | num_heads: 8 16 | max_deg: 16 17 | 18 | # dataset: ogbn-arxiv 19 | dataset: cora 20 | n_trainer: 20 21 | num_layers: 2 22 | num_hops: 2 23 | gpu: false 24 | momentum: 0.0 25 | iid_beta: 10000 26 | logdir: ./runs 27 | device: cpu 28 | optim_kind: Adam 29 | glob_comm: FedAvg 30 | optim_reset: False 31 | dampening: 0.0 32 | limit_node_degree: 150 33 | # method: DistributedGAT 34 | # method: CentralizedGAT 35 | method: FedGAT 36 | batch_size: False 37 | vecgen: True 38 | communication_grad: True 39 | -------------------------------------------------------------------------------- /benchmark/configs/config_FedGCN.yaml: -------------------------------------------------------------------------------- 1 | dataset: cora 2 | fedtype: fedgcn 3 | global_rounds: 100 4 | local_step: 3 5 | learning_rate: 0.5 6 | n_trainer: 2 7 | num_layers: 2 8 | num_hops: 2 9 | gpu: false 10 | iid_beta: 10000 11 | logdir: ./runs 12 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_FedAvg.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'FedAvg' 3 | 4 | # dataset 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters 17 | num_trainers: 2 18 | num_rounds: 200 19 | local_epoch: 1 20 | lr: 0.001 21 | weight_decay: 0.0005 22 | nlayer: 3 23 | hidden: 64 24 | dropout: 0.5 25 | batch_size: 128 26 | 27 | # output 28 | outbase: './outputs' 29 | save_files: False 30 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_FedProx.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'FedAvg' 3 | 4 | # dataset: 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup: 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters: 17 | num_trainers: 2 18 | num_rounds: 200 19 | local_epoch: 1 20 | lr: 0.001 21 | weight_decay: 0.0005 22 | nlayer: 3 23 | hidden: 64 24 | dropout: 0.5 25 | batch_size: 128 26 | mu: 0.01 27 | 28 | # output: 29 | outbase: './outputs' 30 | save_files: False 31 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_GCFL+.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'GCFL' 3 | 4 | # 
dataset: 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup: 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters: 17 | num_trainers: 2 18 | num_rounds: 200 19 | local_epoch: 1 20 | lr: 0.001 21 | weight_decay: 0.0005 22 | nlayer: 3 23 | hidden: 64 24 | dropout: 0.5 25 | batch_size: 128 26 | standardize: False 27 | seq_length: 5 28 | epsilon1: 0.05 29 | epsilon2: 0.1 30 | 31 | # output: 32 | outbase: './outputs' 33 | save_files: False 34 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_GCFL+dWs.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'GCFL' 3 | 4 | # dataset: 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup: 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters: 17 | num_trainers: 2 18 | num_rounds: 200 19 | local_epoch: 1 20 | lr: 0.001 21 | weight_decay: 0.0005 22 | nlayer: 3 23 | hidden: 64 24 | dropout: 0.5 25 | batch_size: 128 26 | standardize: False 27 | seq_length: 5 28 | epsilon1: 0.05 29 | epsilon2: 0.1 30 | 31 | # output: 32 | outbase: './outputs' 33 | save_files: False 34 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_GCFL.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'GCFL' 3 | 4 | # dataset: 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup: 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters: 17 | num_trainers: 2 18 | num_rounds: 200 19 | local_epoch: 1 20 | lr: 0.001 21 | weight_decay: 0.0005 22 | nlayer: 3 23 | hidden: 64 24 | dropout: 0.5 25 | batch_size: 128 26 | standardize: False 27 | seq_length: 5 28 | epsilon1: 0.05 29 | epsilon2: 0.1 30 | 31 | # output: 32 | outbase: './outputs' 33 | save_files: False 34 | -------------------------------------------------------------------------------- /benchmark/configs/config_GC_SelfTrain.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | model: 'SelfTrain' 3 | 4 | # dataset 5 | dataset: "IMDB-BINARY" 6 | is_multiple_dataset: False 7 | datapath: './data' 8 | convert_x: False 9 | overlap: False 10 | 11 | # setup 12 | device: 'cpu' 13 | seed: 10 14 | seed_split_data: 42 15 | 16 | # model_parameters 17 | num_trainers: 2 18 | local_epoch: 1 19 | lr: 0.001 20 | weight_decay: 0.0005 21 | nlayer: 3 22 | hidden: 64 23 | dropout: 0.5 24 | batch_size: 128 25 | 26 | # output 27 | outbase: './outputs' 28 | save_files: False 29 | -------------------------------------------------------------------------------- /benchmark/configs/config_LP.yaml: -------------------------------------------------------------------------------- 1 | # general: 2 | method: FedLink 3 | fedgraph_task: LP 4 | # dataset: 5 | country_codes: ["US", "BR"] # country_codes = ['US', 'BR', 'ID', 'TR', 'JP'] 6 | dataset_path: data/LPDataset 7 | global_file_path: data/LPDataset/data_five_countries.txt 8 | traveled_file_path: data/LPDataset/traveled_users.txt 9 | 10 | # setup: 11 | device: gpu 12 | use_buffer: false 13 | buffer_size: 300000 14 | online_learning: false 15 | seed: 10 16 | 17 | # 
model_parameters: 18 | global_rounds: 8 19 | local_steps: 3 20 | repeat_time: 1 21 | hidden_channels: 64 22 | 23 | # output: 24 | record_results: false 25 | -------------------------------------------------------------------------------- /benchmark/configs/grafana_customized_metric_dashboard.json: -------------------------------------------------------------------------------- 1 | { 2 | "annotations": { 3 | "list": [] 4 | }, 5 | "editable": true, 6 | "gnetId": null, 7 | "graphTooltip": 0, 8 | "links": [], 9 | "panels": [ 10 | { 11 | "aliasColors": {}, 12 | "bars": false, 13 | "dashLength": 10, 14 | "dashes": false, 15 | "datasource": "${datasource}", 16 | "fieldConfig": { 17 | "defaults": {}, 18 | "overrides": [] 19 | }, 20 | "fill": 10, 21 | "fillGradient": 0, 22 | "gridPos": { 23 | "h": 8, 24 | "w": 12, 25 | "x": 0, 26 | "y": 0 27 | }, 28 | "hiddenSeries": false, 29 | "id": 1, 30 | "legend": { 31 | "alignAsTable": true, 32 | "current": true, 33 | "hideEmpty": false, 34 | "hideZero": true, 35 | "show": true, 36 | "sort": "current", 37 | "sortDesc": true, 38 | "values": true 39 | }, 40 | "lines": true, 41 | "linewidth": 1, 42 | "nullPointMode": "null", 43 | "options": { 44 | "alertThreshold": true 45 | }, 46 | "percentage": false, 47 | "pluginVersion": "7.5.17", 48 | "pointradius": 2, 49 | "points": false, 50 | "renderer": "flot", 51 | "seriesOverrides": [], 52 | "stack": true, 53 | "steppedLine": false, 54 | "targets": [ 55 | { 56 | "expr": "ray_train_node_network", 57 | "interval": "", 58 | "legendFormat": "{{instance}}", 59 | "refId": "A" 60 | } 61 | ], 62 | "thresholds": [], 63 | "timeFrom": null, 64 | "timeRegions": [], 65 | "timeShift": null, 66 | "title": "Node Communication (Train)", 67 | "tooltip": { 68 | "shared": true, 69 | "sort": 0, 70 | "value_type": "individual" 71 | }, 72 | "type": "graph", 73 | "xaxis": { 74 | "mode": "time", 75 | "show": true, 76 | "values": [] 77 | }, 78 | "yaxes": [ 79 | { 80 | "format": "bytes", 81 | "label": "", 82 | "logBase": 1, 83 | "show": true 84 | }, 85 | { 86 | "format": "short", 87 | "label": null, 88 | "logBase": 1, 89 | "show": true 90 | } 91 | ], 92 | "yaxis": { 93 | "align": false, 94 | "alignLevel": null 95 | } 96 | }, 97 | { 98 | "aliasColors": {}, 99 | "bars": false, 100 | "dashLength": 10, 101 | "dashes": false, 102 | "datasource": "${datasource}", 103 | "fieldConfig": { 104 | "defaults": {}, 105 | "overrides": [] 106 | }, 107 | "fill": 10, 108 | "fillGradient": 0, 109 | "gridPos": { 110 | "h": 8, 111 | "w": 12, 112 | "x": 12, 113 | "y": 0 114 | }, 115 | "hiddenSeries": false, 116 | "id": 2, 117 | "legend": { 118 | "alignAsTable": true, 119 | "current": true, 120 | "hideEmpty": false, 121 | "hideZero": true, 122 | "show": true, 123 | "sort": "current", 124 | "sortDesc": true, 125 | "values": true 126 | }, 127 | "lines": true, 128 | "linewidth": 1, 129 | "nullPointMode": "null", 130 | "options": { 131 | "alertThreshold": true 132 | }, 133 | "percentage": false, 134 | "pluginVersion": "7.5.17", 135 | "pointradius": 2, 136 | "points": false, 137 | "renderer": "flot", 138 | "seriesOverrides": [], 139 | "stack": true, 140 | "steppedLine": false, 141 | "targets": [ 142 | { 143 | "expr": "ray_pretrain_node_network", 144 | "interval": "", 145 | "legendFormat": "{{instance}}", 146 | "refId": "A" 147 | } 148 | ], 149 | "thresholds": [], 150 | "timeFrom": null, 151 | "timeRegions": [], 152 | "timeShift": null, 153 | "title": "Node Communication (Pretrain)", 154 | "tooltip": { 155 | "shared": true, 156 | "sort": 0, 157 | "value_type": "individual" 
158 | }, 159 | "type": "graph", 160 | "xaxis": { 161 | "mode": "time", 162 | "show": true, 163 | "values": [] 164 | }, 165 | "yaxes": [ 166 | { 167 | "format": "bytes", 168 | "label": "", 169 | "logBase": 1, 170 | "show": true 171 | }, 172 | { 173 | "format": "short", 174 | "label": null, 175 | "logBase": 1, 176 | "show": true 177 | } 178 | ], 179 | "yaxis": { 180 | "align": false, 181 | "alignLevel": null 182 | } 183 | }, 184 | { 185 | "aliasColors": {}, 186 | "bars": false, 187 | "dashLength": 10, 188 | "dashes": false, 189 | "datasource": "${datasource}", 190 | "fieldConfig": { 191 | "defaults": {}, 192 | "overrides": [] 193 | }, 194 | "fill": 10, 195 | "fillGradient": 0, 196 | "gridPos": { 197 | "h": 8, 198 | "w": 12, 199 | "x": 0, 200 | "y": 8 201 | }, 202 | "hiddenSeries": false, 203 | "id": 3, 204 | "legend": { 205 | "alignAsTable": true, 206 | "current": true, 207 | "hideEmpty": false, 208 | "hideZero": true, 209 | "show": true, 210 | "sort": "current", 211 | "sortDesc": true, 212 | "values": true 213 | }, 214 | "lines": true, 215 | "linewidth": 1, 216 | "nullPointMode": "null", 217 | "options": { 218 | "alertThreshold": true 219 | }, 220 | "percentage": false, 221 | "pluginVersion": "7.5.17", 222 | "pointradius": 2, 223 | "points": false, 224 | "renderer": "flot", 225 | "seriesOverrides": [], 226 | "stack": true, 227 | "steppedLine": false, 228 | "targets": [ 229 | { 230 | "expr": "ray_pretrain_time_cost", 231 | "interval": "", 232 | "legendFormat": "{{instance}}", 233 | "refId": "A" 234 | } 235 | ], 236 | "thresholds": [], 237 | "timeFrom": null, 238 | "timeRegions": [], 239 | "timeShift": null, 240 | "title": "Pretrain Time Cost", 241 | "tooltip": { 242 | "shared": true, 243 | "sort": 0, 244 | "value_type": "individual" 245 | }, 246 | "type": "graph", 247 | "xaxis": { 248 | "mode": "time", 249 | "show": true, 250 | "values": [] 251 | }, 252 | "yaxes": [ 253 | { 254 | "format": "ms", 255 | "label": "", 256 | "logBase": 1, 257 | "show": true 258 | }, 259 | { 260 | "format": "short", 261 | "label": null, 262 | "logBase": 1, 263 | "show": true 264 | } 265 | ], 266 | "yaxis": { 267 | "align": false, 268 | "alignLevel": null 269 | } 270 | }, 271 | { 272 | "aliasColors": {}, 273 | "bars": false, 274 | "dashLength": 10, 275 | "dashes": false, 276 | "datasource": "${datasource}", 277 | "fieldConfig": { 278 | "defaults": {}, 279 | "overrides": [] 280 | }, 281 | "fill": 10, 282 | "fillGradient": 0, 283 | "gridPos": { 284 | "h": 8, 285 | "w": 12, 286 | "x": 12, 287 | "y": 8 288 | }, 289 | "hiddenSeries": false, 290 | "id": 4, 291 | "legend": { 292 | "alignAsTable": true, 293 | "current": true, 294 | "hideEmpty": false, 295 | "hideZero": true, 296 | "show": true, 297 | "sort": "current", 298 | "sortDesc": true, 299 | "values": true 300 | }, 301 | "lines": true, 302 | "linewidth": 1, 303 | "nullPointMode": "null", 304 | "options": { 305 | "alertThreshold": true 306 | }, 307 | "percentage": false, 308 | "pluginVersion": "7.5.17", 309 | "pointradius": 2, 310 | "points": false, 311 | "renderer": "flot", 312 | "seriesOverrides": [], 313 | "stack": true, 314 | "steppedLine": false, 315 | "targets": [ 316 | { 317 | "expr": "ray_train_time_cost", 318 | "interval": "", 319 | "legendFormat": "{{instance}}", 320 | "refId": "A" 321 | } 322 | ], 323 | "thresholds": [], 324 | "timeFrom": null, 325 | "timeRegions": [], 326 | "timeShift": null, 327 | "title": "Train Time Cost", 328 | "tooltip": { 329 | "shared": true, 330 | "sort": 0, 331 | "value_type": "individual" 332 | }, 333 | "type": "graph", 334 | 
"xaxis": { 335 | "mode": "time", 336 | "show": true, 337 | "values": [] 338 | }, 339 | "yaxes": [ 340 | { 341 | "format": "ms", 342 | "label": "", 343 | "logBase": 1, 344 | "show": true 345 | }, 346 | { 347 | "format": "short", 348 | "label": null, 349 | "logBase": 1, 350 | "show": true 351 | } 352 | ], 353 | "yaxis": { 354 | "align": false, 355 | "alignLevel": null 356 | } 357 | } 358 | ], 359 | "refresh": false, 360 | "schemaVersion": 27, 361 | "style": "dark", 362 | "tags": [], 363 | "templating": { 364 | "list": [ 365 | { 366 | "current": { 367 | "selected": true, 368 | "text": "All", 369 | "value": "$__all" 370 | }, 371 | "hide": 0, 372 | "includeAll": true, 373 | "label": "Instance", 374 | "multi": true, 375 | "name": "instance", 376 | "options": [], 377 | "query": "label_values(instance)", 378 | "refresh": 1, 379 | "regex": "", 380 | "skipUrlSync": false, 381 | "type": "query" 382 | } 383 | ] 384 | }, 385 | "time": { 386 | "from": "now-30m", 387 | "to": "now" 388 | }, 389 | "timepicker": {}, 390 | "timezone": "", 391 | "title": "Customized Dashboard", 392 | "uid": "customizedDashboard", 393 | "version": 1 394 | } 395 | -------------------------------------------------------------------------------- /benchmark/figure/GC_comm_costs/gc_accuracy_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/GC_comm_costs/gc_accuracy_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/GC_comm_costs/gc_comm_cost_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/GC_comm_costs/gc_comm_cost_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/GC_comm_costs/gc_train_time_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/GC_comm_costs/gc_train_time_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/LP_comm_costs/lp_auc_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/LP_comm_costs/lp_auc_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/LP_comm_costs/lp_comm_cost_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/LP_comm_costs/lp_comm_cost_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/LP_comm_costs/lp_train_time_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/LP_comm_costs/lp_train_time_comparison.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/extract_global_test_acc.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | 7 | 8 | def extract_accuracy_by_dataset_algo(logfile): 9 | """ 10 | Extract round-wise Global Test Accuracy per dataset and algorithm from a log file. 11 | 12 | Returns: 13 | dict: {(dataset, algorithm): pd.DataFrame with columns ['Round', 'Accuracy']} 14 | """ 15 | with open(logfile, "r", encoding="utf-8", errors="replace") as f: 16 | log_content = f.read() 17 | 18 | # Split log into experiment blocks 19 | experiments = re.findall( 20 | r"Running experiment \d+/\d+:.*?(?=Running experiment|\Z)", 21 | log_content, 22 | re.DOTALL, 23 | ) 24 | 25 | results = {} 26 | 27 | for exp in experiments: 28 | # Extract dataset 29 | dataset_match = re.search(r"Dataset: ([a-zA-Z0-9_-]+)", exp) 30 | if not dataset_match: 31 | continue 32 | dataset = dataset_match.group(1) 33 | 34 | # Extract algorithm 35 | algo_match = re.search(r"method': '([A-Za-z0-9+_]+)'", exp) 36 | if not algo_match: 37 | algo_match = re.search(r"Changing method to ([A-Za-z0-9+_]+)", exp) 38 | algorithm = algo_match.group(1).strip() if algo_match else "FedAvg" 39 | 40 | # Extract all round accuracies 41 | round_accs = re.findall(r"Round (\d+): Global Test Accuracy = ([\d.]+)", exp) 42 | if not round_accs: 43 | continue 44 | 45 | rounds = [int(r[0]) for r in round_accs] 46 | accs = [float(r[1]) for r in round_accs] 47 | df = pd.DataFrame({"Round": rounds, "Accuracy": accs}) 48 | results[(dataset, algorithm)] = df 49 | 50 | return results 51 | 52 | 53 | def plot_accuracy_curves_grouped(results): 54 | """ 55 | Plot accuracy curves with both FedAvg and FedGCN in the same chart per dataset. 56 | 57 | Saves 4 figures, one per dataset. 
58 | """ 59 | datasets = { 60 | "cora": "Cora", 61 | "citeseer": "Citeseer", 62 | "pubmed": "Pubmed", 63 | "ogbn-arxiv": "Ogbn-Arxiv", 64 | } 65 | algos = ["FedAvg", "fedgcn"] 66 | display_names = {"FedAvg": "FedAvg", "fedgcn": "FedGCN"} 67 | colors = {"FedAvg": "#1f77b4", "fedgcn": "#ff7f0e"} 68 | 69 | for dataset_key, dataset_title in datasets.items(): 70 | plt.figure(figsize=(10, 9)) # Taller for better visual clarity 71 | for algo in algos: 72 | df = results.get((dataset_key, algo)) 73 | if df is not None and not df.empty: 74 | plt.plot( 75 | df["Round"], 76 | df["Accuracy"], 77 | label=display_names[algo], 78 | linewidth=4, 79 | color=colors[algo], 80 | ) 81 | plt.title(dataset_title, fontsize=38) 82 | plt.xlabel("Training Round", fontsize=34) 83 | plt.ylabel("Global Test Accuracy", fontsize=34) 84 | plt.grid(True, linestyle="--", alpha=0.6) 85 | plt.xticks(fontsize=30) 86 | plt.yticks(fontsize=30) 87 | plt.legend(fontsize=30, loc="lower right") 88 | plt.tight_layout() 89 | plt.savefig(f"nc_accuracy_curve_{dataset_key}.pdf", dpi=300) 90 | plt.close() 91 | 92 | 93 | if __name__ == "__main__": 94 | log_path = "NC.log" 95 | if not os.path.exists(log_path): 96 | print(f"Log file not found: {log_path}") 97 | exit(1) 98 | 99 | results = extract_accuracy_by_dataset_algo(log_path) 100 | plot_accuracy_curves_grouped(results) 101 | -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta10.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta100.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta100.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta10000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_comparison_beta10000.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_curve_citeseer.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_curve_citeseer.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_curve_cora.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_curve_cora.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_curve_ogbn-arxiv.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_curve_ogbn-arxiv.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_accuracy_curve_pubmed.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_accuracy_curve_pubmed.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta10.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta100.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta100.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta10000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_comm_cost_comparison_beta10000.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta10.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta10.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta100.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta100.pdf -------------------------------------------------------------------------------- /benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta10000.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/benchmark/figure/NC_comm_costs/nc_train_time_comparison_beta10000.pdf -------------------------------------------------------------------------------- /benchmark/figure/figure_GCbyAlgo.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | # Load the data from CSV file 5 | # Adjust this to the actual file path 6 | file_path = "new_memory.csv" 7 | df = pd.read_csv(file_path) 8 | 9 | # Group by algorithm and take the mean of communication cost and memory 10 | grouped_by_algo = ( 11 | df.groupby("Algorithm") 12 | .agg( 13 | { 14 | "Pretrain Network Large1": "mean", 15 | "Pretrain Network Large2": "mean", 16 | "Pretrain Network 
Large3": "mean", 17 | "Pretrain Network Large4": "mean", 18 | "Pretrain Network Large5": "mean", 19 | "Pretrain Network Large6": "mean", 20 | "Pretrain Network Large7": "mean", 21 | "Pretrain Network Large8": "mean", 22 | "Pretrain Network Large9": "mean", 23 | "Pretrain Network Large10": "mean", 24 | "Train Network Large1": "mean", 25 | "Train Network Large2": "mean", 26 | "Train Network Large3": "mean", 27 | "Train Network Large4": "mean", 28 | "Train Network Large5": "mean", 29 | "Train Network Large6": "mean", 30 | "Train Network Large7": "mean", 31 | "Train Network Large8": "mean", 32 | "Train Network Large9": "mean", 33 | "Train Network Large10": "mean", 34 | "Pretrain Max Trainer Memory1": "mean", 35 | "Pretrain Max Trainer Memory2": "mean", 36 | "Pretrain Max Trainer Memory3": "mean", 37 | "Pretrain Max Trainer Memory4": "mean", 38 | "Pretrain Max Trainer Memory5": "mean", 39 | "Pretrain Max Trainer Memory6": "mean", 40 | "Pretrain Max Trainer Memory7": "mean", 41 | "Pretrain Max Trainer Memory8": "mean", 42 | "Pretrain Max Trainer Memory9": "mean", 43 | "Pretrain Max Trainer Memory10": "mean", 44 | "Train Max Trainer Memory1": "mean", 45 | "Train Max Trainer Memory2": "mean", 46 | "Train Max Trainer Memory3": "mean", 47 | "Train Max Trainer Memory4": "mean", 48 | "Train Max Trainer Memory5": "mean", 49 | "Train Max Trainer Memory6": "mean", 50 | "Train Max Trainer Memory7": "mean", 51 | "Train Max Trainer Memory8": "mean", 52 | "Train Max Trainer Memory9": "mean", 53 | "Train Max Trainer Memory10": "mean", 54 | } 55 | ) 56 | .reset_index() 57 | ) 58 | 59 | # Plot Pretrain Network for each large 60 | plt.figure() 61 | for i in range(1, 11): 62 | plt.plot( 63 | grouped_by_algo["Algorithm"], 64 | grouped_by_algo[f"Pretrain Network Large{i}"], 65 | label=f"Pretrain Network Large{i}", 66 | marker="o", 67 | ) 68 | 69 | plt.xlabel("Algorithm") 70 | plt.ylabel("Communication Cost (Pretrain Network)") 71 | plt.title( 72 | "Pretrain Network Communication Cost for Different Algorithms (Large Network)" 73 | ) 74 | plt.xticks(rotation=45) 75 | plt.legend(loc="best") 76 | plt.tight_layout() 77 | plt.show() 78 | 79 | # Plot Train Network for each large 80 | plt.figure() 81 | for i in range(1, 11): 82 | plt.plot( 83 | grouped_by_algo["Algorithm"], 84 | grouped_by_algo[f"Train Network Large{i}"], 85 | label=f"Train Network Large{i}", 86 | marker="x", 87 | ) 88 | 89 | plt.xlabel("Algorithm") 90 | plt.ylabel("Communication Cost (Train Network)") 91 | plt.title("Train Network Communication Cost for Different Algorithms (Large Network)") 92 | plt.xticks(rotation=45) 93 | plt.legend(loc="best") 94 | plt.tight_layout() 95 | plt.show() 96 | 97 | # Plot Pretrain Max Trainer Memory for each trainer 98 | plt.figure() 99 | for i in range(1, 11): 100 | plt.plot( 101 | grouped_by_algo["Algorithm"], 102 | grouped_by_algo[f"Pretrain Max Trainer Memory{i}"], 103 | label=f"Pretrain Max Trainer Memory {i}", 104 | marker="o", 105 | ) 106 | 107 | plt.xlabel("Algorithm") 108 | plt.ylabel("Memory (Pretrain Max Trainer Memory)") 109 | plt.title("Pretrain Max Trainer Memory Usage for Different Algorithms") 110 | plt.xticks(rotation=45) 111 | plt.legend(loc="best") 112 | plt.tight_layout() 113 | plt.show() 114 | 115 | # Plot Train Max Trainer Memory for each trainer 116 | plt.figure() 117 | for i in range(1, 11): 118 | plt.plot( 119 | grouped_by_algo["Algorithm"], 120 | grouped_by_algo[f"Train Max Trainer Memory{i}"], 121 | label=f"Train Max Trainer Memory {i}", 122 | marker="x", 123 | ) 124 | 125 | 
plt.xlabel("Algorithm") 126 | plt.ylabel("Memory (Train Max Trainer Memory)") 127 | plt.title("Train Max Trainer Memory Usage for Different Algorithms") 128 | plt.xticks(rotation=45) 129 | plt.legend(loc="best") 130 | plt.tight_layout() 131 | plt.show() 132 | -------------------------------------------------------------------------------- /benchmark/figure/figure_GCbyDataset.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | # Load the data from CSV file 5 | file_path = "9.csv" # Adjust this to the actual file path 6 | df = pd.read_csv(file_path) 7 | 8 | # Filter for GCFL+dWs algorithm 9 | gcfl_dws_data = df[df["Algorithm"] == "GCFL+dWs"] 10 | 11 | # Group by dataset and take the mean of communication cost and memory 12 | grouped_by_dataset = ( 13 | gcfl_dws_data.groupby("Dataset") 14 | .agg( 15 | { 16 | "Pretrain Network Large1": "mean", 17 | "Pretrain Network Large2": "mean", 18 | "Pretrain Network Large3": "mean", 19 | "Pretrain Network Large4": "mean", 20 | "Pretrain Network Large5": "mean", 21 | "Pretrain Network Large6": "mean", 22 | "Pretrain Network Large7": "mean", 23 | "Pretrain Network Large8": "mean", 24 | "Pretrain Network Large9": "mean", 25 | "Pretrain Network Large10": "mean", 26 | "Train Network Large1": "mean", 27 | "Train Network Large2": "mean", 28 | "Train Network Large3": "mean", 29 | "Train Network Large4": "mean", 30 | "Train Network Large5": "mean", 31 | "Train Network Large6": "mean", 32 | "Train Network Large7": "mean", 33 | "Train Network Large8": "mean", 34 | "Train Network Large9": "mean", 35 | "Train Network Large10": "mean", 36 | "Pretrain Max Trainer Memory1": "mean", 37 | "Pretrain Max Trainer Memory2": "mean", 38 | "Pretrain Max Trainer Memory3": "mean", 39 | "Pretrain Max Trainer Memory4": "mean", 40 | "Pretrain Max Trainer Memory5": "mean", 41 | "Pretrain Max Trainer Memory6": "mean", 42 | "Pretrain Max Trainer Memory7": "mean", 43 | "Pretrain Max Trainer Memory8": "mean", 44 | "Pretrain Max Trainer Memory9": "mean", 45 | "Pretrain Max Trainer Memory10": "mean", 46 | "Train Max Trainer Memory1": "mean", 47 | "Train Max Trainer Memory2": "mean", 48 | "Train Max Trainer Memory3": "mean", 49 | "Train Max Trainer Memory4": "mean", 50 | "Train Max Trainer Memory5": "mean", 51 | "Train Max Trainer Memory6": "mean", 52 | "Train Max Trainer Memory7": "mean", 53 | "Train Max Trainer Memory8": "mean", 54 | "Train Max Trainer Memory9": "mean", 55 | "Train Max Trainer Memory10": "mean", 56 | } 57 | ) 58 | .reset_index() 59 | ) 60 | 61 | # Plot Pretrain Network for each large across datasets 62 | plt.figure() 63 | for i in range(1, 11): 64 | plt.plot( 65 | grouped_by_dataset["Dataset"], 66 | grouped_by_dataset[f"Pretrain Network Large{i}"], 67 | label=f"Pretrain Network Large{i}", 68 | marker="o", 69 | ) 70 | 71 | plt.xlabel("Dataset") 72 | plt.ylabel("Communication Cost (Pretrain Network)") 73 | plt.title( 74 | "Pretrain Network Communication Cost for GCFL+dWs Across Datasets (Large Network)" 75 | ) 76 | plt.xticks(rotation=45) 77 | plt.legend(loc="best") 78 | plt.tight_layout() 79 | plt.show() 80 | 81 | # Plot Train Network for each large across datasets 82 | plt.figure() 83 | for i in range(1, 11): 84 | plt.plot( 85 | grouped_by_dataset["Dataset"], 86 | grouped_by_dataset[f"Train Network Large{i}"], 87 | label=f"Train Network Large{i}", 88 | marker="x", 89 | ) 90 | 91 | plt.xlabel("Dataset") 92 | plt.ylabel("Communication Cost (Train Network)") 93 | plt.title( 94 
| "Train Network Communication Cost for GCFL+dWs Across Datasets (Large Network)" 95 | ) 96 | plt.xticks(rotation=45) 97 | plt.legend(loc="best") 98 | plt.tight_layout() 99 | plt.show() 100 | 101 | # Plot Pretrain Max Trainer Memory for each trainer across datasets 102 | plt.figure() 103 | for i in range(1, 11): 104 | plt.plot( 105 | grouped_by_dataset["Dataset"], 106 | grouped_by_dataset[f"Pretrain Max Trainer Memory{i}"], 107 | label=f"Pretrain Max Trainer Memory {i}", 108 | marker="o", 109 | ) 110 | 111 | plt.xlabel("Dataset") 112 | plt.ylabel("Memory (Pretrain Max Trainer Memory)") 113 | plt.title("Pretrain Max Trainer Memory Usage for GCFL+dWs Across Datasets") 114 | plt.xticks(rotation=45) 115 | plt.legend(loc="best") 116 | plt.tight_layout() 117 | plt.show() 118 | 119 | # Plot Train Max Trainer Memory for each trainer across datasets 120 | plt.figure() 121 | for i in range(1, 11): 122 | plt.plot( 123 | grouped_by_dataset["Dataset"], 124 | grouped_by_dataset[f"Train Max Trainer Memory{i}"], 125 | label=f"Train Max Trainer Memory {i}", 126 | marker="x", 127 | ) 128 | 129 | plt.xlabel("Dataset") 130 | plt.ylabel("Memory (Train Max Trainer Memory)") 131 | plt.title("Train Max Trainer Memory Usage for GCFL+dWs Across Datasets") 132 | plt.xticks(rotation=45) 133 | plt.legend(loc="best") 134 | plt.tight_layout() 135 | plt.show() 136 | -------------------------------------------------------------------------------- /benchmark/figure/figure_GCfinal.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | # Increase font sizes for readability 6 | plt.rcParams.update( 7 | { 8 | "font.size": 14, 9 | "axes.titlesize": 16, 10 | "axes.labelsize": 14, 11 | "xtick.labelsize": 12, 12 | "ytick.labelsize": 12, 13 | "legend.fontsize": 12, 14 | } 15 | ) 16 | 17 | # 1. Load the CSV file 18 | 19 | 20 | def load_csv_file(file_path): 21 | df = pd.read_csv(file_path) 22 | return df 23 | 24 | 25 | file_path = "11.csv" 26 | df = load_csv_file(file_path) 27 | 28 | # 2. Define algorithms, datasets, and trainers 29 | algorithms = ["FedAvg", "GCFL", "GCFL+", "GCFL+dWs"] 30 | datasets = ["IMDB-BINARY", "IMDB-MULTI", "MUTAG", "BZR", "COX2"] 31 | trainers = [10] # Specify the number of trainers 32 | 33 | # Function to filter data based on Algorithm, Dataset, and Number of Trainers 34 | 35 | 36 | def filter_data(df, algorithm, dataset, trainers): 37 | filtered_df = df[ 38 | (df["Algorithm"] == algorithm) 39 | & (df["Dataset"] == dataset) 40 | & (df["Number of Trainers"].isin(trainers)) 41 | ] 42 | grouped_df = ( 43 | filtered_df.groupby(["Algorithm", "Dataset", "Number of Trainers"]) 44 | .mean() 45 | .reset_index() 46 | ) 47 | return grouped_df 48 | 49 | 50 | # 3. 
Plot chart for Accuracy, Train Time, and Communication Cost comparison 51 | 52 | 53 | def plot_combined_comparison(df, algorithms, datasets, trainers): 54 | width = 0.15 # Width of each bar 55 | algorithm_range = np.arange(len(algorithms)) # X positions for bars 56 | 57 | # Track min and max values for scaling y-axis for each metric 58 | min_values = { 59 | "accuracy": float("inf"), 60 | "train_time": float("inf"), 61 | "communication_cost": float("inf"), 62 | } 63 | max_values = { 64 | "accuracy": float("-inf"), 65 | "train_time": float("-inf"), 66 | "communication_cost": float("-inf"), 67 | } 68 | 69 | # Create a figure with 3 subplots in one row 70 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(17, 6)) 71 | 72 | for j, dataset in enumerate(datasets): 73 | accuracy_values = [] 74 | train_time_values = [] 75 | communication_cost_values = [] 76 | 77 | for i, algorithm in enumerate(algorithms): 78 | filtered_df = filter_data(df, algorithm, dataset, trainers) 79 | avg_accuracy = filtered_df["Average Test Accuracy"].mean() 80 | avg_train_time = filtered_df["Train Time"].mean() 81 | avg_communication_cost = ( 82 | filtered_df[[f"Train Network Large{k}" for k in range(1, 11)]] 83 | .sum(axis=1) 84 | .mean() 85 | ) # Summing `Train Network Large` columns 86 | 87 | accuracy_values.append(avg_accuracy) 88 | train_time_values.append(avg_train_time) 89 | communication_cost_values.append(avg_communication_cost) 90 | 91 | # Update min and max values for each metric 92 | min_values["accuracy"] = min(min_values["accuracy"], avg_accuracy) 93 | max_values["accuracy"] = max(max_values["accuracy"], avg_accuracy) 94 | min_values["train_time"] = min(min_values["train_time"], avg_train_time) 95 | max_values["train_time"] = max(max_values["train_time"], avg_train_time) 96 | min_values["communication_cost"] = min( 97 | min_values["communication_cost"], avg_communication_cost 98 | ) 99 | max_values["communication_cost"] = max( 100 | max_values["communication_cost"], avg_communication_cost 101 | ) 102 | 103 | # Plot the bars for each metric and dataset 104 | ax1.bar( 105 | algorithm_range + j * width, 106 | accuracy_values, 107 | width=width, 108 | label=f"{dataset}", 109 | ) 110 | ax2.bar( 111 | algorithm_range + j * width, 112 | train_time_values, 113 | width=width, 114 | label=f"{dataset}", 115 | ) 116 | ax3.bar( 117 | algorithm_range + j * width, 118 | communication_cost_values, 119 | width=width, 120 | label=f"{dataset}", 121 | ) 122 | 123 | # Set titles and labels for each subplot 124 | ax1.set_title("Accuracy Comparison") 125 | ax1.set_xlabel("Algorithms") 126 | ax1.set_ylabel("Accuracy") 127 | ax1.set_xticks(algorithm_range + width * (len(datasets) - 1) / 2) 128 | ax1.set_xticklabels(algorithms) 129 | 130 | ax2.set_title("Train Time Comparison") 131 | ax2.set_xlabel("Algorithms") 132 | ax2.set_ylabel("Train Time (ms)") 133 | ax2.set_xticks(algorithm_range + width * (len(datasets) - 1) / 2) 134 | ax2.set_xticklabels(algorithms) 135 | 136 | ax3.set_title("Communication Cost Comparison") 137 | ax3.set_xlabel("Algorithms") 138 | ax3.set_ylabel("Total Communication Cost (Bytes)") 139 | ax3.set_xticks(algorithm_range + width * (len(datasets) - 1) / 2) 140 | ax3.set_xticklabels(algorithms) 141 | 142 | # Adjust y-axis for each subplot to occupy 70% of the plot's height 143 | for ax, metric in zip( 144 | [ax1, ax2, ax3], ["accuracy", "train_time", "communication_cost"] 145 | ): 146 | diff = max_values[metric] - min_values[metric] 147 | lower_bound = 0 148 | ax.set_ylim(lower_bound, max_values[metric] * 1.01) 
149 | 150 | # Display the legend 151 | ax3.legend(loc="upper left", bbox_to_anchor=(1, 1), title="Datasets") 152 | 153 | # Adjust layout to prevent overlap and add space between subplots 154 | plt.subplots_adjust(wspace=0.3) 155 | plt.tight_layout(rect=[0, 0, 1, 0.96]) 156 | plt.show() 157 | 158 | 159 | # 4. Call the plotting function 160 | plot_combined_comparison(df, algorithms, datasets, trainers) 161 | -------------------------------------------------------------------------------- /benchmark/figure/figure_NNN.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | # Set the font, title, and other plotting styles 6 | plt.rcParams.update( 7 | { 8 | "font.size": 12, 9 | "axes.titlesize": 12, 10 | "axes.labelsize": 14, 11 | "xtick.labelsize": 12, 12 | "ytick.labelsize": 12, 13 | "legend.fontsize": 12, 14 | } 15 | ) 16 | 17 | # Path to the input CSV file 18 | file_path = "100.csv" 19 | df = pd.read_csv(file_path) 20 | df = df[df["Number of Hops"] != 1] 21 | df["IID Beta"] = df["IID Beta"].astype(str) 22 | 23 | # Group by 'IID Beta' and 'Number of Trainers' and compute the mean of each metric 24 | numeric_columns = df.select_dtypes(include="number").columns 25 | grouped_df = ( 26 | df[["IID Beta"] + list(numeric_columns)] 27 | .groupby(["IID Beta", "Number of Trainers"]) 28 | .mean() 29 | .reset_index() 30 | ) 31 | 32 | # The three metrics to plot 33 | metrics = ["Average Test Accuracy", "Train Time", "Train Network Server"] 34 | titles = [ 35 | "Accuracy Comparison", 36 | "Train Time Comparison", 37 | "Communication Cost Comparison", 38 | ] 39 | y_labels = ["Accuracy", "Train Time (ms)", "Total Communication Cost (Bytes)"] 40 | 41 | # Width of each bar 42 | bar_width = 0.3 43 | # Unique IID Beta values, used to offset the bars side by side 44 | unique_betas = grouped_df["IID Beta"].unique() 45 | # X-axis positions 46 | num_trainers = grouped_df["Number of Trainers"].unique() 47 | x_positions = np.arange(len(num_trainers)) 48 | 49 | # Plot 50 | fig, axes = plt.subplots(1, 3, figsize=(20, 6), sharey=False) 51 | 52 | # Draw the bar chart for each metric 53 | for i, metric in enumerate(metrics): 54 | for j, beta in enumerate(unique_betas): 55 | beta_data = grouped_df[grouped_df["IID Beta"] == beta] 56 | # Offset the positions so different IID Beta values sit side by side 57 | offset_positions = ( 58 | x_positions + (j * bar_width) - (bar_width * (len(unique_betas) - 1) / 2) 59 | ) 60 | axes[i].bar( 61 | offset_positions, 62 | beta_data[metric], 63 | width=bar_width, 64 | label=f"IID Beta {beta}", 65 | ) 66 | axes[i].set_title(titles[i]) 67 | axes[i].set_xlabel("Number of Trainers") 68 | axes[i].set_ylabel(y_labels[i]) 69 | axes[i].set_xticks(x_positions) 70 | axes[i].set_xticklabels(num_trainers) 71 | 72 | # Place the legend 73 | axes[2].legend(loc="upper right", title="IID Beta") 74 | plt.subplots_adjust(wspace=0.3) 75 | 76 | # Show the figure 77 | plt.show() 78 | -------------------------------------------------------------------------------- /benchmark/figure/figure_batch.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | # 1. Load the CSV file 6 | plt.rcParams.update( 7 | { 8 | "font.size": 12, 9 | "axes.titlesize": 12, 10 | "axes.labelsize": 14, 11 | "xtick.labelsize": 12, 12 | "ytick.labelsize": 12, 13 | "legend.fontsize": 12, 14 | } 15 | ) 16 | 17 | 18 | def load_csv_file(file_path): 19 | df = pd.read_csv(file_path) 20 | return df 21 | 22 | 23 | file_path = "NC_arxiv_batchsize.csv" 24 | df = load_csv_file(file_path) 25 | 26 | # 2.
Define specific IID Beta, Hop values, and Batch Sizes 27 | iid_beta_values = [10000] 28 | hop_values = [1] 29 | batch_sizes = [16, 32, 64] 30 | 31 | # Function to filter data based on IID Beta and Hop 32 | 33 | 34 | def filter_data(df, iid_beta_value, hop_value): 35 | return df[(df["IID Beta"] == iid_beta_value) & (df["Number of Hops"] == hop_value)] 36 | 37 | 38 | # 3. Plot combined charts for Time, Memory, and Accuracy comparison 39 | 40 | 41 | def plot_combined_charts(df, iid_beta_value, hop_value): 42 | batch_data = df[df["Batch Size"].isin(batch_sizes)] 43 | width = 0.25 # Width of the bars 44 | batch_size_range = np.arange(len(batch_sizes)) # X positions for bars 45 | 46 | # Calculate values for each metric 47 | pretrain_values = ( 48 | batch_data.groupby("Batch Size")["Pretrain Time"].mean() 49 | if hop_value == 1 50 | else None 51 | ) 52 | train_values = batch_data.groupby("Batch Size")["Train Time"].mean() 53 | pre_columns = [f"Pretrain Network Large{i}" for i in range(1, 11)] 54 | pre_values = ( 55 | batch_data[pre_columns].sum(axis=1).groupby(batch_data["Batch Size"]).mean() 56 | ) 57 | tre_columns = [f"Train Network Large{i}" for i in range(1, 11)] 58 | tre_values = ( 59 | batch_data[tre_columns].sum(axis=1).groupby(batch_data["Batch Size"]).mean() 60 | ) 61 | accuracy_values = batch_data.groupby("Batch Size")["Average Test Accuracy"].mean() 62 | 63 | # Create a figure with 3 subplots in one row 64 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5)) 65 | 66 | # Plot Train Time and Pretrain Time 67 | if hop_value == 1: 68 | ax1.bar( 69 | batch_size_range - width / 2, 70 | pretrain_values.values, 71 | width=width, 72 | label="Pretrain Time", 73 | color="skyblue", 74 | ) 75 | ax1.bar( 76 | batch_size_range + (width / 2 if hop_value == 1 else 0), 77 | train_values.values, 78 | width=width, 79 | label="Train Time", 80 | color="orange", 81 | ) 82 | ax1.set_title( 83 | f"Pretrain vs Train Time (IID Beta {iid_beta_value}, Hop {hop_value})" 84 | ) 85 | ax1.set_xlabel("Batch Size") 86 | ax1.set_ylabel("Time (ms)") 87 | ax1.set_xticks(batch_size_range) 88 | ax1.set_xticklabels(batch_sizes) 89 | ax1.legend(loc="lower right") 90 | 91 | # Plot Total Communication Cost 92 | if hop_value == 1: 93 | ax2.bar( 94 | batch_size_range - width / 2, 95 | pre_values.values, 96 | width=width, 97 | label="Pretrain Communication Cost", 98 | color="skyblue", 99 | ) 100 | ax2.bar( 101 | batch_size_range + (width / 2 if hop_value == 1 else 0), 102 | tre_values.values, 103 | width=width, 104 | label="Train Communication Cost", 105 | color="orange", 106 | ) 107 | ax2.set_title( 108 | f"Total Communication Cost (IID Beta {iid_beta_value}, Hop {hop_value})" 109 | ) 110 | ax2.set_xlabel("Batch Size") 111 | ax2.set_ylabel("Communication Cost (Bytes)") 112 | ax2.set_xticks(batch_size_range) 113 | ax2.set_xticklabels(batch_sizes) 114 | ax2.legend(loc="lower right") 115 | # Plot Accuracy 116 | ax3.bar(batch_size_range, accuracy_values.values, color="green", width=width) 117 | ax3.set_title(f"Test Accuracy (IID Beta {iid_beta_value}, Hop {hop_value})") 118 | ax3.set_xlabel("Batch Size") 119 | ax3.set_ylabel("Accuracy") 120 | ax3.set_xticks(batch_size_range) 121 | ax3.set_xticklabels(batch_sizes) 122 | 123 | # Adjust layout to prevent overlap (rect is [left, bottom, right, top]) 124 | plt.tight_layout(rect=[0, 0, 1, 0.96]) 125 | plt.show() 126 | 127 | 128 | # 4.
Loop through the IID Beta values and Hops, and plot the combined charts 129 | for iid_beta_value in iid_beta_values: 130 | for hop_value in hop_values: 131 | filtered_df = filter_data(df, iid_beta_value, hop_value) 132 | plt.subplots_adjust(left=0.2, wspace=0.3) 133 | plot_combined_charts(filtered_df, iid_beta_value, hop_value) 134 | -------------------------------------------------------------------------------- /benchmark/figure/figure_batch_3d.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | # 1. Load the CSV file 6 | 7 | plt.rcParams.update( 8 | { 9 | "font.size": 12, 10 | "axes.titlesize": 12, 11 | "axes.labelsize": 14, 12 | "xtick.labelsize": 12, 13 | "ytick.labelsize": 12, 14 | "legend.fontsize": 12, 15 | } 16 | ) 17 | 18 | 19 | def load_csv_file(file_path): 20 | return pd.read_csv(file_path) 21 | 22 | 23 | file_path = "NC_papers100M.csv" 24 | df = load_csv_file(file_path) 25 | 26 | # 2. Filter data for specific IID Beta and Hop values 27 | iid_beta_values = [10000] 28 | hop_values = [0] 29 | batch_sizes = [16, 32, 64] 30 | 31 | # Helper function to filter data 32 | 33 | 34 | def filter_data(df, iid_beta_value, hop_value): 35 | return df.loc[ 36 | (df["IID Beta"] == iid_beta_value) & (df["Number of Hops"] == hop_value) 37 | ] 38 | 39 | 40 | # Function to add values on top of bars 41 | 42 | 43 | def add_values_on_bars(ax, values, xpos, width): 44 | for i, v in enumerate(values): 45 | ax.text( 46 | i + xpos, v + 0.01 * v, f"{v:.2f}", ha="center", va="bottom", fontsize=10 47 | ) 48 | 49 | 50 | # Function to calculate and set y-axis limits based on 70% range 51 | 52 | 53 | def set_scaled_ylim(ax, values): 54 | min_val, max_val = values.min(), values.max() 55 | if min_val == max_val: 56 | ax.set_ylim(0, max_val * 1.1) # If all values are equal, plot from 0 to 1.1x the max 57 | else: 58 | range_val = max_val - min_val 59 | lower_bound = min_val - 0.5 * range_val # Lower bound slightly below the minimum 60 | upper_bound = max_val + 0.5 * range_val # Upper bound slightly above the maximum 61 | ax.set_ylim(lower_bound, upper_bound) # Set the y-axis range 62 | 63 | 64 | def set_scaled_ylim_1(ax, values): 65 | min_val, max_val = values.min(), values.max() 66 | if min_val == max_val: 67 | ax.set_ylim(0, max_val * 1.1) # If all values are equal, plot from 0 to 1.1x the max 68 | else: 69 | range_val = max_val - min_val 70 | lower_bound = 0 # Fixed lower bound of 0 71 | upper_bound = 0.8 # Fixed upper bound of 0.8 72 | ax.set_ylim(lower_bound, upper_bound) # Set the y-axis range 73 | 74 | 75 | # 3.
Plot three separate charts and combine them into one figure 76 | 77 | 78 | def plot_combined_charts(df, iid_beta_value, hop_value): 79 | batch_data = df[df["Batch Size"].isin(batch_sizes)].copy() 80 | memory_columns = [f"Train Network Large{i}" for i in range(1, 11)] 81 | batch_data.loc[:, "Total Communication Cost"] = batch_data[memory_columns].sum( 82 | axis=1 83 | ) 84 | 85 | # Get train_time, memory, accuracy 86 | train_time = batch_data.groupby("Batch Size")["Train Time"].mean() 87 | memory = batch_data.groupby("Batch Size")["Total Communication Cost"].mean() 88 | accuracy = batch_data.groupby("Batch Size")["Average Test Accuracy"].mean() 89 | 90 | # Create a figure with 3 subplots in one row 91 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 8)) 92 | 93 | # X-axis positions for the bars 94 | x = np.arange(len(batch_sizes)) 95 | width = 0.4 96 | 97 | # Plot the bars for Train Time 98 | ax1.bar(x, train_time, width, color="orange") 99 | ax1.set_title("Train Time") 100 | ax1.set_xlabel("Batch Size") 101 | ax1.set_ylabel("Train Time (ms)") 102 | ax1.set_xticks(x) 103 | ax1.set_xticklabels([str(bs) for bs in batch_sizes]) 104 | # add_values_on_bars(ax1, train_time, 0, width) 105 | set_scaled_ylim(ax1, train_time) # Scale the y-axis so the spread between bars stands out 106 | 107 | # Plot the bars for Memory 108 | ax2.bar(x, memory, width, color="skyblue") 109 | ax2.set_title("Communication Cost") 110 | ax2.set_xlabel("Batch Size") 111 | ax2.set_ylabel("Bytes") 112 | ax2.set_xticks(x) 113 | ax2.set_xticklabels([str(bs) for bs in batch_sizes]) 114 | add_values_on_bars(ax2, memory, 0, width) 115 | set_scaled_ylim(ax2, memory) # Scale the y-axis so the spread between bars stands out 116 | 117 | # Plot the bars for Accuracy 118 | ax3.bar(x, accuracy, width, color="green") 119 | ax3.set_title("Test Accuracy") 120 | ax3.set_xlabel("Batch Size") 121 | ax3.set_ylabel("Accuracy") 122 | ax3.set_xticks(x) 123 | ax3.set_xticklabels([str(bs) for bs in batch_sizes]) 124 | add_values_on_bars(ax3, accuracy, 0, width) 125 | set_scaled_ylim_1(ax3, accuracy) # Use the fixed 0-0.8 accuracy range 126 | 127 | # Set a main title for the figure 128 | plt.suptitle(f"Combined Plot (IID Beta {iid_beta_value}, Hop {hop_value})") 129 | 130 | # Adjust layout to prevent overlap 131 | plt.tight_layout(rect=[0, 0, 1, 0.96]) 132 | plt.subplots_adjust(wspace=0.3) 133 | # Show the plot 134 | plt.show() 135 | 136 | 137 | # 4. Loop through the IID Beta values and plot the charts 138 | for iid_beta_value in iid_beta_values: 139 | for hop_value in hop_values: 140 | filtered_df = filter_data(df, iid_beta_value, hop_value) 141 | plot_combined_charts(filtered_df, iid_beta_value, hop_value) 142 | -------------------------------------------------------------------------------- /benchmark/figure/figure_batch_combine.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | # 1. Load the CSV file 6 | 7 | 8 | def load_csv_file(file_path): 9 | df = pd.read_csv(file_path) 10 | return df 11 | 12 | 13 | file_path = "NC_arxiv_batchsize.csv" 14 | df = load_csv_file(file_path) 15 | 16 | # 2.
Filter data for specific Batch Sizes and Hop values 17 | iid_beta_values = [10000, 100, 10] 18 | batch_sizes = [-1, 16, 32, 64] 19 | hop_value = 1 # Set a specific hop value, or change as needed 20 | 21 | # Function to filter data based on IID Beta, Hop and Batch Size 22 | 23 | 24 | def filter_data(df, iid_beta_value, hop_value): 25 | return df[ 26 | (df["IID Beta"] == iid_beta_value) 27 | & (df["Number of Hops"] == hop_value) 28 | & (df["Batch Size"].isin(batch_sizes)) 29 | ] 30 | 31 | 32 | # 3. Plot chart for Accuracy comparison across Batch Sizes for different IID Beta values 33 | 34 | 35 | def plot_accuracy_comparison(df, hop_value): 36 | width = 0.2 # Width of each bar 37 | batch_size_range = np.arange(len(batch_sizes)) # X positions for bars 38 | 39 | min_accuracy = float("inf") # Track the minimum accuracy 40 | max_accuracy = float("-inf") # Track the maximum accuracy 41 | 42 | for i, iid_beta_value in enumerate(iid_beta_values): 43 | filtered_df = filter_data(df, iid_beta_value, hop_value) 44 | accuracy_values = filtered_df.groupby("Batch Size")[ 45 | "Average Test Accuracy" 46 | ].mean() 47 | 48 | # Update min and max accuracy 49 | min_accuracy = min(min_accuracy, accuracy_values.min()) 50 | max_accuracy = max(max_accuracy, accuracy_values.max()) 51 | 52 | # Plot the bars for each IID Beta, with slight shifts in x positions to avoid overlap 53 | plt.bar( 54 | batch_size_range + i * width, 55 | accuracy_values.values, 56 | width=width, 57 | label=f"IID Beta {iid_beta_value}", 58 | ) 59 | 60 | # Calculate diff and adjust the y-axis to make diff occupy 70% of the plot's height 61 | diff = max_accuracy - min_accuracy 62 | # Calculate lower bound to make diff occupy 70% of the plot 63 | lower_bound = max_accuracy - diff / 0.7 64 | 65 | # Set y-axis limit to make the difference more visible 66 | plt.ylim(lower_bound, max_accuracy * 1.01) 67 | 68 | # Title and labels 69 | plt.title(f"Test Accuracy Comparison (Hop {hop_value})") 70 | plt.xlabel("Batch Size") 71 | plt.ylabel("Accuracy") 72 | 73 | # Set x-axis labels to batch sizes 74 | plt.xticks(batch_size_range + width, labels=batch_sizes) 75 | 76 | plt.legend() 77 | plt.show() 78 | 79 | 80 | # 4. 
Call the plotting function for the given hop value 81 | plot_accuracy_comparison(df, hop_value) 82 | -------------------------------------------------------------------------------- /benchmark/figure/figure_papers100M.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | file_path = "NC_papers100M.csv" 5 | data = pd.read_csv(file_path) 6 | # data = { 7 | # 'Batch Size': [16, 32, 64, -1], 8 | # 'Train Time': [620510.904, 625067.836, 646383.4789999999, 625576.189], 9 | # 'Average Test Accuracy': [0.4148867676286986, 0.4148867676286986, 0.41487743657214304, 0.37154400992824416] 10 | # } 11 | 12 | # Create the DataFrame 13 | df = pd.DataFrame(data) 14 | 15 | # Plot Train Time vs Batch Size 16 | plt.figure() 17 | plt.plot( 18 | df["Batch Size"], df["Train Time"], marker="o", color="skyblue", label="Train Time" 19 | ) 20 | plt.xlabel("Batch Size") 21 | plt.ylabel("Train Time") 22 | plt.title("Train Time vs Batch Size") 23 | plt.grid(True) 24 | plt.tight_layout() 25 | plt.show() 26 | 27 | # Plot Accuracy vs Batch Size 28 | plt.figure() 29 | plt.plot( 30 | df["Batch Size"], 31 | df["Average Test Accuracy"], 32 | marker="x", 33 | color="orange", 34 | label="Accuracy", 35 | ) 36 | plt.xlabel("Batch Size") 37 | plt.ylabel("Test Accuracy") 38 | plt.title("Test Accuracy vs Batch Size") 39 | plt.grid(True) 40 | plt.tight_layout() 41 | plt.show() 42 | -------------------------------------------------------------------------------- /benchmark/figure/figure_traintime.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | # Load the data from CSV file 5 | file_path = "8.csv" # Adjust this to the actual file path 6 | df = pd.read_csv(file_path) 7 | 8 | # Group by algorithm and dataset, take the mean of train time and other relevant metrics 9 | grouped_by_algo = ( 10 | df.groupby("Algorithm") 11 | .agg( 12 | { 13 | "Train Time": "mean", # Aggregating by Train Time 14 | "Pretrain Network Large1": "mean", 15 | "Pretrain Network Large2": "mean", 16 | "Pretrain Network Large3": "mean", 17 | "Pretrain Network Large4": "mean", 18 | "Pretrain Network Large5": "mean", 19 | "Pretrain Network Large6": "mean", 20 | "Pretrain Network Large7": "mean", 21 | "Pretrain Network Large8": "mean", 22 | "Pretrain Network Large9": "mean", 23 | "Pretrain Network Large10": "mean", 24 | "Train Network Large1": "mean", 25 | "Train Network Large2": "mean", 26 | "Train Network Large3": "mean", 27 | "Train Network Large4": "mean", 28 | "Train Network Large5": "mean", 29 | "Train Network Large6": "mean", 30 | "Train Network Large7": "mean", 31 | "Train Network Large8": "mean", 32 | "Train Network Large9": "mean", 33 | "Train Network Large10": "mean", 34 | "Pretrain Max Trainer Memory1": "mean", 35 | "Pretrain Max Trainer Memory2": "mean", 36 | "Pretrain Max Trainer Memory3": "mean", 37 | "Pretrain Max Trainer Memory4": "mean", 38 | "Pretrain Max Trainer Memory5": "mean", 39 | "Pretrain Max Trainer Memory6": "mean", 40 | "Pretrain Max Trainer Memory7": "mean", 41 | "Pretrain Max Trainer Memory8": "mean", 42 | "Pretrain Max Trainer Memory9": "mean", 43 | "Pretrain Max Trainer Memory10": "mean", 44 | "Train Max Trainer Memory1": "mean", 45 | "Train Max Trainer Memory2": "mean", 46 | "Train Max Trainer Memory3": "mean", 47 | "Train Max Trainer Memory4": "mean", 48 | "Train Max Trainer Memory5": "mean", 49 | "Train Max Trainer Memory6": "mean", 50 | 
"Train Max Trainer Memory7": "mean", 51 | "Train Max Trainer Memory8": "mean", 52 | "Train Max Trainer Memory9": "mean", 53 | "Train Max Trainer Memory10": "mean", 54 | } 55 | ) 56 | .reset_index() 57 | ) 58 | 59 | # Plot Train Time 60 | plt.figure() 61 | plt.bar(grouped_by_algo["Algorithm"], grouped_by_algo["Train Time"], color="skyblue") 62 | plt.xlabel("Algorithm") 63 | plt.ylabel("Train Time") 64 | plt.title("Train Time for Different Algorithms") 65 | plt.xticks(rotation=45) 66 | plt.tight_layout() 67 | plt.show() 68 | 69 | # Plot Pretrain Network for each large 70 | plt.figure() 71 | for i in range(1, 11): 72 | plt.plot( 73 | grouped_by_algo["Algorithm"], 74 | grouped_by_algo[f"Pretrain Network Large{i}"], 75 | label=f"Pretrain Network Large{i}", 76 | marker="o", 77 | ) 78 | 79 | plt.xlabel("Algorithm") 80 | plt.ylabel("Communication Cost (Pretrain Network)") 81 | plt.title( 82 | "Pretrain Network Communication Cost for Different Algorithms (Large Network)" 83 | ) 84 | plt.xticks(rotation=45) 85 | plt.legend(loc="best") 86 | plt.tight_layout() 87 | plt.show() 88 | 89 | # Plot Train Network for each large 90 | plt.figure() 91 | for i in range(1, 11): 92 | plt.plot( 93 | grouped_by_algo["Algorithm"], 94 | grouped_by_algo[f"Train Network Large{i}"], 95 | label=f"Train Network Large{i}", 96 | marker="x", 97 | ) 98 | 99 | plt.xlabel("Algorithm") 100 | plt.ylabel("Communication Cost (Train Network)") 101 | plt.title("Train Network Communication Cost for Different Algorithms (Large Network)") 102 | plt.xticks(rotation=45) 103 | plt.legend(loc="best") 104 | plt.tight_layout() 105 | plt.show() 106 | 107 | # Plot Pretrain Max Trainer Memory for each trainer 108 | plt.figure() 109 | for i in range(1, 11): 110 | plt.plot( 111 | grouped_by_algo["Algorithm"], 112 | grouped_by_algo[f"Pretrain Max Trainer Memory{i}"], 113 | label=f"Pretrain Max Trainer Memory {i}", 114 | marker="o", 115 | ) 116 | 117 | plt.xlabel("Algorithm") 118 | plt.ylabel("Memory (Pretrain Max Trainer Memory)") 119 | plt.title("Pretrain Max Trainer Memory Usage for Different Algorithms") 120 | plt.xticks(rotation=45) 121 | plt.legend(loc="best") 122 | plt.tight_layout() 123 | plt.show() 124 | 125 | # Plot Train Max Trainer Memory for each trainer 126 | plt.figure() 127 | for i in range(1, 11): 128 | plt.plot( 129 | grouped_by_algo["Algorithm"], 130 | grouped_by_algo[f"Train Max Trainer Memory{i}"], 131 | label=f"Train Max Trainer Memory {i}", 132 | marker="x", 133 | ) 134 | 135 | plt.xlabel("Algorithm") 136 | plt.ylabel("Memory (Train Max Trainer Memory)") 137 | plt.title("Train Max Trainer Memory Usage for Different Algorithms") 138 | plt.xticks(rotation=45) 139 | plt.legend(loc="best") 140 | plt.tight_layout() 141 | plt.show() 142 | -------------------------------------------------------------------------------- /docker_requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | absl-py 3 | 4 | aiohttp 5 | aiohttp-cors 6 | aiosignal 7 | alabaster 8 | annotated-types 9 | antlr4-python3-runtime 10 | attridict 11 | attrs 12 | Babel 13 | beautifulsoup4 14 | cachetools 15 | certifi 16 | cfgv 17 | charset-normalizer 18 | click 19 | colorful 20 | contourpy 21 | cycler 22 | distlib 23 | docutils 24 | dtaidistance 25 | 26 | filelock 27 | fonttools 28 | frozenlist 29 | fsspec 30 | furo 31 | gdown 32 | google-api-core 33 | google-auth 34 | googleapis-common-protos 35 | grpcio 36 | identify 37 | idna 38 | imagesize 39 | Jinja2 40 | joblib 41 | jsonschema 42 | jsonschema-specifications 43 | 
kiwisolver 44 | latexcodec 45 | lightning-utilities 46 | linkify-it-py 47 | Markdown 48 | markdown-it-py 49 | MarkupSafe 50 | matplotlib 51 | mdit-py-plugins 52 | mdurl 53 | memray 54 | mpmath 55 | msgpack 56 | multidict 57 | networkx 58 | nodeenv 59 | numpy 60 | omegaconf 61 | opencensus 62 | opencensus-context 63 | packaging 64 | pandas 65 | pillow 66 | platformdirs 67 | pre-commit 68 | prometheus_client 69 | proto-plus 70 | protobuf 71 | psutil 72 | py-spy 73 | pyasn1 74 | pyasn1_modules 75 | pybtex 76 | pybtex-docutils 77 | pydantic 78 | pydantic_core 79 | Pygments 80 | pyparsing 81 | PySocks 82 | python-dateutil 83 | pytz 84 | PyYAML 85 | ray 86 | referencing 87 | requests 88 | rich 89 | rpds-py 90 | rsa 91 | scikit-learn 92 | scipy 93 | six 94 | smart-open 95 | snowballstemmer 96 | soupsieve 97 | Sphinx 98 | sphinx-basic-ng 99 | sphinx-gallery 100 | sphinxcontrib-applehelp 101 | sphinxcontrib-bibtex 102 | sphinxcontrib-devhelp 103 | sphinxcontrib-htmlhelp 104 | sphinxcontrib-jsmath 105 | sphinxcontrib-qthelp 106 | sphinxcontrib-serializinghtml 107 | sympy 108 | tensorboard 109 | tensorboard-data-server 110 | textual 111 | threadpoolctl 112 | torch-cluster 113 | torch-scatter 114 | torch-sparse 115 | torch-spline-conv 116 | torch_geometric 117 | torchmetrics 118 | tqdm 119 | typing_extensions 120 | tzdata 121 | uc-micro-py 122 | urllib3 123 | virtualenv 124 | Werkzeug 125 | wrapt 126 | yarl 127 | ogb 128 | huggingface_hub 129 | tenseal 130 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/cite.rst: -------------------------------------------------------------------------------- 1 | Cite 2 | ==== 3 | 4 | Please cite our `paper `_ (and the respective papers of the methods used) if you use this code in your own work:: 5 | 6 | @article{yao2023fedgcn, 7 | title={FedGCN: Convergence-Communication Tradeoffs in Federated Training of Graph Convolutional Networks}, 8 | author={Yao, Yuhang and Jin, Weizhao and Ravi, Srivatsan and Joe-Wong, Carlee}, 9 | journal={Advances in Neural Information Processing Systems (NeurIPS)}, 10 | year={2023} 11 | } 12 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | from os.path import abspath, dirname 16 | 17 | sys.path.insert(0, abspath("..")) 18 | root_dir = dirname(dirname(abspath(__file__))) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "FedGraph" 23 | copyright = "2024 FedGraph Team" 24 | author = "FedGraph Team" 25 | 26 | version_path = os.path.join(root_dir, "fedgraph") 27 | # version_path = os.path.join(root_dir, "fedgraph", "version.py") 28 | sys.path.append(version_path) 29 | # exec(open(version_path).read()) 30 | from version import __version__ 31 | 32 | version = __version__ 33 | release = __version__ 34 | 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | "sphinx.ext.autodoc", 43 | "sphinx.ext.autosummary", 44 | "sphinx.ext.doctest", 45 | "sphinx.ext.intersphinx", 46 | "sphinx.ext.coverage", 47 | "sphinx.ext.mathjax", 48 | "sphinx.ext.viewcode", 49 | "sphinxcontrib.bibtex", 50 | "sphinx.ext.napoleon", 51 | "sphinx_gallery.gen_gallery", 52 | ] 53 | 54 | 55 | bibtex_bibfiles = ["zreferences.bib"] 56 | 57 | # Add any paths that contain templates here, relative to this directory. 58 | templates_path = ["_templates"] 59 | 60 | # The suffix(es) of source filenames. 61 | # You can specify multiple suffix as a list of string: 62 | # 63 | # source_suffix = ['.rst', '.md'] 64 | source_suffix = ".rst" 65 | 66 | # The master toctree document. 67 | master_doc = "index" 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path. 72 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 73 | 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | html_theme = "sphinx_rtd_theme" 81 | 82 | # html_favicon = 'pygod.ico' 83 | 84 | # Add any paths that contain custom static files (such as style sheets) here, 85 | # relative to this directory. They are copied after the builtin static files, 86 | # so a file named "default.css" will overwrite the builtin "default.css". 87 | html_static_path = ["_static"] 88 | 89 | # Custom sidebar templates, must be a dictionary that maps document names 90 | # to template names. 91 | # 92 | # The default sidebars (for documents that don't match any pattern) are 93 | # defined by theme itself. Builtin themes are using these templates by 94 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 95 | # 'searchbox.html']``. 
96 | # 97 | # html_sidebars = {} 98 | # html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 99 | # 'searchbox.html']} 100 | 101 | # -- Options for HTMLHelp output --------------------------------------------- 102 | 103 | # Output file base name for HTML help builder. 104 | htmlhelp_basename = "fedgraphdoc" 105 | 106 | # -- Options for LaTeX output ------------------------------------------------ 107 | 108 | latex_elements: dict[str, str] = { 109 | # The paper size ('letterpaper' or 'a4paper'). 110 | # 111 | # 'papersize': 'letterpaper', 112 | # The font size ('10pt', '11pt' or '12pt'). 113 | # 114 | # 'pointsize': '10pt', 115 | # Additional stuff for the LaTeX preamble. 116 | # 117 | # 'preamble': '', 118 | # Latex figure (float) alignment 119 | # 120 | # 'figure_align': 'htbp', 121 | } 122 | 123 | # Grouping the document tree_ into LaTeX files. List of tuples 124 | # (source start file, target name, title, 125 | # author, documentclass [howto, manual, or own class]). 126 | latex_documents = [ 127 | (master_doc, "fedgraph.tex", "FedGraph Documentation", "FedGraph Team", "manual"), 128 | ] 129 | 130 | # -- Options for manual page output ------------------------------------------ 131 | 132 | # One entry per manual page. List of tuples 133 | # (source start file, name, description, authors, manual section). 134 | man_pages = [(master_doc, "fedgraph", "FedGraph Documentation", [author], 1)] 135 | 136 | # -- Options for Texinfo output ---------------------------------------------- 137 | 138 | # Grouping the document tree_ into Texinfo files. List of tuples 139 | # (source start file, target name, title, author, 140 | # dir menu entry, description, category) 141 | texinfo_documents = [ 142 | ( 143 | master_doc, 144 | "fedgraph", 145 | "FedGraph Documentation", 146 | author, 147 | "FedGraph", 148 | "One line description of project.", 149 | "Miscellaneous", 150 | ), 151 | ] 152 | 153 | # -- Extension configuration ------------------------------------------------- 154 | from sphinx_gallery.sorting import FileNameSortKey 155 | 156 | html_static_path = [] 157 | 158 | sphinx_gallery_conf = { 159 | "examples_dirs": "../tutorials/", 160 | "gallery_dirs": "tutorials/", 161 | "within_subsection_order": FileNameSortKey, 162 | "filename_pattern": ".py", 163 | "download_all_examples": False, 164 | } 165 | # -- Options for intersphinx extension --------------------------------------- 166 | 167 | # Example configuration for intersphinx: refer to the Python standard library. 
168 | intersphinx_mapping = { 169 | "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), 170 | "numpy": ("https://numpy.org/doc/stable/", None), 171 | "scipy": ("https://docs.scipy.org/doc/scipy/", None), 172 | "sklearn": ("https://scikit-learn.org/stable/", None), 173 | "networkx": ("https://networkx.org/documentation/stable/", None), 174 | "torch": ("https://pytorch.org/docs/master", None), 175 | "torch_geometric": ("https://pytorch-geometric.readthedocs.io/en/latest", None), 176 | } 177 | -------------------------------------------------------------------------------- /docs/dev_script/kuberflow_sample.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import yaml 4 | from kubernetes import client, config, utils 5 | from kubernetes.client import ApiException 6 | from kubernetes.stream import portforward 7 | 8 | config.load_kube_config() 9 | 10 | api_client = client.ApiClient() 11 | v1 = client.CoreV1Api(api_client) 12 | custom_api = client.CustomObjectsApi(api_client) 13 | 14 | 15 | def create_resource_from_yaml(file_path): 16 | with open(file_path) as f: 17 | resource_yaml = yaml.safe_load(f) 18 | 19 | try: 20 | utils.create_from_yaml(api_client, file_path) 21 | print(f"Resource from {file_path} created successfully.") 22 | except ApiException as e: 23 | print(f"Exception when creating resource: {e}") 24 | 25 | 26 | def wait_for_pods(namespace, label_selector, target_phase="Running", timeout=600): 27 | start_time = time.time() 28 | while time.time() - start_time < timeout: 29 | pods = v1.list_namespaced_pod(namespace, label_selector=label_selector).items 30 | all_in_target_phase = all(pod.status.phase == target_phase for pod in pods) 31 | 32 | if all_in_target_phase: 33 | print( 34 | f"All pods with label '{label_selector}' are in {target_phase} state." 35 | ) 36 | return True 37 | else: 38 | print(f"Waiting for pods to reach {target_phase} state...") 39 | time.sleep(10) 40 | 41 | print( 42 | f"Timeout reached: Pods did not reach {target_phase} state within {timeout} seconds." 
43 | ) 44 | return False 45 | -------------------------------------------------------------------------------- /docs/dev_script/processing_script_GC.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | 4 | import pandas as pd 5 | 6 | 7 | def process_log(log_content): 8 | experiments = [] 9 | current_experiment = {} 10 | 11 | for line in log_content.splitlines(): 12 | experiment_match = re.match( 13 | r"Running experiment with: Algorithm=([^,]+),\s*Dataset=([^,]+),\s*Number of Trainers=(\d+)", 14 | line, 15 | ) 16 | 17 | if experiment_match: 18 | if current_experiment: 19 | experiments.append(current_experiment) 20 | current_experiment = { 21 | "Algorithm": experiment_match.group(1), 22 | "Dataset": experiment_match.group(2), 23 | "Number of Trainers": int(experiment_match.group(3)), 24 | } 25 | pretrain_mode = True 26 | train_mode = False 27 | 28 | pretrain_time_match = re.search(r"pretrain_time: (\d+\.\d+)", line) 29 | if pretrain_time_match: 30 | pretrain_mode = True 31 | train_mode = False 32 | current_experiment["Pretrain Time"] = float(pretrain_time_match.group(1)) 33 | 34 | pretrain_max_trainer_memory_match = re.search( 35 | r"Log Max memory for Large(\d+): (\d+\.\d+)", line 36 | ) 37 | if pretrain_max_trainer_memory_match and pretrain_mode: 38 | current_experiment[ 39 | f"Pretrain Max Trainer Memory{pretrain_max_trainer_memory_match.group(1)}" 40 | ] = float(pretrain_max_trainer_memory_match.group(2)) 41 | 42 | pretrain_max_server_memory_match = re.search( 43 | r"Log Max memory for Server: (\d+\.\d+)", line 44 | ) 45 | if pretrain_max_server_memory_match and pretrain_mode: 46 | current_experiment["Pretrain Max Server Memory"] = float( 47 | pretrain_max_server_memory_match.group(1) 48 | ) 49 | 50 | pretrain_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 51 | if pretrain_network_match and pretrain_mode: 52 | current_experiment[ 53 | f"Pretrain Network {pretrain_network_match.group(1)}" 54 | ] = float(pretrain_network_match.group(2)) 55 | if re.search("Pretrain end time recorded and duration set to gauge.", line): 56 | pretrain_mode = False 57 | train_mode = True 58 | 59 | train_time_match = re.search(r"train_time: (\d+\.\d+)", line) 60 | if train_time_match: 61 | current_experiment["Train Time"] = float(train_time_match.group(1)) 62 | 63 | train_max_trainer_memory_match = re.search( 64 | r"Log Max memory for Large(\d+): (\d+\.\d+)", line 65 | ) 66 | if train_max_trainer_memory_match and train_mode: 67 | current_experiment[ 68 | f"Train Max Trainer Memory{train_max_trainer_memory_match.group(1)}" 69 | ] = float(train_max_trainer_memory_match.group(2)) 70 | 71 | train_max_server_memory_match = re.search( 72 | r"Log Max memory for Server: (\d+\.\d+)", line 73 | ) 74 | if train_max_server_memory_match and train_mode: 75 | current_experiment["Train Max Server Memory"] = float( 76 | train_max_server_memory_match.group(1) 77 | ) 78 | 79 | train_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 80 | if train_network_match and train_mode: 81 | current_experiment[ 82 | f"Train Network {(train_network_match.group(1))}" 83 | ] = float(train_network_match.group(2)) 84 | average_accuracy_match = re.search(r"Average test accuracy: (\d+\.\d+)", line) 85 | if average_accuracy_match: 86 | current_experiment["Average Test Accuracy"] = float( 87 | average_accuracy_match.group(1) 88 | ) 89 | 90 | if current_experiment: 91 | experiments.append(current_experiment) 92 | 93 | return 
pd.DataFrame(experiments) 94 | 95 | 96 | def load_log_file(file_path): 97 | with open(file_path, "r", encoding="utf-8") as file: 98 | log_content = file.read() 99 | return log_content 100 | 101 | 102 | file_path = "new_memory.log" 103 | log_content = load_log_file(file_path) 104 | df = process_log(log_content) 105 | 106 | 107 | def reorder_dataframe_columns(df): 108 | desired_columns = [ 109 | "Algorithm", 110 | "Dataset", 111 | "Number of Trainers", 112 | "Average Test Accuracy", 113 | ] 114 | 115 | new_column_order = desired_columns + [ 116 | col for col in df.columns if col not in desired_columns 117 | ] 118 | 119 | df = df[new_column_order] 120 | 121 | return df 122 | 123 | 124 | df = reorder_dataframe_columns(df) 125 | csv_file_path = "new_memory.csv" 126 | df.to_csv(csv_file_path) 127 | print(df.iloc[0, :]) 128 | -------------------------------------------------------------------------------- /docs/dev_script/processing_script_LP.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | 4 | import pandas as pd 5 | 6 | 7 | def process_log(log_content): 8 | experiments = [] 9 | current_experiment = {} 10 | 11 | for line in log_content.splitlines(): 12 | experiment_match = re.match( 13 | r"Running experiment with: Dataset=([^,]+),\s*Number of Trainers=(\d+),\s*Distribution Type=([^,]+),\s*IID Beta=([0-9.]+),\s*Number of Hops=(\d+),\s*Batch Size=([^,]+)", 14 | line, 15 | ) 16 | 17 | if experiment_match: 18 | if current_experiment: 19 | experiments.append(current_experiment) 20 | current_experiment = { 21 | "Dataset": experiment_match.group(1), 22 | "Number of Trainers": int(experiment_match.group(2)), 23 | "Distribution Type": experiment_match.group(3), 24 | "IID Beta": float(experiment_match.group(4)), 25 | "Number of Hops": int(experiment_match.group(5)), 26 | "Batch Size": int(experiment_match.group(6)), 27 | } 28 | pretrain_mode = True 29 | train_mode = False 30 | 31 | pretrain_time_match = re.search(r"pretrain_time: (\d+\.\d+)", line) 32 | if pretrain_time_match: 33 | pretrain_mode = True 34 | train_mode = False 35 | current_experiment["Pretrain Time"] = float(pretrain_time_match.group(1)) 36 | 37 | pretrain_max_trainer_memory_match = re.search( 38 | r"Log Max memory for Large(\d+): (\d+\.\d+)", line 39 | ) 40 | if pretrain_max_trainer_memory_match and pretrain_mode: 41 | current_experiment[ 42 | f"Pretrain Max Trainer Memory{pretrain_max_trainer_memory_match.group(1)}" 43 | ] = float(pretrain_max_trainer_memory_match.group(2)) 44 | 45 | pretrain_max_server_memory_match = re.search( 46 | r"Log Max memory for Server: (\d+\.\d+)", line 47 | ) 48 | if pretrain_max_server_memory_match and pretrain_mode: 49 | current_experiment["Pretrain Max Server Memory"] = float( 50 | pretrain_max_server_memory_match.group(1) 51 | ) 52 | 53 | pretrain_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 54 | if pretrain_network_match and pretrain_mode: 55 | current_experiment[ 56 | f"Pretrain Network {pretrain_network_match.group(1)}" 57 | ] = float(pretrain_network_match.group(2)) 58 | if re.search("Pretrain end time recorded and duration set to gauge.", line): 59 | pretrain_mode = False 60 | train_mode = True 61 | 62 | train_time_match = re.search(r"train_time: (\d+\.\d+)", line) 63 | if train_time_match: 64 | current_experiment["Train Time"] = float(train_time_match.group(1)) 65 | 66 | train_max_trainer_memory_match = re.search( 67 | r"Log Max memory for Large(\d+): (\d+\.\d+)", line 68 | ) 69 | if 
train_max_trainer_memory_match and train_mode: 70 | current_experiment[ 71 | f"Train Max Trainer Memory{train_max_trainer_memory_match.group(1)}" 72 | ] = float(train_max_trainer_memory_match.group(2)) 73 | 74 | train_max_server_memory_match = re.search( 75 | r"Log Max memory for Server: (\d+\.\d+)", line 76 | ) 77 | if train_max_server_memory_match and train_mode: 78 | current_experiment["Train Max Server Memory"] = float( 79 | train_max_server_memory_match.group(1) 80 | ) 81 | 82 | train_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 83 | if train_network_match and train_mode: 84 | current_experiment[ 85 | f"Train Network {(train_network_match.group(1))}" 86 | ] = float(train_network_match.group(2)) 87 | average_accuracy_match = re.search( 88 | r"Predict Day 20 average auc score: (\d+\.\d+) hit rate: (\d+\.\d+)", line 89 | ) 90 | if average_accuracy_match: 91 | print(23) 92 | current_experiment["Average Test AUC"] = float( 93 | average_accuracy_match.group(1) 94 | ) 95 | current_experiment["Hit Rate"] = float(average_accuracy_match.group(2)) 96 | 97 | if current_experiment: 98 | experiments.append(current_experiment) 99 | 100 | return pd.DataFrame(experiments) 101 | 102 | 103 | def load_log_file(file_path): 104 | with open(file_path, "r", encoding="utf-8") as file: 105 | log_content = file.read() 106 | return log_content 107 | 108 | 109 | file_path = "LP2.log" 110 | log_content = load_log_file(file_path) 111 | df = process_log(log_content) 112 | 113 | 114 | def reorder_dataframe_columns(df): 115 | desired_columns = [ 116 | "Dataset", 117 | "Number of Trainers", 118 | "Distribution Type", 119 | "IID Beta", 120 | "Number of Hops", 121 | "Batch Size", 122 | "Average Test AUC", 123 | "Hit Rate", 124 | ] 125 | 126 | new_column_order = desired_columns + [ 127 | col for col in df.columns if col not in desired_columns 128 | ] 129 | 130 | df = df[new_column_order] 131 | 132 | return df 133 | 134 | 135 | df = reorder_dataframe_columns(df) 136 | csv_file_path = "LP2.csv" 137 | df.to_csv(csv_file_path) 138 | print(df.iloc[0, :]) 139 | -------------------------------------------------------------------------------- /docs/dev_script/processing_script_NC.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | 4 | import pandas as pd 5 | 6 | 7 | def process_log(log_content): 8 | experiments = [] 9 | current_experiment = {} 10 | 11 | for line in log_content.splitlines(): 12 | experiment_match = re.match( 13 | r"Running experiment with: Dataset=([^,]+),\s*Number of Trainers=(\d+),\s*Distribution Type=([^,]+),\s*IID Beta=([0-9.]+),\s*Number of Hops=(\d+),\s*Batch Size=([^,]+)", 14 | line, 15 | ) 16 | 17 | if experiment_match: 18 | if current_experiment: 19 | experiments.append(current_experiment) 20 | current_experiment = { 21 | "Dataset": experiment_match.group(1), 22 | "Number of Trainers": int(experiment_match.group(2)), 23 | "Distribution Type": experiment_match.group(3), 24 | "IID Beta": float(experiment_match.group(4)), 25 | "Number of Hops": int(experiment_match.group(5)), 26 | "Batch Size": int(experiment_match.group(6)), 27 | } 28 | pretrain_mode = True 29 | train_mode = False 30 | 31 | pretrain_time_match = re.search(r"pretrain_time: (\d+\.\d+)", line) 32 | if pretrain_time_match: 33 | pretrain_mode = True 34 | train_mode = False 35 | current_experiment["Pretrain Time"] = float(pretrain_time_match.group(1)) 36 | 37 | pretrain_max_trainer_memory_match = re.search( 38 | r"Log Max memory for Large(\d+): (\d+\.\d+)", 
line 39 | ) 40 | if pretrain_max_trainer_memory_match and pretrain_mode: 41 | current_experiment[ 42 | f"Pretrain Max Trainer Memory{pretrain_max_trainer_memory_match.group(1)}" 43 | ] = float(pretrain_max_trainer_memory_match.group(2)) 44 | 45 | pretrain_max_server_memory_match = re.search( 46 | r"Log Max memory for Server: (\d+\.\d+)", line 47 | ) 48 | if pretrain_max_server_memory_match and pretrain_mode: 49 | current_experiment["Pretrain Max Server Memory"] = float( 50 | pretrain_max_server_memory_match.group(1) 51 | ) 52 | 53 | pretrain_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 54 | if pretrain_network_match and pretrain_mode: 55 | current_experiment[ 56 | f"Pretrain Network {pretrain_network_match.group(1)}" 57 | ] = float(pretrain_network_match.group(2)) 58 | if re.search("Pretrain end time recorded and duration set to gauge.", line): 59 | pretrain_mode = False 60 | train_mode = True 61 | 62 | train_time_match = re.search(r"train_time: (\d+\.\d+)", line) 63 | if train_time_match: 64 | current_experiment["Train Time"] = float(train_time_match.group(1)) 65 | 66 | train_max_trainer_memory_match = re.search( 67 | r"Log Max memory for Large(\d+): (\d+\.\d+)", line 68 | ) 69 | if train_max_trainer_memory_match and train_mode: 70 | current_experiment[ 71 | f"Train Max Trainer Memory{train_max_trainer_memory_match.group(1)}" 72 | ] = float(train_max_trainer_memory_match.group(2)) 73 | 74 | train_max_server_memory_match = re.search( 75 | r"Log Max memory for Server: (\d+\.\d+)", line 76 | ) 77 | if train_max_server_memory_match and train_mode: 78 | current_experiment["Train Max Server Memory"] = float( 79 | train_max_server_memory_match.group(1) 80 | ) 81 | 82 | train_network_match = re.search(r"Log ([^,]+) network: (\d+\.\d+)", line) 83 | if train_network_match and train_mode: 84 | current_experiment[ 85 | f"Train Network {(train_network_match.group(1))}" 86 | ] = float(train_network_match.group(2)) 87 | average_accuracy_match = re.search(r"Average test accuracy: (\d+\.\d+)", line) 88 | if average_accuracy_match: 89 | current_experiment["Average Test Accuracy"] = float( 90 | average_accuracy_match.group(1) 91 | ) 92 | 93 | if current_experiment: 94 | experiments.append(current_experiment) 95 | 96 | return pd.DataFrame(experiments) 97 | 98 | 99 | def load_log_file(file_path): 100 | with open(file_path, "r", encoding="utf-8") as file: 101 | log_content = file.read() 102 | return log_content 103 | 104 | 105 | file_path = "1000.log" 106 | log_content = load_log_file(file_path) 107 | df = process_log(log_content) 108 | 109 | 110 | def reorder_dataframe_columns(df): 111 | desired_columns = [ 112 | "Dataset", 113 | "Number of Trainers", 114 | "Distribution Type", 115 | "IID Beta", 116 | "Number of Hops", 117 | "Batch Size", 118 | "Average Test Accuracy", 119 | ] 120 | 121 | new_column_order = desired_columns + [ 122 | col for col in df.columns if col not in desired_columns 123 | ] 124 | 125 | df = df[new_column_order] 126 | 127 | return df 128 | 129 | 130 | df = reorder_dataframe_columns(df) 131 | csv_file_path = "1000.csv" 132 | df.to_csv(csv_file_path) 133 | print(df.iloc[0, :]) 134 | -------------------------------------------------------------------------------- /docs/dev_script/save_graph_node_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | 4 | import numpy as np 5 | import torch 6 | import torch_geometric 7 | 8 | from fedgraph.utils_nc import get_in_comm_indexes 9 
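# NOTE: the upload helpers below (save_trainer_data_to_hugging_face / save_all_trainers_data)
# call HfApi() and HfFolder.get_token(), so the huggingface_hub import on the next line must
# be uncommented (and the huggingface_hub package installed) before this script can push the
# partitioned tensors to the Hugging Face Hub.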
| 10 | # from huggingface_hub import HfApi, HfFolder 11 | 12 | 13 | def label_dirichlet_partition( 14 | labels: np.array, N: int, K: int, n_parties: int, beta: float 15 | ) -> list: 16 | min_size = 0 17 | min_require_size = 10 18 | 19 | split_data_indexes = [] 20 | 21 | # Separate the indices of nodes with label -1 22 | idx_minus_one = np.where(labels == -1)[0] 23 | np.random.shuffle(idx_minus_one) 24 | split_minus_one = np.array_split(idx_minus_one, n_parties) 25 | 26 | while min_size < min_require_size: 27 | idx_batch: list[list[int]] = [[] for _ in range(n_parties)] 28 | for k in range(K): 29 | idx_k = np.where(labels == k)[0] 30 | np.random.shuffle(idx_k) 31 | proportions = np.random.dirichlet(np.repeat(beta, n_parties)) 32 | 33 | proportions = np.array( 34 | [ 35 | p * (len(idx_j) < N / n_parties) 36 | for p, idx_j in zip(proportions, idx_batch) 37 | ] 38 | ) 39 | 40 | proportions = proportions / proportions.sum() 41 | 42 | proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] 43 | 44 | idx_batch = [ 45 | idx_j + idx.tolist() 46 | for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions)) 47 | ] 48 | min_size = min([len(idx_j) for idx_j in idx_batch]) 49 | 50 | # Append the indices of nodes with label -1 to the respective groups 51 | for j in range(n_parties): 52 | idx_batch[j] = np.concatenate((idx_batch[j], split_minus_one[j])) 53 | np.random.shuffle(idx_batch[j]) 54 | split_data_indexes.append(idx_batch[j]) 55 | return split_data_indexes 56 | 57 | 58 | def save_trainer_data_to_hugging_face( 59 | trainer_id, 60 | local_node_index, 61 | communicate_node_index, 62 | adj, 63 | train_labels, 64 | test_labels, 65 | features, 66 | idx_train, 67 | idx_test, 68 | ): 69 | repo_name = f"FedGraph/fedgraph_{args.dataset}_{args.n_trainer}trainer_{args.num_hops}hop_iid_beta_{args.iid_beta}_trainer_id_{trainer_id}" 70 | user = HfFolder.get_token() 71 | 72 | api = HfApi() 73 | try: 74 | api.create_repo( 75 | repo_id=repo_name, token=user, repo_type="dataset", exist_ok=True 76 | ) 77 | except Exception as e: 78 | print(f"Failed to create or access the repository: {str(e)}") 79 | return 80 | 81 | def save_tensor_to_hf(tensor, file_name): 82 | buffer = BytesIO() 83 | torch.save(tensor, buffer) 84 | buffer.seek(0) 85 | api.upload_file( 86 | path_or_fileobj=buffer, 87 | path_in_repo=file_name, 88 | repo_id=repo_name, 89 | repo_type="dataset", 90 | token=user, 91 | ) 92 | 93 | save_tensor_to_hf(local_node_index, "local_node_index.pt") 94 | save_tensor_to_hf(communicate_node_index, "communicate_node_index.pt") 95 | save_tensor_to_hf(adj, "adj.pt") 96 | save_tensor_to_hf(train_labels, "train_labels.pt") 97 | save_tensor_to_hf(test_labels, "test_labels.pt") 98 | save_tensor_to_hf(features, "features.pt") 99 | save_tensor_to_hf(idx_train, "idx_train.pt") 100 | save_tensor_to_hf(idx_test, "idx_test.pt") 101 | 102 | print(f"Uploaded data for trainer {trainer_id}") 103 | 104 | 105 | def save_all_trainers_data( 106 | split_node_indexes, 107 | communicate_node_indexes, 108 | edge_indexes_clients, 109 | labels, 110 | features, 111 | in_com_train_node_indexes, 112 | in_com_test_node_indexes, 113 | n_trainer, 114 | ): 115 | for i in range(n_trainer): 116 | save_trainer_data_to_hugging_face( 117 | trainer_id=i, 118 | local_node_index=split_node_indexes[i], 119 | communicate_node_index=communicate_node_indexes[i], 120 | adj=edge_indexes_clients[i], 121 | train_labels=labels[communicate_node_indexes[i]][ 122 | in_com_train_node_indexes[i] 123 | ], 124 | 
test_labels=labels[communicate_node_indexes[i]][ 125 | in_com_test_node_indexes[i] 126 | ], 127 | features=features[split_node_indexes[i]], 128 | idx_train=in_com_train_node_indexes[i], 129 | idx_test=in_com_test_node_indexes[i], 130 | ) 131 | 132 | 133 | def FedGCN_load_data(dataset_str: str) -> tuple: 134 | if dataset_str in [ 135 | "ogbn-arxiv", 136 | "ogbn-products", 137 | "ogbn-papers100M", 138 | ]: # 'ogbn-mag' is heteregeneous 139 | from ogb.nodeproppred import PygNodePropPredDataset 140 | 141 | # Download and process data at './dataset/.' 142 | 143 | dataset = PygNodePropPredDataset( 144 | name=dataset_str, transform=torch_geometric.transforms.ToSparseTensor() 145 | ) 146 | 147 | split_idx = dataset.get_idx_split() 148 | idx_train, idx_val, idx_test = ( 149 | split_idx["train"], 150 | split_idx["valid"], 151 | split_idx["test"], 152 | ) 153 | 154 | idx_train = torch.LongTensor(idx_train) 155 | idx_val = torch.LongTensor(idx_val) 156 | idx_test = torch.LongTensor(idx_test) 157 | data = dataset[0] 158 | 159 | features = data.x 160 | print(features.shape) 161 | labels = data.y.reshape(-1) 162 | if dataset_str == "ogbn-arxiv": 163 | adj = data.adj_t.to_symmetric() 164 | else: 165 | adj = data.adj_t 166 | return features.float(), adj, labels, idx_train, idx_val, idx_test 167 | 168 | 169 | def run(): 170 | features, adj, labels, idx_train, idx_val, idx_test = FedGCN_load_data(args.dataset) 171 | class_num = int(np.nanmax(labels)) + 1 172 | print("class_num", class_num) 173 | labels[torch.isnan(labels)] = -1 174 | labels = labels.long() 175 | 176 | row, col, edge_attr = adj.coo() 177 | edge_index = torch.stack([row, col], dim=0) 178 | 179 | print(f"gpu usage: {args.gpu}") 180 | if args.gpu: 181 | edge_index = edge_index.to("cuda") 182 | 183 | split_node_indexes = label_dirichlet_partition( 184 | labels, len(labels), class_num, args.n_trainer, beta=args.iid_beta 185 | ) 186 | 187 | for i in range(args.n_trainer): 188 | split_node_indexes[i] = np.array(split_node_indexes[i]) 189 | print(split_node_indexes[i].shape) 190 | split_node_indexes[i].sort() 191 | split_node_indexes[i] = torch.tensor(split_node_indexes[i]) 192 | 193 | ( 194 | communicate_node_indexes, 195 | in_com_train_node_indexes, 196 | in_com_test_node_indexes, 197 | edge_indexes_clients, 198 | ) = get_in_comm_indexes( 199 | edge_index, 200 | split_node_indexes, 201 | args.n_trainer, 202 | args.num_hops, 203 | idx_train, 204 | idx_test, 205 | ) 206 | save_all_trainers_data( 207 | split_node_indexes=split_node_indexes, 208 | communicate_node_indexes=communicate_node_indexes, 209 | edge_indexes_clients=edge_indexes_clients, 210 | labels=labels, 211 | features=features, 212 | in_com_train_node_indexes=in_com_train_node_indexes, 213 | in_com_test_node_indexes=in_com_test_node_indexes, 214 | n_trainer=args.n_trainer, 215 | ) 216 | 217 | 218 | np.random.seed(42) 219 | torch.manual_seed(42) 220 | 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument("-d", "--dataset", default="ogbn-arxiv", type=str) 223 | 224 | parser.add_argument("-n", "--n_trainer", default=5, type=int) 225 | parser.add_argument("-g", "--gpu", action="store_true") # if -g, use gpu 226 | parser.add_argument("-iid_b", "--iid_beta", default=10000, type=float) 227 | parser.add_argument("-nhop", "--num_hops", default=1, type=int) 228 | 229 | args = parser.parse_args() 230 | 231 | 232 | if __name__ == "__main__": 233 | run() 234 | -------------------------------------------------------------------------------- /docs/fedgraph.data_process.rst: 
-------------------------------------------------------------------------------- 1 | Data Process 2 | ============ 3 | 4 | .. automodule:: fedgraph.data_process 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.federated_methods.rst: -------------------------------------------------------------------------------- 1 | Federated Graph Methods 2 | ========== 3 | 4 | .. automodule:: fedgraph.federated_methods 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.gnn_models.rst: -------------------------------------------------------------------------------- 1 | GNN Models 2 | ========== 3 | 4 | .. automodule:: fedgraph.gnn_models 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.monitor_class.rst: -------------------------------------------------------------------------------- 1 | .. _fedgraph-monitor-class: 2 | 3 | Monitor Class 4 | ============= 5 | 6 | .. automodule:: fedgraph.monitor_class 7 | :members: 8 | :undoc-members: 9 | -------------------------------------------------------------------------------- /docs/fedgraph.server_class.rst: -------------------------------------------------------------------------------- 1 | Server Class 2 | ============ 3 | 4 | .. automodule:: fedgraph.server_class 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.setup_ray_cluster.rst: -------------------------------------------------------------------------------- 1 | Set Up the Ray Cluster 2 | ====================== 3 | 4 | This section provides a step-by-step guide to set up a Ray Cluster on AWS EKS. 5 | 6 | It is recommended to use the following script to set up the cluster. The script will guide you through the setup process on AWS, including Docker image building, EKS cluster creation, and deployment of Ray on Kubernetes. 7 | 8 | 9 | Components Overview 10 | ------------------- 11 | 12 | The following table outlines the key components used in setting up a Ray cluster on AWS EKS: 13 | 14 | .. list-table:: Ray Cluster Components 15 | :widths: 25 75 16 | :header-rows: 1 17 | 18 | * - Component 19 | - Purpose 20 | * - Ray 21 | - Provides distributed computing for machine learning (e.g., FedGraph tasks). 22 | * - Kubernetes 23 | - Orchestrates and manages Ray's deployment in AWS EKS. 24 | * - AWS EKS 25 | - Provides the cloud infrastructure for running Kubernetes and Ray. 26 | * - KubeRay 27 | - Automates Ray cluster setup and management in Kubernetes. 28 | * - Helm 29 | - Installs KubeRay and other Kubernetes services. 30 | * - Ray Dashboard, Prometheus, Grafana 31 | - Monitor the Ray cluster’s performance. 32 | 33 | ======= 34 | 35 | Prerequisites 36 | ------------- 37 | Before you begin, ensure you have the following: 38 | 39 | * AWS CLI installed and configured. 40 | * Docker installed and running. 41 | * Helm installed. 42 | * kubectl installed. 43 | * AWS ECR credentials. 44 | * AWS EKS access. 45 | 46 | Steps to Set Up Ray Cluster on AWS EKS 47 | -------------------------------------- 48 | 49 | 1. **Configure AWS Credentials** 50 | 51 | Run the following commands to set up AWS credentials: 52 | 53 | .. 
code-block:: bash 54 | 55 | aws configure set aws_access_key_id 56 | aws configure set aws_secret_access_key 57 | aws configure set region 58 | 59 | Make sure to replace ``, ``, and `` with your actual credentials and region. 60 | 61 | 2. **Log in to AWS ECR Public** 62 | 63 | To push Docker images to AWS ECR, you need to log in to the public ECR: 64 | 65 | .. code-block:: bash 66 | 67 | aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws 68 | 69 | 3. **Build and Push Docker Image to ECR** 70 | 71 | To build and push the Docker image, run the following commands: 72 | 73 | .. code-block:: bash 74 | 75 | docker buildx build --platform linux/amd64 -t public.ecr.aws/i7t1s5i1/fedgraph:img . --push 76 | 77 | 4. **Create EKS Cluster** 78 | 79 | Create a dynamic EKS cluster with the following command: 80 | 81 | .. code-block:: bash 82 | 83 | eksctl create cluster -f eks_cluster_config.yaml --timeout=60m 84 | 85 | 5. **Update Kubeconfig** 86 | 87 | Update your kubeconfig file to access the newly created EKS cluster: 88 | 89 | .. code-block:: bash 90 | 91 | aws eks --region update-kubeconfig --name 92 | 93 | 6. **Clone KubeRay Repository and Install Prometheus/Grafana** 94 | 95 | Clone the KubeRay repository and install monitoring tools: 96 | 97 | .. code-block:: bash 98 | 99 | git clone https://github.com/ray-project/kuberay.git 100 | cd kuberay 101 | ./install/prometheus/install.sh 102 | 103 | 7. **Install KubeRay Operator** 104 | 105 | To manage Ray on Kubernetes, you need to install the KubeRay operator: 106 | 107 | .. code-block:: bash 108 | 109 | helm repo add kuberay https://ray-project.github.io/kuberay-helm/ 110 | helm repo update 111 | helm install kuberay-operator kuberay/kuberay-operator --version 1.1.1 112 | 113 | 8. **Deploy Ray Kubernetes Cluster** 114 | 115 | Apply the Kubernetes configuration to deploy Ray on EKS: 116 | 117 | .. code-block:: bash 118 | 119 | kubectl apply -f ray_kubernetes_cluster.yaml 120 | kubectl apply -f ray_kubernetes_ingress.yaml 121 | 122 | 9. **Verify Pod Status** 123 | 124 | Check the status of the pods to ensure that they are running: 125 | 126 | .. code-block:: bash 127 | 128 | kubectl get pods 129 | 130 | 10. **Port Forwarding for Ray Dashboard, Prometheus, and Grafana** 131 | 132 | Forward the necessary ports for accessing the Ray dashboard and monitoring tools: 133 | 134 | .. code-block:: bash 135 | 136 | kubectl port-forward service/raycluster-autoscaler-head-svc 8265:8265 & 137 | kubectl port-forward raycluster-autoscaler-head-47mzs 8080:8080 & 138 | kubectl port-forward prometheus-prometheus-kube-prometheus-prometheus-0 -n prometheus-system 9090:9090 & 139 | kubectl port-forward deployment/prometheus-grafana -n prometheus-system 3000:3000 & 140 | 141 | 11. **Final Check** 142 | 143 | To ensure everything is set up correctly, perform a final check: 144 | 145 | .. code-block:: bash 146 | 147 | kubectl get pods --all-namespaces -o wide 148 | 149 | 12. **Submit a Ray Job (Optional)** 150 | 151 | If you want to submit a Ray job, use the following command: 152 | 153 | .. code-block:: bash 154 | 155 | ray job submit --runtime-env-json '{"working_dir": "./", "excludes": [".git"]}' --address http://localhost:8265 -- python3 run.py 156 | 157 | 13. **Stop a Ray Job (Optional)** 158 | 159 | To stop a Ray job, use: 160 | 161 | .. code-block:: bash 162 | 163 | ray job stop --address http://localhost:8265 164 | 165 | 14. 
**Clean Up Resources** 166 | 167 | To clean up resources, delete the RayCluster and EKS cluster: 168 | 169 | .. code-block:: bash 170 | 171 | kubectl delete -f ray_kubernetes_cluster.yaml 172 | kubectl delete -f ray_kubernetes_ingress.yaml 173 | kubectl get nodes -o name | xargs kubectl delete 174 | eksctl delete cluster --region --name 175 | 176 | Setup completed successfully! 177 | -------------------------------------------------------------------------------- /docs/fedgraph.train_func.rst: -------------------------------------------------------------------------------- 1 | Training Function 2 | ================= 3 | 4 | .. automodule:: fedgraph.train_func 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.trainer_class.rst: -------------------------------------------------------------------------------- 1 | Trainer Class 2 | ============= 3 | 4 | .. automodule:: fedgraph.trainer_class 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.utils_gc.rst: -------------------------------------------------------------------------------- 1 | Utility Functions for Graph Classification 2 | ================= 3 | 4 | .. automodule:: fedgraph.utils_gc 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.utils_lp.rst: -------------------------------------------------------------------------------- 1 | Utility Functions for Link Prediction 2 | ================= 3 | 4 | .. automodule:: fedgraph.utils_lp 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/fedgraph.utils_nc.rst: -------------------------------------------------------------------------------- 1 | Utility Functions for Node Classification 2 | ================= 3 | 4 | .. automodule:: fedgraph.utils_nc 5 | :members: 6 | :undoc-members: 7 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TDC documentation master file, created by 2 | sphinx-quickstart on Wed Jul 7 12:08:39 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | FedGraph 7 | ======== 8 | 9 | `Documentation `__ \| `Paper `__ 10 | 11 | 12 | **FedGraph** *(Federated Graph)* is a library built on top of [PyTorch Geometric (PyG)](https://www.pyg.org/), 13 | [Ray](https://docs.ray.io/), and [PyTorch](https://pytorch.org/) to easily train Graph Neural Networks 14 | under federated or distributed settings. 15 | 16 | It supports various federated training methods of graph neural networks under simulated and real federated environments and supports communication between clients and the central server for model update and information aggregation. 17 | 18 | 19 | Main Focus 20 | ---------------------- 21 | 22 | - **Federated Node Classification with Cross-Client Edges**: Our library supports communicating information stored in other clients without affecting the privacy of users. 23 | - **Federated Link Prediction on Dynamic Graphs**: Our library supports balancing temporal heterogeneity across clients with privacy preservation. 24 | - **Federated Graph Classification**: Our library supports federated graph classification with non-IID graphs. 
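As a quick illustration of the API, a federated node classification run needs only a configuration dictionary and one call to ``run_fedgraph``. The snippet below is an abridged sketch of ``quickstart.py``; the values are illustrative and the full key set lives in that script:

.. code-block:: python

    import attridict

    from fedgraph.federated_methods import run_fedgraph

    config = attridict(
        {
            "fedgraph_task": "NC",   # node classification
            "dataset": "cora",
            "method": "FedGCN",
            "global_rounds": 100,
            "local_step": 3,
            "learning_rate": 0.5,
            "n_trainer": 2,
            "num_layers": 2,
            "num_hops": 1,           # n-hop neighbors exchanged between clients
            "gpu": False,
        }
    )
    run_fedgraph(config)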
25 | 26 | 27 | 28 | Cross Platform Training 29 | ------------------------- 30 | 31 | - We support federated training across Linux, macOS, and Windows operating systems. 32 | 33 | Library Highlights 34 | ------------------ 35 | 36 | Whether you are a federated learning researcher or a first-time user of federated learning toolkits, here are some reasons to try out FedGraph for federated learning on graph-structured data. 37 | 38 | - **Easy-to-use and unified API**: All it takes is 10-20 lines of code to get started with training a federated GNN model. GNN models are PyTorch models provided by PyG and DGL. The federated training process is handled by Ray. We abstract away the complexity of federated graph training and provide a unified API for training and evaluating FedGraph models. 39 | 40 | - **Various FedGraph methods**: Most of the state-of-the-art federated graph training methods have been implemented by library developers or authors of research papers and are ready to be applied. 41 | 42 | - **Great flexibility**: Existing FedGraph models can easily be extended for conducting your research. Simply inherit the base class of trainers and implement your methods. 43 | 44 | - **Large-scale real-world FedGraph Training**: We focus on the need for FedGraph applications in challenging real-world scenarios with privacy preservation, and support learning on large-scale graphs across multiple clients. 45 | 46 | 47 | ---- 48 | 49 | 50 | .. toctree:: 51 | :maxdepth: 2 52 | :hidden: 53 | :caption: Getting Started 54 | 55 | install 56 | tutorials/index 57 | fedgraph.setup_ray_cluster 58 | 59 | .. toctree:: 60 | :maxdepth: 2 61 | :hidden: 62 | :caption: API References 63 | 64 | fedgraph.data_process 65 | fedgraph.federated_methods 66 | fedgraph.gnn_models 67 | fedgraph.server_class 68 | fedgraph.train_func 69 | fedgraph.trainer_class 70 | fedgraph.monitor_class 71 | fedgraph.utils_gc 72 | fedgraph.utils_lp 73 | fedgraph.utils_nc 74 | 75 | .. toctree:: 76 | :maxdepth: 2 77 | :hidden: 78 | :caption: Additional Information 79 | 80 | cite 81 | reference 82 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | 5 | It is recommended to use **pip** for installation. 6 | Please make sure **the latest version** is installed, as FedGraph is updated frequently: 7 | 8 | .. code-block:: bash 9 | 10 | pip install fedgraph # normal install 11 | pip install --upgrade fedgraph # or update if needed 12 | 13 | 14 | Alternatively, you could clone and run setup.py file: 15 | 16 | .. code-block:: bash 17 | 18 | git clone https://github.com/FedGraph/fedgraph.git 19 | cd fedgraph 20 | pip install . 21 | 22 | **Required Dependencies**\ : 23 | 24 | * python>=3.8 25 | * ray 26 | * tensorboard 27 | 28 | **Note on PyG and PyTorch Installation**\ : 29 | FedGraph depends on `torch `_ and `torch_geometric (including its optional dependencies) `_. 30 | To streamline the installation, FedGraph does **NOT** install these libraries for you. 
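For example, a CPU-only environment can typically be prepared with the commands below (shown only as an illustration; pick the exact command for your OS, Python, and CUDA version from the linked installation pages):

.. code-block:: bash

    pip install torch --index-url https://download.pytorch.org/whl/cpu
    pip install torch_geometric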
31 | Please install them from the above links for running FedGraph: 32 | 33 | * torch>=2.0.0 34 | * torch_geometric>=2.3.0 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=fedgraph 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/reference.rst: -------------------------------------------------------------------------------- 1 | Reference 2 | ========= 3 | 4 | .. bibliography:: 5 | :cited: 6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # dependencies required for documentation 2 | sphinx_rtd_theme 3 | https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp310-cp310-linux_x86_64.whl 4 | torch_geometric 5 | https://data.pyg.org/whl/torch-2.0.0%2Bcpu/pyg_lib-0.2.0%2Bpt20cpu-cp310-cp310-linux_x86_64.whl 6 | https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_cluster-1.6.1%2Bpt20cpu-cp310-cp310-linux_x86_64.whl 7 | https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_scatter-2.1.1%2Bpt20cpu-cp310-cp310-linux_x86_64.whl 8 | https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_sparse-0.6.17%2Bpt20cpu-cp310-cp310-linux_x86_64.whl 9 | https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_spline_conv-1.2.2%2Bpt20cpu-cp310-cp310-linux_x86_64.whl 10 | torchmetrics 11 | setuptools 12 | sphinxcontrib-bibtex 13 | scikit-learn 14 | matplotlib 15 | sphinx_gallery 16 | ray[default] 17 | tensorboard 18 | attridict 19 | dtaidistance 20 | gdown 21 | numpy==1.26 22 | pandas 23 | tenseal 24 | huggingface_hub 25 | ogb 26 | -e . 27 | -------------------------------------------------------------------------------- /docs/sg_execution_times.rst: -------------------------------------------------------------------------------- 1 | 2 | :orphan: 3 | 4 | .. _sphx_glr_sg_execution_times: 5 | 6 | 7 | Computation times 8 | ================= 9 | **02:11.421** total execution time for 4 files **from all galleries**: 10 | 11 | .. container:: 12 | 13 | .. raw:: html 14 | 15 | 19 | 20 | 21 | 22 | 27 | 28 | .. 
list-table:: 29 | :header-rows: 1 30 | :class: table table-striped sg-datatable 31 | 32 | * - Example 33 | - Time 34 | - Mem (MB) 35 | * - :ref:`sphx_glr_tutorials_FGL_LP.py` (``../tutorials/FGL_LP.py``) 36 | - 01:32.812 37 | - 0.0 38 | * - :ref:`sphx_glr_tutorials_FGL_NC_HE.py` (``../tutorials/FGL_NC_HE.py``) 39 | - 00:21.875 40 | - 0.0 41 | * - :ref:`sphx_glr_tutorials_FGL_GC.py` (``../tutorials/FGL_GC.py``) 42 | - 00:08.598 43 | - 0.0 44 | * - :ref:`sphx_glr_tutorials_FGL_NC.py` (``../tutorials/FGL_NC.py``) 45 | - 00:08.136 46 | - 0.0 47 | -------------------------------------------------------------------------------- /docs/zreferences.bib: -------------------------------------------------------------------------------- 1 | @article{liu2022benchmarking, 2 | author = {Liu, Kay and Dou, Yingtong and Zhao, Yue and Ding, Xueying and Hu, Xiyang and Zhang, Ruitong and Ding, Kaize and Chen, Canyu and Peng, Hao and Shu, Kai and Sun, Lichao and Li, Jundong and Chen, George H. and Jia, Zhihao and Yu, Philip S.}, 3 | title = {Benchmarking Node Outlier Detection on Graphs}, 4 | journal = {arXiv preprint arXiv:2206.10071}, 5 | year = {2022}, 6 | } 7 | -------------------------------------------------------------------------------- /fedgraph/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | data_process, 3 | federated_methods, 4 | gnn_models, 5 | monitor_class, 6 | server_class, 7 | train_func, 8 | trainer_class, 9 | utils_gc, 10 | utils_lp, 11 | utils_nc, 12 | ) 13 | from .version import __version__ 14 | -------------------------------------------------------------------------------- /fedgraph/he_context.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/fedgraph/he_context.pkl -------------------------------------------------------------------------------- /fedgraph/he_training_context.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/fedgraph/he_training_context.pkl -------------------------------------------------------------------------------- /fedgraph/train_func.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import ray 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def accuracy(output: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 11 | """ 12 | This function returns the accuracy of the output with respect to the ground truth given 13 | 14 | Parameters 15 | ---------- 16 | output: torch.Tensor 17 | the output labels predicted by the model 18 | 19 | labels: torch.Tensor 20 | ground truth labels 21 | 22 | Returns 23 | ------- 24 | (tensor): torch.Tensor 25 | Accuracy of the output with respect to the ground truth given 26 | """ 27 | 28 | preds = output.max(1)[1].type_as(labels) 29 | correct = preds.eq(labels).double() 30 | correct = correct.sum() 31 | return correct / len(labels) 32 | 33 | 34 | def gc_avg_accuracy(frame: pd.DataFrame, trainers: list) -> float: 35 | """ 36 | This function calculates the weighted average accuracy of the trainers in the frame. 
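    For example (illustrative numbers): with test accuracies [0.8, 0.9] and training
    sizes [100, 300], the weighted average is (0.8 * 100 + 0.9 * 300) / 400 = 0.875.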
37 | 38 | Parameters 39 | ---------- 40 | frame: pd.DataFrame 41 | The frame containing the accuracies of the trainers 42 | trainers: list 43 | List of trainer objects 44 | 45 | Returns 46 | ------- 47 | (float): float 48 | The average accuracy of the trainers in the frame 49 | """ 50 | 51 | # weighted average accuracy 52 | accs = frame["test_acc"] 53 | weights = [ray.get(c.get_train_size.remote()) for c in trainers] 54 | return np.average(accs, weights=weights) 55 | 56 | 57 | def test( 58 | model: torch.nn.Module, 59 | features: torch.Tensor, 60 | adj: torch.Tensor, 61 | test_labels: torch.Tensor, 62 | idx_test: torch.Tensor, 63 | ) -> tuple: 64 | """ 65 | This function tests the model and calculates the loss and accuracy 66 | 67 | Parameters 68 | ---------- 69 | model : torch.nn.Module 70 | Specific model passed 71 | features : torch.Tensor 72 | Tensor representing the input features 73 | adj : torch.Tensor 74 | Adjacency matrix 75 | labels : torch.Tensor 76 | Contains the ground truth labels for the data. 77 | idx_test : torch.Tensor 78 | Indices specifying the test data points 79 | 80 | Returns 81 | ------- 82 | loss_test.item() : float 83 | Loss of the model on the test data 84 | acc_test.item() : float 85 | Accuracy of the model on the test data 86 | 87 | """ 88 | model.eval() 89 | output = model(features, adj) 90 | loss_test = F.nll_loss(output[idx_test], test_labels) 91 | acc_test = accuracy(output[idx_test], test_labels) 92 | 93 | return loss_test.item(), acc_test.item() # , f1_test, auc_test 94 | 95 | 96 | def train( 97 | epoch: int, 98 | model: torch.nn.Module, 99 | optimizer: torch.optim.Optimizer, 100 | features: torch.Tensor, 101 | adj: torch.Tensor, 102 | train_labels: torch.Tensor, 103 | idx_train: torch.Tensor, 104 | ) -> tuple: # Centralized or new FL 105 | """ 106 | Trains the model and calculates the loss and accuracy of the model on the training data, 107 | performs backpropagation, and updates the model parameters. 108 | 109 | Parameters 110 | ---------- 111 | epoch : int 112 | Specifies the number of epoch on which the model is trained 113 | model : torch.nn.Module 114 | Specific model to be trained 115 | optimizer : optimizer 116 | Type of the optimizer used for updating the model parameters 117 | features : torch.FloatTensor 118 | Tensor representing the input features 119 | adj : torch_sparse.tensor.SparseTensor 120 | Adjacency matrix 121 | train_labels : torch.LongTensor 122 | Contains the ground truth labels for the data. 
123 | idx_train : torch.LongTensor 124 | Indices specifying the test data points 125 | 126 | 127 | Returns 128 | ------- 129 | loss_train.item() : float 130 | Loss of the model on the training data 131 | acc_train.item() : float 132 | Accuracy of the model on the training data 133 | 134 | """ 135 | 136 | model.train() 137 | optimizer.zero_grad() 138 | 139 | output = model(features, adj) 140 | loss_train = F.nll_loss(output[idx_train], train_labels) 141 | acc_train = accuracy(output[idx_train], train_labels) 142 | loss_train.backward() 143 | optimizer.step() 144 | optimizer.zero_grad() 145 | 146 | return loss_train.item(), acc_train.item() 147 | -------------------------------------------------------------------------------- /fedgraph/utils_gc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | from collections import OrderedDict 4 | from typing import Any 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.nn.functional as F 10 | from sklearn.model_selection import train_test_split 11 | from torch_geometric.utils import degree, to_networkx 12 | 13 | from fedgraph.server_class import Server_GC 14 | from fedgraph.trainer_class import Trainer_GC 15 | 16 | 17 | def setup_trainers( 18 | splited_data: dict, base_model: Any, args: argparse.Namespace 19 | ) -> tuple: 20 | """ 21 | Setup trainers for graph classification. 22 | 23 | Parameters 24 | ---------- 25 | splited_data: dict 26 | The data for each trainer. 27 | base_model: Any 28 | The base model for the trainer. The base model shown in the example is GIN. 29 | args: argparse.ArgumentParser 30 | The input arguments. 31 | 32 | Returns 33 | ------- 34 | (trainers, idx_trainers): tuple(list, dict) 35 | trainers: List of trainers 36 | idx_trainers: Dictionary with the index of the trainer as the key and the dataset name as the value 37 | """ 38 | idx_trainers = {} 39 | trainers = [] 40 | for idx, dataset_trainer_name in enumerate(splited_data.keys()): 41 | idx_trainers[idx] = dataset_trainer_name 42 | """acquire data""" 43 | dataloaders, num_node_features, num_graph_labels, train_size = splited_data[ 44 | dataset_trainer_name 45 | ] 46 | 47 | """build GIN model""" 48 | cmodel_gc = base_model( 49 | nfeat=num_node_features, 50 | nhid=args.hidden, 51 | nclass=num_graph_labels, 52 | nlayer=args.nlayer, 53 | dropout=args.dropout, 54 | ) 55 | 56 | """build optimizer""" 57 | optimizer = torch.optim.Adam( 58 | params=filter(lambda p: p.requires_grad, cmodel_gc.parameters()), 59 | lr=args.lr, 60 | weight_decay=args.weight_decay, 61 | ) 62 | 63 | """build trainer""" 64 | trainer = Trainer_GC( 65 | model=cmodel_gc, # GIN model 66 | trainer_id=idx, # trainer id 67 | trainer_name=dataset_trainer_name, # trainer name 68 | train_size=train_size, # training size 69 | dataloader=dataloaders, # data loader 70 | optimizer=optimizer, # optimizer 71 | args=args, 72 | ) 73 | 74 | trainers.append(trainer) 75 | 76 | return trainers, idx_trainers 77 | 78 | 79 | def setup_server(base_model: Any, args: argparse.Namespace) -> Server_GC: 80 | """ 81 | Setup server. 82 | 83 | Parameters 84 | ---------- 85 | base_model: Any 86 | The base model for the server. The base model shown in the example is GIN_server. 
87 | args: argparse.ArgumentParser 88 | The input arguments 89 | 90 | Returns 91 | ------- 92 | server: Server_GC 93 | The server object 94 | """ 95 | 96 | smodel = base_model(nlayer=args.nlayer, nhid=args.hidden) 97 | server = Server_GC(smodel, args.device, args.use_cluster) 98 | return server 99 | 100 | 101 | def get_max_degree(graphs: Any) -> int: 102 | """ 103 | Get the maximum degree of the graphs in the dataset. 104 | 105 | Parameters 106 | ---------- 107 | graphs: Any 108 | The object of graphs 109 | 110 | Returns 111 | ------- 112 | max_degree: int 113 | The maximum degree of the graphs in the dataset 114 | """ 115 | max_degree = 0 116 | for i, graph in enumerate(graphs): 117 | g = to_networkx(graph, to_undirected=True) 118 | g_degree = max(dict(g.degree).values()) 119 | max_degree = max(max_degree, g_degree) 120 | 121 | return max_degree 122 | 123 | 124 | def convert_to_node_attributes(graphs: Any) -> list: 125 | """ 126 | Use only the node attributes of the graphs. This function will treat the graphs as callable objects. 127 | 128 | Parameters 129 | ---------- 130 | graphs: Any 131 | The object of of graphs 132 | 133 | Returns 134 | ------- 135 | new_graphs: list 136 | List of graphs with only the node attributes 137 | """ 138 | num_node_attributes = graphs.num_node_attributes 139 | new_graphs = [] 140 | for _, graph in enumerate(graphs): 141 | new_graph = graph.clone() 142 | new_graph.__setitem__("x", graph.x[:, :num_node_attributes]) 143 | new_graphs.append(new_graph) 144 | return new_graphs 145 | 146 | 147 | def convert_to_node_degree_features(graphs: list) -> list: 148 | """ 149 | Convert the node attributes of the graphs to node degree features. 150 | 151 | Parameters 152 | ---------- 153 | graphs: list 154 | List of graphs 155 | 156 | Returns 157 | ------- 158 | new_graphs: list 159 | List of graphs with node degree features 160 | """ 161 | graph_infos = [] 162 | max_degree = 0 163 | for _, graph in enumerate(graphs): 164 | g = to_networkx(graph, to_undirected=True) 165 | g_degree = max(dict(g.degree).values()) 166 | max_degree = max(max_degree, g_degree) 167 | graph_infos.append( 168 | (graph, g.degree, graph.num_nodes) 169 | ) # (graph, node_degrees, num_nodes) 170 | 171 | new_graphs = [] 172 | for i, tuple in enumerate(graph_infos): 173 | idx, x = tuple[0].edge_index[0], tuple[0].x 174 | deg = degree(idx, tuple[2], dtype=torch.long) 175 | deg = F.one_hot(deg, num_classes=max_degree + 1).to(torch.float) 176 | 177 | new_graph = tuple[0].clone() 178 | new_graph.__setitem__("x", deg) 179 | new_graphs.append(new_graph) 180 | 181 | return new_graphs 182 | 183 | 184 | def split_data( 185 | graphs: list, 186 | train_size: float = 0.8, 187 | test_size: float = 0.2, 188 | shuffle: bool = True, 189 | seed: int = 42, 190 | ) -> tuple: 191 | """ 192 | Split the dataset into training and test sets. 
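    For example (illustrative usage): ``graphs_train, graphs_test = split_data(graphs,
    train_size=0.8, test_size=0.2, seed=42)`` produces an 80/20 split that is stratified
    by graph label whenever every label occurs at least twice.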
193 | 194 | Parameters 195 | ---------- 196 | graphs: list 197 | List of graphs 198 | train_size: float 199 | The proportion (ranging from 0.0 to 1.0) of the dataset to include in the training set 200 | test_size: float 201 | The proportion (ranging from 0.0 to 1.0) of the dataset to include in the test set 202 | shuffle: bool 203 | Whether or not to shuffle the data before splitting 204 | seed: int 205 | Seed for the random number generator 206 | 207 | Returns 208 | ------- 209 | graphs_train: list 210 | List of training graphs 211 | graphs_test: list 212 | List of testing graphs 213 | 214 | Note 215 | ---- 216 | The function uses sklearn.model_selection.train_test_split to split the dataset into training and test sets. 217 | If the dataset needs to be split into training, validation, and test sets, the function should be called twice. 218 | """ 219 | y = torch.cat([graph.y for graph in graphs]) 220 | y_indices = np.unique(y, return_inverse=True)[1] 221 | class_counts = np.bincount(y_indices) 222 | if np.min(class_counts) < 2: 223 | stratify = None 224 | else: 225 | stratify = y 226 | graphs_train, graphs_test = train_test_split( 227 | graphs, 228 | train_size=train_size, 229 | test_size=test_size, 230 | stratify=stratify, 231 | shuffle=shuffle, 232 | random_state=seed, 233 | ) 234 | return graphs_train, graphs_test 235 | 236 | 237 | def get_num_graph_labels(dataset: list) -> int: 238 | """ 239 | Get the number of unique graph labels in the dataset. 240 | 241 | Parameters 242 | ---------- 243 | dataset: list 244 | List of graphs 245 | 246 | Returns 247 | ------- 248 | (labels.length): int 249 | Number of unique graph labels in the dataset 250 | """ 251 | s = set() 252 | for g in dataset: 253 | s.add(g.y.item()) 254 | return len(s) 255 | 256 | 257 | def get_avg_nodes_edges(graphs: list) -> tuple: 258 | """ 259 | Calculate the average number of nodes and edges in the dataset. 260 | 261 | Parameters 262 | ---------- 263 | graphs: list 264 | List of graphs 265 | 266 | Returns 267 | ------- 268 | avg_nodes: float 269 | The average number of nodes in the dataset 270 | avg_edges: float 271 | The average number of edges in the dataset 272 | """ 273 | num_nodes, num_edges = 0.0, 0.0 274 | num_graphs = len(graphs) 275 | for g in graphs: 276 | num_nodes += g.num_nodes 277 | num_edges += g.num_edges / 2.0 # undirected 278 | 279 | avg_nodes = num_nodes / num_graphs 280 | avg_edges = num_edges / num_graphs 281 | return avg_nodes, avg_edges 282 | 283 | 284 | def get_stats( 285 | df: pd.DataFrame, 286 | dataset: str, 287 | graphs_train: list = [], 288 | graphs_val: list = [], 289 | graphs_test: list = [], 290 | ) -> pd.DataFrame: 291 | """ 292 | Calculate and store the statistics of the dataset, including the number of graphs, average number of nodes and edges 293 | for the training, validation, and testing sets. 294 | 295 | Parameters 296 | ---------- 297 | df: pd.DataFrame 298 | An empty DataFrame to store the statistics of the dataset. 299 | dataset: str 300 | The name of the dataset. 301 | graphs_train: list 302 | List of training graphs. 303 | graphs_val: list 304 | List of validation graphs. 305 | graphs_test: list 306 | List of testing graphs. 307 | 308 | Returns 309 | ------- 310 | df: pd.DataFrame 311 | The filled statistics of the dataset. 
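    For example (illustrative): after ``get_stats(df, "MUTAG", graphs_train, graphs_test=graphs_test)``,
    the row ``df.loc["MUTAG"]`` holds ``#graphs_train``, ``avgNodes_train``, and ``avgEdges_train``,
    plus the corresponding ``_test`` columns.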
312 | """ 313 | 314 | df.loc[dataset, "#graphs_train"] = len(graphs_train) 315 | avgNodes, avgEdges = get_avg_nodes_edges(graphs_train) 316 | df.loc[dataset, "avgNodes_train"] = avgNodes 317 | df.loc[dataset, "avgEdges_train"] = avgEdges 318 | 319 | if graphs_val: 320 | df.loc[dataset, "#graphs_val"] = len(graphs_val) 321 | avgNodes, avgEdges = get_avg_nodes_edges(graphs_val) 322 | df.loc[dataset, "avgNodes_val"] = avgNodes 323 | df.loc[dataset, "avgEdges_val"] = avgEdges 324 | 325 | if graphs_test: 326 | df.loc[dataset, "#graphs_test"] = len(graphs_test) 327 | avgNodes, avgEdges = get_avg_nodes_edges(graphs_test) 328 | df.loc[dataset, "avgNodes_test"] = avgNodes 329 | df.loc[dataset, "avgEdges_test"] = avgEdges 330 | 331 | return df 332 | 333 | 334 | def generate_context(poly_modulus_degree=8192, coeff_mod_bit_sizes=[60, 40, 40, 60]): 335 | context = ts.context( 336 | ts.SCHEME_TYPE.CKKS, 337 | poly_modulus_degree=poly_modulus_degree, 338 | coeff_mod_bit_sizes=coeff_mod_bit_sizes, 339 | ) 340 | context.global_scale = 2**40 341 | context.generate_galois_keys() 342 | return context 343 | 344 | 345 | def encryption_he(context, model_params, total_client_number): 346 | weight_factors = copy.deepcopy(model_params) 347 | for key in weight_factors.keys(): 348 | weight_factors[key] = torch.flatten( 349 | torch.full_like(weight_factors[key], 1 / total_client_number) 350 | ) 351 | 352 | enc_model_params = OrderedDict() 353 | for key in model_params.keys(): 354 | prepared_tensor = (torch.flatten(model_params[key])) * weight_factors[key] 355 | plain_tensor = ts.plain_tensor(prepared_tensor) 356 | enc_model_params[key] = ts.ckks_vector(context, plain_tensor).serialize() 357 | 358 | return enc_model_params 359 | 360 | 361 | def fedavg_he(context, list_enc_model_params): 362 | n_clients = len(list_enc_model_params) 363 | enc_global_params = copy.deepcopy(list_enc_model_params[0]) 364 | 365 | for key in enc_global_params.keys(): 366 | sum_vector = ts.ckks_vector_from(context, list_enc_model_params[0][key]) 367 | for i in range(1, n_clients): 368 | temp = ts.ckks_vector_from(context, list_enc_model_params[i][key]) 369 | sum_vector += temp 370 | enc_global_params[key] = sum_vector.serialize() 371 | 372 | return enc_global_params 373 | 374 | 375 | def decryption_he(context, template_model_params, enc_model_params): 376 | params_shape = OrderedDict() 377 | for key in template_model_params.keys(): 378 | params_shape[key] = template_model_params[key].size() 379 | 380 | params_tensor = OrderedDict() 381 | for key in enc_model_params.keys(): 382 | dec_vector = ts.ckks_vector_from(context, enc_model_params[key]) 383 | params_tensor[key] = torch.FloatTensor(dec_vector.decrypt()) 384 | params_tensor[key] = torch.reshape(params_tensor[key], tuple(params_shape[key])) 385 | 386 | return params_tensor 387 | -------------------------------------------------------------------------------- /fedgraph/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.4" 2 | -------------------------------------------------------------------------------- /generate_he_context.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import tenseal as ts 4 | 5 | 6 | def create_training_context(): 7 | scheme = ts.SCHEME_TYPE.CKKS 8 | 9 | # Keep the same settings that worked for features 10 | poly_modulus_degree = 8192 11 | coeff_mod_bit_sizes = [60, 40, 40, 60] 12 | 13 | context = ts.context( 14 | scheme=scheme, 15 | 
poly_modulus_degree=poly_modulus_degree, 16 | coeff_mod_bit_sizes=coeff_mod_bit_sizes, 17 | ) 18 | 19 | # Higher scale for better precision with small parameter values 20 | context.global_scale = 2**40 21 | context.generate_galois_keys() 22 | context.auto_relin = True 23 | context.auto_rescale = True 24 | 25 | return context 26 | 27 | 28 | if __name__ == "__main__": 29 | training_context = create_training_context() 30 | training_secret_context = training_context.serialize(save_secret_key=True) 31 | 32 | with open("fedgraph/he_context.pkl", "wb") as f: 33 | pickle.dump(training_secret_context, f) 34 | print("Saved HE context with secret key.") 35 | -------------------------------------------------------------------------------- /kuberay/config/prometheus/podMonitor.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: monitoring.coreos.com/v1 2 | kind: PodMonitor 3 | metadata: 4 | name: ray-workers-monitor 5 | namespace: prometheus-system 6 | labels: 7 | # `release: $HELM_RELEASE`: Prometheus can only detect PodMonitor with this label. 8 | release: prometheus 9 | spec: 10 | jobLabel: ray-workers 11 | # Only select Kubernetes Pods in the "default" namespace. 12 | namespaceSelector: 13 | matchNames: 14 | - default 15 | # Only select Kubernetes Pods with "matchLabels". 16 | selector: 17 | matchLabels: 18 | ray.io/node-type: worker 19 | # A list of endpoints allowed as part of this PodMonitor. 20 | podMetricsEndpoints: 21 | - port: metrics 22 | interval: 1s 23 | scrapeTimeout: 1s 24 | -------------------------------------------------------------------------------- /kuberay/config/prometheus/rules/prometheusRules.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: monitoring.coreos.com/v1 2 | kind: PrometheusRule 3 | metadata: 4 | name: ray-cluster-gcs-rules 5 | namespace: prometheus-system 6 | labels: 7 | # `release: $HELM_RELEASE`: Prometheus can only detect Rule with this label. 8 | release: prometheus 9 | spec: 10 | groups: 11 | - interval: 30s 12 | name: ray-cluster-main-staging-gcs.rules 13 | rules: 14 | - expr: |2 15 | ( 16 | 100 * ( 17 | sum( 18 | rate( 19 | ray_gcs_update_resource_usage_time_bucket{container="ray-head", le="20.0"}[30d] 20 | ) 21 | ) 22 | / 23 | sum( 24 | rate( 25 | ray_gcs_update_resource_usage_time_count{container="ray-head"}[30d] 26 | ) 27 | ) 28 | ) 29 | ) 30 | record: ray_gcs_availability_30d 31 | - alert: MissingMetricRayGlobalControlStore 32 | annotations: 33 | description: Ray GCS is not emitting any metrics for Resource Update requests 34 | summary: Ray GCS is not emitting metrics anymore 35 | expr: |2 36 | ( 37 | absent(ray_gcs_update_resource_usage_time_bucket) == 1 38 | ) 39 | for: 5m 40 | labels: 41 | severity: critical 42 | -------------------------------------------------------------------------------- /kuberay/config/prometheus/serviceMonitor.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: monitoring.coreos.com/v1 2 | kind: ServiceMonitor 3 | metadata: 4 | name: ray-head-monitor 5 | namespace: prometheus-system 6 | labels: 7 | # `release: $HELM_RELEASE`: Prometheus can only detect ServiceMonitor with this label. 8 | release: prometheus 9 | spec: 10 | jobLabel: ray-head 11 | # Only select Kubernetes Services in the "default" namespace. 12 | namespaceSelector: 13 | matchNames: 14 | - default 15 | # Only select Kubernetes Services with "matchLabels". 
16 | selector: 17 | matchLabels: 18 | ray.io/node-type: head 19 | # A list of endpoints allowed as part of this ServiceMonitor. 20 | endpoints: 21 | - port: metrics 22 | interval: 1s # Set the scrape interval to 1 second 23 | scrapeTimeout: 1s # Set the scrape timeout to 1 second 24 | - port: as-metrics # autoscaler metrics 25 | interval: 1s 26 | scrapeTimeout: 1s 27 | - port: dash-metrics # dashboard metrics 28 | interval: 1s 29 | scrapeTimeout: 1s 30 | targetLabels: 31 | - ray.io/cluster 32 | -------------------------------------------------------------------------------- /kuberay/install/prometheus/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | set errexit 5 | 6 | helm repo add prometheus-community https://prometheus-community.github.io/helm-charts 7 | helm repo update 8 | 9 | # DIR is the absolute directory of this script (`install.sh`) 10 | DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" > /dev/null && pwd)" 11 | 12 | # Install the kube-prometheus-stack v48.2.1 helm chart with `overrides.yaml` file. 13 | # https://github.com/prometheus-community/helm-charts/tree/kube-prometheus-stack-48.2.1/charts/kube-prometheus-stack 14 | helm --namespace prometheus-system install prometheus prometheus-community/kube-prometheus-stack --create-namespace --version 48.2.1 -f ${DIR}/overrides.yaml 15 | 16 | # set the place of monitor files 17 | monitor_dir=${DIR}/../../config/prometheus 18 | 19 | # start to install monitor 20 | pushd ${monitor_dir} 21 | for file in `ls` 22 | do 23 | kubectl apply -f ${file} 24 | done 25 | popd 26 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.9 3 | ;TODO: decease python_version requirement 4 | platform = linux 5 | ;TODO: support multiple platform 6 | files = fedgraph 7 | 8 | show_column_numbers=True 9 | 10 | disallow_untyped_calls = False 11 | check_untyped_defs = False 12 | ignore_missing_imports=True 13 | disable_error_code=attr-defined,var-annotated,import-untyped 14 | 15 | [mypy-yaml.*] 16 | # ignore_missing_imports = True 17 | 18 | # be strict 19 | warn_return_any=True 20 | strict_optional=True 21 | warn_no_return=True 22 | warn_redundant_casts=True 23 | warn_unused_ignores=True 24 | 25 | # No incremental mode 26 | cache_dir=/dev/null 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | # third party tools 6 | [tool.isort] 7 | profile = "black" 8 | -------------------------------------------------------------------------------- /quickstart.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple FedGraph Example 3 | ======================= 4 | 5 | Run a simple example of FedGraph. 
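Assuming the package is installed and a Ray cluster is reachable (the configuration below
uses ``"ray_address": "auto"`` and ``"use_cluster": True``), the script can be launched
directly with ``python quickstart.py`` or submitted as a Ray job.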
6 | 7 | (Time estimate: 3 minutes) 8 | """ 9 | 10 | ####################################################################### 11 | # Load libraries 12 | # -------------- 13 | 14 | import os 15 | 16 | import attridict 17 | 18 | from fedgraph.federated_methods import run_fedgraph 19 | 20 | ####################################################################### 21 | # Specify the Node Classification configuration 22 | # --------------------------------------------- 23 | config = { 24 | # Task, Method, and Dataset Settings 25 | "fedgraph_task": "NC", 26 | "dataset": "cora", 27 | "method": "FedGCN", # Federated learning method, e.g., "FedGCN" 28 | "iid_beta": 10000, # Dirichlet distribution parameter for label distribution among clients 29 | "distribution_type": "average", # Distribution type among clients 30 | # Training Configuration 31 | "global_rounds": 100, 32 | "local_step": 3, 33 | "learning_rate": 0.5, 34 | "n_trainer": 2, 35 | "batch_size": -1, # -1 indicates full batch training 36 | # Model Structure 37 | "num_layers": 2, 38 | "num_hops": 1, # Number of n-hop neighbors for client communication 39 | # Resource and Hardware Settings 40 | "gpu": False, 41 | "num_cpus_per_trainer": 1, 42 | "num_gpus_per_trainer": 0, 43 | "ray_address": "auto", # Connect to existing Ray cluster 44 | # Logging and Output Configuration 45 | "logdir": "./runs", 46 | # Security and Privacy 47 | "use_encryption": True, # Whether to use Homomorphic Encryption for secure aggregation 48 | # Dataset Handling Options 49 | "use_huggingface": False, # Load dataset directly from Hugging Face Hub 50 | "saveto_huggingface": False, # Save partitioned dataset to Hugging Face Hub 51 | # Scalability and Cluster Configuration 52 | "use_cluster": True, # Use Kubernetes for scalability if True 53 | } 54 | 55 | ####################################################################### 56 | # Run fedgraph method 57 | # ------------------- 58 | 59 | config = attridict(config) 60 | run_fedgraph(config) 61 | 62 | # ####################################################################### 63 | # # Specify the Graph Classification configuration 64 | # # ---------------------------------------------- 65 | # config = { 66 | # "fedgraph_task": "GC", 67 | # # General configuration 68 | # # algorithm options: "SelfTrain", "FedAvg", "FedProx", "GCFL", "GCFL+", "GCFL+dWs" 69 | # "algorithm": "GCFL+dWs", 70 | # # Dataset configuration 71 | # "dataset": "MUTAG", 72 | # "is_multiple_dataset": False, 73 | # "datapath": "./data", 74 | # "convert_x": False, 75 | # "overlap": False, 76 | # # Setup configuration 77 | # "device": "cpu", 78 | # "seed": 10, 79 | # "seed_split_data": 42, 80 | # # Model parameters 81 | # "num_trainers": 2, 82 | # "num_rounds": 200, # Used by "FedAvg" and "GCFL" (not used in "SelfTrain") 83 | # "local_epoch": 1, # Used by "FedAvg" and "GCFL" 84 | # # Specific for "SelfTrain" (used instead of "num_rounds" and "local_epoch") 85 | # "local_epoch_selftrain": 200, 86 | # "lr": 0.001, 87 | # "weight_decay": 0.0005, 88 | # "nlayer": 3, # Number of model layers 89 | # "hidden": 64, # Hidden layer dimension 90 | # "dropout": 0.5, # Dropout rate 91 | # "batch_size": 128, 92 | # "gpu": False, 93 | # "num_cpus_per_trainer": 1, 94 | # "num_gpus_per_trainer": 0, 95 | # # FedProx specific parameter 96 | # "mu": 0.01, # Regularization parameter, only used in "FedProx" 97 | # # GCFL specific parameters 98 | # "standardize": False, # Used only in "GCFL", "GCFL+", "GCFL+dWs" 99 | # "seq_length": 5, # Sequence length, only used in "GCFL", "GCFL+", 
"GCFL+dWs" 100 | # "epsilon1": 0.05, # Privacy epsilon1, specific to "GCFL", "GCFL+", "GCFL+dWs" 101 | # "epsilon2": 0.1, # Privacy epsilon2, specific to "GCFL", "GCFL+", "GCFL+dWs" 102 | # # Output configuration 103 | # "outbase": "./outputs", 104 | # "save_files": False, 105 | # # Scalability and Cluster Configuration 106 | # "use_cluster": False, # Use Kubernetes for scalability if True 107 | # } 108 | # ####################################################################### 109 | # # Run fedgraph method 110 | # # ------------------- 111 | 112 | # config = attridict(config) 113 | # run_fedgraph(config) 114 | # ####################################################################### 115 | # # Specify the Link Prediction configuration 116 | # # ---------------------------------------------- 117 | # BASE_DIR = os.path.dirname(os.path.abspath(".")) 118 | # DATASET_PATH = os.path.join( 119 | # BASE_DIR, "data", "LPDataset" 120 | # ) # Could be modified based on the user needs 121 | # config = { 122 | # "fedgraph_task": "LP", 123 | # # method = ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"] 124 | # "method": "STFL", 125 | # # Dataset configuration 126 | # # country_codes = ['US', 'BR', 'ID', 'TR', 'JP'] 127 | # "country_codes": ["ID", "TR"], 128 | # "dataset_path": DATASET_PATH, 129 | # # Setup configuration 130 | # "device": "cpu", 131 | # "use_buffer": False, 132 | # "buffer_size": 300000, 133 | # "online_learning": False, 134 | # "seed": 10, 135 | # # Model parameters 136 | # "global_rounds": 8, 137 | # "local_steps": 3, 138 | # "hidden_channels": 64, 139 | # # Output configuration 140 | # "record_results": False, 141 | # # System configuration 142 | # "gpu": False, 143 | # "num_cpus_per_trainer": 1, 144 | # "num_gpus_per_trainer": 0, 145 | # "use_cluster": False, # whether use kubernetes for scalability or not 146 | # "distribution_type": "average", # the node number distribution among clients 147 | # "batch_size": -1, # -1 is full batch 148 | # } 149 | # ####################################################################### 150 | # # Run fedgraph method 151 | # # ------------------- 152 | 153 | # config = attridict(config) 154 | # run_fedgraph(config) 155 | -------------------------------------------------------------------------------- /ray_cluster_configs/eks_cluster_config.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: eksctl.io/v1alpha5 2 | kind: ClusterConfig 3 | 4 | metadata: 5 | name: mlarge-1739510276 6 | region: us-east-1 7 | 8 | nodeGroups: 9 | - name: head-nodes 10 | instanceType: m5.24xlarge 11 | desiredCapacity: 1 12 | minSize: 0 13 | maxSize: 1 14 | volumeSize: 256 15 | labels: 16 | ray-node-type: head 17 | 18 | - name: worker-nodes 19 | instanceType: m5.16xlarge 20 | desiredCapacity: 4 21 | minSize: 4 22 | maxSize: 4 23 | volumeSize: 1024 24 | amiFamily: Bottlerocket 25 | labels: 26 | ray-node-type: worker 27 | -------------------------------------------------------------------------------- /ray_cluster_configs/ray_kubernetes_cluster.yaml: -------------------------------------------------------------------------------- 1 | # For most use-cases, it makes sense to schedule one Ray pod per Kubernetes node. 2 | 3 | # Optimal resource allocation will depend on your Kubernetes infrastructure and might 4 | # require some experimentation. 5 | apiVersion: ray.io/v1alpha1 6 | kind: RayCluster 7 | metadata: 8 | labels: 9 | controller-tools.k8s.io: "1.0" 10 | # An unique identifier for the head node and workers of this cluster. 
11 | name: raycluster-autoscaler 12 | namespace: default 13 | spec: 14 | rayVersion: "1.13.0" 15 | enableInTreeAutoscaling: True 16 | ######################headGroupSpecs################################# 17 | # head group template and specs, (perhaps 'group' is not needed in the name) 18 | headGroupSpec: 19 | # Kubernetes Service Type, valid values are 'ClusterIP', 'NodePort' and 'LoadBalancer' 20 | serviceType: ClusterIP 21 | # for the head group, replicas should always be 1. 22 | # headGroupSpec.replicas is deprecated in KubeRay >= 0.3.0. 23 | # logical group name, for this called head-group, also can be functional 24 | # pod type head or worker 25 | # rayNodeType: head # Not needed since it is under the headgroup 26 | # the following params are used to complete the ray start: ray start --head --block --redis-port=6379 ... 27 | rayStartParams: 28 | port: "6379" 29 | dashboard-host: "0.0.0.0" 30 | block: "true" 31 | 32 | #pod template 33 | template: 34 | metadata: 35 | labels: 36 | # custom labels. NOTE: do not define custom labels start with `raycluster.`, they may be used in controller. 37 | # Refer to https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/ 38 | rayCluster: raycluster-sample # will be injected if missing 39 | rayNodeType: head # will be injected if missing, must be head or wroker 40 | groupName: headgroup # will be injected if missing 41 | # annotations for pod 42 | annotations: 43 | key: value 44 | spec: 45 | containers: 46 | - name: ray-head 47 | image: public.ecr.aws/i7t1s5i1/fedgraph:new 48 | imagePullPolicy: Always 49 | # Optimal resource allocation will depend on your Kubernetes infrastructure and might 50 | # require some experimentation. 51 | # Setting requests=limits is recommended with Ray. K8s limits are used for Ray-internal 52 | # resource accounting. K8s requests are not used by Ray. 
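# NOTE: the cpu/memory values below assume the large head-node instance type (m5.24xlarge) defined in ray_cluster_configs/eks_cluster_config.yaml; with a smaller instance type, scale them down so the head pod can still be scheduled.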
53 | resources: 54 | limits: 55 | cpu: "2" 56 | memory: "220Gi" 57 | # nvidia.com/gpu: "1" 58 | 59 | requests: 60 | cpu: "2" 61 | memory: "220Gi" 62 | # nvidia.com/gpu: "1" 63 | env: 64 | - name: CPU_REQUEST 65 | valueFrom: 66 | resourceFieldRef: 67 | containerName: ray-head 68 | resource: requests.cpu 69 | - name: CPU_LIMITS 70 | valueFrom: 71 | resourceFieldRef: 72 | containerName: ray-head 73 | resource: limits.cpu 74 | - name: MEMORY_LIMITS 75 | valueFrom: 76 | resourceFieldRef: 77 | containerName: ray-head 78 | resource: limits.memory 79 | - name: MEMORY_REQUESTS 80 | valueFrom: 81 | resourceFieldRef: 82 | containerName: ray-head 83 | resource: requests.memory 84 | - name: MY_POD_IP 85 | valueFrom: 86 | fieldRef: 87 | fieldPath: status.podIP 88 | - name: RAY_GRAFANA_IFRAME_HOST 89 | value: http://127.0.0.1:3000 90 | - name: RAY_GRAFANA_HOST 91 | value: http://prometheus-grafana.prometheus-system.svc:80 92 | - name: RAY_PROMETHEUS_HOST 93 | value: http://prometheus-kube-prometheus-prometheus.prometheus-system.svc:9090 94 | ports: 95 | - containerPort: 6379 96 | name: gcs 97 | - containerPort: 8265 98 | name: dashboard 99 | - containerPort: 10001 100 | name: client 101 | - containerPort: 8080 102 | name: metrics 103 | - containerPort: 8000 104 | name: serve 105 | - containerPort: 44217 106 | name: as-metrics # autoscaler 107 | - containerPort: 44227 108 | name: dash-metrics # dashboard 109 | lifecycle: 110 | preStop: 111 | exec: 112 | command: ["/bin/sh", "-c", "ray stop"] 113 | workerGroupSpecs: 114 | # the pod replicas in this group typed worker 115 | - replicas: 4 116 | minReplicas: 4 117 | maxReplicas: 4 118 | # logical group name, for this called large-group, also can be functional 119 | groupName: large-group 120 | # if worker pods need to be added, we can simply increment the replicas 121 | # if worker pods need to be removed, we decrement the replicas, and populate the podsToDelete list 122 | # the operator will remove pods from the list until the number of replicas is satisfied 123 | # when a pod is confirmed to be deleted, its name will be removed from the list below 124 | #scaleStrategy: 125 | # workersToDelete: 126 | # - raycluster-complete-worker-small-group-bdtwh 127 | # - raycluster-complete-worker-small-group-hv457 128 | # - raycluster-complete-worker-small-group-k8tj7 129 | # the following params are used to complete the ray start: ray start --block --node-ip-address= ... 130 | rayStartParams: 131 | block: "true" 132 | #pod template 133 | template: 134 | metadata: 135 | labels: 136 | rayCluster: raycluster-autoscaler # will be injected if missing 137 | rayNodeType: worker # will be injected if missing 138 | groupName: large-group # will be injected if missing 139 | # annotations for pod 140 | annotations: 141 | key: value 142 | spec: 143 | containers: 144 | - name: machine-learning # must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc' 145 | image: public.ecr.aws/i7t1s5i1/fedgraph:new 146 | imagePullPolicy: Always 147 | # Setting requests=limits is recommended with Ray. K8s limits are used for Ray-internal 148 | # resource accounting. K8s requests are not used by Ray. 149 | resources: 150 | limits: 151 | cpu: "60" 152 | memory: "200Gi" 153 | # nvidia.com/gpu: "1" 154 | requests: 155 | cpu: "60" 156 | memory: "200Gi" 157 | # nvidia.com/gpu: "1" 158 | # environment variables to set in the container.Optional. 
159 | # Refer to https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/ 160 | env: 161 | - name: RAY_DISABLE_DOCKER_CPU_WARNING 162 | value: "1" 163 | - name: TYPE 164 | value: "worker" 165 | - name: CPU_REQUEST 166 | valueFrom: 167 | resourceFieldRef: 168 | containerName: machine-learning 169 | resource: requests.cpu 170 | - name: CPU_LIMITS 171 | valueFrom: 172 | resourceFieldRef: 173 | containerName: machine-learning 174 | resource: limits.cpu 175 | - name: MEMORY_LIMITS 176 | valueFrom: 177 | resourceFieldRef: 178 | containerName: machine-learning 179 | resource: limits.memory 180 | - name: MEMORY_REQUESTS 181 | valueFrom: 182 | resourceFieldRef: 183 | containerName: machine-learning 184 | resource: requests.memory 185 | - name: MY_POD_NAME 186 | valueFrom: 187 | fieldRef: 188 | fieldPath: metadata.name 189 | - name: MY_POD_IP 190 | valueFrom: 191 | fieldRef: 192 | fieldPath: status.podIP 193 | ports: 194 | - containerPort: 80 195 | lifecycle: 196 | preStop: 197 | exec: 198 | command: ["/bin/sh", "-c", "ray stop"] 199 | # use volumeMounts.Optional. 200 | # Refer to https://kubernetes.io/docs/concepts/storage/volumes/ 201 | volumeMounts: 202 | - mountPath: /var/log 203 | name: log-volume 204 | initContainers: 205 | # the env var $RAY_IP is set by the operator if missing, with the value of the head service name 206 | - name: init-myservice 207 | image: busybox:1.28 208 | # Change the cluster postfix if you don't have a default setting 209 | command: 210 | [ 211 | "sh", 212 | "-c", 213 | "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done", 214 | ] 215 | # use volumes 216 | # Refer to https://kubernetes.io/docs/concepts/storage/volumes/ 217 | volumes: 218 | - name: log-volume 219 | emptyDir: {} 220 | -------------------------------------------------------------------------------- /ray_cluster_configs/ray_kubernetes_ingress.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: networking.k8s.io/v1 2 | kind: Ingress 3 | metadata: 4 | name: ray-cluster-ingress 5 | namespace: default 6 | annotations: 7 | kubernetes.io/ingress.class: "nginx" 8 | nginx.ingress.kubernetes.io/rewrite-target: /$2 9 | spec: 10 | rules: 11 | - http: 12 | paths: 13 | - path: /dashboard(/|$)(.*) 14 | pathType: Prefix 15 | backend: 16 | service: 17 | name: raycluster-autoscaler-head-svc 18 | port: 19 | number: 8265 20 | - path: /serve(/|$)(.*) 21 | pathType: Prefix 22 | backend: 23 | service: 24 | name: raycluster-autoscaler-head-svc 25 | port: 26 | number: 8000 27 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FedGraph/fedgraph/96ed93d1a1f764e5f3eb196b53716e791cf003ed/setup.cfg -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | __version__ = "" 4 | 5 | with open("fedgraph/version.py", "r") as f: 6 | exec(f.read(), globals()) 7 | 8 | 9 | with open("README.md", "r") as f: 10 | README = f.read() 11 | 12 | setup( 13 | name="fedgraph", 14 | version=__version__, 15 | packages=["fedgraph"], 16 | author="Yuhang Yao", 17 | author_email="yuhangya@andrew.cmu.edu", 18 | description="Federated Graph Learning", 19 | 
long_description=README, 20 | long_description_content_type="text/markdown", 21 | url="https://github.com/FedGraph/fedgraph", 22 | classifiers=[ 23 | "Programming Language :: Python :: 3", 24 | # "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | ], 27 | keywords=["Graph Neural Networks", "Federated Learning"], 28 | python_requires=">=3.9", 29 | install_requires=[ 30 | "torch>=2.0.0", 31 | "torch-scatter>=2.0.9", 32 | "torch-sparse>=0.6.15", 33 | "torch-cluster>=1.6.0", 34 | "torch-spline-conv>=1.2.1", 35 | "torch-geometric>=2.1.0.post1", 36 | "omegaconf>=2.3.0", 37 | "ray[default]>=2.6.3", 38 | "PyYAML>=5.4.0", 39 | "attridict", 40 | "torchmetrics", 41 | "setuptools", 42 | "sphinx_rtd_theme", 43 | "sphinxcontrib-bibtex", 44 | "matplotlib", 45 | "sphinx_gallery", 46 | "tensorboard", 47 | "dtaidistance", 48 | "gdown", 49 | "pandas", 50 | "twine==5.0.0", 51 | "scikit-learn", 52 | "tenseal", 53 | "huggingface_hub", 54 | "ogb", 55 | ], 56 | extras_require={"dev": ["build", "mypy", "pre-commit", "pytest"]}, 57 | include_package_data=True, 58 | package_data={ 59 | "fedgraph": ["he_context.pkl"], 60 | }, 61 | ) 62 | -------------------------------------------------------------------------------- /setup_cluster.md: -------------------------------------------------------------------------------- 1 | # Instructions for Setting Up a Ray Cluster on AWS EKS 2 | 3 | ## Step-by-Step Guide to Push customized Docker ECR image 4 | 5 | Configure AWS: 6 | 7 | ```bash 8 | aws configure 9 | ``` 10 | 11 | Login to ECR (Only for pushing public image, FedGraph already provided public docker image that includes all of the environmental dependencies) 12 | 13 | ```bash 14 | aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws 15 | ``` 16 | 17 | Build Docker with amd64 architecture on the cloud and push to ECR 18 | 19 | ```bash 20 | # You can modify the cloud builder using the CLI, with the docker buildx create command. 21 | docker buildx create --driver cloud ryanli3/fedgraph 22 | # Set your new cloud builder as default on your local machine. 23 | docker buildx use cloud-ryanli3-fedgraph --global 24 | # Build and push image to ECR 25 | docker buildx build --platform linux/amd64 -t public.ecr.aws/i7t1s5i1/fedgraph:img . 
--push 26 | ``` 27 | 28 | ## Step-by-Step Guide to Set Up the Ray Cluster 29 | 30 | Create an EKS Cluster with eksctl: 31 | 32 | ```bash 33 | eksctl create cluster -f eks_cluster_config.yaml --timeout=60m 34 | ``` 35 | 36 | After the cluster setup finishes, update the kubeconfig for AWS EKS so that kubectl can manage the cluster: 37 | 38 | ```bash 39 | # --region and --name are configured in eks_cluster_config.yaml 40 | # metadata: 41 | # name: user 42 | # region: us-west-2 43 | aws eks --region us-west-2 update-kubeconfig --name mlarge 44 | 45 | ``` 46 | Optional: check or switch the current cluster context only if you have multiple clusters running at the same time: 47 | 48 | ```bash 49 | 50 | kubectl config current-context 51 | kubectl config use-context arn:aws:eks:us-west-2:312849146674:cluster/large 52 | 53 | 54 | ``` 55 | Clone the KubeRay repository and install the Prometheus and Grafana servers: 56 | 57 | ```bash 58 | git clone https://github.com/ray-project/kuberay.git 59 | cd kuberay 60 | ./install/prometheus/install.sh 61 | ``` 62 | 63 | Add the KubeRay Helm repository and install the KubeRay operator: 64 | 65 | ```bash 66 | helm repo add kuberay https://ray-project.github.io/kuberay-helm/ 67 | helm repo update 68 | helm install kuberay-operator kuberay/kuberay-operator --version 1.1.1 69 | ``` 70 | 71 | Navigate to the cluster configurations directory: 72 | 73 | ```bash 74 | cd ray_cluster_configs 75 | ``` 76 | 77 | Apply the Ray Kubernetes cluster and ingress configurations: 78 | 79 | ```bash 80 | kubectl apply -f ray_kubernetes_cluster.yaml 81 | kubectl apply -f ray_kubernetes_ingress.yaml 82 | ``` 83 | Check that every pod is running correctly: 84 | ```bash 85 | kubectl get pods 86 | # NAME READY STATUS RESTARTS AGE 87 | # kuberay-operator-7d7998bcdb-bzpkj 1/1 Running 0 35m 88 | # raycluster-autoscaler-head-47mzs 2/2 Running 0 35m 89 | # raycluster-autoscaler-worker-large-group-grw8w 1/1 Running 0 35m 90 | ``` 91 | 92 | If a pod's status is Pending, ray_kubernetes_cluster.yaml requests more resources than the cluster can provide. Delete the RayCluster, modify the config, and reapply it: 93 | ```bash 94 | kubectl delete -f ray_kubernetes_cluster.yaml 95 | kubectl apply -f ray_kubernetes_cluster.yaml 96 | ``` 97 | 98 | Forward ports for the Ray Dashboard, Prometheus, and Grafana: 99 | 100 | ```bash 101 | kubectl port-forward service/raycluster-autoscaler-head-svc 8265:8265 102 | # raycluster-autoscaler-head-xxx is the head pod name 103 | kubectl port-forward raycluster-autoscaler-head-47mzs 8080:8080 104 | kubectl port-forward prometheus-prometheus-kube-prometheus-prometheus-0 -n prometheus-system 9090:9090 105 | kubectl port-forward deployment/prometheus-grafana -n prometheus-system 3000:3000 106 | ``` 107 | 108 | Final check: 109 | 110 | ```bash 111 | kubectl get pods --all-namespaces -o wide 112 | ``` 113 | 114 | Submit a Ray Job: 115 | 116 | ```bash 117 | cd fedgraph 118 | ray job submit --runtime-env-json '{ 119 | "working_dir": "./", 120 | "excludes": [".git"] 121 | }' --address http://localhost:8265 -- python3 run.py 122 | 123 | 124 | ``` 125 | 126 | Stop a Ray Job: 127 | 128 | ```bash 129 | # raysubmit_xxx is the submission ID printed when the job is submitted (also listed by `ray job list`) 130 | ray job stop raysubmit_m5PN9xqV6drJQ8k2 --address http://localhost:8265 131 | ``` 132 | 133 | ## How to Delete the Ray Cluster 134 | 135 | Delete the RayCluster Custom Resource: 136 | 137 | ```bash 138 | cd ray_cluster_configs 139 | kubectl delete -f ray_kubernetes_cluster.yaml 140 | kubectl delete -f ray_kubernetes_ingress.yaml 141 | ```
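Optionally, once you are completely done with the cluster, the monitoring stack and the KubeRay operator installed earlier can also be removed via Helm. A minimal sketch, assuming the release names used above (`kuberay-operator` in the `default` namespace and `prometheus` in `prometheus-system`):

```bash
helm uninstall kuberay-operator
helm -n prometheus-system uninstall prometheus
```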
142 | 143 | Confirm that the RayCluster Pods are Terminated: 144 | 145 | ```bash 146 | kubectl get pods 147 | # Ensure the output shows no Ray pods except kuberay-operator 148 | ``` 149 | 150 | Finally, Delete the node first and then delete EKS Cluster: 151 | 152 | ```bash 153 | kubectl get nodes -o name | xargs kubectl delete 154 | eksctl delete cluster --region us-west-2 --name user 155 | ``` 156 | 157 | ## Step to Push Data to Hugging Face Hub CLI 158 | 159 | Use the following command to login to the Hugging Face Hub CLI tool when you set "save: True" in node classification tasks if you haven't done so already: 160 | 161 | ```bash 162 | huggingface-cli login 163 | ``` 164 | -------------------------------------------------------------------------------- /setup_cluster.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ======================================= 4 | # Script to Set Up a Ray Cluster on AWS EKS 5 | # ======================================= 6 | 7 | # Function to check command success 8 | check_command() { 9 | if [ $? -ne 0 ]; then 10 | echo "Error: $1 failed. Exiting." 11 | exit 1 12 | fi 13 | } 14 | 15 | # Step 1: Configure AWS credentials 16 | echo "Configuring AWS credentials..." 17 | read -p "Enter AWS Access Key ID: " aws_access_key 18 | read -p "Enter AWS Secret Access Key: " aws_secret_key 19 | read -p "Enter AWS Default Region (e.g., us-east-1): " aws_region 20 | 21 | aws configure set aws_access_key_id $aws_access_key 22 | check_command "AWS Access Key configuration" 23 | aws configure set aws_secret_access_key $aws_secret_key 24 | check_command "AWS Secret Key configuration" 25 | aws configure set region $aws_region 26 | check_command "AWS Region configuration" 27 | 28 | # Step 2: Login to AWS ECR Public 29 | # Note: You do NOT need to rebuild and push the Docker image every time. 30 | # Only rebuild if you have added new dependencies or made changes to the Dockerfile. 31 | 32 | # echo "Logging in to AWS ECR Public..." 33 | # aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws 34 | # check_command "AWS ECR login" 35 | 36 | # # Step 3: Build and push Docker image to ECR 37 | # echo "Building and pushing Docker image to ECR..." 38 | 39 | # # Define the builder name 40 | # BUILDER_NAME="fedgraph-builder" 41 | 42 | # # Check if the builder already exists 43 | # if docker buildx ls | grep -q $BUILDER_NAME; then 44 | # echo "Builder $BUILDER_NAME already exists. Using the existing builder." 45 | # docker buildx use $BUILDER_NAME --global 46 | # else 47 | # echo "Creating a new builder: $BUILDER_NAME" 48 | # docker buildx create --driver docker-container --name $BUILDER_NAME 49 | # check_command "Docker buildx create" 50 | # docker buildx use $BUILDER_NAME --global 51 | # check_command "Docker buildx use" 52 | # fi 53 | 54 | # # Build and push the Docker image 55 | # docker buildx build --platform linux/amd64 -t public.ecr.aws/i7t1s5i1/fedgraph:img . --push 56 | # check_command "Docker build and push" 57 | 58 | # Step 4: Check if EKS Cluster exists 59 | CLUSTER_NAME="mlarge-1739510276" # You can keep a fixed name or change it dynamically 60 | echo "Checking if the EKS cluster '$CLUSTER_NAME' exists..." 61 | 62 | eksctl get cluster --name $CLUSTER_NAME --region $aws_region > /dev/null 2>&1 63 | if [ $? -eq 0 ]; then 64 | echo "Cluster '$CLUSTER_NAME' already exists. Skipping cluster creation." 65 | else 66 | echo "Cluster '$CLUSTER_NAME' does not exist. 
Creating EKS cluster..." 67 | 68 | if [ ! -f "ray_cluster_configs/eks_cluster_config.yaml" ]; then 69 | echo "Error: eks_cluster_config.yaml not found in the ray_cluster_configs folder." 70 | exit 1 71 | fi 72 | 73 | # Modify the configuration file to include the dynamic cluster name 74 | sed -i.bak "s/^ name: .*/ name: $CLUSTER_NAME/" ray_cluster_configs/eks_cluster_config.yaml 75 | 76 | # Create the cluster using the modified configuration file 77 | eksctl create cluster -f ray_cluster_configs/eks_cluster_config.yaml --timeout=60m 78 | check_command "EKS cluster creation" 79 | fi 80 | 81 | # Step 5: Update kubeconfig for AWS EKS 82 | echo "Updating kubeconfig for AWS EKS..." 83 | aws eks --region $aws_region update-kubeconfig --name $CLUSTER_NAME 84 | check_command "Kubeconfig update" 85 | 86 | # Step 6: Clone KubeRay Repository and Install Prometheus/Grafana 87 | echo "Cloning KubeRay repository and installing Prometheus and Grafana..." 88 | if [ ! -d "kuberay" ]; then 89 | git clone https://github.com/ray-project/kuberay.git 90 | fi 91 | cd kuberay 92 | ./install/prometheus/install.sh 93 | check_command "Prometheus and Grafana installation" 94 | 95 | # Step 7: Install KubeRay Operator via Helm 96 | echo "Installing KubeRay Operator..." 97 | helm repo add kuberay https://ray-project.github.io/kuberay-helm/ 98 | helm repo update 99 | helm install kuberay-operator kuberay/kuberay-operator --version 1.1.1 100 | check_command "KubeRay Operator installation" 101 | 102 | # Step 8: Deploy Ray Kubernetes Cluster and Ingress 103 | echo "Deploying Ray Kubernetes Cluster and Ingress..." 104 | # Ensure the script starts from the root directory of the project 105 | cd "$(dirname "$0")/.." 106 | # Apply the Ray Kubernetes cluster and ingress YAML files from the correct path 107 | kubectl apply -f ray_cluster_configs/ray_kubernetes_cluster.yaml 108 | check_command "Ray Kubernetes Cluster deployment" 109 | kubectl apply -f ray_cluster_configs/ray_kubernetes_ingress.yaml 110 | check_command "Ray Kubernetes Ingress deployment" 111 | 112 | # Step 9: Verify Pod Status 113 | echo "Checking pod status..." 114 | kubectl get pods 115 | echo "If any pod status is Pending, modify ray_kubernetes_cluster.yaml and reapply." 116 | 117 | # Step 10: Handle Pending Pod Issues (Optional) 118 | echo "To handle Pending pods, delete the cluster and reapply:" 119 | echo "kubectl delete -f ray_cluster_configs/ray_kubernetes_cluster.yaml" 120 | echo "kubectl apply -f ray_cluster_configs/ray_kubernetes_cluster.yaml" 121 | 122 | # Step 11: Forward Ports for Ray Dashboard, Prometheus, and Grafana 123 | # Note: You must open separate terminal windows for each port forwarding command below. 124 | # Do NOT run them all in one terminal with background (&) processes, as that may cause issues.
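# Optional helper (assumes KubeRay's ray.io/node-type label on the head pod): look up
# the head pod name automatically instead of copying it from `kubectl get pods`, e.g.:
#   HEAD_POD=$(kubectl get pods -l ray.io/node-type=head -o jsonpath='{.items[0].metadata.name}')
#   echo "kubectl port-forward ${HEAD_POD} 8080:8080"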
125 | echo "Open a new terminal and run the following commands one by one in separate terminals:" 126 | echo "kubectl port-forward service/raycluster-autoscaler-head-svc 8265:8265" 127 | # To get <head-pod-name>, run `kubectl get pods` 128 | echo "kubectl port-forward <head-pod-name> 8080:8080" 129 | echo "kubectl port-forward prometheus-prometheus-kube-prometheus-prometheus-0 -n prometheus-system 9090:9090" 130 | # To get the default username and password for Grafana, check https://docs.ray.io/en/latest/cluster/kubernetes/k8s-ecosystem/prometheus-grafana.html 131 | echo "kubectl port-forward deployment/prometheus-grafana -n prometheus-system 3000:3000" 132 | 133 | # Step 12: Final Check 134 | echo "Final check for all pods across namespaces:" 135 | kubectl get pods --all-namespaces -o wide 136 | 137 | # Step 13: Submit a Ray Job 138 | echo "To submit a Ray job, run:" 139 | echo "cd fedgraph" 140 | echo "ray job submit \ 141 | --address http://localhost:8265 \ 142 | --runtime-env-json '{ 143 | \"working_dir\": \".\", 144 | \"excludes\": [\".git\", \"__pycache__\", \"outputs\", \"fedgraph/he_training_context.pkl\"], 145 | \"pip\": [\"fsspec\", \"huggingface_hub\", \"tenseal\"] 146 | }' \ 147 | -- python benchmark/benchmark_GC.py" 148 | 149 | # Step 14: Stop a Ray Job (Optional) 150 | echo "To stop a Ray job, use:" 151 | echo "ray job stop <submission-id> --address http://localhost:8265" 152 | 153 | # Step 15: Clean Up Resources 154 | echo "To clean up resources, delete the RayCluster Custom Resource and EKS cluster:" 155 | echo "kubectl delete -f ray_cluster_configs/ray_kubernetes_cluster.yaml" 156 | echo "kubectl delete -f ray_cluster_configs/ray_kubernetes_ingress.yaml" 157 | echo "kubectl get nodes -o name | xargs kubectl delete" 158 | echo "eksctl delete cluster --region $aws_region --name $CLUSTER_NAME" 159 | # eksctl delete cluster --region us-east-1 --name mlarge-1739510276 160 | 161 | echo "Setup completed successfully!" 162 | -------------------------------------------------------------------------------- /tutorials/FGL_GC.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Graph Classification Example 3 | ====================================== 4 | 5 | Federated Graph Classification with GCFL+dWs on the MUTAG dataset.
6 | 7 | (Time estimate: 3 minutes) 8 | """ 9 | 10 | ####################################################################### 11 | # Load libraries 12 | # -------------- 13 | 14 | import attridict 15 | 16 | from fedgraph.federated_methods import run_fedgraph 17 | 18 | ####################################################################### 19 | # Specify the Graph Classification configuration 20 | # ---------------------------------------------- 21 | config = { 22 | "fedgraph_task": "GC", 23 | # General configuration 24 | # algorithm options: "SelfTrain", "FedAvg", "FedProx", "GCFL", "GCFL+", "GCFL+dWs" 25 | "algorithm": "GCFL+dWs", 26 | # Dataset configuration 27 | "dataset": "MUTAG", 28 | "is_multiple_dataset": False, 29 | "datapath": "./data", 30 | "convert_x": False, 31 | "overlap": False, 32 | # Setup configuration 33 | "device": "cpu", 34 | "seed": 10, 35 | "seed_split_data": 42, 36 | # Model parameters 37 | "num_trainers": 2, 38 | "num_rounds": 200, # Used by "FedAvg" and "GCFL" (not used in "SelfTrain") 39 | "local_epoch": 1, # Used by "FedAvg" and "GCFL" 40 | # Specific for "SelfTrain" (used instead of "num_rounds" and "local_epoch") 41 | "local_epoch_selftrain": 200, 42 | "lr": 0.001, 43 | "weight_decay": 0.0005, 44 | "nlayer": 3, # Number of model layers 45 | "hidden": 64, # Hidden layer dimension 46 | "dropout": 0.5, # Dropout rate 47 | "batch_size": 128, 48 | "gpu": False, 49 | "num_cpus_per_trainer": 1, 50 | "num_gpus_per_trainer": 0, 51 | # FedProx specific parameter 52 | "mu": 0.01, # Regularization parameter, only used in "FedProx" 53 | # GCFL specific parameters 54 | "standardize": False, # Used only in "GCFL", "GCFL+", "GCFL+dWs" 55 | "seq_length": 5, # Sequence length, only used in "GCFL", "GCFL+", "GCFL+dWs" 56 | "epsilon1": 0.05, # Privacy epsilon1, specific to "GCFL", "GCFL+", "GCFL+dWs" 57 | "epsilon2": 0.1, # Privacy epsilon2, specific to "GCFL", "GCFL+", "GCFL+dWs" 58 | # Output configuration 59 | "outbase": "./outputs", 60 | "save_files": False, 61 | # Scalability and Cluster Configuration 62 | "use_cluster": False, # Use Kubernetes for scalability if True 63 | } 64 | ####################################################################### 65 | # Run fedgraph method 66 | # ------------------- 67 | 68 | config = attridict(config) 69 | run_fedgraph(config) 70 | -------------------------------------------------------------------------------- /tutorials/FGL_LP.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Link Prediction Example 3 | ================================= 4 | 5 | Federated Link Prediction with STFL on the Link Prediction dataset. 
6 | 7 | (Time estimate: 3 minutes) 8 | """ 9 | 10 | ####################################################################### 11 | # Load libraries 12 | # -------------- 13 | 14 | import os 15 | 16 | import attridict 17 | 18 | from fedgraph.federated_methods import run_fedgraph 19 | 20 | ####################################################################### 21 | # Specify the Link Prediction configuration 22 | # ---------------------------------------------- 23 | BASE_DIR = os.path.dirname(os.path.abspath(".")) 24 | DATASET_PATH = os.path.join( 25 | BASE_DIR, "data", "LPDataset" 26 | ) # Could be modified based on the user needs 27 | config = { 28 | "fedgraph_task": "LP", 29 | # method = ["STFL", "StaticGNN", "4D-FED-GNN+", "FedLink"] 30 | "method": "STFL", 31 | # Dataset configuration 32 | # country_codes = ['US', 'BR', 'ID', 'TR', 'JP'] 33 | "country_codes": ["JP"], 34 | "dataset_path": DATASET_PATH, 35 | # Setup configuration 36 | "device": "cpu", 37 | "use_buffer": False, 38 | "buffer_size": 300000, 39 | "online_learning": False, 40 | "seed": 10, 41 | # Model parameters 42 | "global_rounds": 8, 43 | "local_steps": 3, 44 | "hidden_channels": 64, 45 | # Output configuration 46 | "record_results": False, 47 | # System configuration 48 | "gpu": False, 49 | "num_cpus_per_trainer": 1, 50 | "num_gpus_per_trainer": 0, 51 | "use_cluster": False, # whether use kubernetes for scalability or not 52 | "distribution_type": "average", # the node number distribution among clients 53 | "batch_size": -1, # -1 is full batch 54 | } 55 | ####################################################################### 56 | # Run fedgraph method 57 | # ------------------- 58 | 59 | config = attridict(config) 60 | run_fedgraph(config) 61 | -------------------------------------------------------------------------------- /tutorials/FGL_NC.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Node Classification Example 3 | ======================= 4 | 5 | Federated Node Classification with FedGCN on the Cora dataset. 
6 | 7 | (Time estimate: 3 minutes) 8 | """ 9 | 10 | ####################################################################### 11 | # Load libraries 12 | # -------------- 13 | 14 | import attridict 15 | 16 | from fedgraph.federated_methods import run_fedgraph 17 | 18 | ####################################################################### 19 | # Specify the Node Classification configuration 20 | # --------------------------------------------- 21 | config = { 22 | # Task, Method, and Dataset Settings 23 | "fedgraph_task": "NC", 24 | "dataset": "cora", 25 | "method": "FedGCN", # Federated learning method, e.g., "FedGCN" 26 | "iid_beta": 10000, # Dirichlet distribution parameter for label distribution among clients 27 | "distribution_type": "average", # Distribution type among clients 28 | # Training Configuration 29 | "global_rounds": 100, 30 | "local_step": 3, 31 | "learning_rate": 0.5, 32 | "n_trainer": 2, 33 | "batch_size": -1, # -1 indicates full batch training 34 | # Model Structure 35 | "num_layers": 2, 36 | "num_hops": 1, # Number of n-hop neighbors for client communication 37 | # Resource and Hardware Settings 38 | "gpu": False, 39 | "num_cpus_per_trainer": 1, 40 | "num_gpus_per_trainer": 0, 41 | # Logging and Output Configuration 42 | "logdir": "./runs", 43 | # Security and Privacy 44 | "use_encryption": False, # Whether to use Homomorphic Encryption for secure aggregation 45 | # Dataset Handling Options 46 | "use_huggingface": False, # Load dataset directly from Hugging Face Hub 47 | "saveto_huggingface": False, # Save partitioned dataset to Hugging Face Hub 48 | # Scalability and Cluster Configuration 49 | "use_cluster": False, # Use Kubernetes for scalability if True 50 | } 51 | 52 | ####################################################################### 53 | # Run fedgraph method 54 | # ------------------- 55 | 56 | config = attridict(config) 57 | run_fedgraph(config) 58 | -------------------------------------------------------------------------------- /tutorials/FGL_NC_HE.py: -------------------------------------------------------------------------------- 1 | """ 2 | Federated Node Classification with Homomorphic Encryption Example 3 | ====================================== 4 | 5 | Federated Node Classification with FedGCN and Homomorphic Encryption on the Cora dataset. 
6 | 7 | (Time estimate: 3 minutes) 8 | """ 9 | 10 | ####################################################################### 11 | # Load libraries 12 | # -------------- 13 | 14 | import attridict 15 | 16 | from fedgraph.federated_methods import run_fedgraph 17 | 18 | ####################################################################### 19 | # Specify the Node Classification with Homomorphic Encryption configuration 20 | # --------------------------------------------- 21 | config = { 22 | # Task, Method, and Dataset Settings 23 | "fedgraph_task": "NC", 24 | "dataset": "cora", 25 | "method": "FedGCN", # Federated learning method, e.g., "FedGCN" 26 | "iid_beta": 10000, # Dirichlet distribution parameter for label distribution among clients 27 | "distribution_type": "average", # Distribution type among clients 28 | # Training Configuration 29 | "global_rounds": 100, 30 | "local_step": 3, 31 | "learning_rate": 0.5, 32 | "n_trainer": 2, 33 | "batch_size": -1, # -1 indicates full batch training 34 | # Model Structure 35 | "num_layers": 2, 36 | "num_hops": 1, # Number of n-hop neighbors for client communication 37 | # Resource and Hardware Settings 38 | "gpu": False, 39 | "num_cpus_per_trainer": 1, 40 | "num_gpus_per_trainer": 0, 41 | # Logging and Output Configuration 42 | "logdir": "./runs", 43 | # Security and Privacy 44 | "use_encryption": True, # Whether to use Homomorphic Encryption for secure aggregation 45 | # Dataset Handling Options 46 | "use_huggingface": False, # Load dataset directly from Hugging Face Hub 47 | "saveto_huggingface": False, # Save partitioned dataset to Hugging Face Hub 48 | # Scalability and Cluster Configuration 49 | "use_cluster": False, # Use Kubernetes for scalability if True 50 | } 51 | 52 | ####################################################################### 53 | # Run fedgraph method 54 | # ------------------- 55 | 56 | config = attridict(config) 57 | run_fedgraph(config) 58 | -------------------------------------------------------------------------------- /tutorials/README.txt: -------------------------------------------------------------------------------- 1 | Federated Graph Learning: A Tutorial 2 | ===================================== 3 | --------------------------------------------------------------------------------