├── .github ├── ISSUE_TEMPLATE │ ├── blank-issue.md │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── check-style.yml │ ├── push-docker-image.yml │ ├── run-benchmarks.yml │ ├── run-tests-on-modal.yml │ └── run-tests.yml ├── .gitignore ├── .readthedocs.yml ├── CITATION.cff ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── benchmarks ├── benchmark_averaging.py ├── benchmark_dht.py ├── benchmark_optimizer.py ├── benchmark_tensor_compression.py └── benchmark_throughput.py ├── codecov.yml ├── docs ├── Makefile ├── _static │ ├── bug.gif │ ├── bug.odp │ ├── bug_preview.gif │ ├── dht.odp │ ├── dht.png │ ├── favicon.png │ └── fix_rtd.css ├── conf.py ├── index.rst ├── make.bat ├── modules │ ├── averaging.rst │ ├── client.rst │ ├── dht.rst │ ├── index.rst │ ├── optim.rst │ └── server.rst └── user │ ├── acknowledgements.md │ ├── benchmarks.md │ ├── contributing.md │ ├── dht.md │ ├── moe.md │ └── quickstart.md ├── examples └── albert │ ├── README.md │ ├── arguments.py │ ├── requirements.txt │ ├── run_trainer.py │ ├── run_training_monitor.py │ ├── tokenize_wikitext103.py │ └── utils.py ├── hivemind ├── __init__.py ├── averaging │ ├── __init__.py │ ├── allreduce.py │ ├── averager.py │ ├── control.py │ ├── group_info.py │ ├── key_manager.py │ ├── load_balancing.py │ ├── matchmaking.py │ └── partition.py ├── compression │ ├── __init__.py │ ├── adaptive.py │ ├── base.py │ ├── floating.py │ ├── quantization.py │ └── serialization.py ├── dht │ ├── __init__.py │ ├── crypto.py │ ├── dht.py │ ├── node.py │ ├── protocol.py │ ├── routing.py │ ├── schema.py │ ├── storage.py │ ├── traverse.py │ └── validation.py ├── hivemind_cli │ ├── __init__.py │ ├── config.yml │ ├── run_dht.py │ └── run_server.py ├── moe │ ├── __init__.py │ ├── client │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── expert.py │ │ ├── moe.py │ │ ├── remote_expert_worker.py │ │ └── switch_moe.py │ ├── expert_uid.py │ └── server │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── connection_handler.py │ │ ├── dht_handler.py │ │ ├── layers │ │ ├── __init__.py │ │ ├── common.py │ │ ├── custom_experts.py │ │ ├── dropout.py │ │ ├── lr_schedule.py │ │ └── optim.py │ │ ├── module_backend.py │ │ ├── runtime.py │ │ ├── server.py │ │ └── task_pool.py ├── optim │ ├── __init__.py │ ├── grad_averager.py │ ├── grad_scaler.py │ ├── optimizer.py │ ├── power_sgd_averager.py │ ├── progress_tracker.py │ ├── state_averager.py │ └── training_averager.py ├── p2p │ ├── __init__.py │ ├── p2p_daemon.py │ ├── p2p_daemon_bindings │ │ ├── __init__.py │ │ ├── control.py │ │ ├── datastructures.py │ │ ├── p2pclient.py │ │ └── utils.py │ └── servicer.py ├── proto │ ├── auth.proto │ ├── averaging.proto │ ├── crypto.proto │ ├── dht.proto │ ├── p2pd.proto │ ├── runtime.proto │ └── test.proto └── utils │ ├── __init__.py │ ├── asyncio.py │ ├── auth.py │ ├── crypto.py │ ├── limits.py │ ├── logging.py │ ├── math.py │ ├── mpfuture.py │ ├── multiaddr │ ├── __init__.py │ ├── codecs │ │ ├── __init__.py │ │ ├── cid.py │ │ ├── domain.py │ │ ├── fspath.py │ │ ├── ip4.py │ │ ├── ip6.py │ │ ├── onion.py │ │ ├── onion3.py │ │ ├── uint16be.py │ │ └── utf8.py │ ├── exceptions.py │ ├── multiaddr.py │ ├── protocols.py │ └── transforms.py │ ├── nested.py │ ├── networking.py │ ├── performance_ema.py │ ├── serializer.py │ ├── streaming.py │ ├── tensor_descr.py │ └── timed_storage.py ├── modal_ci.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements-docs.txt ├── requirements.txt ├── setup.py └── tests ├── conftest.py ├── test_allreduce.py ├── 
test_allreduce_fault_tolerance.py ├── test_auth.py ├── test_averaging.py ├── test_cli_scripts.py ├── test_compression.py ├── test_connection_handler.py ├── test_custom_experts.py ├── test_dht.py ├── test_dht_crypto.py ├── test_dht_experts.py ├── test_dht_node.py ├── test_dht_protocol.py ├── test_dht_schema.py ├── test_dht_storage.py ├── test_dht_validation.py ├── test_expert_backend.py ├── test_moe.py ├── test_multiaddr.py ├── test_multiaddr_protocols.py ├── test_multiaddr_transforms.py ├── test_optimizer.py ├── test_p2p_daemon.py ├── test_p2p_daemon_bindings.py ├── test_p2p_servicer.py ├── test_relays.py ├── test_routing.py ├── test_start_server.py ├── test_training.py ├── test_util_modules.py └── test_utils ├── custom_networks.py ├── dht_swarms.py ├── networking.py └── p2p_daemon.py /.github/ISSUE_TEMPLATE/blank-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Blank issue 3 | about: An issue that doesn't fit into the templates. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report a problem or undesired behavior 4 | title: "[BUG] YOUR_TITLE_HERE" 5 | labels: bug 6 | assignees: justheuristic 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | If applicable, please create a minimal script that reproduces the problem for you. It would be great to include script outputs as well. 15 | 16 | **Environment** 17 | Please list: 18 | * python version (e.g. 3.8.1); 19 | * hivemind.__version__; branch or commit id 20 | * Please copy and paste the output from pytorch [environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) 21 | 22 | If the script doesn't work, please report pytorch and numpy versions manually. We also encourage you to include any additional information that you believe can help us solve the issue. 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/check-style.yml: -------------------------------------------------------------------------------- 1 | name: Check style 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | codespell: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: codespell-project/actions-codespell@v1 18 | with: 19 | only_warn: 1 20 | ignore_words_list: ibrary,nd,Buss 21 | 22 | ruff: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v4 26 | - uses: astral-sh/ruff-action@v3 27 | with: 28 | version-file: "requirements-dev.txt" 29 | - uses: astral-sh/ruff-action@v3 30 | with: 31 | args: "format --check --diff" 32 | version-file: "requirements-dev.txt" 33 | -------------------------------------------------------------------------------- /.github/workflows/push-docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Push to Docker Hub 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | tags: 7 | - "*.*.*" 8 | pull_request: 9 | branches: [ master ] 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v4 22 | 23 | - name: Docker meta 24 | id: meta 25 | uses: crazy-max/ghaction-docker-meta@v2 26 | with: 27 | # list of Docker images to use as base name for tags 28 | images: | 29 | learningathome/hivemind 30 | # generate Docker tags based on the following events/attributes 31 | tags: | 32 | type=ref,event=branch 33 | type=ref,event=pr 34 | type=semver,pattern={{version}} 35 | type=semver,pattern={{major}}.{{minor}} 36 | type=semver,pattern={{major}} 37 | 38 | - name: Set up Docker Buildx 39 | id: buildx 40 | uses: docker/setup-buildx-action@v1 41 | 42 | - name: Login to Docker Hub 43 | if: github.event_name != 'pull_request' 44 | uses: docker/login-action@v1 45 | with: 46 | username: ${{ secrets.DOCKER_HUB_USERNAME }} 47 | password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} 48 | 49 | - name: Build and push 50 | id: docker_build 51 | uses: docker/build-push-action@v2 52 | with: 53 | context: . 
54 | push: ${{ github.event_name != 'pull_request' }} 55 | tags: ${{ steps.meta.outputs.tags }} 56 | 57 | - name: Image digest 58 | run: echo ${{ steps.docker_build.outputs.digest }} 59 | -------------------------------------------------------------------------------- /.github/workflows/run-benchmarks.yml: -------------------------------------------------------------------------------- 1 | name: Benchmarks 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | run_benchmarks: 14 | 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 10 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.11 23 | - name: Cache dependencies 24 | uses: actions/cache@v4 25 | with: 26 | path: ~/.cache/pip 27 | key: Key-v1-3.11-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | pip install -r requirements-dev.txt 33 | - name: Build bitsandbytes 34 | run: | 35 | pip install bitsandbytes==0.45.2 36 | - name: Build hivemind 37 | run: | 38 | pip install . 39 | - name: Benchmark 40 | run: | 41 | cd benchmarks 42 | python benchmark_throughput.py --preset minimalistic 43 | python benchmark_tensor_compression.py 44 | python benchmark_dht.py 45 | -------------------------------------------------------------------------------- /.github/workflows/run-tests-on-modal.yml: -------------------------------------------------------------------------------- 1 | name: Modal tests 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | run_tests: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: ["3.9", "3.10", "3.11", "3.12"] 18 | fail-fast: false 19 | env: 20 | MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} 21 | MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} 22 | PYTHON_VERSION: ${{ matrix.python-version }} 23 | timeout-minutes: 15 24 | steps: 25 | - name: Checkout Repository 26 | uses: actions/checkout@v4 27 | 28 | - name: Install Python 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: "3.12" 32 | 33 | - name: Cache dependencies 34 | uses: actions/cache@v4 35 | with: 36 | path: ~/.cache/pip 37 | key: Key-v1-3.12-modal 38 | 39 | - name: Install build dependencies 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install modal==0.73.32 43 | 44 | - name: Run tests 45 | run: | 46 | modal run modal_ci.py::run_tests 47 | 48 | measure_coverage: 49 | runs-on: ubuntu-latest 50 | env: 51 | MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} 52 | MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} 53 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 54 | GITHUB_EVENT_NAME: ${{ github.event_name }} 55 | GITHUB_EVENT_NUMBER: ${{ github.event.number }} 56 | GITHUB_EVENT_PULL_REQUEST_HEAD_SHA: ${{ github.event.pull_request.head.sha }} 57 | PYTHON_VERSION: "3.11" 58 | timeout-minutes: 15 59 | steps: 60 | - name: Checkout Repository 61 | uses: actions/checkout@v4 62 | 63 | - name: Install Python 64 | uses: actions/setup-python@v5 65 | with: 66 | python-version: "3.12" 67 | 68 | - name: Cache dependencies 69 | uses: 
actions/cache@v4 70 | with: 71 | path: ~/.cache/pip 72 | key: Key-v1-3.12-modal 73 | 74 | - name: Install build dependencies 75 | run: | 76 | python -m pip install --upgrade pip 77 | pip install modal==0.73.32 78 | 79 | - name: Measure and upload coverage 80 | run: | 81 | modal run modal_ci.py::run_codecov 82 | 83 | build_and_test_p2pd: 84 | runs-on: ubuntu-latest 85 | env: 86 | MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} 87 | MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} 88 | PYTHON_VERSION: "3.11" 89 | timeout-minutes: 10 90 | steps: 91 | - name: Checkout Repository 92 | uses: actions/checkout@v4 93 | 94 | - name: Install Python 95 | uses: actions/setup-python@v5 96 | with: 97 | python-version: "3.12" 98 | 99 | - name: Cache dependencies 100 | uses: actions/cache@v4 101 | with: 102 | path: ~/.cache/pip 103 | key: Key-v1-3.12-modal 104 | 105 | - name: Install build dependencies 106 | run: | 107 | python -m pip install --upgrade pip 108 | pip install modal==0.73.32 109 | 110 | - name: Run p2pd tests 111 | run: | 112 | modal run modal_ci.py::build_and_test_p2pd 113 | -------------------------------------------------------------------------------- /.github/workflows/run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | run_tests: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [ "3.9", "3.10", "3.11", "3.12" ] 17 | fail-fast: false 18 | timeout-minutes: 15 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Cache dependencies 26 | uses: actions/cache@v4 27 | with: 28 | path: ~/.cache/pip 29 | key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }} 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install -r requirements.txt 34 | pip install -r requirements-dev.txt 35 | - name: Build bitsandbytes 36 | run: | 37 | pip install bitsandbytes==0.45.2 38 | - name: Build hivemind 39 | run: | 40 | pip install . 41 | - name: Test 42 | run: | 43 | cd tests 44 | export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor 45 | pytest --durations=0 --durations-min=1.0 -v 46 | build_and_test_p2pd: 47 | runs-on: ubuntu-latest 48 | timeout-minutes: 10 49 | steps: 50 | - uses: actions/checkout@v3 51 | - uses: actions/setup-go@v3 52 | with: 53 | go-version: '1.20.11' 54 | check-latest: true 55 | - name: Set up Python 56 | uses: actions/setup-python@v3 57 | with: 58 | python-version: '3.11' 59 | - name: Cache dependencies 60 | uses: actions/cache@v3 61 | with: 62 | path: ~/.cache/pip 63 | key: Key-v1-3.11-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }} 64 | - name: Install dependencies 65 | run: | 66 | python -m pip install --upgrade pip setuptools wheel 67 | pip install -r requirements.txt 68 | pip install -r requirements-dev.txt 69 | - name: Build hivemind 70 | run: | 71 | pip install . 
--global-option=build_py --global-option="--buildgo" --no-use-pep517 72 | - name: Test 73 | run: | 74 | cd tests 75 | export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor 76 | pytest -k "p2p" -v 77 | codecov_in_develop_mode: 78 | runs-on: ubuntu-latest 79 | timeout-minutes: 20 80 | steps: 81 | - uses: actions/checkout@v3 82 | - name: Set up Python 83 | uses: actions/setup-python@v3 84 | with: 85 | python-version: '3.11' 86 | - name: Cache dependencies 87 | uses: actions/cache@v3 88 | with: 89 | path: ~/.cache/pip 90 | key: Key-v1-3.11-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }} 91 | - name: Install dependencies 92 | run: | 93 | python -m pip install --upgrade pip setuptools wheel 94 | pip install -r requirements.txt 95 | pip install -r requirements-dev.txt 96 | - name: Build bitsandbytes 97 | run: | 98 | pip install bitsandbytes==0.45.2 99 | - name: Build hivemind 100 | run: | 101 | pip install -e . --no-use-pep517 102 | - name: Test 103 | run: | 104 | export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor 105 | pytest --cov hivemind --cov-config=pyproject.toml -v tests 106 | - name: Upload coverage to Codecov 107 | uses: codecov/codecov-action@v3 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # node and NPM 2 | npm-debug.log 3 | node_modules 4 | 5 | # swap files 6 | *~ 7 | *.swp 8 | 9 | examples/data/* 10 | examples/runs/* 11 | examples/.ipynb_checkpoints/* 12 | 13 | env.sh 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | bin/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg/ 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | .idea 57 | .vscode 58 | .ipynb_checkpoints 59 | 60 | # Rope 61 | .ropeproject 62 | 63 | # Django stuff: 64 | *.log 65 | *.pot 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | docs/tmp* 70 | 71 | # OS X garbage 72 | .DS_Store 73 | 74 | # Debian things 75 | debian/reproducible-experiment-platform 76 | debian/files 77 | *.substvars 78 | *.debhelper.log 79 | 80 | # protobuf stuff 81 | hivemind/proto/*_pb2* 82 | 83 | # libp2p-daemon binary 84 | hivemind/hivemind_cli/p2pd 85 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | fail_on_warning: true 5 | configuration: docs/conf.py 6 | 7 | python: 8 | install: 9 | - requirements: requirements.txt 10 | - requirements: requirements-docs.txt 11 | - method: pip 12 | path: . 13 | 14 | build: 15 | os: ubuntu-22.04 16 | tools: 17 | python: "3.11" 18 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.2.0" 2 | date-released: 2020-04 3 | message: "If you use this software, please cite it as below." 
4 | title: "Hivemind: A Library For Decentralized Deep Learning" 5 | url: "https://github.com/learning-at-home/hivemind" 6 | authors: 7 | - family-names: Ryabinin 8 | given-names: Max 9 | - family-names: Borzunov 10 | given-names: Alexander 11 | - family-names: Diskin 12 | given-names: Michael 13 | - family-names: Gusev 14 | given-names: Anton 15 | - family-names: Mazur 16 | given-names: Denis 17 | - family-names: Plokhotnyuk 18 | given-names: Vsevolod 19 | - family-names: Bukhtiyarov 20 | given-names: Alexey 21 | - family-names: Samygin 22 | given-names: Pavel 23 | - family-names: Sinitsin 24 | given-names: Anton 25 | - family-names: Chumachenko 26 | given-names: Artem 27 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 2 | LABEL maintainer="Learning@home" 3 | LABEL repository="hivemind" 4 | 5 | WORKDIR /home 6 | # Set en_US.UTF-8 locale by default 7 | RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment 8 | 9 | # Install packages 10 | RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \ 11 | build-essential \ 12 | curl \ 13 | wget \ 14 | git \ 15 | vim \ 16 | && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* 17 | 18 | RUN curl https://sh.rustup.rs -sSf | sh -s -- -y 19 | ENV PATH="/root/.cargo/bin:${PATH}" 20 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \ 21 | bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh 22 | ENV PATH="/opt/conda/bin:${PATH}" 23 | 24 | RUN conda install python~=3.11.0 pip && \ 25 | pip install --no-cache-dir torch torchvision torchaudio && \ 26 | conda clean --all 27 | 28 | COPY requirements.txt hivemind/requirements.txt 29 | COPY requirements-dev.txt hivemind/requirements-dev.txt 30 | COPY examples/albert/requirements.txt hivemind/examples/albert/requirements.txt 31 | RUN pip install --no-cache-dir -r hivemind/requirements.txt && \ 32 | pip install --no-cache-dir -r hivemind/requirements-dev.txt && \ 33 | pip install --no-cache-dir -r hivemind/examples/albert/requirements.txt && \ 34 | rm -rf ~/.cache/pip 35 | 36 | COPY . hivemind/ 37 | RUN cd hivemind && \ 38 | pip install --no-cache-dir .[dev] && \ 39 | conda clean --all && rm -rf ~/.cache/pip 40 | 41 | CMD bash 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Learning@home authors and collaborators 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /benchmarks/benchmark_averaging.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import threading 4 | import time 5 | 6 | import torch 7 | 8 | import hivemind 9 | from hivemind.compression import Float16Compression 10 | from hivemind.utils.limits import increase_file_limit 11 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 12 | 13 | use_hivemind_log_handler("in_root_logger") 14 | logger = get_logger(__name__) 15 | 16 | 17 | def sample_tensors(hid_size, num_layers): 18 | tensors = [] 19 | for i in range(num_layers): 20 | tensors.append(torch.randn(hid_size, 3 * hid_size)) 21 | tensors.append(torch.randn(3 * hid_size)) 22 | tensors.append(torch.randn(3 * hid_size)) 23 | tensors.append(torch.randn(hid_size, hid_size)) 24 | tensors.append(torch.ones(hid_size)) 25 | tensors.append(torch.zeros(hid_size)) 26 | tensors.append(torch.randn(hid_size, 4 * hid_size)) 27 | tensors.append(torch.randn(4 * hid_size)) 28 | tensors.append(torch.ones(4 * hid_size)) 29 | tensors.append(torch.randn(2, hid_size, hid_size, 2)) 30 | tensors.append(torch.randn(hid_size)) 31 | tensors.append(torch.randn(hid_size)) 32 | tensors.append(torch.randn(hid_size)) 33 | return tuple(tensors) 34 | 35 | 36 | def benchmark_averaging( 37 | num_peers: int, 38 | target_group_size: int, 39 | num_rounds: int, 40 | min_matchmaking_time: float, 41 | request_timeout: float, 42 | round_timeout: float, 43 | hid_size: int, 44 | num_layers: int, 45 | spawn_dtime: float, 46 | ): 47 | dht_root = hivemind.DHT(start=True) 48 | initial_peers = dht_root.get_visible_maddrs() 49 | 50 | num_groups = 2 ** int(round(math.log2(num_peers / target_group_size))) 51 | nbits = int(round(math.log2(num_groups))) 52 | peer_tensors = [sample_tensors(hid_size, num_layers) for _ in range(num_peers)] 53 | processes = {dht_root} 54 | lock_stats = threading.Lock() 55 | successful_steps = total_steps = 0 56 | 57 | def run_averager(index): 58 | nonlocal successful_steps, total_steps, lock_stats 59 | dht = hivemind.DHT(initial_peers=initial_peers, start=True) 60 | initial_bits = bin(index % num_groups)[2:].rjust(nbits, "0") 61 | averager = hivemind.averaging.DecentralizedAverager( 62 | peer_tensors[index], 63 | dht, 64 | prefix="my_tensor", 65 | initial_group_bits=initial_bits, 66 | compression=Float16Compression(), 67 | target_group_size=target_group_size, 68 | min_matchmaking_time=min_matchmaking_time, 69 | request_timeout=request_timeout, 70 | start=True, 71 | ) 72 | processes.update({dht, averager}) 73 | 74 | logger.info( 75 | f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}" 76 | ) 77 | for step in range(num_rounds): 78 | try: 79 | success = averager.step(timeout=round_timeout) is not None 80 | except: # noqa: E722 81 | success = False 82 | with lock_stats: 83 | successful_steps += int(success) 84 | total_steps += 1 85 | logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}") 86 | logger.info(f"Averager {index}: done.") 87 | 88 | threads = [] 89 | for i in range(num_peers): 90 | thread = threading.Thread(target=run_averager, 
args=[i]) 91 | threads.append(thread) 92 | thread.start() 93 | time.sleep(spawn_dtime) 94 | 95 | t = time.time() 96 | for thread in threads: 97 | thread.join() 98 | 99 | logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.") 100 | logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)") 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--num_peers", type=int, default=16, required=False) 106 | parser.add_argument("--target_group_size", type=int, default=4, required=False) 107 | parser.add_argument("--num_rounds", type=int, default=5, required=False) 108 | parser.add_argument("--hid_size", type=int, default=256, required=False) 109 | parser.add_argument("--num_layers", type=int, default=3, required=False) 110 | parser.add_argument("--min_matchmaking_time", type=float, default=5, required=False) 111 | parser.add_argument("--round_timeout", type=float, default=15, required=False) 112 | parser.add_argument("--request_timeout", type=float, default=1, required=False) 113 | parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False) 114 | parser.add_argument("--increase_file_limit", action="store_true") 115 | args = vars(parser.parse_args()) 116 | 117 | if args.pop("increase_file_limit", False): 118 | increase_file_limit() 119 | 120 | benchmark_averaging(**args) 121 | -------------------------------------------------------------------------------- /benchmarks/benchmark_tensor_compression.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch 5 | 6 | from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor 7 | from hivemind.proto.runtime_pb2 import CompressionType 8 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 9 | 10 | use_hivemind_log_handler("in_root_logger") 11 | logger = get_logger(__name__) 12 | 13 | 14 | def benchmark_compression(tensor: torch.Tensor, compression_type: CompressionType) -> [float, float, int]: 15 | t = time.time() 16 | serialized = serialize_torch_tensor(tensor, compression_type) 17 | result = deserialize_torch_tensor(serialized) 18 | return time.time() - t, (tensor - result).square().mean(), serialized.ByteSize() 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--size", type=int, default=10_000_000, required=False) 24 | parser.add_argument("--seed", type=int, default=7348, required=False) 25 | parser.add_argument("--num_iters", type=int, default=30, required=False) 26 | 27 | args = parser.parse_args() 28 | 29 | torch.manual_seed(args.seed) 30 | X = torch.randn(args.size, dtype=torch.float32) 31 | 32 | for name, compression_type in CompressionType.items(): 33 | total_time = 0 34 | compression_error = 0 35 | total_size = 0 36 | for i in range(args.num_iters): 37 | iter_time, iter_distortion, size = benchmark_compression(X, compression_type) 38 | total_time += iter_time 39 | compression_error += iter_distortion 40 | total_size += size 41 | total_time /= args.num_iters 42 | compression_error /= args.num_iters 43 | total_size /= args.num_iters 44 | logger.info( 45 | f"Compression type: {name}, time: {total_time:.5f}, compression error: {compression_error:.5f}, " 46 | f"size: {int(total_size):d}" 47 | ) 48 | -------------------------------------------------------------------------------- /codecov.yml: 
-------------------------------------------------------------------------------- 1 | comment: 2 | layout: "diff, files" 3 | behavior: default 4 | require_changes: true 5 | coverage: 6 | status: 7 | patch: 8 | default: 9 | informational: true 10 | project: 11 | default: 12 | threshold: 1% 13 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_static/bug.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/bug.gif -------------------------------------------------------------------------------- /docs/_static/bug.odp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/bug.odp -------------------------------------------------------------------------------- /docs/_static/bug_preview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/bug_preview.gif -------------------------------------------------------------------------------- /docs/_static/dht.odp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/dht.odp -------------------------------------------------------------------------------- /docs/_static/dht.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/dht.png -------------------------------------------------------------------------------- /docs/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/docs/_static/favicon.png -------------------------------------------------------------------------------- /docs/_static/fix_rtd.css: -------------------------------------------------------------------------------- 1 | /* work around https://github.com/snide/sphinx_rtd_theme/issues/149 */ 2 | .rst-content table.field-list .field-body { 3 | padding-top: 8px; 4 | } 5 | /* unlimited page width */ 6 | .wy-nav-content { 7 | max-width: none; 8 | } -------------------------------------------------------------------------------- /docs/index.rst: 
-------------------------------------------------------------------------------- 1 | |logo| **Hivemind docs & tutorials** 2 | ==================================== 3 | 4 | .. |logo| image:: _static/favicon.png 5 | :scale: 48 6 | 7 | Hivemind is a library for decentralized deep learning computations. It allows you to train large neural networks using vast numbers 8 | of computers, whether you're running a very capable computer or a less reliable one. 9 | Learn how to create or join a Hivemind run in the `quickstart tutorial <./user/quickstart.html>`__ or browse the API 10 | documentation below. 11 | 12 | | Hivemind is currently in active development, so expect some adventures. If you have any questions, feel free to ask them 13 | in `our Discord chat `_ or 14 | file an `issue `__. 15 | 16 | **Table of contents:** 17 | ~~~~~~~~~~~~~~~~~~~~~~ 18 | .. toctree:: 19 | :maxdepth: 2 20 | :glob: 21 | 22 | user/quickstart 23 | modules/index 24 | user/dht 25 | user/moe 26 | user/contributing 27 | user/benchmarks 28 | user/acknowledgements 29 | 30 | Indices and tables 31 | ~~~~~~~~~~~~~~~~~~ 32 | * :ref:`genindex` 33 | * :ref:`modindex` 34 | * :ref:`search` 35 | 36 | .. _GitHub: https://github.com/learning-at-home/hivemind 37 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/modules/averaging.rst: -------------------------------------------------------------------------------- 1 | **hivemind.averaging** 2 | ====================== 3 | 4 | .. automodule:: hivemind.averaging 5 | 6 | .. currentmodule:: hivemind.averaging 7 | .. raw:: html 8 | 9 | This module lets you average tensors in a decentralized manner. 10 |
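A minimal usage sketch (the arguments mirror those used in ``benchmarks/benchmark_averaging.py`` later in this dump; the prefix, group size, and tensor shape are illustrative, and the sketch assumes other peers run the same code under the same prefix):

.. code-block:: python

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    averager = hivemind.averaging.DecentralizedAverager(
        [torch.randn(16)],       # local tensors to be averaged with groupmates
        dht,
        prefix="my_tensor",      # peers declaring the same prefix get matched together
        target_group_size=4,
        start=True,
    )
    if averager.step(timeout=30) is not None:     # matchmake, then run all-reduce
        with averager.get_tensors() as tensors:   # read back the averaged result
            print(tensors[0])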

11 | 12 | .. autoclass:: DecentralizedAverager 13 | :members: 14 | :member-order: bysource 15 | :exclude-members: get_tensors, get_tensors_async, update_tensors, rpc_join_group, rpc_aggregate_part, register_allreduce_group 16 | -------------------------------------------------------------------------------- /docs/modules/client.rst: -------------------------------------------------------------------------------- 1 | **hivemind.moe.client** 2 | ======================= 3 | 4 | .. automodule:: hivemind.moe.client 5 | 6 | .. currentmodule:: hivemind.moe.client 7 | 8 | .. raw:: html 9 | 10 | This module lets you connect to distributed Mixture-of-Experts or individual experts hosted 11 | in the cloud or on someone else's computer. 12 |
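As a hedged sketch of typical usage (it assumes servers have already declared a matching grid of experts under ``uid_prefix="expert."``, and ``INITIAL_PEERS`` stands in for the multiaddrs of known peers):

.. code-block:: python

    import torch
    import hivemind

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)  # join an existing swarm
    dmoe = hivemind.RemoteMixtureOfExperts(
        in_features=512,
        grid_size=(32, 32),   # should correspond to the expert grid declared by servers
        dht=dht,
        uid_prefix="expert.",
        k_best=4,             # route each input through the 4 best-scoring experts
    )
    outputs = dmoe(torch.randn(8, 512))  # differentiable, like a local nn.Module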

13 | 14 | .. autoclass:: RemoteExpert 15 | :members: forward 16 | 17 | .. autoclass:: RemoteMixtureOfExperts 18 | :members: 19 | :member-order: bysource 20 | 21 | .. autoclass:: RemoteSwitchMixtureOfExperts 22 | :members: 23 | :member-order: bysource -------------------------------------------------------------------------------- /docs/modules/dht.rst: -------------------------------------------------------------------------------- 1 | **hivemind.dht** 2 | ==================== 3 | 4 | .. automodule:: hivemind.dht 5 | .. currentmodule:: hivemind.dht 6 | 7 | Here's a high-level scheme of how these components interact with one another: 8 | 9 | .. image:: ../_static/dht.png 10 | :width: 640 11 | :align: center 12 | 13 | 14 | DHT and DHTNode 15 | ############### 16 | 17 | .. autoclass:: DHT 18 | :members: 19 | :exclude-members: make_key 20 | :member-order: bysource 21 | 22 | .. autoclass:: DHTNode 23 | :members: 24 | :member-order: bysource 25 | 26 | DHT communication protocol 27 | ########################## 28 | .. automodule:: hivemind.dht.protocol 29 | .. currentmodule:: hivemind.dht.protocol 30 | 31 | .. autoclass:: DHTProtocol 32 | :members: 33 | :member-order: bysource 34 | 35 | .. currentmodule:: hivemind.dht.routing 36 | 37 | .. autoclass:: RoutingTable 38 | :members: 39 | :member-order: bysource 40 | 41 | .. autoclass:: KBucket 42 | :members: 43 | :member-order: bysource 44 | 45 | .. autoclass:: DHTID 46 | :members: 47 | :exclude-members: HASH_FUNC 48 | :member-order: bysource 49 | 50 | Traverse (crawl) DHT 51 | #################### 52 | 53 | .. automodule:: hivemind.dht.traverse 54 | .. currentmodule:: hivemind.dht.traverse 55 | 56 | .. autofunction:: simple_traverse_dht 57 | 58 | .. autofunction:: traverse_dht -------------------------------------------------------------------------------- /docs/modules/index.rst: -------------------------------------------------------------------------------- 1 | #################### 2 | API documentation 3 | #################### 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | optim 9 | averaging 10 | dht 11 | client 12 | server -------------------------------------------------------------------------------- /docs/modules/optim.rst: -------------------------------------------------------------------------------- 1 | **hivemind.optim** 2 | ================== 3 | 4 | .. raw:: html 5 | 6 | This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers. 7 | Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent to those of regular data-parallel training, 8 | or perform asynchronous local updates and average model parameters. 9 | 10 |
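For reference, a sketch in the spirit of the quickstart tutorial (``INITIAL_PEERS``, the run id, and the batch sizes are illustrative placeholders):

.. code-block:: python

    import torch
    import hivemind

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    model = torch.nn.Linear(16, 2)

    optimizer = hivemind.Optimizer(
        dht=dht,
        run_id="my_run",           # a unique identifier of this collaborative run
        batch_size_per_step=32,    # samples processed per local optimizer.step()
        target_batch_size=10_000,  # samples accumulated globally per optimizer epoch
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        use_local_updates=True,    # apply updates locally, average parameters in background
        matchmaking_time=3.0,
        averaging_timeout=10.0,
    )
    # in the training loop it is used like a regular PyTorch optimizer:
    # optimizer.zero_grad(); loss.backward(); optimizer.step()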

11 | 12 | .. automodule:: hivemind.optim.optimizer 13 | .. currentmodule:: hivemind.optim.optimizer 14 | 15 | **hivemind.Optimizer** 16 | ---------------------- 17 | 18 | .. autoclass:: Optimizer 19 | :members: step, local_epoch, zero_grad, load_state_from_peers, param_groups, shutdown 20 | :member-order: bysource 21 | 22 | .. currentmodule:: hivemind.optim.grad_scaler 23 | .. autoclass:: GradScaler 24 | :member-order: bysource -------------------------------------------------------------------------------- /docs/modules/server.rst: -------------------------------------------------------------------------------- 1 | **hivemind.moe.server** 2 | ======================================== 3 | 4 | A hivemind server hosts one or several experts and processes incoming requests to those experts. It periodically 5 | re-publishes these experts to the dht via a dedicated **hivemind.dht.DHT** peer that runs in background. 6 | The experts can be accessed directly as **hivemind.moe.client.RemoteExpert("addr:port", "expert.uid.here")** 7 | or as a part of **hivemind.moe.client.RemoteMixtureOfExperts** that finds the most suitable experts across the DHT. 8 | 9 | The hivemind.moe.server module is organized as follows: 10 | 11 | - Server_ is the main class that publishes experts, accepts incoming requests, and passes them to Runtime_ for compute. 12 | - ModuleBackend_ is a wrapper for `torch.nn.Module `_ \ 13 | that can be accessed by remote clients. It has two TaskPool_ s for forward and backward requests. 14 | - Runtime_ balances the device (GPU) usage between several ModuleBackend_ instances that each service one expert. 15 | - TaskPool_ stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches \ 16 | and offers those batches to Runtime_ for processing. 17 | 18 | 19 | .. automodule:: hivemind.moe.server 20 | 21 | .. currentmodule:: hivemind.moe.server 22 | 23 | .. _Server: 24 | .. autoclass:: Server 25 | :members: 26 | :member-order: bysource 27 | 28 | .. _ModuleBackend: 29 | .. autoclass:: ModuleBackend 30 | :members: forward, backward, on_backward, get_info, get_pools 31 | :member-order: bysource 32 | 33 | .. currentmodule:: hivemind.moe.server.runtime 34 | 35 | .. _Runtime: 36 | .. autoclass:: Runtime 37 | :members: 38 | :member-order: bysource 39 | 40 | .. currentmodule:: hivemind.moe.server.task_pool 41 | 42 | .. _TaskPool: 43 | .. autoclass:: TaskPool 44 | :members: submit_task, iterate_minibatches, load_batch_to_runtime, send_outputs_from_runtime, get_task_size, empty 45 | :member-order: bysource -------------------------------------------------------------------------------- /docs/user/acknowledgements.md: -------------------------------------------------------------------------------- 1 | # Acknowledgements 2 | 3 | We kindly thank (in no particular order) 4 | 5 | * [Artem Babenko](https://research.yandex.com/people/102794) and 6 | [Vladimir Aliev](https://ru.linkedin.com/in/vladimir-aliev-19b93282) for helpful discussions and editorial review of 7 | the paper, 8 | * [Jacob R. Steeves](https://github.com/unconst) for discussions on RPC frameworks and NAT traversal and peer-to-peer 9 | technologies. 
10 | * [Dmitry Afanasiev](https://www.linkedin.com/in/dmitry-afanasiev-295a231/) for his guidance on networking and 11 | communication technologies, 12 | * [Lidi Zheng](https://github.com/lidizheng) and grpc-aio contributors for their awesome framework 13 | and [this PR](https://github.com/grpc/grpc/pull/23265) 14 | * [Brian Muller](https://github.com/bmuller/kademlia) for his implementations 15 | of [kademlia](https://github.com/bmuller/kademlia) and [rpcudp](https://github.com/bmuller/rpcudp) 16 | * Alexander Sherbakov for helpful discussions on PC and server component architecture, 17 | * [Yandex School of Data Analysis](https://yandexdataschool.com) students, for helping us run the first truly collaborative experiments. 18 | * The [Neuropark community](https://neuropark.co/), for hosting early collaborative training experiments of sahajBERT with hivemind. 19 | * Our early adopters, [contributors](https://github.com/learning-at-home/hivemind/graphs/contributors), and conference reviewers. 20 | 21 | # Related projects 22 | 23 | In this section, we list several organizations and research projects that bring humanity closer to the dream of world-scale deep learning with volunteer computing. 24 | * [Hugging Face](https://huggingface.co) — an AI community with world-leading NLP research that builds collaborative hub training using hivemind. 25 | * [EYDLE](https://www.eydle.com) — a start-up that works towards distributed deep learning on volunteer hardware using centralized infrastructure. 26 | * [BitTensor](https://github.com/opentensor/BitTensor) — a decentralized deep learning ecosystem with an incentive 27 | mechanism. Each peer trains for its own objective and rewards others for useful features. 28 | * Also building collaborative deep learning? Let us know! `hivemind-team (at) hotmail.com` 29 | -------------------------------------------------------------------------------- /docs/user/benchmarks.md: -------------------------------------------------------------------------------- 1 | # Benchmarking 2 | 3 | This page describes the benchmark scripts that can be used to measure the performance impact of different changes to 4 | hivemind. 5 | 6 | ### Server throughput 7 | 8 | You can use [this benchmark](https://github.com/learning-at-home/hivemind/blob/master/benchmarks/benchmark_throughput.py) to 9 | check the performance impact of your changes to hivemind.moe. The benchmark will start one server without 10 | a DHT, hosting several experts, and then spawn trainer processes that load the server with requests. The two main statistics 11 | in this benchmark are samples/s and startup time. 12 | 13 | `python benchmark_throughput.py --preset default` (aka `ffn_forward_backward`) 14 | 15 |
16 | Console outputs 17 | 18 | ```sh 19 | Benchmark finished, status:Success 20 | Server parameters: num_experts=16, num_handlers=64, max_batch_size=8192, expert_cls=ffn, hid_dim=1024, device=cuda 21 | Client parameters: num_clients=128, num_batches_per_client=16, batch_size=2048, backprop=True 22 | Results: 23 | Server startup took 10.965 s. (3.075 s. experts + 7.889 s. networking) 24 | Processed 4194304 examples in 146.750 25 | Throughput for forward + backward passes: 28581.213 samples / s. 26 | Benchmarking took 157.948 s. 27 | Using device: cuda 28 | GeForce GTX 1080 Ti 29 | Memory Usage: 30 | Allocated: 6.0 GB 31 | Cached: 7.7 GB 32 | 33 | ``` 34 | 35 |
36 | 37 | `python benchmark_throughput.py --preset ffn_forward` 38 | 39 |
40 | Console outputs 41 | 42 | ```sh 43 | Benchmark finished, status:Success 44 | Server parameters: num_experts=16, num_handlers=64, max_batch_size=8192, expert_cls=ffn, hid_dim=1024, device=cuda 45 | Client parameters: num_clients=128, num_batches_per_client=16, batch_size=2048, backprop=False 46 | Results: 47 | Server startup took 19.941 s. (3.065 s. experts + 16.877 s. networking) 48 | Processed 4194304 examples in 42.973 49 | Throughput for forward passes: 97604.282 samples / s. 50 | Benchmarking took 63.167 s. 51 | Using device: cuda 52 | GeForce GTX 1080 Ti 53 | Memory Usage: 54 | Allocated: 1.5 GB 55 | Cached: 3.2 GB 56 | ``` 57 | 58 |
59 | 60 | ### DHT performance 61 | 62 | In turn, [this benchmark](https://github.com/learning-at-home/hivemind/blob/master/benchmarks/benchmark_dht.py) can be used 63 | to measure the performance impact of changes to hivemind.dht. It spawns a DHT with `num_peers` participants, then chooses 64 | one peer that will declare `num_experts` total experts in batches of `expert_batch_size`. Then, another peer will 65 | consecutively get all declared experts and check if they are there; a minimal sketch of the underlying store/get calls is shown after the console output below. 66 | 67 | Here's a run with 1024 participants on the same machine that was used for benchmark_throughput: 68 | 69 | `python benchmark_dht.py --num_peers 1024 --num_experts 16384 --expert_batch_size 64 --expiration 99999 --increase_file_limit` 70 |
71 | Console outputs 72 | 73 | ```sh 74 | Increasing file limit - soft 1024=>32768, hard 1048576=>32768 75 | Creating peers... 76 | 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [01:45<00:00, 9.74it/s] 77 | Sampled 16384 unique ids (after deduplication) 78 | Storing peers to dht in batches of 64... 79 | 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [12:07<00:00, 2.84s/it] 80 | Store success rate: 100.0% (48920 / 48920) 81 | Mean store time: 0.01487, Total: 727.46 82 | 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [01:48<00:00, 2.35it/s] 83 | Get success rate: 100.0 (16384 / 16384) 84 | Mean get time: 0.00664, Total: 108.73952 85 | Node survival rate: 100.000% 86 | ``` 87 | 88 |
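Each store/get that the benchmark measures in bulk boils down to plain DHT calls like these (a minimal sketch, not part of the benchmark script; the key and expiration are illustrative):

```python
import hivemind

dht = hivemind.DHT(start=True)
# store a value under a key for 600 seconds; returns True on success
assert dht.store("expert.0.0", b"info", expiration_time=hivemind.get_dht_time() + 600)
result = dht.get("expert.0.0")  # a (value, expiration) record, or None if not found
if result is not None:
    print(result.value)
```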
89 | 90 | The three main statistics in this benchmark are total store time, total get time and get success rate. Please also note 91 | that this benchmark does not emulate node failure, latency and does not benefit from caching. If one wants to account 92 | for these factors, one must introduce them manually by changing the code. -------------------------------------------------------------------------------- /docs/user/contributing.md: -------------------------------------------------------------------------------- 1 | # Contributing to hivemind 2 | 3 | This section describes the ways to contribute to the hivemind library. For technical details of developing this library 4 | and getting towards merging your code in the master branch, read 5 | the [guidelines](https://github.com/learning-at-home/hivemind/blob/master/CONTRIBUTING.md#) in our GitHub repository. In 6 | any case, please follow the [Contributor Covenant](https://www.contributor-covenant.org/version/2/0/code_of_conduct/) 7 | code of conduct when discussing the library and the changes with other community members. 8 | 9 | ## Ways to contribute 10 | 11 | ### Reporting issues 12 | 13 | ### Proposing new features 14 | 15 | ### Implementing new features 16 | 17 | ### Fixing bugs and improving performance 18 | 19 | ### Improving tests 20 | 21 | ### Improving code readability 22 | 23 | ### Adding tutorials 24 | 25 | ### Improving documentation 26 | 27 | ### Reviewing pull requests 28 | 29 | -------------------------------------------------------------------------------- /examples/albert/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers~=4.6 2 | datasets~=1.5 3 | torch_optimizer==0.1.0 4 | wandb==0.10.26 5 | sentencepiece 6 | requests 7 | nltk==3.6.7 8 | -------------------------------------------------------------------------------- /examples/albert/tokenize_wikitext103.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """This script builds a pre-tokenized compressed representation of WikiText-103 using huggingface/datasets""" 3 | 4 | import random 5 | from functools import partial 6 | 7 | import nltk 8 | from datasets import load_dataset 9 | from transformers import AlbertTokenizerFast 10 | 11 | COLUMN_NAMES = ("attention_mask", "input_ids", "sentence_order_label", "special_tokens_mask", "token_type_ids") 12 | 13 | 14 | def create_instances_from_document(tokenizer, document, max_seq_length): 15 | """ 16 | Creates training instances from a single document. 17 | Reuses code from the original ALBERT implementation (Google AI, 2018) 18 | https://github.com/google-research/albert/blob/master/create_pretraining_data.py#L267 19 | """ 20 | # We DON'T just concatenate all of the tokens from a document into a long 21 | # sequence and choose an arbitrary split point because this would make the 22 | # next sentence prediction task too easy. Instead, we split the input into 23 | # segments "A" and "B" based on the actual "sentences" provided by the user 24 | # input. 
25 | instances = [] 26 | current_chunk = [] 27 | current_length = 0 28 | 29 | segmented_sents = list(nltk.sent_tokenize(document)) 30 | 31 | for i, sent in enumerate(segmented_sents): 32 | current_chunk.append(sent) 33 | current_length += len(tokenizer.tokenize(sent)) 34 | if i == len(segmented_sents) - 1 or current_length >= max_seq_length: 35 | if len(current_chunk) > 1: 36 | # `a_end` is how many segments from `current_chunk` go into the `A` 37 | # (first) sentence. 38 | a_end = random.randint(1, len(current_chunk) - 1) 39 | 40 | tokens_a = [] 41 | for j in range(a_end): 42 | tokens_a.append(current_chunk[j]) 43 | 44 | tokens_b = [] 45 | 46 | for j in range(a_end, len(current_chunk)): 47 | tokens_b.append(current_chunk[j]) 48 | 49 | if random.random() < 0.5: 50 | # Random next 51 | is_random_next = True 52 | # Note(mingdachen): in this case, we just swap tokens_a and tokens_b 53 | tokens_a, tokens_b = tokens_b, tokens_a 54 | else: 55 | # Actual next 56 | is_random_next = False 57 | 58 | assert len(tokens_a) >= 1 59 | assert len(tokens_b) >= 1 60 | 61 | instance = tokenizer( 62 | " ".join(tokens_a), 63 | " ".join(tokens_b), 64 | truncation="longest_first", 65 | max_length=max_seq_length, 66 | # We use this option because DataCollatorForLanguageModeling 67 | # is more efficient when it receives the `special_tokens_mask`. 68 | return_special_tokens_mask=True, 69 | ) 70 | assert len(instance["input_ids"]) <= max_seq_length 71 | instance["sentence_order_label"] = 1 if is_random_next else 0 72 | instances.append(instance) 73 | 74 | current_chunk = [] 75 | current_length = 0 76 | 77 | return instances 78 | 79 | 80 | def tokenize_function(tokenizer, examples): 81 | # Remove empty texts 82 | texts = (text for text in examples["text"] if len(text) > 0 and not text.isspace()) 83 | 84 | new_examples = {col: [] for col in COLUMN_NAMES} 85 | 86 | for text in texts: 87 | instances = create_instances_from_document(tokenizer, text, max_seq_length=512) 88 | for instance in instances: 89 | for key, value in instance.items(): 90 | new_examples[key].append(value) 91 | 92 | return new_examples 93 | 94 | 95 | if __name__ == "__main__": 96 | random.seed(0) 97 | nltk.download("punkt") 98 | tokenizer = AlbertTokenizerFast.from_pretrained("albert-large-v2") 99 | wikitext = load_dataset("wikitext", "wikitext-103-v1", cache_dir="./data/cache") 100 | 101 | tokenized_datasets = wikitext.map( 102 | partial(tokenize_function, tokenizer), 103 | batched=True, 104 | num_proc=8, 105 | remove_columns=["text"], 106 | ) 107 | 108 | tokenized_datasets.save_to_disk("./data/albert_tokenized_wikitext") 109 | tokenizer.save_pretrained("./data/tokenizer") 110 | -------------------------------------------------------------------------------- /examples/albert/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from pydantic.v1 import BaseModel, StrictFloat, confloat, conint 4 | 5 | from hivemind.dht.crypto import RSASignatureValidator 6 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator 7 | from hivemind.dht.validation import RecordValidatorBase 8 | from hivemind.utils.logging import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class LocalMetrics(BaseModel): 14 | step: conint(ge=0, strict=True) 15 | samples_per_second: confloat(ge=0.0, strict=True) 16 | samples_accumulated: conint(ge=0, strict=True) 17 | loss: StrictFloat 18 | mini_steps: conint(ge=0, strict=True) 19 | 20 | 21 | class MetricSchema(BaseModel): 
22 | metrics: Dict[BytesWithPublicKey, LocalMetrics] 23 | 24 | 25 | def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]: 26 | signature_validator = RSASignatureValidator() 27 | validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator] 28 | return validators, signature_validator.local_public_key 29 | -------------------------------------------------------------------------------- /hivemind/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.averaging import DecentralizedAverager 2 | from hivemind.compression import * 3 | from hivemind.dht import DHT 4 | from hivemind.moe import ( 5 | ModuleBackend, 6 | RemoteExpert, 7 | RemoteMixtureOfExperts, 8 | RemoteSwitchMixtureOfExperts, 9 | Server, 10 | register_expert_class, 11 | ) 12 | from hivemind.optim import GradScaler, Optimizer, TrainingAverager 13 | from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo 14 | from hivemind.utils import * 15 | 16 | __version__ = "1.2.0.dev0" 17 | -------------------------------------------------------------------------------- /hivemind/averaging/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.averaging.averager import DecentralizedAverager 2 | -------------------------------------------------------------------------------- /hivemind/averaging/group_info.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | from hivemind.p2p import PeerID 5 | 6 | 7 | @dataclass(frozen=True) 8 | class GroupInfo: 9 | """A group of peers assembled through decentralized matchmaking""" 10 | 11 | group_id: bytes # random unique bytestring that describes the current group, generated by group leader 12 | peer_ids: Tuple[PeerID, ...] # an ordered sequence of peer_ids of each groupmate 13 | gathered: Tuple[bytes, ...] 
# binary metadata gathered from all peers by leader, same order as peer_ids 14 | 15 | @property 16 | def group_size(self): 17 | return len(self.peer_ids) 18 | 19 | def __contains__(self, peer_id: PeerID): 20 | return peer_id in self.peer_ids 21 | -------------------------------------------------------------------------------- /hivemind/compression/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compression strategies that reduce the network communication in .averaging, .optim and .moe 3 | """ 4 | 5 | from hivemind.compression.adaptive import PerTensorCompression, RoleAdaptiveCompression, SizeAdaptiveCompression 6 | from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression, TensorRole 7 | from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression 8 | from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization 9 | from hivemind.compression.serialization import ( 10 | deserialize_tensor_stream, 11 | deserialize_torch_tensor, 12 | serialize_torch_tensor, 13 | ) 14 | -------------------------------------------------------------------------------- /hivemind/compression/adaptive.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Mapping, Sequence, Union 3 | 4 | import torch 5 | 6 | from hivemind.compression.base import CompressionBase, CompressionInfo, Key, NoCompression, TensorRole 7 | from hivemind.compression.serialization import deserialize_torch_tensor 8 | from hivemind.proto import runtime_pb2 9 | 10 | 11 | class AdaptiveCompressionBase(CompressionBase, ABC): 12 | @abstractmethod 13 | def choose_compression(self, info: CompressionInfo) -> CompressionBase: ... 14 | 15 | def estimate_compression_ratio(self, info: CompressionInfo) -> float: 16 | return self.choose_compression(info).estimate_compression_ratio(info) 17 | 18 | def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor: 19 | return self.choose_compression(info).compress(tensor, info=info, allow_inplace=allow_inplace) 20 | 21 | def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor: 22 | return deserialize_torch_tensor(serialized_tensor) 23 | 24 | 25 | class SizeAdaptiveCompression(AdaptiveCompressionBase): 26 | """Apply compression strategy 1 if tensor has more than :threshold: elements and strategy 2 otherwise""" 27 | 28 | def __init__(self, threshold: int, less: CompressionBase, greater_equal: CompressionBase): 29 | self.threshold, self.less, self.greater_equal = threshold, less, greater_equal 30 | 31 | def choose_compression(self, info: CompressionInfo) -> CompressionBase: 32 | return self.greater_equal if info.descriptor.numel() >= self.threshold else self.less 33 | 34 | 35 | class RoleAdaptiveCompression(AdaptiveCompressionBase): 36 | """Compress a tensor based on its role in training. 
Any non-specified compressions will use the "default" option""" 37 | 38 | def __init__( 39 | self, 40 | *, 41 | activation: CompressionBase = None, 42 | parameter: CompressionBase = None, 43 | gradient: CompressionBase = None, 44 | optimizer: CompressionBase = None, 45 | default: CompressionBase = NoCompression(), 46 | ): 47 | self.role_compressions = { 48 | TensorRole.ACTIVATION: activation or default, 49 | TensorRole.PARAMETER: parameter or default, 50 | TensorRole.GRADIENT: gradient or default, 51 | TensorRole.OPTIMIZER: optimizer or default, 52 | TensorRole.UNSPECIFIED: default, 53 | } 54 | 55 | def choose_compression(self, info: CompressionInfo) -> CompressionBase: 56 | return self.role_compressions[info.role] 57 | 58 | 59 | class PerTensorCompression(AdaptiveCompressionBase): 60 | """Manually specify the compression strategy depending on tensor key""" 61 | 62 | def __init__(self, tensor_compressions: Union[Sequence[CompressionBase], Mapping[Key, CompressionBase]]): 63 | self.tensor_compressions = tensor_compressions 64 | 65 | def choose_compression(self, info: CompressionInfo) -> CompressionBase: 66 | return self.tensor_compressions[info.key] 67 | -------------------------------------------------------------------------------- /hivemind/compression/floating.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from hivemind.compression.base import CompressionBase, CompressionInfo 7 | from hivemind.proto import runtime_pb2 8 | 9 | 10 | class Float16Compression(CompressionBase): 11 | compression_type = runtime_pb2.CompressionType.FLOAT16 12 | FP16_MIN, FP16_MAX = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max 13 | 14 | def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor: 15 | if not torch.is_floating_point(tensor) or tensor.dtype == torch.bfloat16: 16 | raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors") 17 | requires_grad = tensor.requires_grad 18 | tensor = tensor.detach().cpu() 19 | dtype_name = tensor.numpy().dtype.name 20 | tensor = tensor.to(torch.float32, copy=not allow_inplace) 21 | tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16) 22 | return runtime_pb2.Tensor( 23 | compression=self.compression_type, 24 | buffer=tensor.numpy().tobytes(), 25 | size=tensor.shape, 26 | dtype=dtype_name, 27 | requires_grad=requires_grad, 28 | ) 29 | 30 | def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor: 31 | original_dtype = np.dtype(serialized_tensor.dtype) 32 | array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16) 33 | return ( 34 | torch.as_tensor(np.asarray(array, dtype=original_dtype)) 35 | .reshape(tuple(serialized_tensor.size)) 36 | .requires_grad_(serialized_tensor.requires_grad) 37 | ) 38 | 39 | def estimate_compression_ratio(self, info: CompressionInfo) -> float: 40 | return 16.0 / get_num_bits(info.descriptor.dtype) 41 | 42 | 43 | class ScaledFloat16Compression(Float16Compression): 44 | """A compression strategy that applies mean-std scaling over last axis before casting to float16""" 45 | 46 | compression_type = runtime_pb2.CompressionType.MEANSTD_16BIT 47 | FP32_BYTES = torch.finfo(torch.float32).bits // 8 48 | FP32_EPS = torch.finfo(torch.float32).eps 49 | 50 | def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor: 51 | if not torch.is_floating_point(tensor) 
or tensor.dtype == torch.bfloat16: 52 | raise ValueError(f"{self.__class__.__name__} does not support {tensor.dtype} tensors") 53 | requires_grad = tensor.requires_grad 54 | tensor = tensor.detach().cpu() 55 | dtype_name = tensor.numpy().dtype.name 56 | tensor = tensor.to(dtype=torch.float32, copy=not allow_inplace) 57 | means = torch.mean(tensor, dim=-1, keepdim=True) 58 | tensor.sub_(means) 59 | stds = tensor.norm(dim=-1, keepdim=True) / math.sqrt(tensor.shape[-1]) 60 | stds.clamp_min_(self.FP32_EPS) 61 | tensor.div_(stds) 62 | tensor = tensor.clamp_(self.FP16_MIN, self.FP16_MAX).to(torch.float16) 63 | 64 | data = b"".join((tensor.numpy().tobytes(), means.float().numpy().tobytes(), stds.float().numpy().tobytes())) 65 | 66 | return runtime_pb2.Tensor( 67 | compression=self.compression_type, 68 | buffer=data, 69 | size=tensor.shape, 70 | dtype=dtype_name, 71 | requires_grad=requires_grad, 72 | ) 73 | 74 | def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor: 75 | stats_shape = list(serialized_tensor.size) 76 | stats_shape[-1] = 1 77 | stats_count = np.prod(stats_shape) 78 | means_offset = len(serialized_tensor.buffer) - 2 * stats_count * self.FP32_BYTES 79 | stds_offset = len(serialized_tensor.buffer) - stats_count * self.FP32_BYTES 80 | 81 | array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16, count=np.prod(serialized_tensor.size)) 82 | means = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=means_offset, count=stats_count) 83 | stds = np.frombuffer(serialized_tensor.buffer, dtype=np.float32, offset=stds_offset, count=stats_count) 84 | 85 | means = torch.as_tensor(means).reshape(stats_shape) 86 | stds = torch.as_tensor(stds).reshape(stats_shape) 87 | tensor = torch.as_tensor(np.asarray(array, dtype=serialized_tensor.dtype)).reshape( 88 | list(serialized_tensor.size) 89 | ) 90 | dtype = getattr(torch, serialized_tensor.dtype) 91 | return tensor.mul_(stds).add_(means).to(dtype).requires_grad_(serialized_tensor.requires_grad) 92 | 93 | 94 | def get_num_bits(dtype: torch.dtype) -> int: 95 | if dtype == torch.bool: 96 | return 8 # see https://github.com/pytorch/pytorch/issues/41571 97 | elif dtype.is_floating_point: 98 | return torch.finfo(dtype).bits 99 | else: 100 | try: 101 | return torch.iinfo(dtype).bits 102 | except TypeError: 103 | raise TypeError(f"Could not infer size for tensor type {dtype}") 104 | -------------------------------------------------------------------------------- /hivemind/compression/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import AsyncIterator, Dict, Iterable, List, Optional 4 | 5 | import torch 6 | 7 | from hivemind.compression.base import CompressionBase, CompressionInfo, NoCompression 8 | from hivemind.compression.floating import Float16Compression, ScaledFloat16Compression 9 | from hivemind.compression.quantization import BlockwiseQuantization, Quantile8BitQuantization, Uniform8BitQuantization 10 | from hivemind.proto import runtime_pb2 11 | from hivemind.utils.streaming import combine_from_streaming 12 | 13 | _BASE_COMPRESSION_TYPES: Dict[str, CompressionBase] = dict( 14 | NONE=NoCompression(), 15 | FLOAT16=Float16Compression(), 16 | MEANSTD_16BIT=ScaledFloat16Compression(), 17 | QUANTILE_8BIT=Quantile8BitQuantization(), 18 | UNIFORM_8BIT=Uniform8BitQuantization(), 19 | BLOCKWISE_8BIT=BlockwiseQuantization(), 20 | ) 21 | 22 | for key in runtime_pb2.CompressionType.keys(): 23 | assert key in 
_BASE_COMPRESSION_TYPES, f"Compression type {key} does not have a registered deserializer" 24 | actual_compression_type = _BASE_COMPRESSION_TYPES[key].compression_type 25 | assert runtime_pb2.CompressionType.Name(actual_compression_type) == key, ( 26 | f"Compression strategy for {key} has inconsistent type" 27 | ) 28 | 29 | 30 | def serialize_torch_tensor( 31 | tensor: torch.Tensor, 32 | compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE, 33 | info: Optional[CompressionInfo] = None, 34 | allow_inplace: bool = False, 35 | **kwargs, 36 | ) -> runtime_pb2.Tensor: 37 | """Serialize a given tensor into a protobuf message using the specified compression strategy""" 38 | assert tensor.device == torch.device("cpu") 39 | compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(compression_type)] 40 | info = info or CompressionInfo.from_tensor(tensor, **kwargs) 41 | return compression.compress(tensor, info, allow_inplace) 42 | 43 | 44 | def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor: 45 | """Restore a pytorch tensor from a protobuf message""" 46 | compression = _BASE_COMPRESSION_TYPES[runtime_pb2.CompressionType.Name(serialized_tensor.compression)] 47 | return compression.extract(serialized_tensor).requires_grad_(serialized_tensor.requires_grad) 48 | 49 | 50 | async def deserialize_tensor_stream( 51 | stream: AsyncIterator[Iterable[runtime_pb2.Tensor]], 52 | ) -> List[torch.Tensor]: 53 | """Async wrapper of combine_from_streaming that combines tensors from a stream of parts and deserializes them""" 54 | 55 | tensors = [] 56 | tensor_parts = [] 57 | 58 | async for parts in stream: 59 | for part in parts: 60 | if part.dtype and tensor_parts: 61 | tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts))) 62 | tensor_parts = [] 63 | 64 | tensor_parts.append(part) 65 | if tensor_parts: 66 | tensors.append(deserialize_torch_tensor(combine_from_streaming(tensor_parts))) 67 | 68 | return tensors 69 | -------------------------------------------------------------------------------- /hivemind/dht/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a Distributed Hash Table optimized for rapidly accessing a lot of lightweight metadata. 3 | Hivemind DHT is based on Kademlia [1] with added support for improved bulk store/get operations and caching. 4 | 5 | The code is organized as follows: 6 | 7 | * **class DHT (dht.py)** - high-level class for model training. Runs DHTNode in a background process. 8 | * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys. 9 | * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes. 10 | * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers. 11 | 12 | - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric. 13 | - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! 
you're awesome :) 14 | """ 15 | 16 | from hivemind.dht.dht import DHT 17 | from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode 18 | from hivemind.dht.routing import DHTID, DHTValue 19 | from hivemind.dht.validation import CompositeValidator, RecordValidatorBase 20 | -------------------------------------------------------------------------------- /hivemind/dht/crypto.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import re 3 | from typing import Optional 4 | 5 | from hivemind.dht.validation import DHTRecord, RecordValidatorBase 6 | from hivemind.utils import MSGPackSerializer, get_logger 7 | from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class RSASignatureValidator(RecordValidatorBase): 13 | """ 14 | Introduces a notion of *protected records* whose key/subkey contains substring 15 | "[owner:ssh-rsa ...]" with an RSA public key of the owner. 16 | 17 | If this validator is used, changes to such records always must be signed with 18 | the corresponding private key (so only the owner can change them). 19 | """ 20 | 21 | PUBLIC_KEY_FORMAT = b"[owner:_key_]" 22 | SIGNATURE_FORMAT = b"[signature:_value_]" 23 | 24 | PUBLIC_KEY_REGEX = re.escape(PUBLIC_KEY_FORMAT).replace(b"_key_", rb"(.+?)") 25 | _PUBLIC_KEY_RE = re.compile(PUBLIC_KEY_REGEX) 26 | _SIGNATURE_RE = re.compile(re.escape(SIGNATURE_FORMAT).replace(b"_value_", rb"(.+?)")) 27 | 28 | _cached_private_key = None 29 | 30 | def __init__(self, private_key: Optional[RSAPrivateKey] = None): 31 | if private_key is None: 32 | private_key = RSAPrivateKey.process_wide() 33 | self._private_key = private_key 34 | 35 | serialized_public_key = private_key.get_public_key().to_bytes() 36 | self._local_public_key = self.PUBLIC_KEY_FORMAT.replace(b"_key_", serialized_public_key) 37 | 38 | @property 39 | def local_public_key(self) -> bytes: 40 | return self._local_public_key 41 | 42 | def validate(self, record: DHTRecord) -> bool: 43 | public_keys = self._PUBLIC_KEY_RE.findall(record.key) 44 | if record.subkey is not None: 45 | public_keys += self._PUBLIC_KEY_RE.findall(record.subkey) 46 | if not public_keys: 47 | return True # The record is not protected with a public key 48 | 49 | if len(set(public_keys)) > 1: 50 | logger.debug(f"Key and subkey can't contain different public keys in {record}") 51 | return False 52 | public_key = RSAPublicKey.from_bytes(public_keys[0]) 53 | 54 | signatures = self._SIGNATURE_RE.findall(record.value) 55 | if len(signatures) != 1: 56 | logger.debug(f"Record should have exactly one signature in {record}") 57 | return False 58 | signature = signatures[0] 59 | 60 | stripped_record = dataclasses.replace(record, value=self.strip_value(record)) 61 | if not public_key.verify(self._serialize_record(stripped_record), signature): 62 | logger.debug(f"Signature is invalid in {record}") 63 | return False 64 | return True 65 | 66 | def sign_value(self, record: DHTRecord) -> bytes: 67 | if self._local_public_key not in record.key and self._local_public_key not in record.subkey: 68 | return record.value 69 | 70 | signature = self._private_key.sign(self._serialize_record(record)) 71 | return record.value + self.SIGNATURE_FORMAT.replace(b"_value_", signature) 72 | 73 | def strip_value(self, record: DHTRecord) -> bytes: 74 | return self._SIGNATURE_RE.sub(b"", record.value) 75 | 76 | def _serialize_record(self, record: DHTRecord) -> bytes: 77 | return MSGPackSerializer.dumps(dataclasses.astuple(record)) 78 | 79 
| @property 80 | def priority(self) -> int: 81 | # On validation, this validator must be executed before validators 82 | # that deserialize the record 83 | return 10 84 | 85 | def merge_with(self, other: RecordValidatorBase) -> bool: 86 | if not isinstance(other, RSASignatureValidator): 87 | return False 88 | 89 | # Ignore another RSASignatureValidator instance (it doesn't make sense to have several 90 | # instances of this class) and report successful merge 91 | return True 92 | -------------------------------------------------------------------------------- /hivemind/dht/storage.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional, Union 4 | 5 | from hivemind.dht.routing import DHTID, BinaryDHTValue, Subkey 6 | from hivemind.utils.serializer import MSGPackSerializer 7 | from hivemind.utils.timed_storage import DHTExpiration, KeyType, TimedStorage, ValueType 8 | 9 | 10 | @MSGPackSerializer.ext_serializable(0x50) 11 | class DictionaryDHTValue(TimedStorage[Subkey, BinaryDHTValue]): 12 | """a dictionary-like DHT value type that maps sub-keys to values with individual expirations""" 13 | 14 | latest_expiration_time = float("-inf") 15 | 16 | def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool: 17 | self.latest_expiration_time = max(self.latest_expiration_time, expiration_time) 18 | return super().store(key, value, expiration_time) 19 | 20 | def packb(self) -> bytes: 21 | """custom behavior for MSGPackSerializer.dumps""" 22 | packed_items = [[key, value, expiration_time] for key, (value, expiration_time) in self.items()] 23 | return MSGPackSerializer.dumps([self.maxsize, self.latest_expiration_time, packed_items]) 24 | 25 | @classmethod 26 | def unpackb(cls, raw: bytes) -> DictionaryDHTValue: 27 | maxsize, latest_expiration_time, items = MSGPackSerializer.loads(raw) 28 | with DictionaryDHTValue(maxsize).freeze() as new_dict: 29 | for key, value, expiration_time in items: 30 | new_dict.store(key, value, expiration_time) 31 | new_dict.latest_expiration_time = latest_expiration_time 32 | return new_dict 33 | 34 | 35 | class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTValue]]): 36 | """A dictionary-like storage that can store binary values and/or nested dictionaries until expiration""" 37 | 38 | def store( 39 | self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration, subkey: Optional[Subkey] = None 40 | ) -> bool: 41 | """ 42 | Store a (key, value) pair locally at least until expiration_time. See class docstring for details. 43 | If subkey is not None, adds a subkey-value pair to a dictionary associated with :key: (see store_subkey below) 44 | :returns: True if new value was stored, False it was rejected (current value is newer) 45 | """ 46 | if subkey is not None: # add one sub-key 47 | return self.store_subkey(key, subkey, value, expiration_time) 48 | else: # store regular key 49 | return super().store(key, value, expiration_time) 50 | 51 | def store_subkey(self, key: DHTID, subkey: Subkey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool: 52 | """ 53 | Save a (sub-key, value) into a dictionary associated with a given key. 
54 | 1) if self[key] is empty, create a new dictionary and add sub-key there 55 | 2) if self[key] is a dictionary (DictionaryDHTValue), store {sub-key: value, expiration} to that storage 56 | 3) if self[key] is a normal value with smaller expiration time, overwrite it with a dictionary and add sub-key 57 | :returns: True if new entry was stored, False it was rejected (current value is newer) 58 | """ 59 | previous_value, previous_expiration_time = self.get(key) or (b"", -float("inf")) 60 | if isinstance(previous_value, BinaryDHTValue) and expiration_time > previous_expiration_time: 61 | new_storage = DictionaryDHTValue() 62 | new_storage.store(subkey, value, expiration_time) 63 | return super().store(key, new_storage, new_storage.latest_expiration_time) 64 | elif isinstance(previous_value, DictionaryDHTValue): 65 | if expiration_time > previous_value.latest_expiration_time: 66 | super().store(key, previous_value, expiration_time) # refresh expiration time 67 | return previous_value.store(subkey, value, expiration_time) 68 | else: 69 | return False 70 | -------------------------------------------------------------------------------- /hivemind/dht/validation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from abc import ABC, abstractmethod 3 | from typing import Iterable 4 | 5 | 6 | @dataclasses.dataclass(init=True, repr=True, frozen=True) 7 | class DHTRecord: 8 | key: bytes 9 | subkey: bytes 10 | value: bytes 11 | expiration_time: float 12 | 13 | 14 | class RecordValidatorBase(ABC): 15 | """ 16 | Record validators are a generic mechanism for checking the DHT records including: 17 | - Enforcing a data schema (e.g. checking content types) 18 | - Enforcing security requirements (e.g. allowing only the owner to update the record) 19 | """ 20 | 21 | @abstractmethod 22 | def validate(self, record: DHTRecord) -> bool: 23 | """ 24 | Should return whether the `record` is valid. 25 | The valid records should have been extended with sign_value(). 26 | 27 | validate() is called when another DHT peer: 28 | - Asks us to store the record 29 | - Returns the record by our request 30 | """ 31 | 32 | pass 33 | 34 | def sign_value(self, record: DHTRecord) -> bytes: 35 | """ 36 | Should return `record.value` extended with the record's signature. 37 | 38 | Note: there's no need to overwrite this method if a validator doesn't use a signature. 39 | 40 | sign_value() is called after the application asks the DHT to store the record. 41 | """ 42 | 43 | return record.value 44 | 45 | def strip_value(self, record: DHTRecord) -> bytes: 46 | """ 47 | Should return `record.value` stripped of the record's signature. 48 | strip_value() is only called if validate() was successful. 49 | 50 | Note: there's no need to overwrite this method if a validator doesn't use a signature. 51 | 52 | strip_value() is called before the DHT returns the record by the application's request. 53 | """ 54 | 55 | return record.value 56 | 57 | @property 58 | def priority(self) -> int: 59 | """ 60 | Defines the order of applying this validator with respect to other validators. 61 | 62 | The validators are applied: 63 | - In order of increasing priority for signing a record 64 | - In order of decreasing priority for validating and stripping a record 65 | """ 66 | 67 | return 0 68 | 69 | def merge_with(self, other: "RecordValidatorBase") -> bool: 70 | """ 71 | By default, all validators are applied sequentially (i.e. 
we require all validate() calls 72 | to return True for a record to be validated successfully). 73 | 74 | However, you may want to define another policy for combining your validator classes 75 | (e.g. for schema validators, we want to require only one validate() call to return True 76 | because each validator bears a part of the schema). 77 | 78 | This can be achieved with overriding merge_with(). It should: 79 | 80 | - Return True if it has successfully merged the `other` validator to `self`, 81 | so that `self` became a validator that combines the old `self` and `other` using 82 | the necessary policy. In this case, `other` should remain unchanged. 83 | 84 | - Return False if the merging has not happened. In this case, both `self` and `other` 85 | should remain unchanged. The DHT will try merging `other` to another validator or 86 | add it as a separate validator (to be applied sequentially). 87 | """ 88 | 89 | return False 90 | 91 | 92 | class CompositeValidator(RecordValidatorBase): 93 | def __init__(self, validators: Iterable[RecordValidatorBase] = ()): 94 | self._validators = [] 95 | self.extend(validators) 96 | 97 | def extend(self, validators: Iterable[RecordValidatorBase]) -> None: 98 | for new_validator in validators: 99 | for existing_validator in self._validators: 100 | if existing_validator.merge_with(new_validator): 101 | break 102 | else: 103 | self._validators.append(new_validator) 104 | self._validators.sort(key=lambda item: item.priority) 105 | 106 | def validate(self, record: DHTRecord) -> bool: 107 | for i, validator in enumerate(reversed(self._validators)): 108 | if not validator.validate(record): 109 | return False 110 | if i < len(self._validators) - 1: 111 | record = dataclasses.replace(record, value=validator.strip_value(record)) 112 | return True 113 | 114 | def sign_value(self, record: DHTRecord) -> bytes: 115 | for validator in self._validators: 116 | record = dataclasses.replace(record, value=validator.sign_value(record)) 117 | return record.value 118 | 119 | def strip_value(self, record: DHTRecord) -> bytes: 120 | for validator in reversed(self._validators): 121 | record = dataclasses.replace(record, value=validator.strip_value(record)) 122 | return record.value 123 | -------------------------------------------------------------------------------- /hivemind/hivemind_cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learning-at-home/hivemind/4d5c41495be082490ea44cce4e9dd58f9926bb4e/hivemind/hivemind_cli/__init__.py -------------------------------------------------------------------------------- /hivemind/hivemind_cli/config.yml: -------------------------------------------------------------------------------- 1 | num_experts: 16 2 | expert_cls: ffn 3 | hidden_dim: 1024 4 | expert_pattern: expert.[0:4].[0:4] 5 | max_batch_size: 16384 6 | optimizer: adam 7 | no_dht: True 8 | initial_peers: "[]" 9 | increase_file_limit: True 10 | -------------------------------------------------------------------------------- /hivemind/hivemind_cli/run_dht.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from secrets import token_hex 3 | from signal import SIGINT, SIGTERM, signal, strsignal 4 | from threading import Event 5 | 6 | from hivemind.dht import DHT, DHTNode 7 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 8 | from hivemind.utils.networking import log_visible_maddrs 9 | 10 | 
use_hivemind_log_handler("in_root_logger") 11 | logger = get_logger(__name__) 12 | 13 | 14 | async def report_status(dht: DHT, node: DHTNode): 15 | logger.info( 16 | f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) " 17 | f"are in the local routing table " 18 | ) 19 | logger.debug(f"Routing table contents: {node.protocol.routing_table}") 20 | logger.info(f"Local storage contains {len(node.protocol.storage)} keys") 21 | logger.debug(f"Local storage contents: {node.protocol.storage}") 22 | 23 | # Contact peers and keep the routing table healthy (remove stale PeerIDs) 24 | await node.get(f"heartbeat_{token_hex(16)}", latest=True) 25 | 26 | 27 | def main(): 28 | parser = ArgumentParser() 29 | parser.add_argument( 30 | "--initial_peers", 31 | nargs="*", 32 | help="Multiaddrs of the peers that will welcome you into the existing DHT. " 33 | "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY", 34 | ) 35 | parser.add_argument( 36 | "--host_maddrs", 37 | nargs="*", 38 | default=["/ip4/0.0.0.0/tcp/0"], 39 | help="Multiaddrs to listen for external connections from other DHT instances. " 40 | "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0", 41 | ) 42 | parser.add_argument( 43 | "--announce_maddrs", 44 | nargs="*", 45 | help="Visible multiaddrs the host announces for external connections from other DHT instances", 46 | ) 47 | parser.add_argument( 48 | "--use_ipfs", 49 | action="store_true", 50 | help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" ' 51 | "part of the multiaddrs for the initial_peers " 52 | "(no need to specify a particular IPv4/IPv6 host and port)", 53 | ) 54 | parser.add_argument( 55 | "--identity_path", 56 | help="Path to a private key file. If defined, makes the peer ID deterministic. 
" 57 | "If the file does not exist, writes a new private key to this file.", 58 | ) 59 | parser.add_argument( 60 | "--no_relay", 61 | action="store_false", 62 | dest="use_relay", 63 | help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)", 64 | ) 65 | parser.add_argument( 66 | "--use_auto_relay", 67 | action="store_true", 68 | help="Look for libp2p relays to become reachable if we are behind NAT/firewall", 69 | ) 70 | parser.add_argument( 71 | "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT" 72 | ) 73 | 74 | args = parser.parse_args() 75 | 76 | dht = DHT( 77 | start=True, 78 | initial_peers=args.initial_peers, 79 | host_maddrs=args.host_maddrs, 80 | announce_maddrs=args.announce_maddrs, 81 | use_ipfs=args.use_ipfs, 82 | identity_path=args.identity_path, 83 | use_relay=args.use_relay, 84 | use_auto_relay=args.use_auto_relay, 85 | ) 86 | log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs) 87 | 88 | exit_event = Event() 89 | 90 | def signal_handler(signal_number: int, _) -> None: 91 | logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down") 92 | exit_event.set() 93 | 94 | signal(SIGTERM, signal_handler) 95 | signal(SIGINT, signal_handler) 96 | 97 | try: 98 | while not exit_event.is_set(): 99 | dht.run_coroutine(report_status, return_future=False) 100 | exit_event.wait(args.refresh_period) 101 | finally: 102 | dht.shutdown() 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /hivemind/moe/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.moe.client import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts 2 | from hivemind.moe.server import ( 3 | ModuleBackend, 4 | Server, 5 | background_server, 6 | declare_experts, 7 | get_experts, 8 | register_expert_class, 9 | ) 10 | -------------------------------------------------------------------------------- /hivemind/moe/client/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.moe.client.expert import RemoteExpert 2 | from hivemind.moe.client.moe import RemoteMixtureOfExperts 3 | from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts 4 | -------------------------------------------------------------------------------- /hivemind/moe/client/remote_expert_worker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from concurrent.futures import Future 4 | from threading import Thread 5 | from typing import Awaitable 6 | 7 | from hivemind.utils import switch_to_uvloop 8 | 9 | 10 | class RemoteExpertWorker: 11 | """Local thread for managing async tasks related to RemoteExpert""" 12 | 13 | _event_thread = None 14 | _event_loop_fut = None 15 | _pid = None 16 | 17 | @classmethod 18 | def _run_event_loop(cls): 19 | try: 20 | loop = switch_to_uvloop() 21 | cls._event_loop_fut.set_result(loop) 22 | except Exception as e: 23 | cls._event_loop_fut.set_exception(e) 24 | loop.run_forever() 25 | 26 | @classmethod 27 | def run_coroutine(cls, coro: Awaitable, return_future: bool = False): 28 | if cls._event_thread is None or cls._pid != os.getpid(): 29 | cls._pid = os.getpid() 30 | cls._event_loop_fut = Future() 31 | cls._event_thread = Thread(target=cls._run_event_loop, daemon=True) 32 | 
cls._event_thread.start() 33 | 34 | loop = cls._event_loop_fut.result() 35 | future = asyncio.run_coroutine_threadsafe(coro, loop) 36 | return future if return_future else future.result() 37 | -------------------------------------------------------------------------------- /hivemind/moe/expert_uid.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | from typing import NamedTuple, Tuple, Union 5 | 6 | from hivemind.p2p import PeerID 7 | 8 | ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float 9 | 10 | 11 | class ExpertInfo(NamedTuple): 12 | uid: ExpertUID 13 | peer_id: PeerID 14 | 15 | 16 | UID_DELIMITER = "." # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix 17 | FLAT_EXPERT = -1 # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case. 18 | UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$") # e.g. ffn_expert.98.76.54 - prefix + some dims 19 | PREFIX_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))*[.]$") # e.g. expert. or ffn.45. (ends with ".") 20 | # formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)} 21 | 22 | 23 | def is_valid_uid(maybe_uid: str) -> bool: 24 | """An uid must contain a string expert type, followed by one or more .-separated numeric indices""" 25 | return bool(UID_PATTERN.fullmatch(maybe_uid)) 26 | 27 | 28 | def is_valid_prefix(maybe_prefix: str) -> bool: 29 | """An uid prefix must contain a string expert type, followed by optional numeric indices and a trailing period""" 30 | return bool(PREFIX_PATTERN.fullmatch(maybe_prefix)) 31 | 32 | 33 | def split_uid(uid_or_prefix: Union[ExpertUID, ExpertPrefix]) -> Tuple[ExpertPrefix, Coordinate]: 34 | """Separate an expert UID or prefix into a new ExpertPrefix and integer for the last coordinate""" 35 | uid_or_prefix = uid_or_prefix.rstrip(UID_DELIMITER) 36 | pivot = uid_or_prefix.rindex(UID_DELIMITER) + 1 37 | return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:]) 38 | -------------------------------------------------------------------------------- /hivemind/moe/server/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.moe.server.dht_handler import declare_experts, get_experts 2 | from hivemind.moe.server.layers import register_expert_class 3 | from hivemind.moe.server.module_backend import ModuleBackend 4 | from hivemind.moe.server.server import Server, background_server 5 | -------------------------------------------------------------------------------- /hivemind/moe/server/checkpoints.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | from datetime import datetime 4 | from pathlib import Path 5 | from shutil import copy2 6 | from tempfile import TemporaryDirectory 7 | from typing import Dict 8 | 9 | import torch 10 | 11 | from hivemind.moe.server.module_backend import ModuleBackend 12 | from hivemind.utils.logging import get_logger 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def is_directory(directory: Path): 18 | assert directory is not None 19 | assert directory.exists() 20 | assert directory.is_dir() 21 | return True 22 | 23 | 24 | def copy_tree(src: str, dst: str): 25 | if not os.path.exists(dst): 26 | os.makedirs(dst) 27 | for item in os.listdir(src): 28 | src_entry = os.path.join(src, item) 29 | 
dst_entry = os.path.join(dst, item) 30 | if os.path.isdir(src_entry): 31 | copy_tree(src_entry, dst_entry) 32 | else: 33 | copy2(src_entry, dst_entry) 34 | 35 | 36 | class CheckpointSaver(threading.Thread): 37 | def __init__(self, module_backends: Dict[str, ModuleBackend], checkpoint_dir: Path, update_period: float): 38 | super().__init__() 39 | assert is_directory(checkpoint_dir) 40 | self.module_backends = module_backends 41 | self.update_period = update_period 42 | self.checkpoint_dir = checkpoint_dir 43 | self.stop = threading.Event() 44 | 45 | # create expert directories to ensure that the directory is writable and checkpoints can be loaded 46 | store_experts(self.module_backends, self.checkpoint_dir) 47 | 48 | def run(self) -> None: 49 | while not self.stop.wait(self.update_period): 50 | store_experts(self.module_backends, self.checkpoint_dir) 51 | 52 | 53 | def store_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): 54 | logger.debug(f"Storing experts at {checkpoint_dir.absolute()}") 55 | assert is_directory(checkpoint_dir) 56 | timestamp = datetime.now().isoformat(sep="_") 57 | with TemporaryDirectory() as tmpdirname: 58 | for expert_name, expert_backend in experts.items(): 59 | expert_dir = Path(tmpdirname) / expert_name 60 | expert_dir.mkdir() 61 | checkpoint_name = expert_dir / f"checkpoint_{timestamp}.pt" 62 | torch.save(expert_backend.state_dict(), checkpoint_name) 63 | os.symlink(checkpoint_name, expert_dir / "checkpoint_last.pt") 64 | copy_tree(tmpdirname, str(checkpoint_dir)) 65 | 66 | 67 | def load_experts(experts: Dict[str, ModuleBackend], checkpoint_dir: Path): 68 | assert is_directory(checkpoint_dir) 69 | for expert_name, expert in experts.items(): 70 | checkpoints_folder = checkpoint_dir / expert_name 71 | latest_checkpoint = checkpoints_folder / "checkpoint_last.pt" 72 | if latest_checkpoint.exists(): 73 | expert.load_state_dict(torch.load(latest_checkpoint)) 74 | else: 75 | logger.warning(f"Failed to load checkpoint for expert {expert_name}") 76 | -------------------------------------------------------------------------------- /hivemind/moe/server/dht_handler.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from functools import partial 3 | from typing import Dict, List, Optional, Sequence, Tuple, Union 4 | 5 | from hivemind.dht import DHT, DHTNode, DHTValue 6 | from hivemind.moe.client.expert import RemoteExpert, create_remote_experts 7 | from hivemind.moe.expert_uid import ( 8 | FLAT_EXPERT, 9 | UID_DELIMITER, 10 | UID_PATTERN, 11 | Coordinate, 12 | ExpertInfo, 13 | ExpertPrefix, 14 | ExpertUID, 15 | is_valid_uid, 16 | split_uid, 17 | ) 18 | from hivemind.p2p import PeerID 19 | from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS, DHTExpiration, MPFuture, get_dht_time 20 | 21 | 22 | class DHTHandlerThread(threading.Thread): 23 | def __init__( 24 | self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs 25 | ): 26 | super().__init__(**kwargs) 27 | if expiration is None: 28 | expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS) 29 | self.module_backends = module_backends 30 | self.dht = dht 31 | self.update_period = update_period 32 | self.expiration = expiration 33 | self.stop = threading.Event() 34 | 35 | def run(self) -> None: 36 | declare_experts(self.dht, self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) 37 | while not self.stop.wait(self.update_period): 38 | declare_experts(self.dht, 
self.module_backends.keys(), expiration_time=get_dht_time() + self.expiration) 39 | 40 | 41 | def declare_experts( 42 | dht: DHT, uids: Sequence[ExpertUID], expiration_time: DHTExpiration, wait: bool = True 43 | ) -> Union[Dict[ExpertUID, bool], MPFuture[Dict[ExpertUID, bool]]]: 44 | """ 45 | Make experts visible to all DHT peers; update timestamps if declared previously. 46 | 47 | :param uids: a list of expert ids to update 48 | :param wait: if True, waits for the declaration to finish; otherwise runs it in the background 49 | :param expiration_time: experts will be visible until this time (absolute, as returned by get_dht_time) 50 | :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected) 51 | """ 52 | assert not isinstance(uids, str), "Please send a list / tuple of expert uids." 53 | if not isinstance(uids, list): 54 | uids = list(uids) 55 | for uid in uids: 56 | assert is_valid_uid(uid), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}" 57 | return dht.run_coroutine( 58 | partial(_declare_experts, uids=uids, expiration_time=expiration_time), return_future=not wait 59 | ) 60 | 61 | 62 | async def _declare_experts( 63 | dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: DHTExpiration 64 | ) -> Dict[ExpertUID, bool]: 65 | num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) 66 | data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {} 67 | peer_id_base58 = dht.peer_id.to_base58() 68 | 69 | for uid in uids: 70 | data_to_store[uid, None] = peer_id_base58 71 | prefix = uid if uid.count(UID_DELIMITER) > 1 else f"{uid}{UID_DELIMITER}{FLAT_EXPERT}" 72 | for i in range(prefix.count(UID_DELIMITER) - 1): 73 | prefix, last_coord = split_uid(prefix) 74 | data_to_store[prefix, last_coord] = (uid, peer_id_base58) 75 | 76 | keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items())) 77 | store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers) 78 | return store_ok 79 | 80 | 81 | def get_experts( 82 | dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False 83 | ) -> Union[List[Optional[RemoteExpert]], MPFuture[List[Optional[RemoteExpert]]]]: 84 | """ 85 | :param uids: find experts with these ids from across the DHT 86 | :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time) 87 | :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background. 88 | :returns: a list of [RemoteExpert if found else None] 89 | """ 90 | assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
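# the lookup runs as a coroutine inside the DHT's background process; create_remote_experts then wraps each found ExpertInfo into a RemoteExpert, keeping None for uids that were not found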
91 | result = dht.run_coroutine(partial(_get_experts, uids=list(uids), expiration_time=expiration_time), return_future) 92 | return create_remote_experts(result, dht, return_future) 93 | 94 | 95 | async def _get_experts( 96 | dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] 97 | ) -> List[Optional[ExpertInfo]]: 98 | if expiration_time is None: 99 | expiration_time = get_dht_time() 100 | num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) 101 | found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers) 102 | 103 | experts: List[Optional[ExpertInfo]] = [None] * len(uids) 104 | for i, uid in enumerate(uids): 105 | server_peer_id = found[uid] 106 | if server_peer_id is not None and isinstance(server_peer_id.value, str): 107 | experts[i] = ExpertInfo(uid, PeerID.from_base58(server_peer_id.value)) 108 | return experts 109 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/__init__.py: -------------------------------------------------------------------------------- 1 | name_to_block = {} 2 | name_to_input = {} 3 | 4 | import hivemind.moe.server.layers.common 5 | import hivemind.moe.server.layers.dropout 6 | from hivemind.moe.server.layers.custom_experts import add_custom_models_from_file, register_expert_class 7 | from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup 8 | 9 | schedule_name_to_scheduler = {"linear": get_linear_schedule_with_warmup, "none": None} 10 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/common.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | from hivemind.moe.server.layers.custom_experts import register_expert_class 7 | 8 | 9 | # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py 10 | @torch.jit.script 11 | def gelu_fast(x): 12 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 13 | 14 | 15 | ffn_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)) 16 | 17 | 18 | @register_expert_class("ffn", ffn_sample_input) 19 | class FeedforwardBlock(nn.Module): 20 | def __init__(self, hid_dim): 21 | super().__init__() 22 | self.ffn = nn.Linear(hid_dim, 4 * hid_dim) 23 | self.ffn_output = nn.Linear(4 * hid_dim, hid_dim) 24 | self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12) 25 | 26 | def forward(self, x): 27 | ffn_output = self.ffn(x) 28 | ffn_output = gelu_fast(ffn_output) 29 | ffn_output = self.ffn_output(ffn_output) 30 | return self.layer_norm(x + ffn_output) 31 | 32 | 33 | class TransformerEncoderLayer(nn.Module): 34 | """ 35 | A slight modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting 36 | """ 37 | 38 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): 39 | super().__init__() 40 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 41 | # Implementation of Feedforward model 42 | self.linear1 = nn.Linear(d_model, dim_feedforward) 43 | self.dropout = nn.Dropout(dropout) 44 | self.linear2 = nn.Linear(dim_feedforward, d_model) 45 | 46 | self.norm1 = nn.LayerNorm(d_model) 47 | self.norm2 = nn.LayerNorm(d_model) 48 | self.dropout1 = nn.Dropout(dropout) 49 | self.dropout2 = nn.Dropout(dropout) 50 | 51 | self.activation = gelu_fast 52 | 53 | 
def forward(self, src, src_key_padding_mask=None): 54 | # (N, S, E) -> (S, N, E) 55 | src = src.transpose(0, 1) 56 | 57 | src2 = self.self_attn(src, src, src, key_padding_mask=src_key_padding_mask)[0] 58 | src = src + self.dropout1(src2) 59 | src = self.norm1(src) 60 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 61 | src = src + self.dropout2(src2) 62 | src = self.norm2(src) 63 | 64 | # (S, N, E) -> (N, S, E) 65 | src = src.transpose(0, 1) 66 | return src 67 | 68 | 69 | transformer_sample_input = lambda batch_size, hid_dim: ( 70 | torch.empty((batch_size, 128, hid_dim)), 71 | torch.empty((batch_size, 128), dtype=torch.bool), 72 | ) 73 | 74 | 75 | @register_expert_class("transformer", transformer_sample_input) 76 | class TunedTransformer(TransformerEncoderLayer): 77 | def __init__(self, hid_dim): 78 | super().__init__(hid_dim, dim_feedforward=4 * hid_dim, nhead=16) 79 | 80 | 81 | nop_sample_input = lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)) 82 | 83 | 84 | @register_expert_class("nop", nop_sample_input) 85 | class NopExpert(nn.Sequential): 86 | def __init__(self, hid_dim): 87 | super().__init__() 88 | self.w = nn.Parameter(torch.zeros(0), requires_grad=True) 89 | 90 | def forward(self, x): 91 | return x.clone() 92 | 93 | 94 | @register_expert_class("nop_delay", nop_sample_input) 95 | class DelayedNopExpert(nn.Sequential): 96 | def __init__(self, hid_dim, delay=0.5): 97 | super().__init__() 98 | self.w = nn.Parameter(torch.zeros(0), requires_grad=True) 99 | self.delay = delay 100 | 101 | def forward(self, x): 102 | time.sleep(self.delay) 103 | return x.clone() 104 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/custom_experts.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from typing import Callable, Type 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from hivemind.moe.server.layers import name_to_block, name_to_input 9 | 10 | 11 | def add_custom_models_from_file(path: str): 12 | spec = importlib.util.spec_from_file_location("custom_module", os.path.abspath(path)) 13 | foo = importlib.util.module_from_spec(spec) 14 | spec.loader.exec_module(foo) 15 | 16 | 17 | def register_expert_class(name: str, sample_input: Callable[[int, int], torch.tensor]): 18 | """ 19 | Adds a custom user expert to hivemind server. 20 | :param name: the name of the expert. 
It shouldn't coincide with existing modules\ 21 | ('ffn', 'transformer', 'nop', 'det_dropout') 22 | :param sample_input: a function which takes batch_size and hid_dim and outputs a \ 23 | sample input for the module 24 | :returns: the expert class itself, unchanged (registration happens as a side effect) 25 | """ 26 | 27 | def _register_expert_class(custom_class: Type[nn.Module]): 28 | if name in name_to_block or name in name_to_input: 29 | raise RuntimeError("The class might already exist or be added twice") 30 | name_to_block[name] = custom_class 31 | name_to_input[name] = sample_input 32 | 33 | return custom_class 34 | 35 | return _register_expert_class 36 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/dropout.py: -------------------------------------------------------------------------------- 1 | import torch.autograd 2 | from torch import nn as nn 3 | 4 | from hivemind.moe.server.layers.custom_experts import register_expert_class 5 | 6 | 7 | class DeterministicDropoutFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x, keep_prob, mask): 10 | ctx.keep_prob = keep_prob 11 | ctx.save_for_backward(mask) 12 | return x * mask / keep_prob 13 | 14 | @staticmethod 15 | def backward(ctx, grad_output): 16 | return ctx.saved_tensors[0] * grad_output / ctx.keep_prob, None, None 17 | 18 | 19 | class DeterministicDropout(nn.Module): 20 | """ 21 | Custom dropout layer which accepts dropout mask as an input (drop_prob is only used for scaling input activations). 22 | Can be used with RemoteExpert/ModuleBackend to ensure that dropout mask is the same at forward and backward steps 23 | """ 24 | 25 | def __init__(self, drop_prob): 26 | super().__init__() 27 | self.keep_prob = 1 - drop_prob 28 | 29 | def forward(self, x, mask): 30 | if self.training: 31 | return DeterministicDropoutFunction.apply(x, self.keep_prob, mask) 32 | else: 33 | return x 34 | 35 | 36 | dropout_sample_input = lambda batch_size, hid_dim: ( 37 | torch.empty((batch_size, hid_dim)), 38 | torch.randint(0, 1, (batch_size, hid_dim)), 39 | ) 40 | 41 | 42 | @register_expert_class("det_dropout", dropout_sample_input) 43 | class DeterministicDropoutNetwork(nn.Module): 44 | def __init__(self, hid_dim, dropout_prob=0.2): 45 | super().__init__() 46 | self.linear_in = nn.Linear(hid_dim, 2 * hid_dim) 47 | self.activation = nn.ReLU() 48 | self.dropout = DeterministicDropout(dropout_prob) 49 | self.linear_out = nn.Linear(2 * hid_dim, hid_dim) 50 | 51 | def forward(self, x, mask): 52 | x = self.linear_in(self.dropout(x, mask)) 53 | return self.linear_out(self.activation(x)) 54 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/lr_schedule.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | 3 | 4 | # https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py 5 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): 6 | """ 7 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 8 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 9 | Args: 10 | optimizer (:class:`~torch.optim.Optimizer`): 11 | The optimizer for which to schedule the learning rate. 12 | num_warmup_steps (:obj:`int`): 13 | The number of steps for the warmup phase. 14 | num_training_steps (:obj:`int`): 15 | The total number of training steps.
16 | Return: 17 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 18 | """ 19 | 20 | def lr_lambda(current_step: int): 21 | if current_step < num_warmup_steps: 22 | return float(current_step) / float(max(1, num_warmup_steps)) 23 | return max( 24 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 25 | ) 26 | 27 | return LambdaLR(optimizer, lr_lambda) 28 | -------------------------------------------------------------------------------- /hivemind/moe/server/layers/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class OptimizerWrapper: 5 | """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer""" 6 | 7 | def __init__(self, optim: torch.optim.Optimizer): 8 | self.optim = optim 9 | 10 | @property 11 | def defaults(self): 12 | return self.optim.defaults 13 | 14 | @property 15 | def state(self): 16 | return self.optim.state 17 | 18 | def __getstate__(self): 19 | return self.optim.__getstate__() 20 | 21 | def __setstate__(self, state): 22 | self.optim.__setstate__(state) 23 | 24 | def __repr__(self): 25 | return f"{self.__class__.__name__}({repr(self.optim)})" 26 | 27 | def state_dict(self): 28 | return self.optim.state_dict() 29 | 30 | def load_state_dict(self, state_dict: dict) -> None: 31 | return self.optim.load_state_dict(state_dict) 32 | 33 | def step(self, *args, **kwargs): 34 | return self.optim.step(*args, **kwargs) 35 | 36 | def zero_grad(self, *args, **kwargs): 37 | return self.optim.zero_grad(*args, **kwargs) 38 | 39 | @property 40 | def param_groups(self): 41 | return self.optim.param_groups 42 | 43 | def add_param_group(self, param_group: dict) -> None: 44 | return self.optim.add_param_group(param_group) 45 | 46 | 47 | class ClippingWrapper(OptimizerWrapper): 48 | """A wrapper of torch.Optimizer that clips gradients by global norm before each step""" 49 | 50 | def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float): 51 | super().__init__(optim) 52 | self.clip_grad_norm = clip_grad_norm 53 | 54 | def step(self, *args, **kwargs): 55 | parameters = tuple(param for group in self.param_groups for param in group["params"]) 56 | torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm) 57 | return super().step(*args, **kwargs) 58 | -------------------------------------------------------------------------------- /hivemind/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.optim.grad_scaler import GradScaler 2 | from hivemind.optim.optimizer import Optimizer 3 | from hivemind.optim.training_averager import TrainingAverager 4 | -------------------------------------------------------------------------------- /hivemind/p2p/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.p2p.p2p_daemon import P2P, P2PContext 2 | from hivemind.p2p.p2p_daemon_bindings import P2PDaemonError, P2PHandlerError, PeerID, PeerInfo 3 | from hivemind.p2p.servicer import ServicerBase, StubBase 4 | -------------------------------------------------------------------------------- /hivemind/p2p/p2p_daemon_bindings/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo 2 | from hivemind.p2p.p2p_daemon_bindings.utils import P2PDaemonError, P2PHandlerError 3 | 
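# Illustrative sketch (not from the original sources): minimal usage of the PeerID class
# defined in datastructures.py below, relying only on the methods shown there.
from hivemind.p2p import PeerID

raw_id = bytes(34)  # hypothetical raw peer id; real ids are multihash digests of public keys
peer = PeerID(raw_id)
assert peer == raw_id and peer == peer.to_base58()  # __eq__ accepts raw bytes or a base58 string
assert PeerID.from_base58(peer.to_base58()) == peer  # base58 encoding round-trips losslessly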
-------------------------------------------------------------------------------- /hivemind/p2p/p2p_daemon_bindings/datastructures.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings 3 | Licence: MIT 4 | Author: Kevin Mai-Husan Chia 5 | """ 6 | 7 | import hashlib 8 | from typing import Any, Sequence, Union 9 | 10 | import base58 11 | import multihash 12 | from cryptography.hazmat.primitives import serialization 13 | 14 | from hivemind.proto import crypto_pb2, p2pd_pb2 15 | from hivemind.utils.multiaddr import Multiaddr 16 | 17 | 18 | class PeerID: 19 | def __init__(self, peer_id_bytes: bytes) -> None: 20 | self._bytes = peer_id_bytes 21 | self._b58_str = base58.b58encode(self._bytes).decode() 22 | 23 | def to_bytes(self) -> bytes: 24 | return self._bytes 25 | 26 | def to_base58(self) -> str: 27 | return self._b58_str 28 | 29 | def __repr__(self) -> str: 30 | return f"<libp2p.peer.id.ID ({self.to_base58()})>" 31 | 32 | def __str__(self): 33 | return self.to_base58() 34 | 35 | def pretty(self): 36 | return self.to_base58() 37 | 38 | def to_string(self): 39 | return self.to_base58() 40 | 41 | def __eq__(self, other: object) -> bool: 42 | if isinstance(other, str): 43 | return self.to_base58() == other 44 | elif isinstance(other, bytes): 45 | return self._bytes == other 46 | elif isinstance(other, PeerID): 47 | return self._bytes == other._bytes 48 | else: 49 | return False 50 | 51 | def __lt__(self, other: object) -> bool: 52 | if not isinstance(other, PeerID): 53 | raise TypeError(f"'<' not supported between instances of 'PeerID' and '{type(other)}'") 54 | 55 | return self.to_base58() < other.to_base58() 56 | 57 | def __hash__(self) -> int: 58 | return hash(self._bytes) 59 | 60 | @classmethod 61 | def from_base58(cls, base58_id: str) -> "PeerID": 62 | peer_id_bytes = base58.b58decode(base58_id) 63 | return cls(peer_id_bytes) 64 | 65 | @classmethod 66 | def from_identity(cls, data: bytes) -> "PeerID": 67 | """ 68 | See [1] for the specification of how this conversion should happen.
69 | 70 | [1] https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md#peer-ids 71 | """ 72 | key_data = crypto_pb2.PrivateKey.FromString(data).data 73 | private_key = serialization.load_der_private_key(key_data, password=None) 74 | 75 | encoded_public_key = private_key.public_key().public_bytes( 76 | encoding=serialization.Encoding.DER, 77 | format=serialization.PublicFormat.SubjectPublicKeyInfo, 78 | ) 79 | encoded_public_key = crypto_pb2.PublicKey( 80 | key_type=crypto_pb2.RSA, 81 | data=encoded_public_key, 82 | ).SerializeToString() 83 | 84 | encoded_digest = multihash.encode( 85 | hashlib.sha256(encoded_public_key).digest(), 86 | multihash.coerce_code("sha2-256"), 87 | ) 88 | return cls(encoded_digest) 89 | 90 | 91 | def sha256_digest(data: Union[str, bytes]) -> bytes: 92 | if isinstance(data, str): 93 | data = data.encode("utf8") 94 | return hashlib.sha256(data).digest() 95 | 96 | 97 | class StreamInfo: 98 | def __init__(self, peer_id: PeerID, addr: Multiaddr, proto: str) -> None: 99 | self.peer_id = peer_id 100 | self.addr = addr 101 | self.proto = proto 102 | 103 | def __repr__(self) -> str: 104 | return f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>" 105 | 106 | def to_protobuf(self) -> p2pd_pb2.StreamInfo: 107 | pb_msg = p2pd_pb2.StreamInfo(peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto) 108 | return pb_msg 109 | 110 | @classmethod 111 | def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo": 112 | stream_info = cls(peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto) 113 | return stream_info 114 | 115 | 116 | class PeerInfo: 117 | def __init__(self, peer_id: PeerID, addrs: Sequence[Multiaddr]) -> None: 118 | self.peer_id = peer_id 119 | self.addrs = list(addrs) 120 | 121 | def __eq__(self, other: Any) -> bool: 122 | return isinstance(other, PeerInfo) and self.peer_id == other.peer_id and self.addrs == other.addrs 123 | 124 | @classmethod 125 | def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo": 126 | peer_id = PeerID(peer_info_pb.id) 127 | addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs] 128 | return PeerInfo(peer_id, addrs) 129 | 130 | def __str__(self): 131 | return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}" 132 | 133 | def __repr__(self): 134 | return f"PeerInfo(peer_id={repr(self.peer_id)}, addrs={repr(self.addrs)})" 135 | -------------------------------------------------------------------------------- /hivemind/p2p/p2p_daemon_bindings/p2pclient.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings 3 | Licence: MIT 4 | Author: Kevin Mai-Husan Chia 5 | """ 6 | 7 | import asyncio 8 | from contextlib import asynccontextmanager 9 | from typing import AsyncIterator, Iterable, Sequence, Tuple 10 | 11 | from hivemind.p2p.p2p_daemon_bindings.control import ( 12 | DEFAULT_MAX_MSG_SIZE, 13 | ControlClient, 14 | DaemonConnector, 15 | StreamHandler, 16 | TUnaryHandler, 17 | ) 18 | from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo 19 | from hivemind.utils.multiaddr import Multiaddr 20 | 21 | 22 | class Client: 23 | control: ControlClient 24 | 25 | def __init__(self, *, _initialized_with_create=False) -> None: 26 | assert _initialized_with_create, "Please use Client.create coroutine to spawn new client instances" 27 | self.control = None 28 | 29 | @classmethod 30 | async def create( 31 | cls, 32 | control_maddr: Multiaddr = None, 33 | listen_maddr: 
Multiaddr = None, 34 | *, 35 | persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE, 36 | ) -> "Client": 37 | client = cls(_initialized_with_create=True) 38 | 39 | daemon_connector = DaemonConnector(control_maddr=control_maddr) 40 | client.control = await ControlClient.create( 41 | daemon_connector=daemon_connector, 42 | listen_maddr=listen_maddr, 43 | persistent_conn_max_msg_size=persistent_conn_max_msg_size, 44 | ) 45 | 46 | return client 47 | 48 | def close(self) -> None: 49 | if self.control is not None: 50 | self.control.close() 51 | 52 | def __del__(self): 53 | self.close() 54 | 55 | @asynccontextmanager 56 | async def listen(self) -> AsyncIterator["Client"]: 57 | """ 58 | Starts listening for incoming connections, serving handlers registered via stream_handler. 59 | :return: 60 | """ 61 | async with self.control.listen(): 62 | yield self 63 | 64 | async def add_unary_handler(self, proto: str, handler: TUnaryHandler, balanced: bool = False) -> None: 65 | await self.control.add_unary_handler(proto, handler, balanced=balanced) 66 | 67 | async def remove_unary_handler(self, proto: str) -> None: 68 | await self.control.remove_unary_handler(proto) 69 | 70 | async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes: 71 | return await self.control.call_unary_handler(peer_id, proto, data) 72 | 73 | async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]: 74 | """ 75 | Get the current node's peer id and list of addresses 76 | """ 77 | return await self.control.identify() 78 | 79 | async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None: 80 | """ 81 | Connect to a p2p node with the specified addresses and peer id. 82 | :peer_id: peer id of the node you want to connect to 83 | :maddrs: multiaddresses of the node you want to connect to; they must be reachable 84 | """ 85 | await self.control.connect(peer_id=peer_id, maddrs=maddrs) 86 | 87 | async def list_peers(self) -> Tuple[PeerInfo, ...]: 88 | """ 89 | Get the list of peers that the node is connected to 90 | """ 91 | return await self.control.list_peers() 92 | 93 | async def disconnect(self, peer_id: PeerID) -> None: 94 | """ 95 | Disconnect from the node with the specified peer id 96 | :peer_id: peer id of the node you want to disconnect from 97 | """ 98 | await self.control.disconnect(peer_id=peer_id) 99 | 100 | async def stream_open( 101 | self, peer_id: PeerID, protocols: Sequence[str] 102 | ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]: 103 | """ 104 | Open a stream to call another peer's (peer_id) handler for the specified protocols 105 | :peer_id: the other peer's id 106 | :protocols: list of protocols that the other peer's handler serves 107 | :return: a tuple of stream info (about the connection to the other peer) and a reader/writer pair 108 | """ 109 | return await self.control.stream_open(peer_id=peer_id, protocols=protocols) 110 | 111 | async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced: bool = False) -> None: 112 | """ 113 | Register a stream handler 114 | :param proto: protocol that the handler serves 115 | :param handler_cb: handler callback 116 | :param balanced: whether the stream handler should be balanced on the p2pd side. Default: False. 
117 | :return: 118 | """ 119 | await self.control.stream_handler(proto=proto, handler_cb=handler_cb, balanced=balanced) 120 | 121 | async def remove_stream_handler(self, proto: str) -> None: 122 | await self.control.remove_stream_handler(proto=proto) 123 | -------------------------------------------------------------------------------- /hivemind/p2p/p2p_daemon_bindings/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings 3 | Licence: MIT 4 | Author: Kevin Mai-Husan Chia 5 | """ 6 | 7 | import asyncio 8 | 9 | from google.protobuf.message import Message as PBMessage 10 | 11 | from hivemind.proto import p2pd_pb2 as p2pd_pb 12 | 13 | DEFAULT_MAX_BITS: int = 64 14 | 15 | 16 | class P2PHandlerError(Exception): 17 | """ 18 | Raised if remote handled a request with an exception 19 | """ 20 | 21 | 22 | class P2PDaemonError(Exception): 23 | """ 24 | Raised if daemon failed to handle request 25 | """ 26 | 27 | 28 | class ControlFailure(P2PDaemonError): 29 | pass 30 | 31 | 32 | class DispatchFailure(P2PDaemonError): 33 | pass 34 | 35 | 36 | async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None: 37 | max_int = 1 << max_bits 38 | if integer < 0: 39 | raise ValueError(f"negative integer: {integer}") 40 | if integer >= max_int: 41 | raise ValueError(f"integer too large: {integer}") 42 | while True: 43 | value = integer & 0x7F 44 | integer >>= 7 45 | if integer != 0: 46 | value |= 0x80 47 | byte = value.to_bytes(1, "big") 48 | stream.write(byte) 49 | await stream.drain() 50 | if integer == 0: 51 | break 52 | 53 | 54 | async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int: 55 | max_int = 1 << max_bits 56 | iteration = 0 57 | result = 0 58 | has_next = True 59 | while has_next: 60 | data = await stream.readexactly(1) 61 | c = data[0] 62 | value = c & 0x7F 63 | result |= value << (iteration * 7) 64 | has_next = (c & 0x80) != 0 65 | iteration += 1 66 | if result >= max_int: 67 | raise ValueError(f"Varint overflowed: {result}") 68 | return result 69 | 70 | 71 | def raise_if_failed(response: p2pd_pb.Response) -> None: 72 | if response.type == p2pd_pb.Response.ERROR: 73 | raise ControlFailure(f"Connect failed. 
msg={response.error.msg}") 74 | 75 | 76 | async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None: 77 | size = pbmsg.ByteSize() 78 | await write_unsigned_varint(stream, size) 79 | msg_bytes: bytes = pbmsg.SerializeToString() 80 | stream.write(msg_bytes) 81 | await stream.drain() 82 | 83 | 84 | async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None: 85 | len_msg_bytes = await read_unsigned_varint(stream) 86 | msg_bytes = await stream.readexactly(len_msg_bytes) 87 | pbmsg.ParseFromString(msg_bytes) 88 | -------------------------------------------------------------------------------- /hivemind/proto/auth.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message AccessToken { 4 | string username = 1; 5 | bytes public_key = 2; 6 | string expiration_time = 3; 7 | bytes signature = 4; 8 | } 9 | 10 | message RequestAuthInfo { 11 | AccessToken client_access_token = 1; 12 | bytes service_public_key = 2; 13 | double time = 3; 14 | bytes nonce = 4; 15 | bytes signature = 5; 16 | } 17 | 18 | message ResponseAuthInfo { 19 | AccessToken service_access_token = 1; 20 | bytes nonce = 2; 21 | bytes signature = 3; 22 | } 23 | -------------------------------------------------------------------------------- /hivemind/proto/averaging.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | import "runtime.proto"; 3 | 4 | 5 | enum MessageCode { 6 | NO_CODE = 0; // Default value that should not be used explicitly 7 | REQUEST_JOIN = 1; // "Dear maybe leader, will you have me in your group as a follower?" 8 | ACCEPTED = 2; // "I accept you in my group, you now commit to responding to me" 9 | BEGIN_ALLREDUCE = 3; // "We can begin allreduce now. These are your peers." 10 | PART_FOR_AVERAGING = 4; // "I am running allreduce with you, here's a part of my tensor that you should aggregate" 11 | AVERAGED_PART = 5; // "I aggregated your part with others and here's the average for that part" 12 | NOT_DECLARED = 6; // "I have not declared my group id yet, how the heck did you even find me? Go away." 13 | NOT_A_LEADER = 7; // "I am not a group leader. Go ask my leader instead." 14 | BAD_EXPIRATION_TIME = 8; // "I will not accept you. I cannot guarantee that we begin before you expire." 15 | BAD_SCHEMA_HASH = 9; // "I will not accept you. I am not averaging the same type of tensors as you." 16 | BAD_GROUP_ID = 10; // "I will not accept your request, your group id does not match any groups I'm in." 17 | DUPLICATE_PEER_ID = 11; // "I will not accept you, I already have exactly the same peer id in my current group." 18 | GROUP_IS_FULL = 12; // "I will not accept you, my group already contains too many peers." 19 | NOT_LOOKING_FOR_GROUP = 13;// "I'm not available at the moment. Please, get lost." 20 | PROTOCOL_VIOLATION = 14; // "You did something so unspeakable that I don't have a special code for that." 21 | INTERNAL_ERROR = 15; // "I messed up, we will have to stop allreduce because of that." 22 | CANCELLED = 16; // "[from peer during allreduce] I no longer want to participate in AllReduce." 23 | GROUP_DISBANDED = 17; // "[from leader] The group is closed. Go find another group." 24 | BAD_GROUP_KEY = 18; // "I will not accept you. My current group key differs (maybe you used my older key)." 
25 | } 26 | 27 | message JoinRequest { 28 | bytes schema_hash = 2; // A hash that describes the follower's tensors (shapes, num tensors, etc.) 29 | double expiration = 3; // Follower would like to **begin** all_reduce by this point in time 30 | bytes gather = 4; // optional metadata that is gathered from all peers (e.g. batch size or current loss) 31 | bool client_mode = 5; // if True, the incoming averager is a client with no capacity for averaging 32 | string group_key = 6; // group key identifying an All-Reduce bucket, e.g. my_averager.0b011011101 33 | } 34 | 35 | message MessageFromLeader { 36 | MessageCode code = 1; 37 | bytes group_id = 2; // a unique identifier of this group, only valid until allreduce is finished/failed 38 | bytes suggested_leader = 3; // if peer is already in a group, it'll provide us with the peer id of its leader 39 | repeated bytes ordered_peer_ids = 4; // a sequence of peers, each responsible for one shard during averaging 40 | repeated bytes gathered = 5; // metadata (gather) from all groupmates in the same order as their peer ids 41 | } 42 | 43 | message AveragingData { 44 | MessageCode code = 1; // in case of a protocol violation, this will be the error message 45 | bytes group_id = 2; // a unique group identifier, same as in MessageFromLeader 46 | bytes peer_id = 3; // sender's rpc peer_id, used for coordination 47 | Tensor tensor_part = 4; // either peer's local tensor part (rpc input) or group average of this part (rpc output) 48 | double weight = 5; // peers will be averaged in proportion to these weights 49 | } 50 | 51 | message DownloadRequest {} 52 | 53 | message DownloadData { 54 | bytes metadata = 1; 55 | Tensor tensor_part = 2; 56 | } 57 | -------------------------------------------------------------------------------- /hivemind/proto/crypto.proto: -------------------------------------------------------------------------------- 1 | // Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings 2 | // Licence: MIT 3 | // Author: Kevin Mai-Husan Chia 4 | 5 | syntax = "proto2"; 6 | 7 | package crypto.pb; 8 | 9 | enum KeyType { 10 | RSA = 0; 11 | Ed25519 = 1; 12 | Secp256k1 = 2; 13 | ECDSA = 3; 14 | } 15 | 16 | message PublicKey { 17 | required KeyType key_type = 1; 18 | required bytes data = 2; 19 | } 20 | 21 | message PrivateKey { 22 | required KeyType key_type = 1; 23 | required bytes data = 2; 24 | } 25 | -------------------------------------------------------------------------------- /hivemind/proto/dht.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | import "auth.proto"; 3 | 4 | // This protocol defines how Hivemind nodes form a distributed hash table. 
5 | // For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode) 6 | 7 | message NodeInfo { 8 | // note: both node_id and port are optional: if specified, ask peer to add you to its routing table; 9 | // if either node_id or port is absent, simply request recipient info (for client-only mode) 10 | bytes node_id = 1; // sender's own node id serialized with DHTID.to_bytes() 11 | } 12 | 13 | message PingRequest { 14 | RequestAuthInfo auth = 1; 15 | NodeInfo peer = 2; // (optional) sender's own node info, same behavior as in DHT.rpc_ping 16 | bool validate = 3; // set to True if sender wants to validate that he is accessible and synchronized 17 | } 18 | 19 | message PingResponse { 20 | ResponseAuthInfo auth = 1; 21 | NodeInfo peer = 2; // respondent's node id, for you to update routing table 22 | double dht_time = 4; // recipient's local DHT time - used to soft-synchronize peers 23 | bool available = 5; // if validate = True, this flag asserts that the sender is available for ping 24 | } 25 | 26 | message StoreRequest { 27 | RequestAuthInfo auth = 1; 28 | // three lists of the same length representing dht keys, dht values and expiration 29 | repeated bytes keys = 2; // keys in the form of DHTID.generate(raw_key).to_bytes() 30 | repeated bytes subkeys = 3; // serialized subkeys for DictionaryDHTValue type. None means no subkey 31 | repeated bytes values = 4; // serialized value for i-th key 32 | repeated double expiration_time = 5; // expirations for i-th key (type = DHTExpiration) 33 | repeated bool in_cache = 6; // if in_cache[i], store i-th key in cache, else store normally 34 | NodeInfo peer = 7; // (optional) sender's own node info, same behavior as in DHT.rpc_ping 35 | } 36 | 37 | message StoreResponse { 38 | ResponseAuthInfo auth = 1; 39 | repeated bool store_ok = 2; // for every key, True means store accepted, False means store rejected/failed 40 | NodeInfo peer = 3; // respondent's node id, for you to update routing table 41 | } 42 | 43 | message FindRequest { 44 | RequestAuthInfo auth = 1; 45 | repeated bytes keys = 2; // a list of DHTID search keys encoded as bytes 46 | NodeInfo peer = 3; // optional, same behavior as in DHT.ping 47 | } 48 | 49 | enum ResultType {NOT_FOUND = 0; FOUND_REGULAR = 1; FOUND_DICTIONARY = 2;} 50 | 51 | message FindResult { 52 | ResultType type = 1; // NONE | REGULAR | DICTIONARY 53 | bytes value = 2; // n/a | serialized value | serialized DictionaryDHTValue with serialized fields 54 | double expiration_time = 3; // n/a | expiration time | DictionaryDHTValue.latest_expiration_time 55 | 56 | // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance) 57 | repeated bytes nearest_node_ids = 4; // DHTIDs of the nearest peers serialized with node_id.to_bytes() 58 | repeated bytes nearest_peer_ids = 5; // libp2p PeerIDs of the nearest peers 59 | } 60 | 61 | message FindResponse { 62 | ResponseAuthInfo auth = 1; 63 | repeated FindResult results = 2; // for each item, return value/expiration (if found) and nearest peers 64 | NodeInfo peer = 3; // respondent's node id, for you to update routing table 65 | } 66 | -------------------------------------------------------------------------------- /hivemind/proto/p2pd.proto: -------------------------------------------------------------------------------- 1 | // Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings 2 | // Licence: MIT 3 | // Author: Kevin Mai-Husan Chia 4 | 5 | syntax = "proto2"; 6 | 7 | package 
p2pclient.p2pd.pb; 8 | 9 | message Request { 10 | enum Type { 11 | IDENTIFY = 0; 12 | CONNECT = 1; 13 | STREAM_OPEN = 2; 14 | STREAM_HANDLER = 3; 15 | REMOVE_STREAM_HANDLER = 10; 16 | DHT = 4; 17 | LIST_PEERS = 5; 18 | CONNMANAGER = 6; 19 | DISCONNECT = 7; 20 | PUBSUB = 8; 21 | PERSISTENT_CONN_UPGRADE = 9; 22 | } 23 | 24 | required Type type = 1; 25 | 26 | optional ConnectRequest connect = 2; 27 | optional StreamOpenRequest streamOpen = 3; 28 | optional StreamHandlerRequest streamHandler = 4; 29 | optional RemoveStreamHandlerRequest removeStreamHandler = 9; 30 | optional DHTRequest dht = 5; 31 | optional ConnManagerRequest connManager = 6; 32 | optional DisconnectRequest disconnect = 7; 33 | optional PSRequest pubsub = 8; 34 | } 35 | 36 | message Response { 37 | enum Type { 38 | OK = 0; 39 | ERROR = 1; 40 | } 41 | 42 | required Type type = 1; 43 | optional ErrorResponse error = 2; 44 | optional StreamInfo streamInfo = 3; 45 | optional IdentifyResponse identify = 4; 46 | optional DHTResponse dht = 5; 47 | repeated PeerInfo peers = 6; 48 | optional PSResponse pubsub = 7; 49 | } 50 | 51 | message PersistentConnectionRequest { 52 | required bytes callId = 1; 53 | 54 | oneof message { 55 | AddUnaryHandlerRequest addUnaryHandler = 2; 56 | RemoveUnaryHandlerRequest removeUnaryHandler = 6; 57 | CallUnaryRequest callUnary = 3; 58 | CallUnaryResponse unaryResponse = 4; 59 | Cancel cancel = 5; 60 | } 61 | } 62 | 63 | message PersistentConnectionResponse { 64 | required bytes callId = 1; 65 | 66 | oneof message { 67 | CallUnaryResponse callUnaryResponse = 2; 68 | CallUnaryRequest requestHandling = 3; 69 | DaemonError daemonError = 4; 70 | Cancel cancel = 5; 71 | } 72 | } 73 | 74 | 75 | message IdentifyResponse { 76 | required bytes id = 1; 77 | repeated bytes addrs = 2; 78 | } 79 | 80 | message ConnectRequest { 81 | required bytes peer = 1; 82 | repeated bytes addrs = 2; 83 | optional int64 timeout = 3; 84 | } 85 | 86 | message StreamOpenRequest { 87 | required bytes peer = 1; 88 | repeated string proto = 2; 89 | optional int64 timeout = 3; 90 | } 91 | 92 | message StreamHandlerRequest { 93 | required bytes addr = 1; 94 | repeated string proto = 2; 95 | required bool balanced = 3; 96 | } 97 | 98 | message RemoveStreamHandlerRequest { 99 | required bytes addr = 1; 100 | repeated string proto = 2; 101 | } 102 | 103 | message ErrorResponse { 104 | required string msg = 1; 105 | } 106 | 107 | message StreamInfo { 108 | required bytes peer = 1; 109 | required bytes addr = 2; 110 | required string proto = 3; 111 | } 112 | 113 | message DHTRequest { 114 | enum Type { 115 | FIND_PEER = 0; 116 | FIND_PEERS_CONNECTED_TO_PEER = 1; 117 | FIND_PROVIDERS = 2; 118 | GET_CLOSEST_PEERS = 3; 119 | GET_PUBLIC_KEY = 4; 120 | GET_VALUE = 5; 121 | SEARCH_VALUE = 6; 122 | PUT_VALUE = 7; 123 | PROVIDE = 8; 124 | } 125 | 126 | required Type type = 1; 127 | optional bytes peer = 2; 128 | optional bytes cid = 3; 129 | optional bytes key = 4; 130 | optional bytes value = 5; 131 | optional int32 count = 6; 132 | optional int64 timeout = 7; 133 | } 134 | 135 | message DHTResponse { 136 | enum Type { 137 | BEGIN = 0; 138 | VALUE = 1; 139 | END = 2; 140 | } 141 | 142 | required Type type = 1; 143 | optional PeerInfo peer = 2; 144 | optional bytes value = 3; 145 | } 146 | 147 | message PeerInfo { 148 | required bytes id = 1; 149 | repeated bytes addrs = 2; 150 | } 151 | 152 | message ConnManagerRequest { 153 | enum Type { 154 | TAG_PEER = 0; 155 | UNTAG_PEER = 1; 156 | TRIM = 2; 157 | } 158 | 159 | required Type type = 1; 160 | 161 | 
optional bytes peer = 2; 162 | optional string tag = 3; 163 | optional int64 weight = 4; 164 | } 165 | 166 | message DisconnectRequest { 167 | required bytes peer = 1; 168 | } 169 | 170 | message PSRequest { 171 | enum Type { 172 | GET_TOPICS = 0; 173 | LIST_PEERS = 1; 174 | PUBLISH = 2; 175 | SUBSCRIBE = 3; 176 | } 177 | 178 | required Type type = 1; 179 | optional string topic = 2; 180 | optional bytes data = 3; 181 | } 182 | 183 | message PSMessage { 184 | optional bytes from = 1; 185 | optional bytes data = 2; 186 | optional bytes seqno = 3; 187 | repeated string topicIDs = 4; 188 | optional bytes signature = 5; 189 | optional bytes key = 6; 190 | } 191 | 192 | message PSResponse { 193 | repeated string topics = 1; 194 | repeated bytes peerIDs = 2; 195 | } 196 | 197 | message CallUnaryRequest { 198 | required bytes peer = 1; 199 | required string proto = 2; 200 | required bytes data = 3; 201 | } 202 | 203 | message CallUnaryResponse { 204 | oneof result { 205 | bytes response = 1; 206 | bytes error = 2; 207 | } 208 | } 209 | 210 | message AddUnaryHandlerRequest { 211 | required string proto = 1; 212 | required bool balanced = 2; 213 | } 214 | 215 | message RemoveUnaryHandlerRequest { 216 | required string proto = 1; 217 | } 218 | 219 | message DaemonError { 220 | optional string message = 1; 221 | } 222 | 223 | message Cancel { 224 | } 225 | 226 | message RPCError { 227 | optional string message = 1; 228 | } 229 | -------------------------------------------------------------------------------- /hivemind/proto/runtime.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | 4 | message ExpertUID { 5 | string uid = 1; 6 | } 7 | 8 | message ExpertInfo { 9 | bytes serialized_info = 1; 10 | } 11 | 12 | message ExpertRequest { 13 | string uid = 1; 14 | repeated Tensor tensors = 2; 15 | bytes metadata = 3; 16 | } 17 | 18 | message ExpertResponse { 19 | repeated Tensor tensors = 2; 20 | bytes metadata = 3; 21 | } 22 | 23 | enum CompressionType{ 24 | NONE = 0; 25 | MEANSTD_16BIT = 1; 26 | FLOAT16 = 2; 27 | QUANTILE_8BIT = 3; 28 | UNIFORM_8BIT = 4; 29 | BLOCKWISE_8BIT = 5; 30 | } 31 | 32 | message Tensor { 33 | bytes buffer = 1; 34 | repeated uint32 size = 2; 35 | bool requires_grad = 3; 36 | string dtype = 4; 37 | CompressionType compression = 5; 38 | int32 chunks = 6; 39 | } 40 | 41 | -------------------------------------------------------------------------------- /hivemind/proto/test.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | message TestRequest { 4 | int32 number = 1; 5 | } 6 | 7 | message TestResponse { 8 | int32 number = 1; 9 | } 10 | -------------------------------------------------------------------------------- /hivemind/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from hivemind.utils.asyncio import * 2 | from hivemind.utils.limits import increase_file_limit 3 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 4 | from hivemind.utils.mpfuture import * 5 | from hivemind.utils.nested import * 6 | from hivemind.utils.networking import log_visible_maddrs 7 | from hivemind.utils.performance_ema import PerformanceEMA 8 | from hivemind.utils.serializer import MSGPackSerializer, SerializerBase 9 | from hivemind.utils.streaming import combine_from_streaming, split_for_streaming 10 | from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor 11 | from 
hivemind.utils.timed_storage import * 12 | -------------------------------------------------------------------------------- /hivemind/utils/crypto.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | import threading 5 | from abc import ABC, abstractmethod 6 | 7 | from cryptography import exceptions 8 | from cryptography.hazmat.primitives import hashes, serialization 9 | from cryptography.hazmat.primitives.asymmetric import padding, rsa 10 | 11 | 12 | class PrivateKey(ABC): 13 | @abstractmethod 14 | def sign(self, data: bytes) -> bytes: ... 15 | 16 | @abstractmethod 17 | def get_public_key(self) -> PublicKey: ... 18 | 19 | 20 | class PublicKey(ABC): 21 | @abstractmethod 22 | def verify(self, data: bytes, signature: bytes) -> bool: ... 23 | 24 | @abstractmethod 25 | def to_bytes(self) -> bytes: ... 26 | 27 | @classmethod 28 | @abstractmethod 29 | def from_bytes(cls, key: bytes) -> PublicKey: ... 30 | 31 | 32 | _RSA_PADDING = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH) 33 | _RSA_HASH_ALGORITHM = hashes.SHA256() 34 | 35 | 36 | class RSAPrivateKey(PrivateKey): 37 | def __init__(self): 38 | self._private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) 39 | 40 | _process_wide_key = None 41 | _process_wide_key_lock = threading.RLock() 42 | 43 | @classmethod 44 | def process_wide(cls) -> RSAPrivateKey: 45 | if cls._process_wide_key is None: 46 | with cls._process_wide_key_lock: 47 | if cls._process_wide_key is None: 48 | cls._process_wide_key = cls() 49 | return cls._process_wide_key 50 | 51 | def sign(self, data: bytes) -> bytes: 52 | signature = self._private_key.sign(data, _RSA_PADDING, _RSA_HASH_ALGORITHM) 53 | return base64.b64encode(signature) 54 | 55 | def get_public_key(self) -> RSAPublicKey: 56 | return RSAPublicKey(self._private_key.public_key()) 57 | 58 | def to_bytes(self) -> bytes: 59 | return self._private_key.private_bytes( 60 | encoding=serialization.Encoding.DER, 61 | format=serialization.PrivateFormat.TraditionalOpenSSL, 62 | encryption_algorithm=serialization.NoEncryption(), 63 | ) 64 | 65 | def __getstate__(self): 66 | state = self.__dict__.copy() 67 | # Serializes the private key to make the class instances picklable 68 | state["_private_key"] = self.to_bytes() 69 | return state 70 | 71 | def __setstate__(self, state): 72 | self.__dict__.update(state) 73 | self._private_key = serialization.load_der_private_key(self._private_key, password=None) 74 | 75 | 76 | class RSAPublicKey(PublicKey): 77 | def __init__(self, public_key: rsa.RSAPublicKey): 78 | self._public_key = public_key 79 | 80 | def verify(self, data: bytes, signature: bytes) -> bool: 81 | try: 82 | signature = base64.b64decode(signature) 83 | 84 | # Returns None if the signature is correct, raises an exception otherwise 85 | self._public_key.verify(signature, data, _RSA_PADDING, _RSA_HASH_ALGORITHM) 86 | 87 | return True 88 | except (ValueError, exceptions.InvalidSignature): 89 | return False 90 | 91 | def to_bytes(self) -> bytes: 92 | return self._public_key.public_bytes( 93 | encoding=serialization.Encoding.OpenSSH, format=serialization.PublicFormat.OpenSSH 94 | ) 95 | 96 | @classmethod 97 | def from_bytes(cls, key: bytes) -> RSAPublicKey: 98 | key = serialization.load_ssh_public_key(key) 99 | if not isinstance(key, rsa.RSAPublicKey): 100 | raise ValueError(f"Expected an RSA public key, got {key}") 101 | return cls(key) 102 | 
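Example (illustrative; not a file from the repository): a minimal sketch of a sign/verify round trip with the RSA helpers above, assuming the classes behave exactly as defined in hivemind/utils/crypto.py:

from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey

private_key = RSAPrivateKey()
public_key = private_key.get_public_key()

signature = private_key.sign(b"hello hivemind")  # base64-encoded RSA-PSS signature
assert public_key.verify(b"hello hivemind", signature)

# public keys round-trip through their OpenSSH byte representation
restored = RSAPublicKey.from_bytes(public_key.to_bytes())
assert restored.verify(b"hello hivemind", signature)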
-------------------------------------------------------------------------------- /hivemind/utils/limits.py: -------------------------------------------------------------------------------- 1 | from hivemind.utils.logging import get_logger 2 | 3 | logger = get_logger(__name__) 4 | 5 | 6 | def increase_file_limit(new_soft=2**15, new_hard=2**15): 7 | """Increase the maximum number of open files. On Linux, this allows spawning more processes/threads.""" 8 | try: 9 | import resource # local import to avoid ImportError for Windows users 10 | 11 | soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) 12 | new_soft = max(soft, new_soft) 13 | new_hard = max(hard, new_hard) 14 | logger.info(f"Increasing file limit: soft {soft}=>{new_soft}, hard {hard}=>{new_hard}") 15 | return resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, new_hard)) 16 | except Exception as e: 17 | logger.warning(f"Failed to increase file limit: {e}") 18 | -------------------------------------------------------------------------------- /hivemind/utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | @torch.jit.script 6 | def orthogonalize_(matrix, eps: float = 1e-8): 7 | """Orthogonalize a 2d tensor in-place over the last dimension""" 8 | n, m = matrix.shape 9 | for i in range(m): 10 | col = matrix[:, i] 11 | F.normalize(col, dim=0, eps=eps, out=col) 12 | if i + 1 < m: 13 | rest = matrix[:, i + 1 :] 14 | rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1) 15 | 16 | 17 | def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2): 18 | """get dims to flatten tensor up to max_ndim dimensions by merging small axes together""" 19 | dims = list(tensor.shape) 20 | while len(dims) > max_ndim: 21 | squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1]) 22 | squeezed_dim = dims.pop(squeeze_ix) 23 | dims[squeeze_ix] *= squeezed_dim 24 | return dims 25 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/__init__.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | from .multiaddr import Multiaddr # NOQA 26 | 27 | __author__ = "Steven Buss" 28 | __email__ = "steven.buss@gmail.com" 29 | __version__ = "0.0.9" 30 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/__init__.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import importlib 26 | 27 | # These are special sizes 28 | LENGTH_PREFIXED_VAR_SIZE = -1 29 | 30 | 31 | class NoneCodec: 32 | SIZE = 0 33 | IS_PATH = False 34 | 35 | 36 | CODEC_CACHE = {} 37 | 38 | 39 | def codec_by_name(name): 40 | if name is None: # Special “do nothing – expect nothing” pseudo-codec 41 | return NoneCodec 42 | codec = CODEC_CACHE.get(name) 43 | if not codec: 44 | codec = CODEC_CACHE[name] = importlib.import_module(".{0}".format(name), __name__) 45 | return codec 46 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/domain.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import idna 26 | 27 | from . import LENGTH_PREFIXED_VAR_SIZE 28 | 29 | SIZE = LENGTH_PREFIXED_VAR_SIZE 30 | IS_PATH = False 31 | 32 | 33 | def to_bytes(proto, string): 34 | return idna.uts46_remap(string).encode("utf-8") 35 | 36 | 37 | def to_string(proto, buf): 38 | string = buf.decode("utf-8") 39 | for label in string.split("."): 40 | idna.check_label(label) 41 | return string 42 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/fspath.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import os 26 | 27 | from . 
import LENGTH_PREFIXED_VAR_SIZE 28 | 29 | SIZE = LENGTH_PREFIXED_VAR_SIZE 30 | IS_PATH = True 31 | 32 | 33 | def to_bytes(proto, string): 34 | return os.fsencode(string) 35 | 36 | 37 | def to_string(proto, buf): 38 | return os.fsdecode(buf) 39 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/ip4.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import netaddr 26 | 27 | SIZE = 32 28 | IS_PATH = False 29 | 30 | 31 | def to_bytes(proto, string): 32 | return netaddr.IPAddress(string, version=4).packed 33 | 34 | 35 | def to_string(proto, buf): 36 | return str(netaddr.IPAddress(int.from_bytes(buf, byteorder="big"), version=4)) 37 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/ip6.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import netaddr 26 | 27 | SIZE = 128 28 | IS_PATH = False 29 | 30 | 31 | def to_bytes(proto, string): 32 | return netaddr.IPAddress(string, version=6).packed 33 | 34 | 35 | def to_string(proto, buf): 36 | return str(netaddr.IPAddress(int.from_bytes(buf, byteorder="big"), version=6)) 37 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/onion.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | import base64 26 | import struct 27 | 28 | SIZE = 96 29 | IS_PATH = False 30 | 31 | 32 | def to_bytes(proto, string): 33 | addr = string.split(":") 34 | if len(addr) != 2: 35 | raise ValueError("Does not contain a port number") 36 | 37 | # onion address without the ".onion" substring 38 | if len(addr[0]) != 16: 39 | raise ValueError("Invalid onion host address length (must be 16 characters)") 40 | try: 41 | onion_host_bytes = base64.b32decode(addr[0].upper()) 42 | except Exception as exc: 43 | raise ValueError("Cannot decode {0!r} as base32: {1}".format(addr[0], exc)) from exc 44 | 45 | # onion port number 46 | try: 47 | port = int(addr[1], 10) 48 | except ValueError as exc: 49 | raise ValueError("Port number is not a base 10 integer") from exc 50 | if port not in range(1, 65536): 51 | raise ValueError("Port number is not in range(1, 65536)") 52 | 53 | return b"".join((onion_host_bytes, struct.pack(">H", port))) 54 | 55 | 56 | def to_string(proto, buf): 57 | addr_bytes, port_bytes = (buf[:-2], buf[-2:]) 58 | addr = base64.b32encode(addr_bytes).decode("ascii").lower() 59 | port = str(struct.unpack(">H", port_bytes)[0]) 60 | return ":".join([addr, port]) 61 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/onion3.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | import base64 26 | import struct 27 | 28 | SIZE = 296 29 | IS_PATH = False 30 | 31 | 32 | def to_bytes(proto, string): 33 | addr = string.split(":") 34 | if len(addr) != 2: 35 | raise ValueError("Does not contain a port number") 36 | 37 | # onion3 address without the ".onion" substring 38 | if len(addr[0]) != 56: 39 | raise ValueError("Invalid onion3 host address length (must be 56 characters)") 40 | try: 41 | onion3_host_bytes = base64.b32decode(addr[0].upper()) 42 | except Exception as exc: 43 | raise ValueError("Cannot decode {0!r} as base32: {1}".format(addr[0], exc)) from exc 44 | 45 | # onion3 port number 46 | try: 47 | port = int(addr[1], 10) 48 | except ValueError as exc: 49 | raise ValueError("Port number is not a base 10 integer") from exc 50 | if port not in range(1, 65536): 51 | raise ValueError("Port number is not in range(1, 65536)") 52 | 53 | return b"".join((onion3_host_bytes, struct.pack(">H", port))) 54 | 55 | 56 | def to_string(proto, buf): 57 | addr_bytes, port_bytes = (buf[:-2], buf[-2:]) 58 | addr = base64.b32encode(addr_bytes).decode("ascii").lower() 59 | port = str(struct.unpack(">H", port_bytes)[0]) 60 | return ":".join([addr, port]) 61 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/uint16be.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | import struct 26 | 27 | SIZE = 16 28 | IS_PATH = False 29 | 30 | 31 | def to_bytes(proto, string): 32 | try: 33 | return struct.pack(">H", int(string, 10)) 34 | except ValueError as exc: 35 | raise ValueError("Not a base 10 integer") from exc 36 | except struct.error as exc: 37 | raise ValueError("Integer not in range(65536)") from exc 38 | 39 | 40 | def to_string(proto, buf): 41 | if len(buf) != 2: 42 | raise ValueError("Invalid integer length (must be 2 bytes / 16 bits)") 43 | return str(struct.unpack(">H", buf)[0]) 44 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/codecs/utf8.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | from __future__ import absolute_import 26 | 27 | from . 
import LENGTH_PREFIXED_VAR_SIZE 28 | 29 | SIZE = LENGTH_PREFIXED_VAR_SIZE 30 | IS_PATH = False 31 | 32 | 33 | def to_bytes(proto, string): 34 | if len(string) == 0: 35 | raise ValueError("{0} value must not be empty".format(proto.name)) 36 | return string.encode("utf-8") 37 | 38 | 39 | def to_string(proto, buf): 40 | if len(buf) == 0: 41 | raise ValueError("invalid length (should be > 0)") 42 | return buf.decode("utf-8") 43 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/exceptions.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | class Error(Exception): 26 | pass 27 | 28 | 29 | class LookupError(LookupError, Error): 30 | pass 31 | 32 | 33 | class ProtocolLookupError(LookupError): 34 | """ 35 | MultiAddr did not contain a protocol with the requested code 36 | """ 37 | 38 | def __init__(self, proto, string): 39 | self.proto = proto 40 | self.string = string 41 | 42 | super().__init__("MultiAddr {0!r} does not contain protocol {1}".format(string, proto)) 43 | 44 | 45 | class ParseError(ValueError, Error): 46 | pass 47 | 48 | 49 | class StringParseError(ParseError): 50 | """ 51 | MultiAddr string representation could not be parsed 52 | """ 53 | 54 | def __init__(self, message, string, protocol=None, original=None): 55 | self.message = message 56 | self.string = string 57 | self.protocol = protocol 58 | self.original = original 59 | 60 | if protocol: 61 | message = "Invalid MultiAddr {0!r} protocol {1}: {2}".format(string, protocol, message) 62 | else: 63 | message = "Invalid MultiAddr {0!r}: {1}".format(string, message) 64 | 65 | super().__init__(message) 66 | 67 | 68 | class BinaryParseError(ParseError): 69 | """ 70 | MultiAddr binary representation could not be parsed 71 | """ 72 | 73 | def __init__(self, message, binary, protocol, original=None): 74 | self.message = message 75 | self.binary = binary 76 | self.protocol = protocol 77 | self.original = original 78 | 79 | message = "Invalid binary MultiAddr protocol {0}: {1}".format(protocol, message) 80 | 81 | super().__init__(message) 82 | 83 | 84 | class ProtocolRegistryError(Error): 85 | pass 86 | 87 | 88 | ProtocolManagerError = ProtocolRegistryError 89 | 90 | 91 | class ProtocolRegistryLocked(Error): 92 | """Protocol registry was locked and doesn't allow any further additions""" 93 | 94 | def __init__(self): 95 | super().__init__("Protocol registry is locked and does not accept any new values") 96 | 97 | 98 | class ProtocolExistsError(ProtocolRegistryError): 99 | """Protocol with the given name or code already exists""" 100 | 101 | def __init__(self, proto, kind="name"): 102 | self.proto = proto 103 | self.kind = kind 104 | 105 | super().__init__("Protocol with {0} {1!r} already exists".format(kind, getattr(proto, kind))) 106 | 107 | 108 | class ProtocolNotFoundError(ProtocolRegistryError): 109 | """No protocol with the given name or code found""" 110 | 111 | def __init__(self, value, kind="name"): 112 | self.value = value 113 | self.kind = kind 114 | 115 | super().__init__("No protocol with {0} {1!r} found".format(kind, value)) 116 | -------------------------------------------------------------------------------- /hivemind/utils/multiaddr/transforms.py: -------------------------------------------------------------------------------- 1 | # This code is originally taken from https://github.com/multiformats/py-multiaddr 2 | # 3 | # The MIT License (MIT) 4 | # 5 | # Copyright (c) 2014-2015 Steven Buss 6 | # Copyright (c) 2019-2020 Alexander Schlarb 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the 
Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | import io 26 | 27 | import varint 28 | 29 | from . import exceptions 30 | from .codecs import LENGTH_PREFIXED_VAR_SIZE, codec_by_name 31 | from .protocols import protocol_with_code, protocol_with_name 32 | 33 | 34 | def string_to_bytes(string): 35 | bs = [] 36 | for proto, codec, value in string_iter(string): 37 | bs.append(varint.encode(proto.code)) 38 | if value is not None: 39 | try: 40 | buf = codec.to_bytes(proto, value) 41 | except Exception as exc: 42 | raise exceptions.StringParseError(str(exc), string, proto.name, exc) from exc 43 | if codec.SIZE == LENGTH_PREFIXED_VAR_SIZE: 44 | bs.append(varint.encode(len(buf))) 45 | bs.append(buf) 46 | return b"".join(bs) 47 | 48 | 49 | def bytes_to_string(buf): 50 | st = [""] # start with empty string so we get a leading slash on join() 51 | for _, proto, codec, part in bytes_iter(buf): 52 | st.append(proto.name) 53 | if codec.SIZE != 0: 54 | try: 55 | value = codec.to_string(proto, part) 56 | except Exception as exc: 57 | raise exceptions.BinaryParseError(str(exc), buf, proto.name, exc) from exc 58 | if codec.IS_PATH and value[0] == "/": 59 | st.append(value[1:]) 60 | else: 61 | st.append(value) 62 | return "/".join(st) 63 | 64 | 65 | def size_for_addr(codec, buf_io): 66 | if codec.SIZE >= 0: 67 | return codec.SIZE // 8 68 | else: 69 | return varint.decode_stream(buf_io) 70 | 71 | 72 | def string_iter(string): 73 | if not string.startswith("/"): 74 | raise exceptions.StringParseError("Must begin with /", string) 75 | # consume trailing slashes 76 | string = string.rstrip("/") 77 | sp = string.split("/") 78 | 79 | # skip the first element, since it starts with / 80 | sp.pop(0) 81 | while sp: 82 | element = sp.pop(0) 83 | try: 84 | proto = protocol_with_name(element) 85 | codec = codec_by_name(proto.codec) 86 | except (ImportError, exceptions.ProtocolNotFoundError) as exc: 87 | raise exceptions.StringParseError("Unknown Protocol", string, element) from exc 88 | value = None 89 | if codec.SIZE != 0: 90 | if len(sp) < 1: 91 | raise exceptions.StringParseError("Protocol requires address", string, proto.name) 92 | if codec.IS_PATH: 93 | value = "/" + "/".join(sp) 94 | sp.clear() 95 | else: 96 | value = sp.pop(0) 97 | yield proto, codec, value 98 | 99 | 100 | def bytes_iter(buf): 101 | buf_io = io.BytesIO(buf) 102 | while buf_io.tell() < len(buf): 103 | offset = buf_io.tell() 104 | code = varint.decode_stream(buf_io) 105 | proto = None 106 | try: 107 | proto = protocol_with_code(code) 108 | codec = codec_by_name(proto.codec) 109 | except (ImportError, exceptions.ProtocolNotFoundError) as exc: 110 | raise exceptions.BinaryParseError( 111 | "Unknown Protocol", 112 | buf, 113 | proto.name if proto else code, 114 | ) from exc 115 | 116 | size = size_for_addr(codec, buf_io) 117 | yield offset, proto, codec, buf_io.read(size) 118 | -------------------------------------------------------------------------------- /hivemind/utils/nested.py: -------------------------------------------------------------------------------- 1 | """utility 
functions that help you process nested dicts, tuples, lists and namedtuples""" 2 | 3 | 4 | def nested_compare(t, u): 5 | """ 6 | Return whether the nested structure of t and u matches. 7 | """ 8 | if isinstance(t, (list, tuple)): 9 | if not isinstance(u, type(t)): 10 | return False 11 | if len(t) != len(u): 12 | return False 13 | for a, b in zip(t, u): 14 | if not nested_compare(a, b): 15 | return False 16 | return True 17 | 18 | if isinstance(t, dict): 19 | if not isinstance(u, dict): 20 | return False 21 | if set(t.keys()) != set(u.keys()): 22 | return False 23 | for k in t: 24 | if not nested_compare(t[k], u[k]): 25 | return False 26 | return True 27 | 28 | else: 29 | return True 30 | 31 | 32 | def nested_flatten(t): 33 | """ 34 | Turn nested list/tuple/dict into a flat iterator. 35 | """ 36 | if isinstance(t, (list, tuple)): 37 | for x in t: 38 | yield from nested_flatten(x) 39 | elif isinstance(t, dict): 40 | for k, v in sorted(t.items()): 41 | yield from nested_flatten(v) 42 | else: 43 | yield t 44 | 45 | 46 | def nested_pack(flat, structure): 47 | """ 48 | Restore nested structure from flattened state 49 | :param flat: result of nested_flatten 50 | :param structure: used as example when recovering structure 51 | :returns: nested structure like :structure: filled with elements of :flat: 52 | """ 53 | return _nested_pack(iter(flat), structure) 54 | 55 | 56 | def _nested_pack(flat_iter, structure): 57 | if is_namedtuple(structure): 58 | return type(structure)(*[_nested_pack(flat_iter, x) for x in structure]) 59 | elif isinstance(structure, (list, tuple)): 60 | return type(structure)(_nested_pack(flat_iter, x) for x in structure) 61 | elif isinstance(structure, dict): 62 | return {k: _nested_pack(flat_iter, v) for k, v in sorted(structure.items())} 63 | else: 64 | return next(flat_iter) 65 | 66 | 67 | def is_namedtuple(x): 68 | """Checks if x is a namedtuple instance. Taken from https://stackoverflow.com/a/2166841 .""" 69 | t = type(x) 70 | b = t.__bases__ 71 | if len(b) != 1 or b[0] is not tuple: 72 | return False 73 | f = getattr(t, "_fields", None) 74 | if not isinstance(f, tuple): 75 | return False 76 | return all(type(n) is str for n in f) 77 | 78 | 79 | def nested_map(fn, *t): 80 | # Check arguments. 81 | if not t: 82 | raise ValueError("Expected 2+ arguments, got 1") 83 | for i in range(1, len(t)): 84 | if not nested_compare(t[0], t[i]): 85 | msg = "Nested structure of %r and %r differs" 86 | raise ValueError(msg % (t[0], t[i])) 87 | 88 | # Map. 89 | flat = map(nested_flatten, t) 90 | return nested_pack(map(fn, *flat), t[0]) 91 | -------------------------------------------------------------------------------- /hivemind/utils/networking.py: -------------------------------------------------------------------------------- 1 | from ipaddress import ip_address 2 | from typing import List, Sequence 3 | 4 | from hivemind.utils.logging import TextStyle, get_logger 5 | from hivemind.utils.multiaddr import Multiaddr 6 | 7 | LOCALHOST = "127.0.0.1" 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def choose_ip_address( 13 | maddrs: Sequence[Multiaddr], prefer_global: bool = True, protocol_priority: Sequence[str] = ("ip4", "ip6") 14 | ) -> str: 15 | """ 16 | Currently, some components of hivemind are not converted to work over libp2p and use classical networking. 17 | To allow other peers to reach a server when needed, these components announce a machine's IP address.
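For example, given the multiaddrs ``/ip4/8.8.8.8/tcp/31337`` and ``/ip4/127.0.0.1/tcp/31337``, this function would return ``"8.8.8.8"``, the only globally routable address among the two (these addresses are purely illustrative).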
18 | 19 | This function automatically selects the best IP address to announce among publicly visible multiaddrs 20 | of this machine identified by libp2p (typically, using the ``P2P.get_visible_maddrs()`` method), 21 | so a user does not need to define this address manually (unless the user wants to). 22 | 23 | The best IP address is chosen using the following logic: 24 | - Prefer IP addresses from global address blocks 25 | (in terms of https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Address.is_global) 26 | - Among the IP addresses of the same globality status, prefer IPv4 addresses over IPv6 27 | 28 | If the default logic does not suit you, it is recommended to set the announced IP address manually. 29 | """ 30 | 31 | for need_global in [prefer_global, not prefer_global]: 32 | for protocol in protocol_priority: 33 | for addr in maddrs: 34 | if protocol in addr.protocols(): 35 | value_for_protocol = addr[protocol] 36 | if ip_address(value_for_protocol).is_global == need_global: 37 | return value_for_protocol 38 | 39 | raise ValueError(f"No IP address found among given multiaddrs: {maddrs}") 40 | 41 | 42 | def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None: 43 | if only_p2p: 44 | unique_addrs = {addr["p2p"] for addr in visible_maddrs} 45 | initial_peers = " ".join(f"/p2p/{addr}" for addr in unique_addrs) 46 | else: 47 | available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr] 48 | if available_ips: 49 | preferred_ip = choose_ip_address(available_ips) 50 | selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)] 51 | else: 52 | selected_maddrs = visible_maddrs 53 | initial_peers = " ".join(str(addr) for addr in selected_maddrs) 54 | 55 | logger.info( 56 | f"Running a DHT instance. To connect other peers to this one, use " 57 | f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers}{TextStyle.RESET}" 58 | ) 59 | logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}") 60 | -------------------------------------------------------------------------------- /hivemind/utils/performance_ema.py: -------------------------------------------------------------------------------- 1 | import time 2 | from contextlib import contextmanager 3 | from threading import Lock 4 | from typing import Optional 5 | 6 | 7 | class PerformanceEMA: 8 | """ 9 | A running estimate of performance (operations/sec) using adjusted exponential moving average 10 | :param alpha: Smoothing factor in range [0, 1], [default: 0.1]. 
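:param eps: a lower bound on the estimated time per sample; caps the estimate at 1 / eps and protects against division by zero [default: 1e-20] :param paused: if True, start in the paused state and count only intervals that are passed explicitly [default: False]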
11 | """ 12 | 13 | def __init__(self, alpha: float = 0.1, eps: float = 1e-20, paused: bool = False): 14 | self.alpha, self.eps, self.num_updates = alpha, eps, 0 15 | self.ema_seconds_per_sample, self.samples_per_second = 0, eps 16 | self.timestamp = time.perf_counter() 17 | self.paused = paused 18 | self.lock = Lock() 19 | 20 | def update(self, task_size: float, interval: Optional[float] = None) -> float: 21 | """ 22 | :param task_size: how many items were processed since last call 23 | :param interval: optionally provide the time delta it took to process this task 24 | :returns: current estimate of performance (samples per second), but at most 25 | """ 26 | assert task_size > 0, f"Can't register processing {task_size} samples" 27 | if not self.paused: 28 | self.timestamp, old_timestamp = time.perf_counter(), self.timestamp 29 | interval = interval if interval is not None else self.timestamp - old_timestamp 30 | else: 31 | assert interval is not None, "If PerformanceEMA is paused, please specify the time interval" 32 | self.ema_seconds_per_sample = ( 33 | self.alpha * interval / task_size + (1 - self.alpha) * self.ema_seconds_per_sample 34 | ) 35 | self.num_updates += 1 36 | adjusted_seconds_per_sample = self.ema_seconds_per_sample / (1 - (1 - self.alpha) ** self.num_updates) 37 | self.samples_per_second = 1 / max(adjusted_seconds_per_sample, self.eps) 38 | return self.samples_per_second 39 | 40 | def reset_timer(self): 41 | """Reset the time since the last update so that the next task performance is counted from current time""" 42 | self.timestamp = time.perf_counter() 43 | 44 | @contextmanager 45 | def pause(self): 46 | """While inside this context, EMA will not count the time passed towards the performance estimate""" 47 | self.paused, was_paused = True, self.paused 48 | try: 49 | yield 50 | finally: 51 | self.paused = was_paused 52 | self.reset_timer() 53 | 54 | def __repr__(self): 55 | return f"{self.__class__.__name__}(ema={self.samples_per_second:.5f}, num_updates={self.num_updates})" 56 | 57 | @contextmanager 58 | def update_threadsafe(self, task_size: float): 59 | """ 60 | Update the EMA throughput of a code that runs inside the context manager, supports multiple concurrent threads. 
61 | 62 | :param task_size: how many items were processed since last call 63 | """ 64 | start_timestamp = time.perf_counter() 65 | yield 66 | with self.lock: 67 | self.update(task_size, interval=time.perf_counter() - max(start_timestamp, self.timestamp)) 68 | # note: we define interval as such to support two distinct scenarios: 69 | # (1) if this is the first call to measure_threadsafe after a pause, count time from entering this context 70 | # (2) if there are concurrent calls to measure_threadsafe, respect the timestamp updates from these calls 71 | -------------------------------------------------------------------------------- /hivemind/utils/serializer.py: -------------------------------------------------------------------------------- 1 | """A unified interface for several common serialization methods""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict 5 | 6 | import msgpack 7 | 8 | from hivemind.utils.logging import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class SerializerBase(ABC): 14 | @staticmethod 15 | @abstractmethod 16 | def dumps(obj: object) -> bytes: 17 | pass 18 | 19 | @staticmethod 20 | @abstractmethod 21 | def loads(buf: bytes) -> object: 22 | pass 23 | 24 | 25 | class MSGPackSerializer(SerializerBase): 26 | _ext_types: Dict[Any, int] = {} 27 | _ext_type_codes: Dict[int, Any] = {} 28 | _TUPLE_EXT_TYPE_CODE = 0x40 29 | 30 | @classmethod 31 | def ext_serializable(cls, type_code: int): 32 | assert isinstance(type_code, int), "Please specify a (unique) int type code" 33 | 34 | def wrap(wrapped_type: type): 35 | assert callable(getattr(wrapped_type, "packb", None)) and callable( 36 | getattr(wrapped_type, "unpackb", None) 37 | ), "Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)" 38 | if type_code in cls._ext_type_codes: 39 | logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting") 40 | cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code 41 | return wrapped_type 42 | 43 | return wrap 44 | 45 | @classmethod 46 | def _encode_ext_types(cls, obj): 47 | type_code = cls._ext_types.get(type(obj)) 48 | if type_code is not None: 49 | return msgpack.ExtType(type_code, obj.packb()) 50 | elif isinstance(obj, tuple): 51 | # Tuples need to be handled separately to ensure that 52 | # 1. tuple serialization works and 2. 
tuples are serialized as tuples, not as lists 53 | data = msgpack.packb(list(obj), strict_types=True, use_bin_type=True, default=cls._encode_ext_types) 54 | return msgpack.ExtType(cls._TUPLE_EXT_TYPE_CODE, data) 55 | return obj 56 | 57 | @classmethod 58 | def _decode_ext_types(cls, type_code: int, data: bytes): 59 | if type_code in cls._ext_type_codes: 60 | return cls._ext_type_codes[type_code].unpackb(data) 61 | elif type_code == cls._TUPLE_EXT_TYPE_CODE: 62 | return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False)) 63 | 64 | logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is") 65 | return data 66 | 67 | @classmethod 68 | def dumps(cls, obj: object) -> bytes: 69 | return msgpack.dumps(obj, use_bin_type=True, default=cls._encode_ext_types, strict_types=True) 70 | 71 | @classmethod 72 | def loads(cls, buf: bytes) -> object: 73 | return msgpack.loads(buf, ext_hook=cls._decode_ext_types, raw=False) 74 | -------------------------------------------------------------------------------- /hivemind/utils/streaming.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for streaming tensors 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Iterable, Iterator 8 | 9 | from hivemind.proto import runtime_pb2 10 | from hivemind.utils.logging import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | STREAMING_CHUNK_SIZE_BYTES = 2**16 15 | 16 | 17 | def split_for_streaming( 18 | serialized_tensor: runtime_pb2.Tensor, 19 | chunk_size_bytes: int = STREAMING_CHUNK_SIZE_BYTES, 20 | ) -> Iterator[runtime_pb2.Tensor]: 21 | """Split serialized_tensor into multiple chunks for streaming""" 22 | buffer = memoryview(serialized_tensor.buffer) 23 | num_chunks = len(range(0, len(buffer), chunk_size_bytes)) 24 | yield runtime_pb2.Tensor( 25 | compression=serialized_tensor.compression, 26 | buffer=buffer[:chunk_size_bytes].tobytes(), 27 | chunks=num_chunks, 28 | size=serialized_tensor.size, 29 | dtype=serialized_tensor.dtype, 30 | requires_grad=serialized_tensor.requires_grad, 31 | ) 32 | for chunk_start in range(chunk_size_bytes, len(buffer), chunk_size_bytes): 33 | yield runtime_pb2.Tensor(buffer=buffer[chunk_start : chunk_start + chunk_size_bytes].tobytes()) 34 | 35 | 36 | def combine_from_streaming(stream: Iterable[runtime_pb2.Tensor]) -> runtime_pb2.Tensor: 37 | """Restore a result of split_for_streaming into a single serialized tensor""" 38 | stream = iter(stream) 39 | first_chunk = next(stream) 40 | serialized_tensor = runtime_pb2.Tensor() 41 | serialized_tensor.CopyFrom(first_chunk) 42 | buffer_chunks = [first_chunk.buffer] 43 | for tensor_part in stream: 44 | buffer_chunks.append(tensor_part.buffer) 45 | serialized_tensor.buffer = b"".join(buffer_chunks) 46 | return serialized_tensor 47 | -------------------------------------------------------------------------------- /hivemind/utils/tensor_descr.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from dataclasses import asdict, dataclass 5 | from typing import Tuple 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from hivemind.proto.runtime_pb2 import CompressionType 11 | from hivemind.utils.serializer import MSGPackSerializer 12 | 13 | DUMMY_BATCH_SIZE = 3 # used for dummy runs only 14 | 15 | warnings.filterwarnings("ignore", "CUDA initialization*", category=UserWarning) 16 | 17 | 18 | # ^-- cures https://github.com/pytorch/pytorch/issues/47038 19
| 20 | 21 | @dataclass(init=True, repr=True, frozen=True) 22 | class DescriptorBase: 23 | pass 24 | 25 | 26 | @dataclass(init=True, repr=True, frozen=True) 27 | class TensorDescriptor(DescriptorBase): 28 | size: tuple 29 | dtype: torch.dtype = None 30 | layout: torch.layout = torch.strided 31 | device: torch.device = None 32 | requires_grad: bool = False 33 | pin_memory: bool = False 34 | compression: CompressionType = CompressionType.NONE 35 | 36 | @property 37 | def shape(self) -> Tuple[int, ...]: 38 | return self.size 39 | 40 | def numel(self) -> int: 41 | return int(np.prod(self.size)) 42 | 43 | @classmethod 44 | def from_tensor(cls, tensor: torch.Tensor) -> TensorDescriptor: 45 | return cls( 46 | tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, _safe_check_pinned(tensor) 47 | ) 48 | 49 | def make_zeros(self, **kwargs): 50 | properties = asdict(self) 51 | properties.update(kwargs) 52 | properties.pop("compression") 53 | return torch.zeros(**properties) 54 | 55 | 56 | def _str_to_torch_type(name: str, torch_type: type): 57 | try: 58 | value = getattr(torch, name.split(".")[-1]) 59 | except AttributeError: 60 | raise ValueError(f"Invalid dtype: torch has no attribute {name}") 61 | if not isinstance(value, torch_type): 62 | raise ValueError(f"Invalid dtype: expected {torch_type}, got: {type(value)}") 63 | 64 | return value 65 | 66 | 67 | @MSGPackSerializer.ext_serializable(0x51) 68 | @dataclass(repr=True, frozen=True) 69 | class BatchTensorDescriptor(TensorDescriptor): 70 | """torch.Tensor with a variable 0-th dimension, used to describe batched data""" 71 | 72 | def __init__(self, *instance_size, **kwargs): # compatibility: allow initializing with *size 73 | if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)): 74 | instance_size = instance_size[0] # we were given size as the only parameter instead of *parameters 75 | super().__init__((None, *instance_size), **kwargs) 76 | 77 | @classmethod 78 | def from_tensor(cls, tensor: torch.Tensor, compression=CompressionType.NONE) -> BatchTensorDescriptor: 79 | return cls( 80 | *tensor.shape[1:], 81 | dtype=tensor.dtype, 82 | layout=tensor.layout, 83 | device=tensor.device, 84 | requires_grad=tensor.requires_grad, 85 | pin_memory=_safe_check_pinned(tensor), 86 | compression=compression if tensor.is_floating_point() else CompressionType.NONE, 87 | ) 88 | 89 | def make_zeros(self, *batch_size: int, **kwargs) -> torch.Tensor: 90 | assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)" 91 | return super().make_zeros(size=(*batch_size, *self.shape[1:]), **kwargs) 92 | 93 | def packb(self) -> bytes: 94 | obj_dict = asdict(self) 95 | 96 | obj_dict["dtype"] = str(self.dtype) if self.dtype is not None else None 97 | obj_dict["layout"] = str(self.layout) if self.layout is not None else None 98 | 99 | device = obj_dict.pop("device") 100 | device_type, device_index = (device.type, device.index) if device is not None else (None, None) 101 | obj_dict.update( 102 | device_type=device_type, 103 | device_index=device_index, 104 | ) 105 | 106 | return MSGPackSerializer.dumps(obj_dict) 107 | 108 | @classmethod 109 | def unpackb(cls, raw: bytes) -> BatchTensorDescriptor: 110 | obj_dict = MSGPackSerializer.loads(raw) 111 | 112 | if obj_dict["dtype"] is not None: 113 | obj_dict["dtype"] = _str_to_torch_type(obj_dict["dtype"], torch.dtype) 114 | 115 | if obj_dict["layout"] is not None: 116 | obj_dict["layout"] = _str_to_torch_type(obj_dict["layout"], torch.layout) 117 | 118 
| if obj_dict["device_type"] is not None: 119 | obj_dict["device"] = torch.device(obj_dict["device_type"], obj_dict["device_index"]) 120 | else: 121 | obj_dict["device"] = None 122 | 123 | del obj_dict["device_type"], obj_dict["device_index"] 124 | 125 | size = obj_dict.pop("size")[1:] 126 | 127 | return BatchTensorDescriptor(*size, **obj_dict) 128 | 129 | 130 | def _safe_check_pinned(tensor: torch.Tensor) -> bool: 131 | """Check whether or not a tensor is pinned. If torch cannot initialize cuda, returns False instead of error.""" 132 | try: 133 | return torch.cuda.is_available() and tensor.is_pinned() 134 | except RuntimeError: 135 | return False 136 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.coverage.run] 2 | concurrency = ["thread", "multiprocessing"] 3 | omit = ["hivemind/proto/*"] 4 | source = ["hivemind"] 5 | parallel = true 6 | sigterm = true 7 | 8 | [tool.ruff] 9 | exclude = [ 10 | "__init__.py", 11 | "*_pb2.py" 12 | ] 13 | line-length = 119 14 | required-version = "==0.11.2" 15 | target-version = "py39" 16 | 17 | [tool.ruff.lint] 18 | select = ["E", "F", "W", "I", "YTT", "LOG"] 19 | ignore = ["E501", "E702", "E731"] 20 | dummy-variable-rgx = "^_$" 21 | 22 | [tool.ruff.lint.isort] 23 | known-local-folder = ["arguments", "test_utils", "tests", "utils"] 24 | 25 | [tool.pytest.ini_options] 26 | asyncio_default_fixture_loop_scope = "function" 27 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest==8.3.5 2 | pytest-forked 3 | pytest-asyncio==0.26.0 4 | pytest-cov 5 | pytest-timeout 6 | coverage 7 | tqdm 8 | scikit-learn 9 | codespell==2.2.2 10 | psutil 11 | pytest-xdist 12 | ruff==0.11.2 13 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | recommonmark==0.5.0 2 | sphinx_rtd_theme==0.4.3 3 | docutils==0.16 4 | sphinx==5.0.0 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | PyYAML 2 | torch>=1.9.0 3 | numpy>=1.17 4 | scipy>=1.2.1 5 | prefetch_generator>=1.0.1 6 | msgpack>=0.5.6 7 | sortedcontainers 8 | uvloop>=0.14.0 9 | grpcio-tools>=1.33.2 10 | protobuf>=5.29.0 11 | configargparse>=1.2.3 12 | py-multihash>=0.2.3 13 | cryptography>=3.4.6 14 | pydantic>=2.0.0 15 | packaging>=20.9 16 | varint>=1.0.2 17 | base58>=1.0.2 18 | netaddr>=1.3.0 19 | idna>=3.10 20 | py-cid>=0.3.0 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import psutil 4 | import pytest 5 | 6 | from hivemind.utils.crypto import RSAPrivateKey 7 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 8 | from hivemind.utils.mpfuture import MPFuture 9 | 10 | use_hivemind_log_handler("in_root_logger") 11 | logger = get_logger(__name__) 12 | 13 | 14 | @pytest.fixture(autouse=True, scope="session") 15 | def cleanup_children(): 16 | yield 17 | 18 | with RSAPrivateKey._process_wide_key_lock: 19 | RSAPrivateKey._process_wide_key = None 20 | 21 | gc.collect() # Call .__del__() for removed objects 
22 | 23 | MPFuture.reset_backend() 24 | 25 | children = psutil.Process().children(recursive=True) 26 | if children: 27 | _gone, alive = psutil.wait_procs(children, timeout=1) 28 | logger.debug(f"Cleaning up {len(alive)} leftover child processes") 29 | for child in alive: 30 | child.terminate() 31 | _gone, alive = psutil.wait_procs(alive, timeout=1) 32 | for child in alive: 33 | child.kill() 34 | -------------------------------------------------------------------------------- /tests/test_cli_scripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from subprocess import PIPE, Popen 4 | from time import sleep 5 | 6 | _DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$") 7 | 8 | 9 | def test_dht_connection_successful(): 10 | dht_refresh_period = 1 11 | 12 | cloned_env = os.environ.copy() 13 | # overriding the loglevel to prevent debug print statements 14 | cloned_env["HIVEMIND_LOGLEVEL"] = "INFO" 15 | 16 | dht_proc = Popen( 17 | ["hivemind-dht", "--host_maddrs", "/ip4/127.0.0.1/tcp/0", "--refresh_period", str(dht_refresh_period)], 18 | stderr=PIPE, 19 | text=True, 20 | encoding="utf-8", 21 | env=cloned_env, 22 | ) 23 | 24 | first_line = dht_proc.stderr.readline() 25 | second_line = dht_proc.stderr.readline() 26 | dht_pattern_match = _DHT_START_PATTERN.search(first_line) 27 | assert dht_pattern_match is not None, first_line 28 | assert "Full list of visible multiaddresses:" in second_line, second_line 29 | 30 | initial_peers = dht_pattern_match.group(1).split(" ") 31 | 32 | dht_client_proc = Popen( 33 | [ 34 | "hivemind-dht", 35 | *initial_peers, 36 | "--host_maddrs", 37 | "/ip4/127.0.0.1/tcp/0", 38 | "--refresh_period", 39 | str(dht_refresh_period), 40 | ], 41 | stderr=PIPE, 42 | text=True, 43 | encoding="utf-8", 44 | env=cloned_env, 45 | ) 46 | 47 | # ensure we get the output of dht_proc after the start of dht_client_proc 48 | sleep(5 * dht_refresh_period) 49 | 50 | # skip first two lines with connectivity info 51 | for _ in range(2): 52 | dht_client_proc.stderr.readline() 53 | first_report_msg = dht_client_proc.stderr.readline() 54 | 55 | assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg 56 | 57 | # expect that one of the next logging outputs from the first peer shows a new connection 58 | for _ in range(20): 59 | first_report_msg = dht_proc.stderr.readline() 60 | second_report_msg = dht_proc.stderr.readline() 61 | 62 | if ( 63 | "2 DHT nodes (including this one) are in the local routing table" in first_report_msg 64 | and "Local storage contains 0 keys" in second_report_msg 65 | ): 66 | break 67 | else: 68 | assert ( 69 | "2 DHT nodes (including this one) are in the local routing table" in first_report_msg 70 | and "Local storage contains 0 keys" in second_report_msg 71 | ) 72 | 73 | dht_proc.stderr.close() 74 | dht_client_proc.stderr.close() 75 | 76 | dht_proc.terminate() 77 | dht_client_proc.terminate() 78 | 79 | dht_proc.wait() 80 | dht_client_proc.wait() 81 | -------------------------------------------------------------------------------- /tests/test_custom_experts.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | 6 | from hivemind.dht import DHT 7 | from hivemind.moe.client.expert import create_remote_experts 8 | from hivemind.moe.expert_uid import ExpertInfo 9 | from hivemind.moe.server import background_server 10 | 
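# custom_networks.py registers the "perceptron" and "multihead" expert classes exercised below; the server imports that module when given custom_module_path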
11 | CUSTOM_EXPERTS_PATH = os.path.join(os.path.dirname(__file__), "test_utils", "custom_networks.py") 12 | 13 | 14 | @pytest.mark.forked 15 | def test_custom_expert(hid_dim=16): 16 | with background_server( 17 | expert_cls="perceptron", 18 | num_experts=2, 19 | device="cpu", 20 | hidden_dim=hid_dim, 21 | num_handlers=2, 22 | custom_module_path=CUSTOM_EXPERTS_PATH, 23 | ) as server_peer_info: 24 | dht = DHT(initial_peers=server_peer_info.addrs, start=True) 25 | expert0, expert1 = create_remote_experts( 26 | [ 27 | ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id), 28 | ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id), 29 | ], 30 | dht=dht, 31 | ) 32 | 33 | for batch_size in (1, 4): 34 | batch = torch.randn(batch_size, hid_dim) 35 | 36 | output0 = expert0(batch) 37 | output1 = expert1(batch) 38 | 39 | loss = output0.sum() 40 | loss.backward() 41 | loss = output1.sum() 42 | loss.backward() 43 | 44 | 45 | @pytest.mark.forked 46 | def test_multihead_expert(hid_dim=16): 47 | with background_server( 48 | expert_cls="multihead", 49 | num_experts=2, 50 | device="cpu", 51 | hidden_dim=hid_dim, 52 | num_handlers=2, 53 | custom_module_path=CUSTOM_EXPERTS_PATH, 54 | ) as server_peer_info: 55 | dht = DHT(initial_peers=server_peer_info.addrs, start=True) 56 | expert0, expert1 = create_remote_experts( 57 | [ 58 | ExpertInfo(uid="expert.0", peer_id=server_peer_info.peer_id), 59 | ExpertInfo(uid="expert.1", peer_id=server_peer_info.peer_id), 60 | ], 61 | dht=dht, 62 | ) 63 | 64 | for batch_size in (1, 4): 65 | batch = ( 66 | torch.randn(batch_size, hid_dim), 67 | torch.randn(batch_size, 2 * hid_dim), 68 | torch.randn(batch_size, 3 * hid_dim), 69 | ) 70 | 71 | output0 = expert0(*batch) 72 | output1 = expert1(*batch) 73 | 74 | loss = output0.sum() 75 | loss.backward() 76 | loss = output1.sum() 77 | loss.backward() 78 | -------------------------------------------------------------------------------- /tests/test_dht.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import random 4 | import time 5 | 6 | import pytest 7 | 8 | import hivemind 9 | from hivemind.utils.multiaddr import Multiaddr 10 | 11 | from test_utils.dht_swarms import launch_dht_instances 12 | from test_utils.networking import get_free_port 13 | 14 | 15 | @pytest.mark.forked 16 | def test_get_store(n_peers=10): 17 | peers = launch_dht_instances(n_peers) 18 | 19 | node1, node2 = random.sample(peers, 2) 20 | assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30) 21 | assert node1.get("key1").value == "value1" 22 | assert node2.get("key1").value == "value1" 23 | assert node2.get("key2") is None 24 | 25 | future = node1.get("foo", return_future=True) 26 | assert future.result() is None 27 | 28 | future = node1.get("foo", return_future=True) 29 | future.cancel() 30 | 31 | assert node2.store("key1", 123, expiration_time=hivemind.get_dht_time() + 31) 32 | assert node2.store("key2", 456, expiration_time=hivemind.get_dht_time() + 32) 33 | assert node1.get("key1", latest=True).value == 123 34 | assert node1.get("key2").value == 456 35 | 36 | assert node1.store("key2", subkey="subkey1", value=789, expiration_time=hivemind.get_dht_time() + 32) 37 | assert node2.store("key2", subkey="subkey2", value="pew", expiration_time=hivemind.get_dht_time() + 32) 38 | found_dict = node1.get("key2", latest=True).value 39 | assert isinstance(found_dict, dict) and len(found_dict) == 2 40 | assert found_dict["subkey1"].value == 789 and 
found_dict["subkey2"].value == "pew" 41 | 42 | for peer in peers: 43 | peer.shutdown() 44 | 45 | 46 | async def dummy_dht_coro(self, node): 47 | return "pew" 48 | 49 | 50 | async def dummy_dht_coro_error(self, node): 51 | raise ValueError("Oops, i did it again...") 52 | 53 | 54 | async def dummy_dht_coro_stateful(self, node): 55 | self._x_dummy = getattr(self, "_x_dummy", 123) + 1 56 | return self._x_dummy 57 | 58 | 59 | async def dummy_dht_coro_long(self, node): 60 | await asyncio.sleep(0.25) 61 | return self._x_dummy**2 62 | 63 | 64 | async def dummy_dht_coro_for_cancel(self, node): 65 | self._x_dummy = -100 66 | await asyncio.sleep(0.5) 67 | self._x_dummy = 999 68 | 69 | 70 | @pytest.mark.forked 71 | def test_run_coroutine(): 72 | dht = hivemind.DHT(start=True) 73 | assert dht.run_coroutine(dummy_dht_coro) == "pew" 74 | 75 | with pytest.raises(ValueError): 76 | dht.run_coroutine(dummy_dht_coro_error) 77 | 78 | bg_task = dht.run_coroutine(dummy_dht_coro_long, return_future=True) 79 | assert dht.run_coroutine(dummy_dht_coro_stateful) == 124 80 | assert dht.run_coroutine(dummy_dht_coro_stateful) == 125 81 | assert dht.run_coroutine(dummy_dht_coro_stateful) == 126 82 | assert not hasattr(dht, "_x_dummy") 83 | assert bg_task.result() == 126**2 84 | 85 | future = dht.run_coroutine(dummy_dht_coro_for_cancel, return_future=True) 86 | time.sleep(0.25) 87 | future.cancel() 88 | assert dht.run_coroutine(dummy_dht_coro_stateful) == -99 89 | 90 | dht.shutdown() 91 | 92 | 93 | @pytest.mark.forked 94 | @pytest.mark.asyncio 95 | async def test_dht_get_visible_maddrs(): 96 | # test 1: IPv4 localhost multiaddr is visible by default 97 | 98 | dht = hivemind.DHT(start=True) 99 | 100 | assert any(str(maddr).startswith("/ip4/127.0.0.1") for maddr in dht.get_visible_maddrs()) 101 | dht.shutdown() 102 | 103 | # test 2: announce_maddrs are the single visible multiaddrs if defined 104 | 105 | dummy_endpoint = Multiaddr("/ip4/123.45.67.89/tcp/31337") 106 | p2p = await hivemind.p2p.P2P.create(announce_maddrs=[dummy_endpoint]) 107 | dht = hivemind.DHT(start=True, p2p=p2p) 108 | 109 | assert dht.get_visible_maddrs() == [dummy_endpoint.encapsulate(f"/p2p/{p2p.peer_id}")] 110 | dht.shutdown() 111 | 112 | 113 | @pytest.mark.asyncio 114 | async def test_startup_error(): 115 | with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"(?i)Failed to connect to bootstrap peers"): 116 | hivemind.DHT( 117 | initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"], 118 | start=True, 119 | ) 120 | 121 | dht = hivemind.DHT(start=True, await_ready=False) 122 | with pytest.raises(concurrent.futures.TimeoutError): 123 | dht.wait_until_ready(timeout=0.01) 124 | dht.shutdown() 125 | -------------------------------------------------------------------------------- /tests/test_dht_validation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Dict 3 | 4 | import pytest 5 | from pydantic.v1 import BaseModel, StrictInt 6 | 7 | import hivemind 8 | from hivemind.dht.crypto import RSASignatureValidator 9 | from hivemind.dht.protocol import DHTProtocol 10 | from hivemind.dht.routing import DHTID 11 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator 12 | from hivemind.dht.validation import CompositeValidator, DHTRecord 13 | 14 | 15 | class SchemaA(BaseModel): 16 | field_a: bytes 17 | 18 | 19 | class SchemaB(BaseModel): 20 | field_b: Dict[BytesWithPublicKey, StrictInt] 21 | 22 | 23 | @pytest.fixture 
24 | def validators_for_app(): 25 | # Each application may add its own validator set 26 | return { 27 | "A": [RSASignatureValidator(), SchemaValidator(SchemaA, allow_extra_keys=False)], 28 | "B": [SchemaValidator(SchemaB, allow_extra_keys=False), RSASignatureValidator()], 29 | } 30 | 31 | 32 | @pytest.mark.forked 33 | def test_dht_add_validators(validators_for_app): 34 | # One app may create a DHT with its validators 35 | dht = hivemind.DHT(start=False, record_validators=validators_for_app["A"]) 36 | 37 | # While the DHT process is not started, you can't send a command to append new validators 38 | with pytest.raises(RuntimeError): 39 | dht.add_validators(validators_for_app["B"]) 40 | dht.run_in_background(await_ready=True) 41 | 42 | # After starting the process, other apps may add new validators to the existing DHT 43 | dht.add_validators(validators_for_app["B"]) 44 | 45 | assert dht.store("field_a", b"bytes_value", hivemind.get_dht_time() + 10) 46 | assert dht.get("field_a", latest=True).value == b"bytes_value" 47 | 48 | assert not dht.store("field_a", 666, hivemind.get_dht_time() + 10) 49 | assert dht.get("field_a", latest=True).value == b"bytes_value" 50 | 51 | local_public_key = validators_for_app["A"][0].local_public_key 52 | assert dht.store("field_b", 777, hivemind.get_dht_time() + 10, subkey=local_public_key) 53 | dictionary = dht.get("field_b", latest=True).value 54 | assert len(dictionary) == 1 and dictionary[local_public_key].value == 777 55 | 56 | assert not dht.store("unknown_key", 666, hivemind.get_dht_time() + 10) 57 | assert dht.get("unknown_key", latest=True) is None 58 | 59 | 60 | def test_composite_validator(validators_for_app): 61 | validator = CompositeValidator(validators_for_app["A"]) 62 | assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator] 63 | 64 | validator.extend(validators_for_app["B"]) 65 | assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator] 66 | assert len(validator._validators[0]._schemas) == 2 67 | 68 | local_public_key = validators_for_app["A"][0].local_public_key 69 | record = DHTRecord( 70 | key=DHTID.generate(source="field_b").to_bytes(), 71 | subkey=DHTProtocol.serializer.dumps(local_public_key), 72 | value=DHTProtocol.serializer.dumps(777), 73 | expiration_time=hivemind.get_dht_time() + 10, 74 | ) 75 | 76 | signed_record = dataclasses.replace(record, value=validator.sign_value(record)) 77 | # Expect only one signature since the two RSASignatureValidators have been merged 78 | assert signed_record.value.count(b"[signature:") == 1 79 | # Expect successful validation since the second SchemaValidator has been merged into the first 80 | assert validator.validate(signed_record) 81 | assert validator.strip_value(signed_record) == record.value 82 | 83 | record = DHTRecord( 84 | key=DHTID.generate(source="unknown_key").to_bytes(), 85 | subkey=DHTProtocol.IS_REGULAR_VALUE, 86 | value=DHTProtocol.serializer.dumps(777), 87 | expiration_time=hivemind.get_dht_time() + 10, 88 | ) 89 | 90 | signed_record = dataclasses.replace(record, value=validator.sign_value(record)) 91 | assert signed_record.value.count(b"[signature:") == 0 92 | # Expect failed validation since `unknown_key` is not a part of any schema 93 | assert not validator.validate(signed_record) 94 | -------------------------------------------------------------------------------- /tests/test_expert_backend.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 |
from tempfile import TemporaryDirectory 3 | 4 | import pytest 5 | import torch 6 | from torch.nn import Linear 7 | 8 | from hivemind import BatchTensorDescriptor, ModuleBackend 9 | from hivemind.moe.server.checkpoints import load_experts, store_experts 10 | from hivemind.moe.server.layers.lr_schedule import get_linear_schedule_with_warmup 11 | 12 | EXPERT_WEIGHT_UPDATES = 3 13 | BACKWARD_PASSES_BEFORE_SAVE = 2 14 | BACKWARD_PASSES_AFTER_SAVE = 2 15 | EXPERT_NAME = "test_expert" 16 | PEAK_LR = 1.0 17 | 18 | 19 | @pytest.fixture 20 | def example_experts(): 21 | expert = Linear(1, 1) 22 | opt = torch.optim.SGD(expert.parameters(), PEAK_LR) 23 | 24 | args_schema = (BatchTensorDescriptor(1),) 25 | expert_backend = ModuleBackend( 26 | name=EXPERT_NAME, 27 | module=expert, 28 | optimizer=opt, 29 | scheduler=get_linear_schedule_with_warmup( 30 | opt, 31 | num_warmup_steps=BACKWARD_PASSES_BEFORE_SAVE, 32 | num_training_steps=BACKWARD_PASSES_BEFORE_SAVE + BACKWARD_PASSES_AFTER_SAVE, 33 | ), 34 | args_schema=args_schema, 35 | outputs_schema=BatchTensorDescriptor(1), 36 | max_batch_size=1, 37 | ) 38 | experts = {EXPERT_NAME: expert_backend} 39 | yield experts 40 | 41 | 42 | @pytest.mark.forked 43 | def test_save_load_checkpoints(example_experts): 44 | expert = example_experts[EXPERT_NAME].module 45 | 46 | with TemporaryDirectory() as tmpdir: 47 | tmp_path = Path(tmpdir) 48 | 49 | for i in range(1, EXPERT_WEIGHT_UPDATES + 1): 50 | expert.weight.data[0] = i 51 | store_experts(example_experts, tmp_path) 52 | 53 | checkpoints_dir = tmp_path / EXPERT_NAME 54 | 55 | assert checkpoints_dir.exists() 56 | # include checkpoint_last.pt 57 | assert len(list(checkpoints_dir.iterdir())) == EXPERT_WEIGHT_UPDATES + 1 58 | 59 | expert.weight.data[0] = 0 60 | 61 | load_experts(example_experts, tmp_path) 62 | assert expert.weight.data[0] == EXPERT_WEIGHT_UPDATES 63 | 64 | 65 | @pytest.mark.forked 66 | def test_restore_update_count(example_experts): 67 | expert_backend = example_experts[EXPERT_NAME] 68 | 69 | batch = torch.randn(1, 1) 70 | loss_grad = torch.randn(1, 1) 71 | 72 | with TemporaryDirectory() as tmpdir: 73 | tmp_path = Path(tmpdir) 74 | 75 | for _ in range(BACKWARD_PASSES_BEFORE_SAVE): 76 | expert_backend.backward(batch, loss_grad) 77 | 78 | store_experts(example_experts, tmp_path) 79 | 80 | for _ in range(BACKWARD_PASSES_AFTER_SAVE): 81 | expert_backend.backward(batch, loss_grad) 82 | 83 | load_experts(example_experts, tmp_path) 84 | assert expert_backend.scheduler._step_count == BACKWARD_PASSES_BEFORE_SAVE + 1 85 | 86 | 87 | @pytest.mark.forked 88 | def test_lr_schedule(example_experts): 89 | expert_backend = example_experts[EXPERT_NAME] 90 | optimizer = expert_backend.optimizer 91 | 92 | batch = torch.randn(1, 1) 93 | loss_grad = torch.randn(1, 1) 94 | 95 | with TemporaryDirectory() as tmpdir: 96 | tmp_path = Path(tmpdir) 97 | 98 | assert optimizer.param_groups[0]["lr"] == 0.0 99 | 100 | for i in range(BACKWARD_PASSES_BEFORE_SAVE): 101 | assert optimizer.param_groups[0]["lr"] == PEAK_LR * i / BACKWARD_PASSES_BEFORE_SAVE 102 | expert_backend.backward(batch, loss_grad) 103 | 104 | assert optimizer.param_groups[0]["lr"] == PEAK_LR 105 | store_experts(example_experts, tmp_path) 106 | 107 | for i in range(BACKWARD_PASSES_AFTER_SAVE): 108 | assert optimizer.param_groups[0]["lr"] == PEAK_LR * (1 - (i / BACKWARD_PASSES_AFTER_SAVE)) 109 | expert_backend.backward(batch, loss_grad) 110 | 111 | assert optimizer.param_groups[0]["lr"] == 0.0 112 | load_experts(example_experts, tmp_path) 113 | assert 
optimizer.param_groups[0]["lr"] == PEAK_LR 114 | -------------------------------------------------------------------------------- /tests/test_relays.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import pytest 5 | 6 | import hivemind 7 | 8 | 9 | async def ping_to_client(dht, node, peer_id: str): 10 | return await node.protocol.call_ping(hivemind.PeerID.from_base58(str(peer_id))) 11 | 12 | 13 | @pytest.mark.forked 14 | @pytest.mark.parametrize( 15 | "use_auto_relay,use_relay", 16 | [ 17 | (True, True), 18 | (False, False), 19 | ], 20 | ) 21 | def test_autorelay(use_auto_relay: bool, use_relay: bool): 22 | dht_first_peer = hivemind.DHT( 23 | start=True, 24 | use_auto_relay=use_auto_relay, 25 | use_relay=use_relay, 26 | force_reachability="public", 27 | ) 28 | dht_first_peer_id = dht_first_peer.peer_id 29 | initial_peers = dht_first_peer.get_visible_maddrs() 30 | assert dht_first_peer_id is not None 31 | 32 | dht_third_peer = hivemind.DHT( 33 | initial_peers=initial_peers, 34 | host_maddrs=[], 35 | start=True, 36 | no_listen=True, 37 | use_relay=use_relay, 38 | client_mode=False, 39 | use_auto_relay=use_auto_relay, 40 | ) 41 | time.sleep(5) 42 | dht_second_peer = hivemind.DHT( 43 | initial_peers=initial_peers, 44 | start=True, 45 | client_mode=False, 46 | no_listen=False, 47 | use_relay=use_relay, 48 | use_auto_relay=use_auto_relay, 49 | ) 50 | 51 | assert dht_first_peer.is_alive() and dht_second_peer.is_alive() and dht_third_peer.is_alive() 52 | 53 | time_start = time.perf_counter() 54 | while time.perf_counter() - time_start < 30: 55 | reached_ip = dht_second_peer.run_coroutine(partial(ping_to_client, peer_id=dht_third_peer.peer_id)) 56 | if reached_ip: 57 | assert use_relay 58 | break 59 | time.sleep(2) 60 | else: 61 | assert not use_relay 62 | 63 | for peer in dht_first_peer, dht_second_peer, dht_third_peer: 64 | peer.shutdown() 65 | -------------------------------------------------------------------------------- /tests/test_start_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from functools import partial 4 | from subprocess import PIPE, Popen 5 | from tempfile import TemporaryDirectory 6 | 7 | import pytest 8 | 9 | from hivemind.moe.server import background_server 10 | 11 | 12 | def cleanup_process(process, timeout=5): 13 | try: 14 | process.terminate() 15 | process.wait(timeout=timeout) # Add timeout to wait 16 | except: # noqa: E722 17 | process.kill() 18 | process.wait(timeout=timeout) 19 | 20 | 21 | @pytest.mark.xfail(reason="Flaky test", strict=False) 22 | def test_background_server_identity_path(): 23 | with TemporaryDirectory() as tempdir: 24 | id_path = os.path.join(tempdir, "id") 25 | 26 | server_runner = partial(background_server, num_experts=1, device="cpu", hidden_dim=1) 27 | 28 | with ( 29 | server_runner(identity_path=id_path) as server_info_1, 30 | server_runner(identity_path=id_path) as server_info_2, 31 | server_runner(identity_path=None) as server_info_3, 32 | ): 33 | assert server_info_1.peer_id == server_info_2.peer_id 34 | assert server_info_1.peer_id != server_info_3.peer_id 35 | assert server_info_2.peer_id != server_info_3.peer_id 36 | 37 | 38 | @pytest.mark.xfail(reason="Flaky test", strict=False) 39 | def test_cli_run_server_identity_path(): 40 | pattern = r"Running DHT node on \[(.+)\]," 41 | 42 | with TemporaryDirectory() as tempdir: 43 | id_path = os.path.join(tempdir, "id") 44 | 45 |
cloned_env = os.environ.copy() 46 | # overriding the loglevel to prevent debug print statements 47 | cloned_env["HIVEMIND_LOGLEVEL"] = "INFO" 48 | 49 | common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"] 50 | 51 | server_1_proc = Popen( 52 | ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args, 53 | stderr=PIPE, 54 | text=True, 55 | encoding="utf-8", 56 | env=cloned_env, 57 | ) 58 | 59 | line = server_1_proc.stderr.readline() 60 | assert "Generating new identity" in line 61 | 62 | line = server_1_proc.stderr.readline() 63 | addrs_pattern_result = re.search(pattern, line) 64 | assert addrs_pattern_result is not None, line 65 | addrs_1 = set(addrs_pattern_result.group(1).split(", ")) 66 | ids_1 = set(a.split("/")[-1] for a in addrs_1) 67 | 68 | assert len(ids_1) == 1 69 | 70 | server_2_proc = Popen( 71 | ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args, 72 | stderr=PIPE, 73 | text=True, 74 | encoding="utf-8", 75 | env=cloned_env, 76 | ) 77 | 78 | line = server_2_proc.stderr.readline() 79 | addrs_pattern_result = re.search(pattern, line) 80 | assert addrs_pattern_result is not None, line 81 | addrs_2 = set(addrs_pattern_result.group(1).split(", ")) 82 | ids_2 = set(a.split("/")[-1] for a in addrs_2) 83 | 84 | assert len(ids_2) == 1 85 | 86 | server_3_proc = Popen( 87 | ["hivemind-server", "--num_experts", "1"] + common_server_args, 88 | stderr=PIPE, 89 | text=True, 90 | encoding="utf-8", 91 | env=cloned_env, 92 | ) 93 | 94 | line = server_3_proc.stderr.readline() 95 | addrs_pattern_result = re.search(pattern, line) 96 | assert addrs_pattern_result is not None, line 97 | addrs_3 = set(addrs_pattern_result.group(1).split(", ")) 98 | ids_3 = set(a.split("/")[-1] for a in addrs_3) 99 | 100 | assert len(ids_3) == 1 101 | 102 | assert ids_1 == ids_2 103 | assert ids_1 != ids_3 104 | assert ids_2 != ids_3 105 | 106 | assert addrs_1 != addrs_2 107 | assert addrs_1 != addrs_3 108 | assert addrs_2 != addrs_3 109 | 110 | for p in [server_1_proc, server_2_proc, server_3_proc]: 111 | cleanup_process(p) 112 | -------------------------------------------------------------------------------- /tests/test_utils/custom_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from hivemind.moe import register_expert_class 6 | 7 | sample_input = lambda batch_size, hidden_dim: torch.empty((batch_size, hidden_dim)) 8 | 9 | 10 | @register_expert_class("perceptron", sample_input) 11 | class MultilayerPerceptron(nn.Module): 12 | def __init__(self, hidden_dim, num_classes=10): 13 | super().__init__() 14 | self.layer1 = nn.Linear(hidden_dim, 2 * hidden_dim) 15 | self.layer2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim) 16 | self.layer3 = nn.Linear(2 * hidden_dim, num_classes) 17 | 18 | def forward(self, x): 19 | x = F.relu(self.layer1(x)) 20 | x = F.relu(self.layer2(x)) 21 | x = self.layer3(x) 22 | return x 23 | 24 | 25 | multihead_sample_input = lambda batch_size, hidden_dim: ( 26 | torch.empty((batch_size, hidden_dim)), 27 | torch.empty((batch_size, 2 * hidden_dim)), 28 | torch.empty((batch_size, 3 * hidden_dim)), 29 | ) 30 | 31 | 32 | @register_expert_class("multihead", multihead_sample_input) 33 | class MultiheadNetwork(nn.Module): 34 | def __init__(self, hidden_dim, num_classes=10): 35 | super().__init__() 36 | self.layer1 = nn.Linear(hidden_dim, num_classes) 37 | self.layer2 = nn.Linear(2 * hidden_dim, 
num_classes) 38 | self.layer3 = nn.Linear(3 * hidden_dim, num_classes) 39 | 40 | def forward(self, x1, x2, x3): 41 | x = self.layer1(x1) + self.layer2(x2) + self.layer3(x3) 42 | return x 43 | -------------------------------------------------------------------------------- /tests/test_utils/dht_swarms.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import multiprocessing as mp 3 | import random 4 | import signal 5 | import threading 6 | from typing import Dict, List, Tuple 7 | 8 | from hivemind.dht import DHT 9 | from hivemind.dht.node import DHTID, DHTNode 10 | from hivemind.p2p import PeerID 11 | from hivemind.utils.multiaddr import Multiaddr 12 | 13 | 14 | def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs): 15 | if asyncio.get_event_loop().is_running(): 16 | asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop 17 | asyncio.set_event_loop(asyncio.new_event_loop()) 18 | loop = asyncio.get_event_loop() 19 | 20 | node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs)) 21 | maddrs = loop.run_until_complete(node.get_visible_maddrs()) 22 | 23 | info_queue.put((node.node_id, node.peer_id, maddrs)) 24 | 25 | async def shutdown(): 26 | await node.shutdown() 27 | loop.stop() 28 | 29 | loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown())) 30 | loop.run_forever() 31 | 32 | 33 | def launch_swarm_in_separate_processes( 34 | n_peers: int, n_sequential_peers: int, **kwargs 35 | ) -> Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]: 36 | assert n_sequential_peers < n_peers, ( 37 | "Parameters imply that first n_sequential_peers of n_peers will be run sequentially" 38 | ) 39 | 40 | processes = [] 41 | dht = {} 42 | swarm_maddrs = [] 43 | 44 | info_queue = mp.Queue() 45 | info_lock = mp.RLock() 46 | 47 | for _ in range(n_sequential_peers): 48 | initial_peers = random.choice(swarm_maddrs) if swarm_maddrs else [] 49 | 50 | proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True) 51 | proc.start() 52 | processes.append(proc) 53 | 54 | node_id, peer_id, peer_maddrs = info_queue.get() 55 | dht[peer_id] = node_id 56 | swarm_maddrs.append(peer_maddrs) 57 | 58 | def collect_info(): 59 | while True: 60 | node_id, peer_id, peer_maddrs = info_queue.get() 61 | with info_lock: 62 | dht[peer_id] = node_id 63 | swarm_maddrs.append(peer_maddrs) 64 | 65 | if len(dht) == n_peers: 66 | break 67 | 68 | collect_thread = threading.Thread(target=collect_info) 69 | collect_thread.start() 70 | 71 | for _ in range(n_peers - n_sequential_peers): 72 | with info_lock: 73 | initial_peers = random.choice(swarm_maddrs) 74 | 75 | proc = mp.Process(target=run_node, args=(initial_peers, info_queue), kwargs=kwargs, daemon=True) 76 | proc.start() 77 | processes.append(proc) 78 | 79 | collect_thread.join() 80 | 81 | return processes, dht, swarm_maddrs 82 | 83 | 84 | async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]: 85 | nodes = [await DHTNode.create(**kwargs)] 86 | initial_peers = await nodes[0].get_visible_maddrs() 87 | nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)]) 88 | return nodes 89 | 90 | 91 | def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]: 92 | dhts = [DHT(start=True, **kwargs)] 93 | initial_peers = dhts[0].get_visible_maddrs() 94 | 95 | dhts.extend(DHT(initial_peers=initial_peers, start=True, 
await_ready=False, **kwargs) for _ in range(n_peers - 1)) 96 | for process in dhts[1:]: 97 | process.wait_until_ready() 98 | 99 | return dhts 100 | -------------------------------------------------------------------------------- /tests/test_utils/networking.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from contextlib import closing 3 | 4 | 5 | def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)): 6 | """ 7 | Finds a TCP port that can be occupied by a socket created with *params and configured with *opt options. 8 | 9 | :note: Using this function is discouraged since it often leads to a race condition 10 | with the "Address is already in use" error if the code is run in parallel. 11 | """ 12 | with closing(socket.socket(*params)) as sock: 13 | sock.bind(("", 0)) 14 | sock.setsockopt(*opt) 15 | return sock.getsockname()[1] 16 | --------------------------------------------------------------------------------
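As a closing illustration, here is a minimal, self-contained sketch (not part of the repository) that exercises the nested helpers and the tuple-aware MSGPackSerializer shown earlier; every imported name comes from the modules listed above, and the sample structure is made up for this example:

from hivemind.utils.nested import nested_flatten, nested_map, nested_pack
from hivemind.utils.serializer import MSGPackSerializer

structure = {"a": (1, 2), "b": [3, {"c": 4}]}

# dicts are traversed in sorted key order, so the flat view is deterministic
flat = list(nested_flatten(structure))  # [1, 2, 3, 4]

# nested_pack consumes the flat values and mirrors the reference structure
assert nested_pack(flat, structure) == structure

# nested_map applies fn leaf-wise while preserving container types
doubled = nested_map(lambda x: 2 * x, structure)
assert doubled["a"] == (2, 4) and doubled["b"] == [6, {"c": 8}]

# MSGPackSerializer packs tuples as a dedicated ExtType (code 0x40), so they
# survive a round-trip instead of coming back as lists, as plain msgpack would
blob = MSGPackSerializer.dumps({"point": (1, 2)})
assert MSGPackSerializer.loads(blob) == {"point": (1, 2)}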