├── src
│   └── node0
│       ├── __init__.py
│       ├── security
│       │   ├── integrity_check.cpython-311-aarch64-linux-gnu.so
│       │   ├── integrity_check.cpython-311-x86_64-linux-gnu.so
│       │   ├── validation.py
│       │   └── authorization.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── flops.py
│       │   ├── dht_monitor.py
│       │   ├── dht_partition.py
│       │   ├── mem_monitor.py
│       │   ├── logging.py
│       │   ├── get_parameters.py
│       │   ├── connection_test_server.py
│       │   ├── network_throughput.py
│       │   ├── common.py
│       │   └── node_info.py
│       ├── models
│       │   ├── llama
│       │   │   └── arguments.py
│       │   ├── arguments.py
│       │   └── lr_schedule.py
│       ├── configs
│       │   └── llama_8B_C.yaml
│       ├── server
│       │   ├── module_collab.py
│       │   ├── matchmaking.py
│       │   ├── state_averager_wrap.py
│       │   ├── power_sgd_averager.py
│       │   ├── HM_gradient_averager.py
│       │   ├── node0_server.py
│       │   └── optim.py
│       └── run_server.py
├── .github
│   ├── CODEOWNERS
│   ├── assets
│   │   └── dashboard-button.svg
│   └── ISSUE_TEMPLATE
│       └── bug_report.md
├── images
│   ├── node0-logo-black.png
│   ├── node0-logo-white.png
│   ├── runpod_edit_pod.png
│   ├── aws_inbound_rules.png
│   ├── gcp_inbound_rules.png
│   ├── lambda_inbound_rules.png
│   ├── runpod_external_port.png
│   ├── runpod_inbound_rules.png
│   ├── tensordock_external_port.png
│   └── tensordock_forwarded_ports.png
├── run.json
├── Dockerfile
├── pyproject.toml
├── THIRD_PARTY_LICENSES.md
├── docs
│   ├── useful_links.md
│   ├── network.md
│   └── cloud.md
├── .gitignore
└── LICENSE
/src/node0/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @violettapluralis @gilpluralis @yanpluralis
--------------------------------------------------------------------------------
/images/node0-logo-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/node0-logo-black.png
--------------------------------------------------------------------------------
/images/node0-logo-white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/node0-logo-white.png
--------------------------------------------------------------------------------
/images/runpod_edit_pod.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_edit_pod.png
--------------------------------------------------------------------------------
/images/aws_inbound_rules.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/aws_inbound_rules.png
--------------------------------------------------------------------------------
/images/gcp_inbound_rules.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/gcp_inbound_rules.png
--------------------------------------------------------------------------------
/images/lambda_inbound_rules.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/lambda_inbound_rules.png
--------------------------------------------------------------------------------
/images/runpod_external_port.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_external_port.png
--------------------------------------------------------------------------------
/images/runpod_inbound_rules.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_inbound_rules.png
--------------------------------------------------------------------------------
/images/tensordock_external_port.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/tensordock_external_port.png
--------------------------------------------------------------------------------
/images/tensordock_forwarded_ports.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/tensordock_forwarded_ports.png
--------------------------------------------------------------------------------
/src/node0/security/integrity_check.cpython-311-aarch64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/src/node0/security/integrity_check.cpython-311-aarch64-linux-gnu.so
--------------------------------------------------------------------------------
/src/node0/security/integrity_check.cpython-311-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/src/node0/security/integrity_check.cpython-311-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/run.json:
--------------------------------------------------------------------------------
1 | {
2 | "run_config": "llama_8B_C",
3 | "auth_server": "https://auth.pluralis.ai",
4 | "seeds": ["/ip4/34.215.69.49/tcp/49200/p2p/QmQQc9r9pKF5hw3MbXHqRhGAHh4JZKiuv1hCenmGKV9e2Y", "/ip4/54.70.227.132/tcp/49200/p2p/QmSzC8yq5yqhiTjP2gRaTaUZbjik2ZrdbCc8rP7dXRdP21", "/ip4/34.208.159.142/tcp/49200/p2p/QmNY87dq6tmNXZJhWYddyA5WbVLY6FEktU1QTXdraYCSa3", "/ip4/35.89.245.43/tcp/49200/p2p/QmcRSpPVX3itcEedowDdLrSdSN7kMTC8jkkbTN5UYsoRH4"]
5 | }
--------------------------------------------------------------------------------
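Illustrative sketch (not a file in the repository): reading the `run.json` fields above with the standard library. The variable names are placeholders; node0's actual entry point may consume this file differently.

```python
import json

# Read the run configuration shown above; field names come straight from run.json.
with open("run.json") as f:
    run_cfg = json.load(f)

config_name = run_cfg["run_config"]    # e.g. "llama_8B_C" -> src/node0/configs/llama_8B_C.yaml
auth_server = run_cfg["auth_server"]   # authentication endpoint
initial_peers = run_cfg["seeds"]       # libp2p multiaddresses of the seed DHT nodes

print(config_name, auth_server, len(initial_peers))
```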
/.github/assets/dashboard-button.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/node0/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .common import build_cls, clean_tmp, infer_expert_params
16 | from .logging import Node0Logger
17 | from .monitor import MonitorWorker
18 | from .dht_partition import update_initial_peers
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
2 |
3 | WORKDIR /home
4 | # Set en_US.UTF-8 locale by default
5 | RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment
6 |
7 | # Install packages
8 | RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \
9 | build-essential \
10 | rsync \
11 | openssh-client \
12 | curl \
13 | wget \
14 | git \
15 | vim \
16 | && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/*
17 |
18 | # Install conda
19 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \
20 | bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh
21 | ENV PATH="/opt/conda/bin:${PATH}"
22 |
23 | # Accept conda TOS
24 | RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
25 | conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
26 |
27 | # Install torch
28 | RUN conda install python=3.11 pip
29 |
30 | # Install node0 lib
31 | COPY . node0/
32 | RUN cd node0 && pip install .
33 |
34 | CMD ["bash"]
35 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "node0"
3 | version = "0.1.0"
4 | readme = "README.md"
5 | requires-python = ">=3.11,<3.12"
6 | license = "Apache-2.0"
7 | license-files = ["*LICEN[CS]E*"]
8 | dependencies = [
9 | "PyYAML",
10 | "prometheus_client",
11 | "torch==2.7.0",
12 | "numpy>=1.17",
13 | "scipy>=1.2.1",
14 | "prefetch_generator>=1.0.1",
15 | "msgpack>=0.5.6",
16 | "sortedcontainers",
17 | "uvloop>=0.14.0",
18 | "grpcio-tools>=1.33.2",
19 | "protobuf>=5.29.0",
20 | "configargparse>=1.2.3",
21 | "py-multihash>=0.2.3",
22 | "cryptography>=3.4.6",
23 | "pydantic>=2.0.0",
24 | "packaging>=20.9",
25 | "varint>=1.0.2",
26 | "base58>=1.0.2",
27 | "netaddr>=1.3.0",
28 | "idna>=3.10",
29 | "py-cid>=0.3.0",
30 | "requests>=2.32.3",
31 | "speedtest-cli>=2.1.3",
32 | "psutil>=7.0.0",
33 | "hivemind @ git+https://github.com/learning-at-home/hivemind.git@4d5c41495be082490ea44cce4e9dd58f9926bb4e"
34 | ]
35 |
36 | [build-system]
37 | requires = ["hatchling"]
38 | build-backend = "hatchling.build"
39 |
40 | [tool.hatch.metadata]
41 | allow-direct-references = true
42 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels:
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Bug Description
11 |
12 | **Describe the bug**
13 |
14 | A clear and concise description of what the bug is and what you expected to happen.
15 |
16 | **Steps to Reproduce**
17 |
18 | Describe the steps to reproduce the issue, including the specific script/command you ran and any input parameters.
19 |
20 | ## Environment Information
21 |
22 | **Operating System:**
23 | - [ ] Windows + WSL (specify version: )
24 | - [ ] Linux (specify distribution and version: )
25 |
26 | **Python Version:**
27 |
28 | **Deployment Method:**
29 | - [ ] Docker
30 | - [ ] Local installation (pip/conda/venv)
31 |
32 | **Runtime Environment:**
33 | - [ ] Personal computer/laptop
34 | - [ ] Cloud service (specify which one):
35 | - [ ] AWS
36 | - [ ] Google Cloud Platform
37 | - [ ] Azure
38 | - [ ] Lambda Labs
39 | - [ ] Other: ___________
40 |
41 | ## Error Details
42 |
43 | **Error Message:**
44 | ```
45 | Paste the full error message/traceback here
46 | ```
47 |
48 | **Log Output:**
49 | ```
50 | Paste relevant log output here
51 | ```
52 |
53 | **Any additional context or information that might be helpful:**
54 |
--------------------------------------------------------------------------------
/src/node0/models/llama/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from node0.models.arguments import ModelArguments
16 |
17 |
18 | class LlamaArguments(ModelArguments):
19 | hidden_dim: int = 4096
20 | n_heads: int = 32
21 | n_kv_heads: int | None = None
22 | vocab_size: int = 50265 # Using AutoTokenizer.from_pretrained("facebook/opt-2.7b")
23 | # vocab_size: int = 50280 # Using AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
24 | multiple_of: int = 256
25 | ffn_dim_multiplier: float | None = None
26 | norm_eps: float = 1e-5
27 | rope_theta: float = 10000
28 | max_seq_len: int = 512
29 | depth_init: bool = True
30 | constant_init: bool = False
31 | norm_type: str = "rmsnorm"
32 |
--------------------------------------------------------------------------------
/src/node0/configs/llama_8B_C.yaml:
--------------------------------------------------------------------------------
1 | model_config:
2 | class_name: node0.models.llama.arguments.LlamaArguments
3 | init_args:
4 | hidden_dim: 4096
5 | n_heads: 32
6 | n_kv_heads: 8
7 | ffn_dim_multiplier: 1.3
8 | multiple_of: 1024
9 | rope_theta: 500000
10 | num_hidden_layers: 1
11 | qk_norm: True
12 | norm_reorder: True
13 | trainable_rmsnorm: False
14 | compression_rate: 100
15 | use_compression: True
16 | max_seq_len: 4096
17 | ss_component: 'https://d2exiwjpgw0bxb.cloudfront.net/subspace_compression/8B_C/subspace_comp.pt'
18 |
19 | optim_config:
20 | class_name: torch.optim.AdamW
21 | init_args:
22 | lr: 0.0003
23 | weight_decay: 0.1
24 |
25 | grad_avg_config:
26 | class_name: node0.server.power_sgd_averager.PowerSGDGradientAverager
27 | init_args:
28 | averager_rank: 64
29 |
30 | # Scheduler
31 | scheduler: linear
32 | num_warmup_steps: 4000
33 | num_training_steps: 100000
34 |
35 | # Training
36 | num_stages: 32
37 | clip_grad_norm: 1.0
38 | weight_decay: 0.0
39 | compression: NONE
40 | min_batch_size: 1
41 | max_batch_size: 1
42 | averaging_target_batch_size: 1024
43 | matchmaking_time: 45
44 | averaging_timeout: 120
45 | request_timeout: 3
46 | sparse_avg: 0.05
47 | average_state_every: 5
48 | load_state_timeout: 150
49 | max_allowed_stale: 5
50 | num_dht: 5
--------------------------------------------------------------------------------
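Illustrative sketch (not part of the repository): one way the `class_name`/`init_args` pattern in this config could be instantiated. The repository's own loader (`build_cls` in `node0.utils.common`) is not shown in this dump, so the dynamic-import approach below is an assumption about the pattern, not the actual implementation.

```python
import importlib

import yaml

with open("src/node0/configs/llama_8B_C.yaml") as f:
    cfg = yaml.safe_load(f)

# Resolve "node0.models.llama.arguments.LlamaArguments" and build it from init_args.
module_path, _, class_name = cfg["model_config"]["class_name"].rpartition(".")
model_args_cls = getattr(importlib.import_module(module_path), class_name)
model_args = model_args_cls(**cfg["model_config"]["init_args"])

print(model_args.hidden_dim, model_args.n_kv_heads, model_args.max_seq_len)
```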
/src/node0/models/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Self
16 |
17 | from pydantic import BaseModel, model_validator
18 |
19 |
20 | class ModelArguments(BaseModel):
21 | # Model parameters
22 | hidden_dim: int
23 | n_heads: int
24 | num_hidden_layers: int
25 | n_layers: int = 0
26 | stage: int | None = None
27 |
28 | # Attention projection parameters
29 | attn_proj: bool = False
30 |
31 | # QK norm
32 | qk_norm: bool = False
33 | norm_reorder: bool = False
34 | trainable_rmsnorm: bool = True
35 |
36 | # Compression parameters
37 | compression_rate: int | None = None
38 | use_compression: bool = False
39 | ss_component: str | None
40 |
41 | @model_validator(mode="after")
42 | def set_n_layers(self) -> Self:
43 | self.n_layers = self.num_hidden_layers
44 | return self
45 |
--------------------------------------------------------------------------------
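Illustrative sketch (not part of the repository): the `set_n_layers` validator above copies `num_hidden_layers` into `n_layers` after construction. The concrete values here are placeholders.

```python
from node0.models.arguments import ModelArguments

args = ModelArguments(hidden_dim=4096, n_heads=32, num_hidden_layers=8, ss_component=None)
assert args.n_layers == args.num_hidden_layers == 8  # filled in by the model_validator
```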
/src/node0/utils/flops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from node0.models.arguments import ModelArguments
18 |
19 |
20 | def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
21 | """Compute number of model parameters."""
22 | num_params = sum(p.numel() for p in model.parameters())
23 | if exclude_embedding and hasattr(model, "tok_embeddings"):
24 | num_params -= model.tok_embeddings.weight.numel()
25 |
26 | if exclude_embedding and hasattr(model, "fixed_tok_embeddings"):
27 | num_params -= model.fixed_tok_embeddings.weight.numel()
28 |
29 | return num_params
30 |
31 |
32 | def get_num_flop_per_token_fwd(num_params: int, model_config: ModelArguments, seq_len: int) -> int:
33 | """Compute number of flop per token in forward pass."""
34 | layers, h, q, t = (
35 | model_config.n_layers,
36 | model_config.n_heads,
37 | model_config.hidden_dim // model_config.n_heads,
38 | seq_len,
39 | )
40 |
41 | flop_per_token = 2 * num_params + 4 * layers * h * q * t
42 |
43 | return flop_per_token
44 |
45 |
46 | def get_num_flop_per_token_bwd(num_params: int, model_config: ModelArguments, seq_len: int) -> int:
47 | """Compute number of flop per token in backward pass."""
48 | layers, h, q, t = (
49 | model_config.n_layers,
50 | model_config.n_heads,
51 | model_config.hidden_dim // model_config.n_heads,
52 | seq_len,
53 | )
54 |
55 | flop_per_token = 4 * num_params + 8 * layers * h * q * t
56 |
57 | return flop_per_token
58 |
--------------------------------------------------------------------------------
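Illustrative sketch (not part of the repository): estimating per-token FLOPs with the helpers above. The 8e9 parameter count is a made-up stand-in; in practice `get_num_params(model)` would be called on an instantiated module.

```python
from node0.models.arguments import ModelArguments
from node0.utils.flops import get_num_flop_per_token_bwd, get_num_flop_per_token_fwd

cfg = ModelArguments(hidden_dim=4096, n_heads=32, num_hidden_layers=32, ss_component=None)
num_params = int(8e9)  # hypothetical; use get_num_params(model) on a real model

fwd = get_num_flop_per_token_fwd(num_params, cfg, seq_len=4096)
bwd = get_num_flop_per_token_bwd(num_params, cfg, seq_len=4096)
print(f"{fwd:.3e} FLOPs/token forward, {bwd:.3e} FLOPs/token backward")
```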
/THIRD_PARTY_LICENSES.md:
--------------------------------------------------------------------------------
1 | # Third-Party Dependencies
2 |
3 | This project uses the following open-source libraries:
4 |
5 | | Name | Version | License |
6 | |--------------------------|------------|--------------------------------------|
7 | | ConfigArgParse | 1.7.1 | MIT License |
8 | | PyYAML | 6.0.2 | MIT License |
9 | | base58 | 1.0.3 | MIT License |
10 | | cryptography | 45.0.7 | Apache-2.0 OR BSD-3-Clause |
11 | | grpcio-tools | 1.74.0 | Apache Software License |
12 | | hivemind | 1.2.0.dev0 | MIT License |
13 | | idna | 3.10 | BSD License |
14 | | msgpack | 1.1.1 | Apache 2.0 |
15 | | netaddr | 1.3.0 | BSD License |
16 | | numpy | 2.3.3 | BSD License |
17 | | packaging | 25.0 | Apache Software License; BSD License |
18 | | prefetch_generator | 1.0.3 | The Unlicense (Unlicense) |
19 | | prometheus_client | 0.22.1 | Apache-2.0 AND BSD-2-Clause |
20 | | protobuf | 6.32.1 | 3-Clause BSD License |
21 | | psutil | 7.0.0 | BSD License |
22 | | py-cid | 0.3.1 | MIT |
23 | | py-multihash | 2.0.1 | MIT License |
24 | | pydantic | 2.11.9 | MIT License |
25 | | requests | 2.32.5 | Apache Software License |
26 | | scipy | 1.16.2 | BSD License |
27 | | sortedcontainers | 2.4.0 | Apache Software License |
28 | | speedtest-cli | 2.1.3 | Apache Software License |
29 | | torch | 2.7.0 | BSD License |
30 | | uvloop | 0.21.0 | Apache Software License; MIT License |
31 | | varint | 1.0.2 | MIT License |
32 |
--------------------------------------------------------------------------------
/src/node0/security/validation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from hivemind.dht.crypto import RSASignatureValidator
18 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
19 | from hivemind.dht.validation import RecordValidatorBase
20 | from pydantic.v1 import BaseModel, StrictFloat, StrictInt, StrictStr
21 |
22 |
23 | if TYPE_CHECKING:
24 | # For type checkers, treat it as bytes
25 | BytesWithPublicKeyType = bytes
26 | else:
27 | # At runtime, use the actual validator
28 | BytesWithPublicKeyType = BytesWithPublicKey
29 |
30 |
31 | class WorkerMetricsV1(BaseModel):
32 | """Schema for worker metrics."""
33 |
34 | peer_id: str
35 | num_flop: StrictFloat
36 | active_time: StrictFloat
37 |
38 |
39 | class WorkerPortV1(BaseModel):
40 | """Schema for worker port reachability."""
41 |
42 | peer_id: str
43 | is_open: bool
44 |
45 |
46 | class RunParameters(BaseModel):
47 | peer_id: bytes
48 | averaging_target_batch_size: StrictInt
49 | scheduler: StrictStr
50 | num_warmup_steps: StrictInt
51 | num_training_steps: StrictInt
52 | averaging_timeout: StrictFloat
53 | matchmaking_time: StrictFloat
54 | request_timeout: StrictFloat
55 | load_state_timeout: StrictFloat
56 | time: StrictFloat
57 |
58 |
59 | class MetricSchema(BaseModel):
60 | """Force metrics keys to have signed subkeys."""
61 |
62 | worker_metrics: dict[BytesWithPublicKeyType, WorkerMetricsV1]
63 |
64 |
65 | class PortSchema(BaseModel):
66 | """Force port keys to have signed subkeys."""
67 |
68 | worker_ports: dict[BytesWithPublicKeyType, WorkerPortV1]
69 |
70 |
71 | class RunParametersSchema(BaseModel):
72 | paramstore: dict[BytesWithPublicKeyType, RunParameters | None]
73 |
74 |
75 | def make_validators(experiment_prefix: str, peer_id: str, stage: str) -> tuple[list[RecordValidatorBase], bytes]:
76 | """Create all validators"""
77 | metric_validator = SchemaValidator(MetricSchema, prefix=f"{experiment_prefix}_{stage}")
78 | port_validator = SchemaValidator(PortSchema, prefix=f"{experiment_prefix}_{peer_id}")
79 | param_validator = SchemaValidator(RunParametersSchema, prefix=stage.split(".")[0])
80 | signature_validator = RSASignatureValidator()
81 |
82 | validators = [metric_validator, port_validator, param_validator, signature_validator]
83 | return validators, signature_validator.local_public_key
84 |
--------------------------------------------------------------------------------
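Illustrative sketch (not part of the repository): building the DHT record validators with `make_validators`. The experiment prefix, peer id and stage strings below are placeholders, not values used by a real run.

```python
from node0.security.validation import make_validators

validators, local_public_key = make_validators(
    experiment_prefix="node0_demo",  # placeholder experiment prefix
    peer_id="QmExamplePeerId",       # placeholder peer id
    stage="body.0",                  # placeholder stage identifier
)
print(f"{len(validators)} validators, {len(local_public_key)}-byte public key")
```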
/src/node0/utils/dht_monitor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import threading
16 | import time
17 |
18 | from hivemind.dht.protocol import DHTProtocol
19 | from hivemind.utils.logging import get_logger
20 |
21 |
22 | logger = get_logger(__name__)
23 |
24 | _rpc_store_count = 0
25 | _rpc_find_count = 0
26 | _last_log_time = time.time()
27 | _counter_lock = threading.Lock()
28 |
29 |
30 | def patch_dht_protocol_logging():
31 | """Monkey patch DHTProtocol to add RPC call counting"""
32 | # Import here to avoid issues with module loading order
33 | from hivemind.p2p import P2PContext
34 | from hivemind.proto import dht_pb2
35 |
36 | global _rpc_store_count, _rpc_find_count, _last_log_time
37 |
38 | # Store original methods
39 | original_rpc_store = DHTProtocol.rpc_store
40 | original_rpc_find = DHTProtocol.rpc_find
41 |
42 | logger.extra("[DHT Monitor] Patching DHTProtocol to monitor RPC calls")
43 |
44 | # Create wrapped methods
45 | async def counted_rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
46 | global _rpc_store_count, _rpc_find_count, _last_log_time
47 |
48 | with _counter_lock:
49 | _rpc_store_count += 1
50 | current_time = time.time()
51 |
52 | # Log every 60 seconds
53 | if current_time - _last_log_time >= 60:
54 | logger.extra(f"[DHT RPC Stats] Last 60s - rpc_store: {_rpc_store_count}, rpc_find: {_rpc_find_count}")
55 | _rpc_store_count = 0
56 | _rpc_find_count = 0
57 | _last_log_time = current_time
58 |
59 | return await original_rpc_store(self, request, context)
60 |
61 | async def counted_rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
62 | global _rpc_store_count, _rpc_find_count, _last_log_time
63 |
64 | with _counter_lock:
65 | _rpc_find_count += 1
66 | current_time = time.time()
67 |
68 | # Log every 60 seconds
69 | if current_time - _last_log_time >= 60:
70 | logger.extra(f"[DHT RPC Stats] Last 60s - rpc_store: {_rpc_store_count}, rpc_find: {_rpc_find_count}")
71 | _rpc_store_count = 0
72 | _rpc_find_count = 0
73 | _last_log_time = current_time
74 |
75 | return await original_rpc_find(self, request, context)
76 |
77 | # Copy over important attributes from original methods
78 | counted_rpc_store.__name__ = "rpc_store"
79 | counted_rpc_store.__qualname__ = "DHTProtocol.rpc_store"
80 | counted_rpc_find.__name__ = "rpc_find"
81 | counted_rpc_find.__qualname__ = "DHTProtocol.rpc_find"
82 |
83 | # Patch with the new methods
84 | DHTProtocol.rpc_store = counted_rpc_store
85 | DHTProtocol.rpc_find = counted_rpc_find
86 |
87 | logger.extra("[DHT Monitor] RPC monitoring active - will log stats every 60 seconds")
88 |
89 | return None
90 |
--------------------------------------------------------------------------------
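Illustrative sketch (not part of the repository): applying the RPC-counting patch. Because `patch_dht_protocol_logging` logs via the custom EXTRA level, this sketch assumes `Node0Logger` (from `node0.utils.logging`) is set up first so that `logger.extra` exists, and applies the patch before the DHT is created so the wrapped handlers are the ones in use.

```python
from hivemind import DHT

from node0.utils.dht_monitor import patch_dht_protocol_logging
from node0.utils.logging import Node0Logger

Node0Logger(log_level="EXTRA")   # registers the EXTRA level used by logger.extra(...)
patch_dht_protocol_logging()     # patch DHTProtocol.rpc_store / rpc_find
dht = DHT(start=True)            # subsequent rpc_store/rpc_find calls are counted and reported
```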
/src/node0/utils/dht_partition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | def partition_array(x: int, y: int) -> list[list[int]]:
17 | """
18 | Partition array [0, 1, 2, ..., x-1] into y partitions as equally as possible.
19 |
20 | Args:
21 | x (int): Size of array (values 0 to x-1)
22 | y (int): Number of partitions
23 |
24 | Returns:
25 | list[list[int]]: list of lists representing the partitions
26 | """
27 | if x == 0:
28 | return []
29 |
30 | if x <= y:
31 | # Each value gets its own partition
32 | return [[i] for i in range(y) if i < x]
33 |
34 | base_size = x // y
35 | remainder = x % y
36 |
37 | partitions = []
38 | start = 0
39 |
40 | for i in range(y):
41 | # First 'remainder' partitions get an extra element
42 | size = base_size + (1 if i < remainder else 0)
43 | partitions.append(list(range(start, start + size)))
44 | start += size
45 |
46 | return partitions
47 |
48 |
49 | def stage_to_dht_map(dht_partition: list[list[int]]) -> list[int]:
50 | """Map stage index to dht index
51 |
52 | Args:
53 | dht_partition (list[list[int]]): list of lists representing the partitions
54 |
55 | Returns:
56 | list[int]: stage to dht mapping
57 | """
58 | stage_to_dht = [part for part, sublist in enumerate(dht_partition) for _ in sublist]
59 | return stage_to_dht
60 |
61 |
62 | def update_initial_peers(
63 | initial_peers: list[str],
64 | pipeline_stage: str,
65 | num_stages: int,
66 | num_dht: int,
67 | ) -> list[str]:
68 | """Update the list of initial peers with correct ports that match the given stage
69 |
70 | Args:
71 | initial_peers (list[str]): list of multiaddress
72 | pipeline_stage (str): stage type in the format: head-X, body-X, tail-X (X is int)
73 | num_stages (int): total number of stages
74 | num_dht (int): number of worker DHTs
75 |
76 | Raises:
77 | ValueError: wrong stage type
78 |
79 | Returns:
80 | list[str]: initial_peers
81 | """
82 |
83 | # Calculate port offset according to stage
84 | stage_idx = int(pipeline_stage.split("-")[1])
85 | dht_worker_partitions = partition_array(num_stages, num_dht)
86 | stage_to_dht = stage_to_dht_map(dht_worker_partitions)
87 | port_offset = stage_to_dht[stage_idx]
88 |
89 | # Update initial peers ports
90 | for i, peeri in enumerate(initial_peers):
91 | try:
92 | # Extract baseline port
93 | parts = peeri.split("/")
94 | port_index = parts.index("tcp") + 1
95 | base_port = int(parts[port_index])
96 | parts[port_index] = str(base_port + port_offset)
97 | initial_peers[i] = "/".join(parts)
98 | except (ValueError, IndexError) as e:
99 | raise ValueError(f"Invalid multiaddress format in peer {i}: {peeri}. Error: {e}") from e
100 |
101 | return initial_peers
102 |
--------------------------------------------------------------------------------
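Illustrative sketch (not part of the repository): how the partition helpers above behave with the llama_8B_C.yaml values (32 stages, 5 worker DHTs). The multiaddress is a placeholder.

```python
from node0.utils.dht_partition import partition_array, stage_to_dht_map, update_initial_peers

parts = partition_array(32, 5)     # [[0..6], [7..13], [14..19], [20..25], [26..31]]
mapping = stage_to_dht_map(parts)  # stage index -> DHT index
print(mapping[8])                  # stage 8 is served by DHT 1

# Ports are shifted by the DHT index, so stage "body-8" dials base_port + 1 here.
peers = update_initial_peers(
    ["/ip4/127.0.0.1/tcp/49200/p2p/QmExamplePeer"],  # placeholder multiaddress
    pipeline_stage="body-8",
    num_stages=32,
    num_dht=5,
)
print(peers[0])  # /ip4/127.0.0.1/tcp/49201/p2p/QmExamplePeer
```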
/src/node0/models/lr_schedule.py:
--------------------------------------------------------------------------------
1 | # This file contains code originally from Hivemind under MIT License
2 | # Original: Copyright 2020 Learning@home authors and collaborators
3 | # Modified by: Pluralis Research 2025
4 | #
5 | # Original code: MIT License (see THIRD_PARTY_LICENSES)
6 | # Modifications: Apache 2.0 License (see LICENSE)
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only;
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | import math
13 |
14 | from torch.optim.lr_scheduler import LambdaLR
15 |
16 |
17 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
18 | """
19 | Create a schedule with a learning rate that decreases following a cosine curve from the initial lr set in the
20 | optimizer to min_lr_ratio * initial_lr, after a warmup period during which it increases linearly from 0 to
21 | the initial lr set in the optimizer.
22 |
23 | Args:
24 | optimizer (:class:`~torch.optim.Optimizer`):
25 | The optimizer for which to schedule the learning rate.
26 | num_warmup_steps (:obj:`int`):
27 | The number of steps for the warmup phase.
28 | num_training_steps (:obj:`int`):
29 | The total number of training steps.
30 | min_lr_ratio (:obj:`float`, optional):
31 | The minimum learning rate as a ratio of the initial learning rate. Default: 0.1
32 | Return:
33 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
34 | """
35 |
36 | def lr_lambda(current_step: int):
37 | if current_step < num_warmup_steps:
38 | # Linear warmup: 0 to 1.0
39 | return float(current_step) / float(max(1, num_warmup_steps))
40 |
41 | # Cosine decay: 1.0 to min_lr_ratio
42 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
43 | progress = min(progress, 1.0) # Clamp to [0, 1]
44 |
45 | # Cosine annealing formula
46 | cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
47 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_factor
48 |
49 | return LambdaLR(optimizer, lr_lambda)
50 |
51 |
52 | # https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py
53 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
54 | """
 55 |     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to
 56 |     min_lr_ratio * initial_lr, after a warmup period during which it increases linearly from 0 to the initial lr.
57 | Args:
58 | optimizer (:class:`~torch.optim.Optimizer`):
59 | The optimizer for which to schedule the learning rate.
60 | num_warmup_steps (:obj:`int`):
61 | The number of steps for the warmup phase.
62 | num_training_steps (:obj:`int`):
63 | The total number of training steps.
64 | min_lr_ratio (:obj:`float`, optional):
65 | The minimum learning rate as a ratio of the initial learning rate. Default: 0.1
66 | Return:
67 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
68 | """
69 |
70 | def lr_lambda(current_step: int):
71 | if current_step < num_warmup_steps:
72 | return float(current_step) / float(max(1, num_warmup_steps))
73 | return max(
74 | min_lr_ratio,
75 | float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)),
76 | )
77 |
78 | return LambdaLR(optimizer, lr_lambda)
79 |
80 |
81 | schedule_name_to_scheduler = {
82 | "linear": get_linear_schedule_with_warmup,
83 | "cosine": get_cosine_schedule_with_warmup,
84 | "none": None,
85 | }
86 |
--------------------------------------------------------------------------------
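Illustrative sketch (not part of the repository): the linear schedule with the warmup/total steps from llama_8B_C.yaml, built on a throwaway parameter. `lr_lambdas[0]` is only used here to inspect the lr multiplier at a few steps.

```python
import torch

from node0.models.lr_schedule import schedule_name_to_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = schedule_name_to_scheduler["linear"](optimizer, num_warmup_steps=4000, num_training_steps=100000)

for step in (0, 2000, 4000, 100000):
    print(step, scheduler.lr_lambdas[0](step))  # 0.0, 0.5, 1.0, 0.1 (floor at min_lr_ratio)
```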
/src/node0/utils/mem_monitor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import threading
16 | import time
17 |
18 | import numpy as np
19 | import psutil
20 | import torch
21 |
22 | from hivemind.utils.logging import get_logger
23 |
24 |
25 | logger = get_logger(__name__)
26 |
27 |
28 | class MemoryTracker(threading.Thread):
29 | """
30 | Monitor memory
31 |
32 | """
33 |
34 | def __init__(
35 | self,
36 | interval: int = 60,
37 | ):
38 | super().__init__()
39 |
40 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
41 |
42 | self.log_gpumem = True
43 | if self.device == "cpu":
44 | logger.info("GPU memory monitor skipped. Device is cpu.")
45 | self.log_gpumem = False
46 |
47 | self.interval = interval
48 | if self.log_gpumem:
49 | self.total_gpu_vram = torch.cuda.get_device_properties(self.device).total_memory / 1024**3 # Gb
50 |
51 | self.total_ram = psutil.virtual_memory().total / 1024**3 # Gb
52 |
53 | self.start()
54 |
55 | def get_gpu_memory_usage(self, device=0):
56 | """
57 | Returns the current GPU memory usage
58 | """
59 |
60 | allocated = torch.cuda.memory_allocated(device) / 1024**3 # Gb
61 | reserved = torch.cuda.memory_reserved(device) / 1024**3 # Gb
62 |
63 | return {
64 | "allocated": allocated,
65 | "reserved": reserved,
66 | }
67 |
68 | def get_memory_usage(self):
69 | """
70 | Returns the current RAM usage
71 | """
72 |
73 | ram = psutil.virtual_memory()
74 | used_ram_gb = ram.used / 1024**3 # Gb
75 |
76 | return {
77 | "used": used_ram_gb,
78 | }
79 |
80 | def run(self):
81 | # Store initial network counters
82 | logger.info("Running memory monitor")
83 | start_time = time.time()
84 | alloc = []
85 | reserv = []
86 | ram_used = []
87 | try:
88 | while True:
89 | if time.time() - start_time > self.interval:
90 | if self.log_gpumem:
91 | logger.info(f"GPU mem size is {round(self.total_gpu_vram)}Gb")
92 | logger.info(f"Allocated GPU mem is {np.mean(alloc) / self.total_gpu_vram:.2f}")
93 | logger.info(f"Reserved GPU mem is {np.mean(reserv) / self.total_gpu_vram:.2f}")
94 | logger.info(f"Total RAM mem is {round(self.total_ram)}Gb")
95 | logger.info(f"Used RAM mem is {np.mean(ram_used) / self.total_ram:.2f}")
96 | start_time = time.time()
97 | alloc = []
98 | reserv = []
99 | ram_used = []
100 |
101 | if self.log_gpumem:
102 | vram_usage = self.get_gpu_memory_usage()
103 | alloc.append(vram_usage["allocated"])
104 | reserv.append(vram_usage["reserved"])
105 | ram_usage = self.get_memory_usage()
106 | ram_used.append(ram_usage["used"])
107 | time.sleep(0.5)
108 |
109 | except KeyboardInterrupt:
110 | print("\n GPU monitoring stopped by user.")
111 |
--------------------------------------------------------------------------------
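Illustrative sketch (not part of the repository): the tracker starts itself from `__init__`, so instantiating it is enough to begin sampling. Note the thread is not a daemon, so it keeps running until interrupted.

```python
from node0.utils.mem_monitor import MemoryTracker

# Samples RAM (and GPU memory when CUDA is available) twice a second and logs
# averaged usage roughly every `interval` seconds.
tracker = MemoryTracker(interval=60)
```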
/src/node0/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 |
17 | from logging.handlers import RotatingFileHandler
18 |
19 | from hivemind.utils import get_logger, use_hivemind_log_handler
20 | from hivemind.utils.logging import TextStyle, always_log_caller
21 |
22 |
23 | class CustomFormatter(logging.Formatter):
24 | """
25 | A formatter that allows a log time and caller info to be overridden via
26 | ``logger.log(level, message, extra={"origin_created": ..., "caller": ...})``.
27 | """
28 |
29 | _LEVEL_TO_COLOR = {
30 | logging.DEBUG: TextStyle.PURPLE,
31 | logging.INFO: TextStyle.BLUE,
32 | logging.WARNING: TextStyle.ORANGE,
33 | logging.ERROR: TextStyle.RED,
34 | logging.CRITICAL: TextStyle.RED,
35 | }
36 |
37 | def format(self, record: logging.LogRecord) -> str:
38 | if hasattr(record, "origin_created"):
39 | record.created = record.origin_created
40 | record.msecs = (record.created - int(record.created)) * 1000
41 |
42 | if record.levelno > logging.INFO or always_log_caller:
43 | if not hasattr(record, "caller"):
44 | record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
45 | record.caller_block = f" [{TextStyle.BOLD}{record.caller}{TextStyle.RESET}]"
46 | else:
47 | record.caller_block = ""
48 |
49 | # Aliases for the format argument
50 | record.levelcolor = (
51 | self._LEVEL_TO_COLOR[record.levelno] if record.levelno in self._LEVEL_TO_COLOR else TextStyle.BLUE
52 | )
53 | record.bold = TextStyle.BOLD
54 | record.reset = TextStyle.RESET
55 |
56 | return super().format(record)
57 |
58 |
59 | class Node0Logger:
60 | def __init__(self, log_level: str = "INFO"):
61 | """Instantiate logger.
62 |
63 | Args:
64 | log_level (str, optional): logging level. Defaults to "INFO".
66 | """
67 | # Add extra log level
68 | EXTRA_LEVEL = 15 # between INFO (20) and DEBUG (10)
69 | logging.addLevelName(EXTRA_LEVEL, "EXTRA")
70 |
71 | # Create a custom method for the logger
72 | def extra(self, message, *args, **kwargs):
73 | if self.isEnabledFor(EXTRA_LEVEL):
74 | self._log(EXTRA_LEVEL, message, args, **kwargs)
75 |
76 | logging.Logger.extra = extra
77 | use_hivemind_log_handler("in_root_logger")
78 |
79 | # Convert log level string to logging constant
80 | numeric_level = getattr(logging, log_level.upper(), log_level)
81 |
82 | # Configure root logger
83 | self.root_logger = get_logger()
84 | self.root_logger.setLevel(numeric_level)
85 |
86 | formatter = CustomFormatter(
87 | fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}]{caller_block} {message}",
88 | style="{",
89 | datefmt="%b %d %H:%M:%S",
90 | )
91 |
92 | for handler in list(self.root_logger.handlers):
93 | if isinstance(handler, logging.StreamHandler) and not isinstance(handler, RotatingFileHandler):
94 | handler.setFormatter(formatter)
95 |
96 | def add_monitor_handler(self, monitor_handler: logging.Handler):
97 | """Attach monitor handler to report logs."""
98 | self.root_logger.addHandler(monitor_handler)
99 |
--------------------------------------------------------------------------------
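Illustrative sketch (not part of the repository): setting up the logger. `Node0Logger` registers the custom EXTRA level (15) and installs the colorized formatter on the root hivemind logger, after which any module logger gains the `logger.extra(...)` helper used elsewhere in node0.

```python
from hivemind.utils import get_logger

from node0.utils.logging import Node0Logger

Node0Logger(log_level="EXTRA")  # EXTRA sits between DEBUG (10) and INFO (20)
logger = get_logger(__name__)
logger.info("info-level message")
logger.extra("extra-level message")  # method added by Node0Logger
```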
/docs/useful_links.md:
--------------------------------------------------------------------------------
1 | # Useful Links
2 |
3 | ### Hugging Face Access Tokens
4 |
5 | * **Creating and managing a HF "Access Token"**
6 | [https://huggingface.co/docs/hub/security-tokens](https://huggingface.co/docs/hub/security-tokens)
7 |
8 |
9 | ### SSH Keys & Remote Login
10 |
11 | * **How to generate an SSH key pair**
12 | [https://www.digitalocean.com/community/tutorials/how-to-create-ssh-keys-with-openssh-on-macos-or-linux](https://www.digitalocean.com/community/tutorials/how-to-create-ssh-keys-with-openssh-on-macos-or-linux)
13 | * **How to connect to a server using SSH**
14 | [https://www.digitalocean.com/community/tutorials/how-to-use-ssh-to-connect-to-a-remote-server](https://www.digitalocean.com/community/tutorials/how-to-use-ssh-to-connect-to-a-remote-server)
15 |
16 |
17 | ### File Permissions & Ownership
18 |
19 | * **`chmod`, `chown`, and UNIX file permission basics**
20 | [https://www.digitalocean.com/community/tutorials/how-to-set-permissions-linux](https://www.digitalocean.com/community/tutorials/how-to-set-permissions-linux)
21 | * **Why “UNPROTECTED PRIVATE KEY FILE!” and how to fix it**
22 | [https://www.cyberciti.biz/faq/warning-unprotected-private-key-file-ssh-linux-unix-error/](https://www.cyberciti.biz/faq/warning-unprotected-private-key-file-ssh-linux-unix-error/)
23 |
24 |
25 | ### Basic Linux Package Management & Tools
26 |
27 | * **`apt` (Debian/Ubuntu) package manager intro**
28 | [https://help.ubuntu.com/community/AptGet/Howto](https://help.ubuntu.com/community/AptGet/Howto)
29 | * **Installing and using `lsof` to see open ports / GPU processes**
30 | [https://linux.die.net/man/8/lsof](https://linux.die.net/man/8/lsof)
31 |
32 |
33 | ### Installing Docker
34 |
35 | * **Official Docker Engine install guide (Ubuntu)**
36 | [https://docs.docker.com/engine/install/ubuntu/](https://docs.docker.com/engine/install/ubuntu/)
37 | * **Post-install steps (add your user to the `docker` group)**
38 | [https://docs.docker.com/engine/install/linux-postinstall/](https://docs.docker.com/engine/install/linux-postinstall/)
39 |
40 |
41 | ### Installing Conda (Miniconda)
42 |
43 | * **Miniconda install instructions (Linux)**
44 | [https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html)
45 | * **Conda cheat sheet (creating/activating envs, installing packages)**
46 | [https://docs.conda.io/projects/conda/en/latest/user-guide/cheatsheet.html](https://docs.conda.io/projects/conda/en/latest/user-guide/cheatsheet.html)
47 |
48 |
49 | ### Checking NVIDIA GPU & Drivers
50 |
51 | * **Installing NVIDIA drivers on Ubuntu**
52 | [https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html](https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html)
53 | * **Using `nvidia-smi` to verify GPU availability**
54 | [https://www.gpu-mart.com/blog/monitor-gpu-utilization-with-nvidia-smi](https://www.gpu-mart.com/blog/monitor-gpu-utilization-with-nvidia-smi)
55 |
56 |
57 | ### Git Basics
58 |
59 | * **Cloning a repository**
60 | [https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository)
61 | * **Generating an SSH key for GitHub (if needed)**
62 | [https://docs.github.com/en/authentication/connecting-to-github-with-ssh](https://docs.github.com/en/authentication/connecting-to-github-with-ssh)
63 |
64 |
65 | ### Firewall / Opening Ports
66 |
67 | * **Opening a Port on Linux to Allow TCP Connections**
68 | [https://www.digitalocean.com/community/tutorials/opening-a-port-on-linux](https://www.digitalocean.com/community/tutorials/opening-a-port-on-linux)
69 | * **General Linux firewall basics (iptables / firewalld)**
70 | [https://www.digitalocean.com/community/tutorials/iptables-essentials-common-firewall-rules-and-commands](https://www.digitalocean.com/community/tutorials/iptables-essentials-common-firewall-rules-and-commands)
71 |
72 |
73 | ### Opening a port on Cloud Providers
74 |
75 | * Check out the [network guide](network.md).
76 |
77 |
--------------------------------------------------------------------------------
/src/node0/server/module_collab.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Tuple
16 |
17 | import torch
18 |
19 | from hivemind.moe.server import ModuleBackend
20 | from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack
21 |
22 |
23 | class ModuleCollab(ModuleBackend):
24 | def __init__(self, optimizer_lock, *args, **kwargs):
25 | super().__init__(*args, **kwargs)
26 |
27 | self.optimizer_lock = optimizer_lock
28 |
29 | def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
30 | """
31 | Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually
32 | To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``.
33 |
34 | Subclassing:
35 | This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``;
36 |
37 | It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``;
38 |
39 | Runtime doesn't guarantee that backward will be performed in the same order and for the same data
40 | as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward.
41 |
42 | Please make sure to call ``ModuleBackend.on_backward`` after each call to backward
43 | """
44 | (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema)
45 |
46 | with torch.enable_grad():
47 | with self.optimizer_lock:
48 | args = [
49 | tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
50 | for tensor in args
51 | ]
52 | kwargs = {
53 | input_key: (
54 | tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach()
55 | )
56 | for input_key, tensor in kwargs.items()
57 | }
58 |
59 | batch_size = args[0].size(0)
60 |
61 | outputs = self.module(*args, **kwargs)
62 | assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure"
63 |
64 | outputs_flat = tuple(nested_flatten(outputs))
65 |
66 | grad_outputs_flat = tuple(
67 | map(
68 | lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True),
69 | nested_flatten(grad_outputs),
70 | outputs_flat,
71 | )
72 | )
73 | torch.autograd.backward(
74 | outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False
75 | )
76 | self.on_backward(batch_size)
77 |
78 | return tuple(
79 | x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs))
80 | )
81 |
82 | def on_backward(self, batch_size: int) -> None:
83 | """
84 | Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients.
85 | """
86 | if self.optimizer is not None:
87 | self.optimizer.step(batch_size=batch_size)
88 | self.optimizer.zero_grad()
89 |
90 | if self.scheduler is not None:
91 | self.scheduler.step()
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | logs*/
7 |
8 | # C extensions
9 | *.so
10 | !integrity_check.*.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # UV
101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | #uv.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
119 | .pdm.toml
120 | .pdm-python
121 | .pdm-build/
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # PyCharm
167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
169 | # and can be added to the global gitignore or merged into this file. For a more nuclear
170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
171 | #.idea/
172 |
173 | # Ruff stuff:
174 | .ruff_cache/
175 |
176 | # PyPI configuration file
177 | .pypirc
178 |
179 |
180 | .vscode
181 |
--------------------------------------------------------------------------------
/src/node0/utils/get_parameters.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import signal
17 |
18 | from typing import Any
19 |
20 | from hivemind.dht import DHT
21 | from hivemind.utils import get_logger
22 | from hivemind.utils.timed_storage import ValueWithExpiration
23 |
24 | from node0.security.validation import RunParameters
25 |
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | def get_parameter_store(dht: DHT, prefix: str) -> tuple[Any, ...]:
31 | """
32 | Retrieve the most recent run parameters from the distributed hash table (DHT).
33 |
34 | This function fetches run parameters stored by peers in the DHT, validates them
 35 |     using RSA signature verification, and returns the training parameters from the
 36 |     most recently updated entry.
37 |
38 | Args:
39 | dht (DHT): The distributed hash table instance to retrieve parameters from.
40 | prefix (str): Prefix string used to construct the DHT storage key for lookup.
41 |
42 | Returns:
 43 |         tuple: The training parameters (averaging target batch size, scheduler, warmup/total steps, timeouts) from the most recent valid peer entry.
44 | """
45 |
46 | param_store_key = f"{prefix}_paramstore"
47 | param_store_result = dht.get(param_store_key, latest=True)
48 |
49 | if not isinstance(param_store_result, ValueWithExpiration):
50 | logger.error("Could not retrieve run parameters from peers. Exiting run.")
51 | os.killpg(os.getpgrp(), signal.SIGTERM)
52 | raise RuntimeError("Could not retrieve run parameters from peers")
53 |
54 | metadata = param_store_result.value
55 |
56 | valid_peer_entries = [
57 | RunParameters.parse_obj(peer_value.value) for peer_value in metadata.values() if peer_value.value is not None
58 | ]
59 |
60 | last_time = -float("inf")
61 | averaging_target_batch_size = 0
62 | scheduler = ""
63 | num_warmup_steps = 0
64 | num_training_steps = 0
65 | averaging_timeout = 0
66 | matchmaking_time = 0
67 | request_timeout = 0
68 | load_state_timeout = 0
69 |
70 | for val in valid_peer_entries:
71 | if val.time > last_time:
72 | averaging_target_batch_size = val.averaging_target_batch_size
73 | scheduler = val.scheduler
74 | num_warmup_steps = val.num_warmup_steps
75 | num_training_steps = val.num_training_steps
76 | averaging_timeout = val.averaging_timeout
77 | matchmaking_time = val.matchmaking_time
78 | request_timeout = val.request_timeout
79 | load_state_timeout = val.load_state_timeout
80 | last_time = val.time
81 |
82 | if (
83 | averaging_target_batch_size <= 0
84 | or scheduler not in ["linear", "cosine"]
85 | or num_warmup_steps <= 0
86 | or num_training_steps <= 0
87 | or averaging_timeout <= 0
88 | or matchmaking_time <= 0
89 | or request_timeout <= 0
90 | or load_state_timeout <= 0
91 | ):
92 | logger.error("Could not retrieve run parameters from peers. Exiting run.")
93 | os.killpg(os.getpgrp(), signal.SIGTERM)
94 | return
95 |
96 | logger.info(
97 | f"Got runtime training parameters: "
98 | f"averaging_target_batch_size = {averaging_target_batch_size}, "
99 | f"scheduler = {scheduler}, "
100 | f"num_warmup_steps = {num_warmup_steps}, "
101 | f"num_training_steps = {num_training_steps}, "
102 | f"averaging_timeout = {averaging_timeout}, "
103 | f"matchmaking_time = {matchmaking_time}, "
104 | f"request_timeout = {request_timeout}, "
105 | f"load_state_timeout = {load_state_timeout}"
106 | )
107 | return (
108 | averaging_target_batch_size,
109 | scheduler,
110 | num_warmup_steps,
111 | num_training_steps,
112 | averaging_timeout,
113 | matchmaking_time,
114 | request_timeout,
115 | load_state_timeout,
116 | )
117 |
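118 | 
119 | # Illustrative usage sketch (assumes a started DHT whose peers have already
120 | # published run parameters under f"{prefix}_paramstore"; not executed by the server):
121 | #
122 | #   (
123 | #       averaging_target_batch_size, scheduler, num_warmup_steps, num_training_steps,
124 | #       averaging_timeout, matchmaking_time, request_timeout, load_state_timeout,
125 | #   ) = get_parameter_store(dht, prefix="pluralis")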
--------------------------------------------------------------------------------
/src/node0/utils/connection_test_server.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import socket
16 | import threading
17 | import time
18 |
19 | from hivemind.utils.logging import get_logger
20 |
21 | from node0.utils.node_info import TestServerError
22 |
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | class TestServer:
28 | def __init__(self, host="0.0.0.0", port=49200):
29 | self.host = host
30 | self.port = port
31 | self.server_socket = None
32 | self.running = False
33 | self.thread = None
34 | self.received_message = None # Store single message
35 | self.message_received = threading.Event() # Signal when message arrives
36 |
37 | def __enter__(self):
38 | self.start()
39 | return self
40 |
41 | def __exit__(self, exc_type, exc_val, exc_tb):
42 | self.close()
43 |
44 | def start(self):
45 | """Start the server"""
46 | try:
47 | self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
48 | self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
49 | self.server_socket.bind((self.host, self.port))
50 | self.server_socket.listen(1)
51 | self.server_socket.settimeout(1)
52 | self.running = True
53 | except OSError as e:
54 | logger.error(f"Failed to bind to {self.host}:{self.port}: {e}")
55 | if self.server_socket:
56 | self.server_socket.close()
57 | self.server_socket = None
58 | raise TestServerError(f"Failed to start server on {self.host}:{self.port}: {e}") from None
59 |
60 | def server_loop():
61 | logger.info(f"Test server listening on {self.host}:{self.port}")
62 | while self.running and not self.message_received.is_set():
63 | try:
64 | client_socket, addr = self.server_socket.accept()
65 | logger.info("Test server connection received")
66 |
67 | # Receive the message
68 | try:
69 | client_socket.settimeout(2)
70 | data = client_socket.recv(1024)
71 | if data:
72 | message = data.decode("utf-8", errors="ignore")
73 | self.received_message = {"message": message, "from": "", "timestamp": time.time()}
74 | self.message_received.set() # Signal message received
75 | logger.info(f"Received message: {message}")
76 | else:
77 | logger.info("No data received")
78 | except TimeoutError:
79 | logger.info("Timeout waiting for data")
80 | except Exception as e:
81 | logger.error(f"Error receiving data: {e}")
82 | finally:
83 | client_socket.close()
84 |
85 | except TimeoutError:
86 | continue
87 | except OSError:
88 | break
89 |
90 | self.thread = threading.Thread(target=server_loop)
91 | self.thread.daemon = True
92 | self.thread.start()
93 | time.sleep(0.1) # Let server start
94 |
95 | def close(self):
96 | """Close the server"""
97 | if self.running:
98 | self.running = False
99 | if self.server_socket:
100 | self.server_socket.close()
101 | if self.thread:
102 | self.thread.join(timeout=1)
103 | logger.info("Test server closed")
104 |
105 | def get_message(self):
106 | """Get the received message (or None if not received yet)"""
107 | return self.received_message
108 |
109 | def wait_for_message(self, timeout=10):
110 | """Wait for the message to be received. Returns True if received, False if timeout"""
111 | return self.message_received.wait(timeout)
112 |
113 | def verify_message(self):
114 | """Verify the received message contains expected content"""
115 | if self.received_message and "auth" in self.received_message["message"]:
116 | return True
117 | return False
118 |
--------------------------------------------------------------------------------
/docs/network.md:
--------------------------------------------------------------------------------
1 | # Network Configuration Guide
2 |
3 | This guide provides instructions for configuring port 49200 to be accessible for external connections across various cloud compute providers and personal networks.
4 |
5 | ## AWS
6 | 1. Navigate to your EC2 instance and click on the Security Group attached to it (found under the Security tab)
7 | 2. Click "Inbound rules" → "Edit inbound rules"
8 | 3. Click "Add rule" and configure:
9 | - **Type:** Custom TCP
10 | - **Port range:** 49200
11 | - **Source:** 0.0.0.0/0 (to allow traffic from any source)
12 |
13 |
14 |
15 | ## Google Cloud Platform (GCP)
16 | 1. Go to "VPC network" → "Firewall" → "Create firewall rule"
17 | 2. Configure the rule:
18 | - **Target:** Choose "All instances in the network" or "Specified target tags" and specify your instance
19 | - **Source IPv4 ranges:** 0.0.0.0/0
20 | - **Protocols and ports:** Select "Specified protocols and ports"
21 | - **TCP:** 49200
22 |
23 |
24 |
25 | ## RunPod
26 | RunPod assigns random external port mappings, so you need to configure port forwarding:
27 |
28 | 1. After deploying a Pod, click the three horizontal lines (Pod Settings) → "Edit Pod"
29 |
30 |
31 |
32 | 2. Under "Expose TCP Ports" add `49200` and save (this will restart the Pod)
33 |
34 |
35 |
36 | 3. Once restarted, click "Connect" to see the external port mapping for internal port 49200. In the example below, the external port is 17969.
37 |
38 |
39 |
40 | 4. Use these flags when running generate_script.py (see **Changing exposed port** section in README for details):
41 |
42 | ```bash
43 | --host_port 49200 # The internal port the library will listen on
44 | --announce_port 17969 # The external port other peers will connect to
45 | ```
46 |
47 | ## Tensordock
48 | ### Distributed Compute
49 | When provisioning an instance using the Distributed Compute option, Tensordock allows you to request a specific internal port and then assigns a random external port mapping to that port, so you need to configure port forwarding:
50 |
51 | 1. During provisioning setup, under the Port Forwarding section → "Request Port" and choose `49200`:
52 |
53 |
54 |
55 | 2. Once deployed, note the randomly assigned external port:
56 |
57 |
58 |
59 | 3. Use these flags when running generate_script.py (see **Changing exposed port** section in README for details). `--announce_port` should specify your randomly assigned external port:
60 |
61 | ```bash
62 | --host_port 49200 # The internal port the library will listen on
63 | --announce_port 10009 # The external port other peers will connect to
64 | ```
65 |
66 | ## Lambda Labs
67 | 1. Navigate to the Firewall page in your Lambda Cloud dashboard
68 | 2. Click "Edit" in the Inbound Rules section
69 | 3. Configure the new rule:
70 | - **Rule type:** Custom TCP
71 | - **Port range:** 49200
72 | - **Source:** 0.0.0.0/0 (to allow traffic from any source)
73 |
74 |
75 |
76 | ## Personal Computer
77 |
78 | Set up your firewall to allow traffic from the outside world to port 49200/tcp.
79 |
80 | If you have a router, set it up to allow connections from the outside world (port 49200/tcp) to your computer.
81 |
82 | ## WSL
83 |
84 | If you're using Windows with WSL 2, you'll need to configure port forwarding:
85 |
86 | 1. **Create WSL configuration file:**
87 | Create a `.wslconfig` file in your Windows user home directory (e.g., `C:\Users\YourUsername\.wslconfig`) with the following content:
88 | ```
89 | [wsl2]
90 | localhostforwarding=true
91 | ```
92 |
93 | 2. **Configure port proxy (run PowerShell as Administrator):**
94 |
95 | Find your WSL container IP:
96 | ```powershell
97 | ((wsl hostname -I) -split " ")[0]
98 | ```
99 |
100 | Add the port proxy (replace `<WSL_IP>` with the IP from the previous command):
101 | ```powershell
102 | netsh interface portproxy add v4tov4 listenport=49200 listenaddress=0.0.0.0 connectport=49200 connectaddress=<WSL_IP>
103 | ```
104 |
105 | Open the firewall for the port:
106 | ```powershell
107 | netsh advfirewall firewall add rule name="node0" dir=in action=allow protocol=TCP localport=49200
108 | ```
109 |
110 | 3. **If you have a router**, set it up to allow connections from the outside world (port 49200/tcp) to your computer.
111 | 4. **Restart your computer** to apply the changes.
112 |
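113 | ## Verifying the port is reachable
114 | 
115 | After opening the port with any of the methods above, you can sanity-check that it is reachable from outside. One option is to reuse the repository's `TestServer` helper from `src/node0/utils/connection_test_server.py`. The sketch below is illustrative only; it assumes the `node0` package is installed on the machine where the port was opened, and that you then connect to the machine's public IP and external port from a second machine using any TCP client.
116 | 
117 | ```python
118 | # Listen on port 49200 and wait for a single test message from outside.
119 | from node0.utils.connection_test_server import TestServer
120 | 
121 | with TestServer(host="0.0.0.0", port=49200) as server:
122 |     if server.wait_for_message(timeout=60):
123 |         print(f"Port reachable, received: {server.get_message()}")
124 |     else:
125 |         print("No connection received within 60 seconds")
126 | ```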
--------------------------------------------------------------------------------
/src/node0/utils/network_throughput.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import threading
17 | import time
18 |
19 | import psutil
20 |
21 | from hivemind.utils.logging import get_logger
22 |
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | class NetworkMonitor(threading.Thread):
28 | """
29 | Monitor network throughput for a specific network interface or all interfaces.
30 |
31 | :param interface: Specific network interface to monitor (e.g., 'eth0')
32 | :param interval: Time between measurements in seconds
33 | :param duration: Total monitoring duration in seconds (None for continuous monitoring)
34 | """
35 |
36 | def __init__(
37 | self,
38 | interface: str | None = None,
39 | interval: int = 60,
40 | duration: int | None = None,
41 | ):
42 | super().__init__()
43 |
44 | self.interface = interface
45 | self.interval = interval
46 | self.duration = duration
47 |
48 | self.start()
49 |
50 | def get_tcp_connection_count(self) -> int:
51 | """
52 | Get the count of TCP connections from host system.
53 | """
54 | try:
55 | tcp_file = "/proc/1/root/proc/net/tcp"
56 | if os.path.exists(tcp_file):
57 | with open(tcp_file, "r") as f:
58 | return len(f.readlines()) - 1 # Subtract 1 for header line
59 | except (OSError, PermissionError):
60 | pass
61 |
62 | # If reading /proc fails, return -1 to indicate the count is unavailable
63 | return -1
64 |
65 | def run(self):
66 | # Store initial network counters
67 | logger.info("Running network bandwidth monitor")
68 | initial_counters = psutil.net_io_counters(pernic=True)
69 | start_time = monitor_start_time = time.time()  # per-interval and overall start timestamps
70 |
71 | try:
72 | # Determine interfaces to monitor
73 | if self.interface:
74 | interfaces = [self.interface]
75 | else:
76 | interfaces = [iface for iface in initial_counters.keys() if iface != "lo"]
77 |
78 | # Monitoring loop
79 | while True:
80 | # Wait for the interval
81 | time.sleep(self.interval)
82 |
83 | # Get current network counters
84 | current_counters = psutil.net_io_counters(pernic=True)
85 | current_time = time.time()
86 |
87 | # Calculate time elapsed
88 | elapsed = current_time - start_time
89 |
90 | # Process each interface
91 | bytes_sent = 0
92 | bytes_recv = 0
93 | for iface in interfaces:
94 | if iface not in current_counters:
95 | continue
96 |
97 | # Calculate throughput
98 | initial = initial_counters.get(iface, None)
99 | if not initial:
100 | continue
101 |
102 | bytes_sent += current_counters[iface].bytes_sent - initial.bytes_sent
103 | bytes_recv += current_counters[iface].bytes_recv - initial.bytes_recv
104 |
105 | # Calculate megabytes sent and received per second
106 | bytes_sent = (bytes_sent) / elapsed / (1024 * 1024)
107 | bytes_recv = (bytes_recv) / elapsed / (1024 * 1024)
108 |
109 | bits_sent = bytes_sent * 8
110 | bits_recv = bytes_recv * 8
111 |
112 | # Get current TCP connection counts
113 | tcp_count = self.get_tcp_connection_count()
114 |
115 | # Log results
116 | logger.info(
117 | f"Time {time.strftime('%Y-%m-%d %H:%M:%S'):<20} "
118 | f"Interface Agg "
119 | f"Sent {bits_sent:>10.2f} Mbps "
120 | f"Rcv {bits_recv:>10.2f} Mbps "
121 | )
122 |
123 | logger.info(f"Time {time.strftime('%Y-%m-%d %H:%M:%S'):<20} Number open connections {tcp_count}")
124 |
125 | # Update initial counters and start time
126 | initial_counters = current_counters
127 | start_time = current_time
128 |
129 | # Stop once the requested total monitoring duration has elapsed
130 | if self.duration and current_time - monitor_start_time >= self.duration:
131 | break
132 |
133 | except KeyboardInterrupt:
134 | print("\nMonitoring stopped by user.")
135 |
--------------------------------------------------------------------------------
/src/node0/utils/common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 | import io
17 |
18 | from functools import partial
19 | from pathlib import Path
20 | from typing import Any
21 |
22 | import requests
23 | import torch
24 |
25 | from hivemind.utils.logging import get_logger
26 |
27 |
28 | logger = get_logger(__name__)
29 |
30 |
31 | def build_cls(class_path: str, init_args: dict, partial_init: bool = False) -> Any:
32 | """Instantiate class.
33 |
34 | Args:
35 | class_path (str): full path to class
36 | init_args (dict): class init arguments
37 | partial_init (bool, optional): if True, return partial function. Defaults to False.
38 |
39 | Raises:
40 | Exception: wrong class path/init arguments
41 |
42 | Returns:
43 | Any: class instance or partial function
44 | """
45 | try:
46 | # Split module and class names
47 | module_name, class_name = class_path.rsplit(".", maxsplit=1)
48 |
49 | # Import module and get class
50 | module = importlib.import_module(module_name)
51 | class_ = getattr(module, class_name)
52 |
53 | # Instantiate class
54 | if partial_init:
55 | instance = partial(class_, **init_args)
56 | else:
57 | instance = class_(**init_args)
58 | return instance
59 | except Exception as e:
60 | logger.error(f"Can't initialize class {class_path}: {e}", exc_info=logger.getEffectiveLevel() <= 15)
61 | raise
62 |
63 |
64 | def infer_expert_params(
65 | pipeline_stage: str,
66 | max_experts: int = 1024,
67 | ) -> dict:
68 | """Infer required expert and stage parameters from pipeline_stage.
69 |
70 | Args:
71 | pipeline_stage str: stage type in the format: head-X, body-X, tail-X (X is int)
72 | max_experts (int, optional): max number of experts per stage. Defaults to 1024.
73 |
74 | Raises:
75 | ValueError: wrong stage type
76 |
77 | Returns:
78 | dict: expert_pattern, expert_class, stage
79 | """
80 | try:
81 | stage, stage_idx = pipeline_stage.split("-")
82 | stage_idx = int(stage_idx)
83 | assert stage in ["head", "body", "tail"]
84 | except Exception as e:
85 | raise ValueError("Wrong stage type. It should be one of: head-X, body-X, tail-X (X is int).") from e
86 |
87 | stage_idx_for_pattern = "" if stage in ["head", "tail"] else stage_idx
88 | expert_pattern = f"{stage}{stage_idx_for_pattern}.0.[32:{max_experts}]"
89 | stage_name = f"{stage}{stage_idx_for_pattern}.0."
90 | expert_class = f"lm_{stage}"
91 |
92 | params = {
93 | "expert_pattern": expert_pattern,
94 | "expert_cls": expert_class,
95 | "stage": stage_idx,
96 | "stage_name": stage_name,
97 | }
98 |
99 | return params
100 |
101 |
102 | def load_ss_components(ss_url: str):
103 | """
104 | Loads rcv and fixed_embeddings from URL location
105 |
106 | Args:
107 | ss_url (str): URL to the subspace compression file
108 |
109 | Returns:
110 | Dict of: rcv, fixed_tok_weight
111 | """
112 | try:
113 | response = requests.get(ss_url)
114 | response.raise_for_status()
115 |
116 | buffer = io.BytesIO(response.content)
117 | ss_comp_dict = torch.load(buffer, map_location="cpu")
118 |
119 | except requests.RequestException as e:
120 | raise RuntimeError(f"Failed to download from URL: {e}") from e
121 | except RuntimeError as e:
122 | raise RuntimeError(f"Remote loading of subspace compression components failed: {e}") from e
123 |
124 | return ss_comp_dict
125 |
126 |
127 | def clean_tmp():
128 | """Remove tmp hivemind socket files."""
129 | tmp_dir = Path("/tmp")
130 |
131 | if not tmp_dir.exists():
132 | return
133 |
134 | try:
135 | matching_files = list(tmp_dir.glob("hivemind*"))
136 |
137 | if len(matching_files) > 0:
138 | for fpath in matching_files:
139 | try:
140 | fpath.unlink()
141 | except Exception as e:
142 | logger.error(
143 | f"Can't remove file {fpath}: {e}. Please remove all hivemind* files from /tmp folder (see README for instructions to stop the server). Exiting run."
144 | )
145 | exit(1)
146 |
147 | matching_files = list(tmp_dir.glob("hivemind*"))
148 | if len(matching_files) > 0:
149 | logger.error(
150 | "Old hivemind* files are found in /tmp folder. Please clean the folder (see README for instructions to stop the server). Exiting run."
151 | )
152 | exit(1)
153 | logger.info("/tmp folder was cleaned")
154 | except Exception as e:
155 | logger.warning(f"Can't check tmp folder: {e}")
156 | return
157 |
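158 | 
159 | # Illustrative usage sketch (the class path, init arguments and stage string below
160 | # are examples, not values shipped with the repo's configs):
161 | #
162 | #   optim_factory = build_cls("torch.optim.AdamW", {"lr": 1e-3}, partial_init=True)
163 | #   optimizer = optim_factory(model.parameters())  # `model` defined elsewhere
164 | #
165 | #   infer_expert_params("body-2")
166 | #   # -> {"expert_pattern": "body2.0.[32:1024]", "expert_cls": "lm_body",
167 | #   #     "stage": 2, "stage_name": "body2.0."}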
--------------------------------------------------------------------------------
/docs/cloud.md:
--------------------------------------------------------------------------------
1 | # Cloud Options
2 | We list various cloud options and how to set them up. The cheapest option is RunPod.
3 |
4 | ## AWS (Amazon Web Services)
5 |
6 | ### How to set up
7 |
8 | **Step 1: Launch a GPU Instance**
9 |
10 | 1. **Log into AWS**: Go to the [AWS Management Console](https://aws.amazon.com/console/). Make an account or log in if you have one.
11 | 2. **Create a new EC2 instance**:
12 | * Go to **EC2** > **Launch Instance**.
13 | * Choose a name for the instance.
14 | * Select an **AMI (Amazon Machine Image)** that is Unix based, supports GPU, and has CUDA and PyTorch installed. For example,
15 | * `Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.7 (Ubuntu 22.04)`.
16 | * **Choose an Instance Type**:
17 | * Select an instance type with a **GPU** that has at least 16GB VRAM and 32GB RAM, e.g. `g4dn.2xlarge` which has an NVIDIA T4
18 | * Choose a key pair
19 | * If you do not have a key pair, click **Create new key pair** and download the **.pem key**
20 | 3. **Configure the Instance**:
21 | * Set up **security groups** to allow SSH (port 22) and a specific port for protocol communication (port 49200, though you can change and specify this)
22 | * We recommend a minimum of 80GB of storage
23 | 4. **Review and Launch** the instance.
24 |
25 | **Step 2: Edit Security Group**
26 | 1. Follow the AWS section in the [network guide](network.md) to configure port 49200 to be accessible for external connections.
27 |
28 | **Step 3: Connect to the Instance**
29 |
30 | 1. **SSH into your instance**:
31 | * Find your public IP and connect via SSH:
32 | ```bash
33 | ssh -i your-key.pem ubuntu@your-ec2-public-ip
34 | ```
35 |
36 | ## GCP (Google Cloud Platform)
37 |
38 | The cheapest option is an `NVIDIA T4` (16GB VRAM) and `n1-standard-8` (8 vCPU, 4 core, 30 GB memory) for $0.51/hour.
39 |
40 | ### How to set up
41 |
42 | **Step 1: Create a GPU-enabled VM**
43 |
44 | 1. **Log into Google Cloud**: Go to the [Google Cloud Console](https://console.cloud.google.com/). Make an account or log in if you have one.
45 | 2. **Create a new VM instance**:
46 |
47 | * Go to **Compute Engine** > **VM instances** > **Create Instance**.
48 | * Choose a name and region for the instance.
49 | * Change from General Purpose to **GPUs** and select a **GPU** and **Machine Type**
50 | * E.g. Choose 1 `NVIDIA T4` and `n1-standard-8` (8 vCPU, 4 core, 30 GB memory)
51 | * In the **OS and storage** tab change the image to one that is Unix based, supports GPU, and has CUDA and PyTorch installed
52 | * E.g. OS `Deep Learning on Linux` and image `Deep Learning VM for PyTorch 2.4 with CUDA 12.4 M129`
53 | * In the **Security** tab click **Manage Access**. Under **Add manually generated SSH keys** click **Add item**, enter your SSH public key, and click **Save**.
54 | * Click **Create**
55 |
56 | **Step 2: Edit Firewall settings**
57 | 1. Follow the GCP section in the [network guide](network.md) to configure port 49200 to be accessible for external connections.
58 |
59 | **Step 3: Connect to the Instance**
60 |
61 | 1. **Set up SSH Keys**
62 | * Go into your instance and click **Edit**
63 | * Under **SSH Keys** click **Add item**, enter your SSH public key, and click **Save**.
64 | 2. **SSH into your VM**:
65 | * Find your external IP and username (this is linked with your SSH key) under the instance details and connect via SSH:
66 | ```bash
67 | ssh ubuntu@your-external-ip
68 | ```
69 |
70 |
71 | ## RunPod
72 |
73 | The cheapest option is an RTX 2000 Ada: 16GB VRAM, 31GB RAM, 6 vCPUs for $0.23/hour.
74 |
75 | RunPod launches your workspace inside a Docker container, so running Docker within that container is difficult.
76 | We recommend using conda instead. See the [installing guide](installing.md) for how to install conda.
77 |
78 | RunPod also assigns random external port mappings, so we need to find and specify that external port. See the RunPod section in the [network guide](network.md).
79 |
80 | Finally, if you need to install anything else with RunPod, note that most standard packages are not installed, so run `apt update` first.
81 |
82 | ### How to set up
83 | **Step 1: Launch a GPU Pod**
84 |
85 | 1. **Log into RunPod**: Go to the [RunPod Console](https://www.runpod.io/console/home). Make an account or log in if you have one.
86 | 2. **Set SSH Keys**:
87 | * Go to **Settings** and under **SSH Public Keys** add your public SSH key. If you have not made an SSH key yet, follow [this guide from RunPod](https://docs.runpod.io/pods/configuration/use-ssh).
88 | 3. **Create a new Pod**:
89 | * Go to **Pods** to see available pods, and choose a Pod
90 | * E.g. RTX 2000 Ada: 16GB VRAM, 31GB RAM, 6 vCPUs for $0.23/hour.
91 | * Choose a Pod name and **Pod Template**
92 | * Choose one with CUDA and PyTorch installed; the default `RunPod Pytorch 2.1` works.
93 | * Ensure that SSH Terminal Access is enabled
94 | * Click **Deploy On-Demand**
95 |
96 | **Step 2: Edit the Pod**
97 | 1. Follow the RunPod section in the [network guide](network.md) to edit the Pod to expose a TCP port.
98 |
99 | **Step 3: Connect to the Pod**
100 |
101 | 1. **SSH into your Pod**:
102 | * Go to **Connect** and in the **SSH** tab look at the ssh command under **SSH over exposed TCP**.
103 |
104 |
105 | ## Tensordock
106 |
107 | Tensordock offers low-cost consumer GPUs, starting as low as an RTX A4000 for $0.105/hr.
108 |
109 | The distributed compute option in Tensordock also assigns random external port mappings, so we need to find and specify that external port. See the Tensordock section in the [network guide](network.md).
110 |
111 | ### How to set up
112 | 1. **Log into Tensordock**: Go to the [Tensordock Deploy Dashboard](https://dashboard.tensordock.com/deploy). Make an account or log in if you have one.
113 | 2. **Set SSH Keys**:
114 | * Go to **Secrets** and click **Add Secret** to add your public SSH key. If you have not made an SSH key yet, follow [this guide for Windows](https://learn.microsoft.com/en-us/viva/glint/setup/sftp-ssh-key-gen) and [this guide for Linux](https://www.digitalocean.com/community/tutorials/how-to-configure-ssh-key-based-authentication-on-a-linux-server).
115 | * Choose a **Name** for your SSH Key, choose **Type** as `SSH Key` and enter your public key value under **Value**. The public key value will look something like this `ssh-rsa ...`
116 | 3. **Deploy a GPU**
117 | * Go to **Deploy GPU** to see available GPUs, and choose a GPU
118 | * E.g. RTX A4000: 16GB VRAM for $0.105/hour.
119 | * Choose an Instance Name, configure the resources (CPU cores, RAM, and storage), choose a location, and select the OS. We recommend `Ubuntu 24.04 LTS`.
120 | * Click **Deploy Instance**
121 | 4. **Connect to your instance**
122 | * Click on **My Servers** and you should see the newly provisioned GPU instance. You can click the instance to get details about the instance
123 | * Instructions for connecting to the instance using SSH can be found under the **Access** section
124 |
125 | #### CUDA & Docker setup
126 | You may need to set up your Tensordock instance with the NVIDIA container toolkit and Docker (if using Docker).
127 |
128 | To install the NVIDIA container toolkit, run the following commands in your instance CLI:
129 | ```
130 | sudo apt update
131 | sudo apt install -y nvidia-container-toolkit
132 | sudo systemctl restart docker
133 | ```
134 |
135 | To install Docker, run the following commands in your instance CLI:
136 | ```
137 | sudo apt install apt-transport-https ca-certificates curl software-properties-common -y
138 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
139 | echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
140 | sudo apt update
141 | sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin -y
142 | sudo groupadd docker
143 | sudo usermod -aG docker $USER
144 | ```
145 |
146 | ## Lambda Labs
147 | Lambda does not offer 16GB GPUs; the cheapest option is an RTX 6000 (24 GB VRAM) with 14 vCPUs, 46 GiB RAM, and a 0.5 TiB SSD for $0.50/hr.
148 |
149 | ### How to set up
150 | **Step 1: Launch a GPU Instance**
151 |
152 | 1. **Log into Lambda**: Go to the [Lambda instances](https://cloud.lambda.ai/instances). Make an account or log in if you have one.
153 | 2. **Set SSH Keys**:
154 | * Go to **SSH Keys** and add your public SSH key.
155 | 3. **Create a new Instance**:
156 | * Go to **Instances** and select **Launch an Instance** to see available instances, and choose an instance
157 | * E.g. 1x RTX 6000 (24 GB), for $0.50/hour.
158 | * Choose a **Region** and **FileSystem**. If you don't have a filesystem, select **Create a filesystem**
159 | * Click **Launch**
160 |
161 | **Step 2: Edit the Firewall**
162 | 1. Follow the Lambda Labs section in the [network guide](network.md) to edit the firewall to expose a TCP port.
163 |
164 | **Step 3: Connect to the Instance**
165 |
166 | 1. **SSH into your Instance**:
167 | * Once the instance has booted, look at the SSH command under **SSH Login**.
168 |
169 |
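170 | ## Checking your instance
171 | 
172 | Whichever provider you use, once you have connected over SSH you can quickly confirm that the instance meets the VRAM and RAM guidance above. The sketch below mirrors the checks in `src/node0/utils/node_info.py` and only assumes that `torch` and `psutil` are installed (both are used by node0):
173 | 
174 | ```python
175 | import psutil
176 | import torch
177 | 
178 | # The node requires a CUDA-capable GPU.
179 | assert torch.cuda.is_available(), "CUDA is not available on this instance"
180 | 
181 | props = torch.cuda.get_device_properties(0)
182 | print(f"GPU: {props.name}, VRAM: {props.total_memory / 1024**3:.1f} GB")
183 | print(f"RAM: {psutil.virtual_memory().total / 1024**3:.1f} GB")
184 | ```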
--------------------------------------------------------------------------------
/src/node0/utils/node_info.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import time
17 |
18 | from collections.abc import Callable
19 | from typing import Any
20 |
21 | import psutil
22 | import speedtest
23 | import torch
24 |
25 | from hivemind.utils.logging import get_logger
26 | from pydantic import BaseModel
27 |
28 |
29 | logger = get_logger(__name__)
30 |
31 |
32 | SPEEDTEST_SERVERS = [
33 | {
34 | "url": "http://speedtest-wa.waveip.org:8080/speedtest/upload.php",
35 | "lat": "47.6062",
36 | "lon": "-122.3321",
37 | "name": "Seattle, WA",
38 | "country": "United States",
39 | "cc": "US",
40 | "sponsor": "Wave",
41 | "id": "60635",
42 | "host": "speedtest-wa.waveip.org:8080",
43 | },
44 | {
45 | "url": "http://speedtest.sea1.nitelusa.net:8080/speedtest/upload.php",
46 | "lat": "47.6062",
47 | "lon": "-122.3321",
48 | "name": "Seattle, WA",
49 | "country": "United States",
50 | "cc": "US",
51 | "sponsor": "Nitel",
52 | "id": "12192",
53 | "host": "speedtest.sea1.nitelusa.net:8080",
54 | },
55 | {
56 | "url": "http://us-sea02.speed.misaka.one:8080/speedtest/upload.php",
57 | "lat": "47.6062",
58 | "lon": "-122.3321",
59 | "name": "Seattle, WA",
60 | "country": "United States",
61 | "cc": "US",
62 | "sponsor": "Misaka Network, Inc.",
63 | "id": "50679",
64 | "host": "us-sea02.speed.misaka.one:8080",
65 | },
66 | {
67 | "url": "http://sea-speedtest.net.sangoma.net:8080/speedtest/upload.php",
68 | "lat": "47.6062",
69 | "lon": "-122.3321",
70 | "name": "Seattle, WA",
71 | "country": "United States",
72 | "cc": "US",
73 | "sponsor": "Sangoma",
74 | "id": "63939",
75 | "host": "sea-speedtest.net.sangoma.net:8080",
76 | },
77 | {
78 | "url": "http://wa01svp-speed11.svc.tds.net:8080/speedtest/upload.php",
79 | "lat": "47.6062",
80 | "lon": "-122.3321",
81 | "name": "Seattle, WA",
82 | "country": "United States",
83 | "cc": "US",
84 | "sponsor": "TDS Telecom",
85 | "id": "70208",
86 | "host": "wa01svp-speed11.svc.tds.net:8080",
87 | },
88 | {
89 | "url": "http://speedtest-jp.tgb-host.com:8080/speedtest/upload.php",
90 | "lat": "35.6833",
91 | "lon": "139.6833",
92 | "name": "Tokyo",
93 | "country": "Japan",
94 | "cc": "JP",
95 | "sponsor": "7 BULL",
96 | "id": "65101",
97 | "host": "speedtest-jp.tgb-host.com:8080",
98 | },
99 | {
100 | "url": "http://speedtest.jp230.hnd.jp.ctcsci.com:8080/speedtest/upload.php",
101 | "lat": "35.6833",
102 | "lon": "139.6833",
103 | "name": "Tokyo",
104 | "country": "Japan",
105 | "cc": "JP",
106 | "sponsor": "CTCSCI TECH LTD",
107 | "id": "62217",
108 | "host": "speedtest.jp230.hnd.jp.ctcsci.com:8080",
109 | },
110 | {
111 | "url": "http://speedtest.3s-labo.com:8080/speedtest/upload.php",
112 | "lat": "35.6074",
113 | "lon": "140.1065",
114 | "name": "Chiba",
115 | "country": "Japan",
116 | "cc": "JP",
117 | "sponsor": "3s-labo",
118 | "id": "70451",
119 | "host": "speedtest.3s-labo.com:8080",
120 | },
121 | {
122 | "url": "http://sto-ste-speedtest1.bahnhof.net:8080/speedtest/upload.php",
123 | "lat": "59.3294",
124 | "lon": "18.0686",
125 | "name": "Stockholm",
126 | "country": "Sweden",
127 | "cc": "SE",
128 | "sponsor": "Bahnhof AB",
129 | "id": "34024",
130 | "host": "sto-ste-speedtest1.bahnhof.net:8080",
131 | },
132 | {
133 | "url": "http://fd.sunet.se:8080/speedtest/upload.php",
134 | "lat": "59.3294",
135 | "lon": "18.0686",
136 | "name": "Stockholm",
137 | "country": "Sweden",
138 | "cc": "SE",
139 | "sponsor": "SUNET",
140 | "id": "26852",
141 | "host": "fd.sunet.se:8080",
142 | },
143 | {
144 | "url": "http://speedtest-sth.netatonce.net:8080/speedtest/upload.php",
145 | "lat": "59.3294",
146 | "lon": "18.0686",
147 | "name": "Stockholm",
148 | "country": "Sweden",
149 | "cc": "SE",
150 | "sponsor": "Net at Once Sweden AB",
151 | "id": "63781",
152 | "host": "speedtest-sth.netatonce.net:8080",
153 | },
154 | {
155 | "url": "http://se-speedt02.hy.nis.telia.net:8080/speedtest/upload.php",
156 | "lat": "59.3294",
157 | "lon": "18.0686",
158 | "name": "Stockholm",
159 | "country": "Sweden",
160 | "cc": "SE",
161 | "sponsor": "Telia Sweden AB",
162 | "id": "45936",
163 | "host": "se-speedt02.hy.nis.telia.net:8080",
164 | },
165 | {
166 | "url": "http://speedtest-sth.84grams.net:8080/speedtest/upload.php",
167 | "lat": "59.3294",
168 | "lon": "18.0686",
169 | "name": "Stockholm",
170 | "country": "Sweden",
171 | "cc": "SE",
172 | "sponsor": "84 Grams AB",
173 | "id": "53521",
174 | "host": "speedtest-sth.84grams.net:8080",
175 | },
176 | ]
177 |
178 |
179 | class NodeInfo(BaseModel):
180 | device_name: str
181 | gpu_memory: float # GB
182 | ram: float # GB
183 | download_speed: float | None
184 | upload_speed: float | None
185 | latency: float | None
186 |
187 |
188 | class NonRetriableError(Exception):
189 | pass
190 |
191 |
192 | class RetriableError(Exception):
193 | pass
194 |
195 |
196 | class NotInAllowlistError(NonRetriableError):
197 | pass
198 |
199 |
200 | class BadRequestError(NonRetriableError):
201 | pass
202 |
203 |
204 | class IntegrityError(NonRetriableError):
205 | pass
206 |
207 |
208 | class ServerUnavailableError(RetriableError):
209 | pass
210 |
211 |
212 | class TestServerError(NonRetriableError):
213 | pass
214 |
215 |
216 | def call_with_retries(func: Callable, n_retries: int = 10, initial_delay: float = 1.0) -> Any:
217 | """Call the function with retries.
218 |
219 | Args:
220 | func (Callable): function to call
221 | n_retries (int, optional): number of retries attempts. Defaults to 10.
222 | initial_delay (float, optional): delay in sec between attempts. Defaults to 1.0.
223 |
224 | Returns:
225 | Any: output of the function
226 | """
227 | i = 0
228 | while True:
229 | try:
230 | i += 1
231 | return func()
232 | except NonRetriableError:
233 | raise
234 | except ServerUnavailableError as e:
235 | error_msg = str(e)
236 | if "Our servers are currently at full capacity" in error_msg:
237 | match = re.search(r"Retry in (\d+) s", error_msg)
238 | if match:
239 | delay = int(match.group(1))
240 | logger.warning(
241 | f"Failed to call function with exception: Our servers are currently at full capacity. Retrying in {delay} sec"
242 | )
243 | time.sleep(delay)
244 | else:
245 | raise
246 | else:
247 | if i >= n_retries:
248 | raise
249 |
250 | delay = initial_delay * (2**i)
251 | logger.warning(f"Failed to call function with exception: {e}. Retrying in {delay:.1f} sec")
252 | time.sleep(delay)
253 | except Exception as e:
254 | if i >= n_retries:
255 | raise
256 |
257 | delay = initial_delay * (2**i)
258 | logger.warning(f"Failed to call function with exception: {e}. Retrying in {delay:.1f} sec")
259 | time.sleep(delay)
260 |
261 |
262 | def robust_internet_speed() -> tuple[float | None, float | None, float | None]:
263 | try:
264 | return call_with_retries(test_internet_speed)
265 | except Exception:
266 | logger.error("An error occurred during the speed test, skipping")
267 | return (None, None, None)
268 |
269 |
270 | def test_internet_speed() -> tuple[float, float, float]:
271 | """Measure download/upload internet speed."""
272 | logger.info("Testing internet speed...")
273 |
274 | st = speedtest.Speedtest(secure=True)
275 | try:
276 | st.get_best_server(SPEEDTEST_SERVERS)
277 | logger.info(f"Best speed test server: {st.best['country']}")
278 | except Exception:
279 | pass
280 |
281 | # Perform the download speed test
282 | download_speed = st.download() / 1000000 # Convert to Mbps
283 |
284 | # Perform the upload speed test
285 | upload_speed = st.upload() / 1000000 # Convert to Mbps
286 |
287 | # Latency
288 | latency = float(st.results.ping) # ms
289 |
290 | # Print the results
291 | logger.info(f"Download Speed: {download_speed:.2f} Mbps")
292 | logger.info(f"Upload Speed: {upload_speed:.2f} Mbps")
293 | logger.info(f"Latency: {latency:.2f} ms")
294 |
295 | return (download_speed, upload_speed, latency)
296 |
297 |
298 | def get_device_info() -> tuple[str, float, float]:
299 | """Get device name and memory"""
300 | ram = float(psutil.virtual_memory().total) / 1024**3
301 |
302 | if not torch.cuda.is_available():
303 | logger.error("CUDA is not available. Exiting run.")
304 | exit(1)
305 |
306 | device_info = torch.cuda.get_device_properties()
307 | device_name = device_info.name
308 | gpu_memory = device_info.total_memory / 1024**3 # GB
309 | return device_name, gpu_memory, ram
310 |
311 |
312 | def get_node_info() -> NodeInfo:
313 | """Collect information about the node."""
314 | device_name, gpu_memory, ram = get_device_info()
315 | download_speed, upload_speed, latency = robust_internet_speed()
316 |
317 | node_info = NodeInfo(
318 | device_name=device_name,
319 | gpu_memory=gpu_memory,
320 | ram=ram,
321 | download_speed=download_speed,
322 | upload_speed=upload_speed,
323 | latency=latency,
324 | )
325 | return node_info
326 |
--------------------------------------------------------------------------------
/src/node0/run_server.py:
--------------------------------------------------------------------------------
1 | # This file contains code originally from Hivemind under MIT License
2 | # Original: Copyright 2020 Learning@home authors and collaborators
3 | # Modified by: Pluralis Research 2025
4 | #
5 | # Original code: MIT License (see THIRD_PARTY_LICENSES)
6 | # Modifications: Apache 2.0 License (see LICENSE)
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only;
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | import argparse
13 | import json
14 | import multiprocessing as mp
15 | import os
16 | import platform
17 |
18 | from functools import partial
19 | from pathlib import Path
20 |
21 | import torch
22 | import yaml
23 |
24 | from hivemind.compression import NoCompression, ScaledFloat16Compression
25 | from hivemind.proto.runtime_pb2 import CompressionType
26 | from hivemind.utils.logging import get_logger
27 |
28 | from node0.security.authorization import authorize_with_pluralis
29 | from node0.security.validation import make_validators
30 | from node0.server.HM_gradient_averager import GradientAverager
31 | from node0.server.node0_server import Node0Server
32 | from node0.server.optim import AutoStepOptimizer
33 | from node0.utils import (
34 | MonitorWorker,
35 | Node0Logger,
36 | build_cls,
37 | clean_tmp,
38 | infer_expert_params,
39 | update_initial_peers,
40 | )
41 | from node0.utils.mem_monitor import MemoryTracker
42 | from node0.utils.network_throughput import NetworkMonitor
43 | from node0.utils.node_info import get_node_info
44 |
45 |
46 | logger = get_logger(__name__)
47 |
48 | if platform.system().lower() == "darwin":
49 | # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
50 | os.environ.setdefault("no_proxy", "*")
51 | os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES")
52 | mp.set_start_method("fork", force=True)
53 |
54 |
55 | def parse_args():
56 | parser = argparse.ArgumentParser()
57 | # Connection arguments
58 | parser.add_argument(
59 | "--host_maddrs",
60 | type=str,
61 | help="Multiaddrs to listen for external connections from other p2p instances",
62 | )
63 | parser.add_argument(
64 | "--announce_maddrs",
65 | type=str,
66 | help="Visible multiaddrs the host announces for external connections from other p2p instances",
67 | )
68 | parser.add_argument(
69 | "--initial_peers",
70 | type=str,
71 | nargs="*",
72 | required=True,
73 | default=[],
74 | help="Multiaddrs of one or more active DHT peers",
75 | )
76 |
77 | # Experiment arguments
78 | parser.add_argument("--run_config", type=str, required=True, help="Run configuration file.")
79 | parser.add_argument(
80 | "--custom_module_path",
81 | type=str,
82 | help="Path of a file with custom nn.modules, wrapped into special decorator",
83 | )
84 | parser.add_argument(
85 | "--num_handlers",
86 | type=int,
87 | default=5,
88 | help="Server will use this many processes to handle incoming requests",
89 | )
90 |
91 | # Authorization parameters
92 | parser.add_argument(
93 | "--identity_path",
94 | type=str,
95 | default="private.key",
96 | help="Path to identity file to be used in P2P",
97 | )
98 | parser.add_argument("--token", type=str, required=True, help="HuggingFace token")
99 | parser.add_argument("--email", type=str, default="", help="Email address")
100 | parser.add_argument("--auth_server", type=str, required=True, help="Authentication server URL")
101 |
102 | args = parser.parse_args()
103 | args = vars(args)
104 |
105 | # Read run config file
106 | with open(args.pop("run_config")) as f:
107 | run_config = yaml.safe_load(f)
108 |
109 | # Combine arguments
110 | run_config.update({k: v for k, v in args.items() if v is not None})
111 | return run_config
112 |
113 |
114 | def main():
115 | if not torch.backends.openmp.is_available():
116 | # Necessary to prevent the server from freezing after forks
117 | torch.set_num_threads(1)
118 |
119 | # Parse arguments
120 | args = parse_args()
121 |
122 | # Set up logging
123 | log_level = args.pop("log_level", "info").upper()
124 | node0_logger = Node0Logger(log_level=log_level)
125 |
126 | # Logging and monitoring
127 | save_clean_logs = args.pop("save_clean_logs", False)
128 | terminate_AR_fail = args.pop("terminate_AR_fail", True)
129 | experiment_prefix = args.pop("experiment_prefix", "pluralis")
130 | monitor = MonitorWorker(
131 | stats_report_interval=args["stats_report_interval"] if "stats_report_interval" in args else 60,
132 | experiment_prefix=experiment_prefix,
133 | log_file="logs/server.log",
134 | save_clean_logs=save_clean_logs,
135 | terminate_AR_fail=terminate_AR_fail,
136 | )
137 | node0_logger.add_monitor_handler(monitor.queue_handler)
138 | monitor.start()
139 |
140 | logger.info(f"Running with configuration: {json.dumps(args, indent=4)}")
141 |
142 | # Check pytorch version
143 | if not torch.__version__.startswith("2.7"):
144 | logger.error("Wrong pytorch version. Please install torch 2.7")
145 | exit(1)
146 |
147 | # Clean tmp folder
148 | clean_tmp()
149 |
150 | # Collect information about the node
151 | node_info = get_node_info()
152 |
153 | # Authorize
154 | check_integrity = args.pop("check_integrity", True)
155 | authorizer = authorize_with_pluralis(
156 | node_info=node_info,
157 | user_token=args.pop("token"),
158 | user_email=args.pop("email"),
159 | role="worker",
160 | auth_server=args.pop("auth_server"),
161 | identity_path=args["identity_path"],
162 | current_path=Path(__file__).resolve().parent,
163 | announce_maddrs=args["announce_maddrs"],
164 | host_port=int(args["host_maddrs"].split("/")[4]),
165 | check_integrity=check_integrity,
166 | )
167 | pipeline_stage = str(authorizer.pipeline_stage)
168 | args["host_maddrs"] = [args["host_maddrs"]]
169 | args["announce_maddrs"] = [args["announce_maddrs"]]
170 |
171 | expert_params = infer_expert_params(pipeline_stage)
172 | stage_name = expert_params.pop("stage_name")
173 | args.update(expert_params)
174 |
175 | # Add validators
176 | validators, public_key = make_validators(
177 | experiment_prefix=experiment_prefix,
178 | peer_id=authorizer.peer_id,
179 | stage=stage_name,
180 | )
181 |
182 | # Add expert parameters to args
183 | num_worker_dhts = args.pop("num_dht") - 1
184 | args["initial_peers"] = update_initial_peers(
185 | args["initial_peers"], pipeline_stage, args["num_stages"], num_dht=num_worker_dhts
186 | )
187 |
188 | # Add authorization details to log monitor
189 | monitor.add_auth_info(
190 | authorizer=authorizer,
191 | peer_id=authorizer.peer_id,
192 | stage=stage_name,
193 | local_public_key=public_key,
194 | )
195 |
196 | # Set BW for the peer
197 | default_bandwidth = args.pop("bandwidth", 20)
198 | if node_info.upload_speed:
199 | bandwidth = node_info.upload_speed
200 | else:
201 | bandwidth = default_bandwidth
202 |
203 | # Build model arguments
204 | model_config = args.pop("model_config")
205 | model_args = build_cls(model_config["class_name"], model_config["init_args"])
206 | model_args.stage = args.pop("stage")
207 |
208 | # Build optimizer
209 | optim_config = args.pop("optim_config")
210 | optim_cls = build_cls(optim_config["class_name"], optim_config["init_args"], partial_init=True)
211 |
212 | # Build gradient averager
213 | grad_avg_config = args.pop("grad_avg_config", None)
214 | is_tail = "tail" in expert_params["expert_pattern"]
215 | if "request_timeout" in args:
216 | request_timeout = args["request_timeout"]
217 | else:
218 | request_timeout = 3.0
219 | if (
220 | grad_avg_config
221 | and ("averager_rank" in grad_avg_config["init_args"])
222 | and grad_avg_config["init_args"]["averager_rank"] > 0
223 | and not is_tail
224 | ):
225 | # If detected to be mac, replace psgd
226 | if platform.system().lower() == "darwin":
227 | grad_avg_config["class_name"] = "node0.server.power_sgd_averager_mac.PowerSGDGradientAverager"
228 | grad_avg_config["init_args"]["request_timeout"] = request_timeout
229 | grad_avg_factory = build_cls(grad_avg_config["class_name"], grad_avg_config["init_args"], partial_init=True)
230 | elif (
231 | grad_avg_config
232 | and ("grad_compression_factor" in grad_avg_config["init_args"])
233 | and grad_avg_config["init_args"]["grad_compression_factor"] > 0
234 | and not is_tail
235 | ):
236 | grad_avg_config["init_args"]["request_timeout"] = request_timeout
237 | grad_avg_factory = build_cls(grad_avg_config["class_name"], grad_avg_config["init_args"], partial_init=True)
238 | else:
239 | grad_avg_factory = partial(GradientAverager, bandwidth=bandwidth, request_timeout=request_timeout)
240 |
241 | # Optimizer compressions
242 | grad_averaging_compression = args.pop("grad_averaging_compression", "NoCompression")
243 | if grad_averaging_compression == "NoCompression":
244 | grad_averaging_compression = NoCompression()
245 | elif grad_averaging_compression == "Float16Compression":
246 | grad_averaging_compression = ScaledFloat16Compression()
247 | else:
248 | raise ValueError("grad_averaging_compression must be NoCompression or Float16Compression")
249 |
250 | load_state_compression = args.pop("load_state_compression", "NoCompression")
251 | if load_state_compression == "NoCompression":
252 | load_state_compression = NoCompression()
253 | elif load_state_compression == "Float16Compression":
254 | load_state_compression = ScaledFloat16Compression()
255 | else:
256 | raise ValueError("load_state_compression must be NoCompression or Float16Compression")
257 |
258 | # Compression
259 | compression_type = args.pop("compression")
260 | compression = getattr(CompressionType, compression_type)
261 |
262 | # Select device
263 | if node_info.device_name == "cpu":
264 | device = "cpu"
265 | elif node_info.device_name.startswith("mps"):
266 | device = "mps"
267 | else:
268 | device = "cuda"
269 |
270 | logger.info(f"Using {device} device")
271 |
272 | # Log machine usage
273 | log_machine_usage = args.pop("log_machine_usage", False)
274 | if log_machine_usage:
275 | NetworkMonitor()
276 | MemoryTracker()
277 |
278 | # Start server
279 | server = Node0Server.create(
280 | model_conf=model_args,
281 | optim_cls=optim_cls,
282 | grad_avg_factory=grad_avg_factory,
283 | optim_collab_cls=AutoStepOptimizer,
284 | grad_averaging_compression=grad_averaging_compression,
285 | load_state_compression=load_state_compression,
286 | start=True,
287 | compression=compression,
288 | record_validators=validators,
289 | authorizer=authorizer,
290 | monitor=monitor,
291 | upload_bw=bandwidth,
292 | device=device,
293 | **args,
294 | )
295 |
296 | try:
297 | server.join()
298 | except KeyboardInterrupt:
299 | logger.info("Caught KeyboardInterrupt, shutting down")
300 |
301 |
302 | if __name__ == "__main__":
303 | main()
304 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/src/node0/server/matchmaking.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import asyncio
16 | import re
17 | import threading
18 | import time
19 |
20 | from typing import List, Optional, Tuple
21 |
22 | import numpy as np
23 |
24 | from hivemind.averaging.control import StepControl
25 | from hivemind.averaging.group_info import GroupInfo
26 | from hivemind.dht import DHT
27 | from hivemind.moe.expert_uid import ExpertUID
28 | from hivemind.p2p import PeerID
29 | from hivemind.utils import DHTExpiration, TimedStorage, ValueWithExpiration, get_dht_time, get_logger
30 |
31 |
32 | GroupKey = str
33 | Endpoint = str
34 | GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$") # e.g. bert_exp4_averaging.0b01001101
35 | logger = get_logger(__name__)
36 |
37 |
38 | def is_valid_group(maybe_group: str) -> bool:
39 | """A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata"""
40 | return bool(GROUP_PATTERN.fullmatch(maybe_group))
41 |
42 |
43 | class GroupKeyManager:
44 | """
45 | Simplified utility class that manages a single fixed group for all nodes
46 | """
47 |
48 | def __init__(
49 | self,
50 | dht: DHT,
51 | prefix: str,
52 | fixed_group_key: Optional[str] = None,
53 | ):
54 | self.dht = dht
55 | self.prefix = prefix
56 | self.peer_id = dht.peer_id
57 |
58 | # Use a fixed group key - either provided or default to "global"
59 | if fixed_group_key:
60 | if not is_valid_group(fixed_group_key):
61 | raise ValueError(f"Invalid fixed group key: {fixed_group_key}")
62 | self._fixed_key = fixed_group_key
63 | else:
64 | # Default fixed group key with empty bits (all nodes use same group)
65 | self._fixed_key = f"{self.prefix}.0b"
66 | self.group_bits = ""
67 |
68 | @property
69 | def current_key(self) -> GroupKey:
70 | """Return the fixed group key that all nodes use"""
71 | return self._fixed_key
72 |
73 | async def declare_averager(self, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True) -> bool:
74 | """
75 | Add (or remove) the averager to the fixed group
76 |
77 | :param peer_id: averager public peer_id for incoming requests
78 | :param expiration_time: intent to run allreduce before this timestamp
79 | :param looking_for_group: by default (True), declare the averager as "looking for group";
80 | If False, mark that the averager is no longer looking for group
81 | :return: True if declared, False if declaration was rejected by DHT peers
82 | """
83 | expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
84 | return await self.dht.store(
85 | key=self._fixed_key,
86 | subkey=peer_id.to_bytes(),
87 | value=looking_for_group,
88 | expiration_time=expiration_time,
89 | return_future=True,
90 | )
91 |
92 | async def get_averagers(self, only_active: bool = True) -> List[Tuple[PeerID, DHTExpiration]]:
93 | """
94 | Find and return averagers in the fixed group
95 |
96 | :param only_active: if True, return only active averagers that are looking for group
97 | if False, return all averagers in the group regardless of status
98 | :return: peer_ids and expirations of every matching averager
99 | """
100 | result = await self.dht.get(self._fixed_key, latest=True, return_future=True)
101 | if result is None or not isinstance(result.value, dict):
102 | logger.debug(f"Fixed group not found: {self._fixed_key}, creating new group")
103 | return []
104 |
105 | averagers = []
106 | for key, looking_for_group in result.value.items():
107 | try:
108 | if only_active and not looking_for_group.value:
109 | continue
110 | averagers.append((PeerID(key), looking_for_group.expiration_time))
111 | except Exception as e:
112 | logger.warning(f"Could not parse peer key {key} ({looking_for_group}, exc={e})")
113 | return averagers
114 |
115 | async def join_group(self, expiration_time: float) -> bool:
116 | """
117 | Join the fixed group (convenience method)
118 |
119 | :param expiration_time: when this declaration expires
120 | :return: True if successfully joined
121 | """
122 | return await self.declare_averager(self.peer_id, expiration_time, looking_for_group=True)
123 |
124 | async def leave_group(self, expiration_time: float) -> bool:
125 | """
126 | Leave the fixed group (convenience method)
127 |
128 | :param expiration_time: original expiration time used when joining
129 | :return: True if successfully left
130 | """
131 | return await self.declare_averager(self.peer_id, expiration_time, looking_for_group=False)
132 |
133 |
134 | class Matchmaking:
135 | """
136 | Simplified matchmaking that works with a single fixed group
137 |
138 |     :param dht: an instance of hivemind.DHT used for all network interactions
139 |     :param prefix: Prefix of the stage, e.g. head, body0, tail.
140 |     :param request_timeout: Kept for backward compatibility with Hivemind Matchmaking; used to cancel matchmaking when no responses arrive.
141 |     :param min_matchmaking_time: How long to wait for peers before the matchmaking operation is cancelled
142 | :param check_interval: How often to poll the DHT to check for new peers in the group
143 | :param update_period: How often to update the peer table with all peers in the stage
144 | """
145 |
146 | def __init__(
147 | self,
148 | dht: DHT,
149 | prefix: str,
150 | request_timeout: float = 5.0,
151 | min_matchmaking_time: float = 10.0,
152 | check_interval: float = 1.0,
153 | update_period: float = 3.0,
154 | ):
155 | self.group_key_manager = GroupKeyManager(dht, prefix)
156 |
157 | self.dht = dht
158 | self.prefix = prefix
159 | self.peer_id = self.group_key_manager.peer_id
160 | self.request_timeout = request_timeout
161 | self.min_matchmaking_time = min_matchmaking_time
162 | self.check_interval = check_interval
163 |
164 | # Parameters for update peer table
165 | self.update_period = update_period
166 | self.peer_table = TimedStorage[ExpertUID, PeerID]()
167 | self.is_alive = threading.Event()
168 | self.is_alive.set()
169 | self.update_trigger, self.update_finished = threading.Event(), threading.Event()
170 | self.update_period, self.last_update = update_period, get_dht_time()
171 | self.update_thread = threading.Thread(target=self.update_peers_in_background, daemon=True)
172 | self.update_thread.start()
173 |
174 | @property
175 | def max_peers(self):
176 | return len(self.peer_table)
177 |
178 | @property
179 | def peer_set(self):
180 | res = set()
181 | for index, expert_info in self.peer_table.items():
182 | res.add(expert_info.value._bytes)
183 | return res
184 |
185 | def update_peers_in_background(self):
186 | while self.is_alive.is_set():
187 | time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time())
188 |             # Event.wait() returns True if the update was triggered by the main thread and
189 |             # False if it timed out (periodic refresh); it never raises TimeoutError,
190 |             # so no try/except is needed here
191 |             self.update_trigger.wait(timeout=time_to_next_update)
192 |
193 |
194 | self.update_trigger.clear()
195 | response = self.dht.get(self.prefix.split("_")[0] + ".0.", latest=True)
196 | if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict):
197 | for index, expert_info in response.value.items():
198 | try:
199 | (uid, endpoint), expiration_time = expert_info
200 | self._add_peer(uid, endpoint, expiration_time)
201 | except Exception as e:
202 | logger.warning(f"Skipping malformed peer info {expert_info} (exc={e})")
203 | else:
204 | logger.warning(
205 | f"Could not refresh peer, dht info key contains {response}, will retry in {time_to_next_update}s"
206 | )
207 | self.last_update = get_dht_time()
208 | self.update_finished.set()
209 |
210 | def _add_peer(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration):
211 | self.peer_table.store(uid, PeerID(endpoint), expiration_time)
212 | logger.debug(f"Storing peer: {uid}, expiration time = {expiration_time:.3f}.")
213 |
214 | async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
215 | """
216 | Look for peers in the fixed group and form a group if enough peers are available
217 |
218 | :param step: To get the step schedule time
219 | :return: GroupInfo if group formed successfully, None otherwise
220 | """
221 | timeout = self.min_matchmaking_time
222 |
223 | new_expiration_time = float(get_dht_time() + timeout)
224 | await self.group_key_manager.join_group(new_expiration_time)
225 |
226 | # Wait and retry logic
227 | start_time = time.time()
228 |
229 |         # Accumulate all peers that have declared join_group; wait until they match the peer_table size
230 | all_peerIds = set()
231 | while time.time() - start_time < timeout:
232 | # Get all active averagers in the group
233 | averagers = await self.group_key_manager.get_averagers(only_active=True)
234 |
235 | _peerIds = {peer_id.to_bytes() for peer_id, _ in averagers}
236 | all_peerIds = all_peerIds.union(_peerIds)
237 |
238 | # We have enough peers, proceed with group formation
239 | if (len(all_peerIds) == self.max_peers) and (self.max_peers > 0):
240 | break
241 |
242 | # Wait for either the peer_table to populate or to find all peers in the table
243 | logger.debug(f"Not enough peers yet: {len(all_peerIds)} < {self.max_peers}, waiting...")
244 | await asyncio.sleep(self.check_interval)
245 |
246 | if len(all_peerIds) == 0:
247 | # Timeout reached without finding enough peers
248 | logger.info(f"Timeout: Not any peers in group")
249 | return None
250 |
251 | # Create group info with all available peers
252 | all_peerIds = sorted(list(all_peerIds))
253 | peer_ids = [PeerID(peer_id) for peer_id in all_peerIds]
254 |
255 |         # Use a fixed, hard-coded group ID shared by all peers in the fixed group
256 | sorted_peer_ids = sorted([str(pid) for pid in peer_ids])
257 | group_id = b"O[\x9aU\xcf%\xf0(\x90Nq\xdf!\x8b\x85)&\x0c\xe9r"
258 | gathered = tuple(step.data_for_gather for peer_id in sorted_peer_ids)
259 |
260 | group_info = GroupInfo(group_id=group_id, peer_ids=tuple(peer_ids), gathered=gathered)
261 |
262 | end_time = time.time()
263 | logger.extra(
264 | f"Formed group with {len(peer_ids)} peers out of {self.max_peers} in {end_time - start_time:.3f} secs"
265 | )
266 | return group_info
267 |
268 | async def leave_group(self, expiration_time: float) -> bool:
269 | """
270 | Leave the current group
271 |
272 | :param expiration_time: original expiration time
273 | :return: True if successfully left
274 | """
275 | return await self.group_key_manager.leave_group(expiration_time)
276 |
277 |
278 | class MatchmakingException(Exception):
279 | """An internal exception that marks undesired edge cases during averaging"""
280 |
--------------------------------------------------------------------------------
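The matchmaking logic above relies on two conventions: every node in a stage declares itself under the same fixed DHT key of the form `{prefix}.0b`, and `look_for_group` keeps polling until the set of declared peers matches the size of the peer table or the matchmaking window expires. The following standalone sketch is illustrative only and not part of the repository; the DHT is replaced by an in-memory callable, and the names (`look_for_group`, `active_peers`, the peer strings) are hypothetical:

    import asyncio
    import re
    import time

    GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # same pattern as in matchmaking.py

    def is_valid_group(maybe_group: str) -> bool:
        return bool(GROUP_PATTERN.fullmatch(maybe_group))

    async def look_for_group(active_peers, expected_peers, timeout=3.0, check_interval=0.5):
        """Accumulate declared peers until all expected peers are seen or the timeout expires."""
        seen = set()
        start = time.time()
        while time.time() - start < timeout:
            seen |= set(active_peers())            # stand-in for GroupKeyManager.get_averagers()
            if expected_peers and len(seen) == expected_peers:
                break
            await asyncio.sleep(check_interval)
        return sorted(seen) if seen else None      # None mirrors the "no peers found" timeout case

    async def main():
        assert is_valid_group("head_averaging.0b")  # example of a valid fixed group key
        peers = {"peerA", "peerB", "peerC"}
        print("formed group:", await look_for_group(lambda: peers, expected_peers=3))

    asyncio.run(main())

--------------------------------------------------------------------------------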
/src/node0/server/state_averager_wrap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import asyncio
16 | import math
17 | import os
18 | import random
19 | import signal
20 | import time
21 |
22 | from typing import Any, Optional
23 |
24 | import torch
25 |
26 | from hivemind.averaging.allreduce import AveragingMode
27 | from hivemind.averaging.group_info import GroupInfo
28 | from hivemind.averaging.load_balancing import load_balance_peers
29 | from hivemind.compression import CompressionInfo, deserialize_torch_tensor
30 | from hivemind.p2p import PeerID
31 | from hivemind.proto import averaging_pb2
32 | from hivemind.utils import MPFuture, get_logger
33 | from hivemind.utils.asyncio import aiter_with_timeout, enter_asynchronously
34 | from hivemind.utils.streaming import combine_from_streaming
35 | from hivemind.utils.tensor_descr import TensorDescriptor
36 | from hivemind.utils.timed_storage import ValueWithExpiration
37 |
38 | from node0.server.HM_state_averager import (
39 | TrainingStateAverager as HivemindTrainingStateAverager,
40 | )
41 | from node0.server.matchmaking import MatchmakingException
42 |
43 |
44 | GatheredData = Any
45 | logger = get_logger(__name__)
46 |
47 |
48 | class IndexSelector:
49 | def __init__(self, p):
50 | self.state = {}
51 | self.p = p
52 |
53 | def get_indices(self, param):
54 | return torch.ones(param.shape).bool()
55 |
56 |
57 | class PartitionedIndexSelector(IndexSelector):
58 | def __init__(self, p, param):
59 | super().__init__(p)
60 | self.state[param] = {}
61 | self._set_partition(param)
62 |
63 | def _set_partition(self, param):
64 | param_state = self.state[param]
65 | param_state["num_partitions"] = min(math.ceil(1 / self.p), param.numel())
66 | param_state["partitions"] = (
67 | torch.rand(param.numel(), device=param.device).argsort().view(param.shape) % param_state["num_partitions"]
68 | )
69 |
70 | def get_indices(self, param, curr_partition):
71 | curr_partition = curr_partition % self.state[param]["num_partitions"]
72 | indices = (self.state[param]["partitions"] == curr_partition).bool()
73 |
74 | return indices
75 |
76 |
77 | class TrainingStateAverager(HivemindTrainingStateAverager):
78 | """
79 | A class that extends Hivemind TrainingStateAverager and prevents too many
80 | consecutive calls to load_state_from_peers.
81 | """
82 |
83 | def __init__(
84 | self,
85 | sparse_avg=0.0,
86 | average_state_every=1,
87 | call_limit: int = 1,
88 | *args,
89 | **kwargs,
90 | ):
91 | kwargs["start"] = False
92 | super().__init__(*args, **kwargs)
93 |
94 | self.zclip_warmup = None
95 | if hasattr(self, "zclip"):
96 | self.zclip_warmup = self.zclip.warmup_steps
97 |
98 | self._call_limit = call_limit
99 | self._consecutive_fails = 0
100 | self._request_timeout = kwargs["request_timeout"]
101 | self.sparse_avg = sparse_avg
102 | self.average_state_every = average_state_every
103 | self.partition_selector = []
104 |
105 | def set_sparta_partitions(self):
106 | with self.get_tensors() as local_tensors:
107 | torch.manual_seed(1337)
108 | for i, p in enumerate(local_tensors):
109 | self.partition_selector.append(PartitionedIndexSelector(self.sparse_avg, p))
110 |
111 | async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
112 | # Adapted from /hivemind/averaging/averager.py
113 | if timeout is not None:
114 | timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
115 | try:
116 | key_manager = self._matchmaking.group_key_manager
117 | peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
118 | peer_priority = {
119 | PeerID(peer_id): (float(info.value), random.random()) # using randomness as a tie breaker
120 | for peer_id, info in peer_priority.items()
121 | if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
122 | }
123 |
124 | if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
125 | logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
126 | future.set_result(None)
127 | return
128 |
129 | metadata = None
130 | for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True):
131 | if peer != self.peer_id:
132 | t0 = time.monotonic()
133 | logger.info(f"Downloading parameters from peer {peer}")
134 | try:
135 | stub = self.get_stub(self._p2p, peer, namespace=self.prefix)
136 | stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
137 | current_tensor_parts, tensors = [], []
138 |
139 | async for message in aiter_with_timeout(stream, timeout=timeout):
140 | if message.metadata:
141 | metadata = self.serializer.loads(message.metadata)
142 | if message.tensor_part.dtype and current_tensor_parts:
143 | # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
144 | tensor = deserialize_torch_tensor(combine_from_streaming(current_tensor_parts))
145 | tensors.append(tensor)
146 | current_tensor_parts = []
147 | current_tensor_parts.append(message.tensor_part)
148 |
149 | if current_tensor_parts:
150 | tensor = deserialize_torch_tensor(combine_from_streaming(current_tensor_parts))
151 | tensors.append(tensor)
152 |
153 | if not metadata:
154 | logger.debug(f"Peer {peer} did not send its state")
155 | continue
156 |
157 | t1 = time.monotonic()
158 | logger.info(f"Finished downloading state in {t1 - t0:.3f}s from {peer}")
159 | self._consecutive_fails = 0
160 |                         # Check if any downloaded tensor contains NaN or inf values
161 | has_nans = any(not torch.isfinite(t).all() for t in tensors)
162 | if has_nans:
163 | logger.error(f"Failed to load state from peer.")
164 | logger.error(f"Downloaded state contains invalid values. Exiting the run.")
165 | os.killpg(os.getpgrp(), signal.SIGTERM)
166 | future.set_result((metadata, tensors))
167 | return
168 | except Exception as e:
169 | self._consecutive_fails = self._consecutive_fails + 1
170 |
171 | if isinstance(e, TimeoutError) and self._consecutive_fails < self._call_limit:
172 | logger.info(
173 | f"{self._consecutive_fails}/{self._call_limit} load state timeout before ending session."
174 | )
175 |
176 |                         elif isinstance(e, TimeoutError) and self._consecutive_fails >= self._call_limit:
177 | logger.error(
178 | f"Failed to load state from peers. "
179 | "Too many TimeoutErrors were caught when trying to _load_state_from_peers. "
180 | "This problem may occur due to slow internet connection, or temporary overload of the peer-to-peer network. Exiting run."
181 | )
182 | os.killpg(os.getpgrp(), signal.SIGTERM)
183 | else:
184 | logger.error(
185 | f"Failed to load state from {peer} - {repr(e)}. Exiting run.",
186 | exc_info=logger.getEffectiveLevel() <= 15,
187 | )
188 | os.killpg(os.getpgrp(), signal.SIGTERM)
189 |
190 | finally:
191 | if not future.done():
192 | future.set_result(None)
193 |
194 | async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
195 | """Run sparse aggregation in a given group and update tensors in place, return gathered metadata"""
196 | try:
197 | bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
198 | user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
199 | modes = tuple(map(AveragingMode, mode_ids))
200 |
201 | # compute optimal part sizes from peer bandwidths;
202 | download_bandwidths = [
203 | thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
204 | ]
205 | peer_fractions = await asyncio.get_event_loop().run_in_executor(
206 | None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
207 | )
208 |
209 | epoch = self.global_epoch
210 | snapped_epoch = int(round(epoch / self.average_state_every))
211 | async with enter_asynchronously(self.get_tensors()) as local_tensors:
212 | if self.sparse_avg > 0:
213 | tensor_idxs = [] # index of tensor in local_tensors
214 | sparse_idxs = [] # the sparse indices that index into the averaged tensor
215 | sparse_tensor = [] # the sparse tensor to be averaged
216 |
217 | tensor_infos = kwargs["tensor_infos"]
218 | for i, val in enumerate(zip(local_tensors, tensor_infos)):
219 | p, ti = val
220 | should_include = False
221 |
222 | # Determine if tensor should be included
223 | if i < len(self.main_parameters):
224 | if self.main_parameters[i].requires_grad:
225 | should_include = True
226 | else:
227 | should_include = True
228 | if self.zclip_warmup:
229 | # Start averaging zclip (mean,var) after 2*zclip_warmup to account for
230 | # peers that joined between step 0 and zclip warmup
231 | epoch_round = int(snapped_epoch * self.average_state_every)
232 | zclip_in_warmup = epoch_round <= 2 * self.zclip_warmup
233 | is_zclip = ti.key == "zclip_mean" or ti.key == "zclip_var"
234 | if zclip_in_warmup and is_zclip:
235 | should_include = False
236 |
237 | # Only process tensors that should be included
238 | if should_include:
239 | tensor_idxs.append(i)
240 | _idx = self.partition_selector[i].get_indices(p, snapped_epoch)
241 | sparse_tensor.append(p[_idx].contiguous())
242 | sparse_idxs.append(_idx)
243 |
244 | # Build tensor info using proper indexing
245 | tensor_infos_sparse = []
246 | for sparse_idx, tensor_idx in enumerate(tensor_idxs):
247 | desc = TensorDescriptor(
248 | size=sparse_tensor[sparse_idx].shape,
249 | dtype=sparse_tensor[sparse_idx].dtype,
250 | device=sparse_tensor[sparse_idx].device,
251 | requires_grad=local_tensors[tensor_idx].requires_grad,
252 | )
253 | tensor_infos_sparse.append(CompressionInfo(key=sparse_idx, descriptor=desc))
254 |
255 | tensor_infos_sparse = tuple(tensor_infos_sparse)
256 | kwargs["tensor_infos"] = tensor_infos_sparse
257 | await self._run_allreduce_inplace_(
258 | sparse_tensor, group_info, peer_fractions=peer_fractions, **kwargs
259 | )
260 |
261 | # Copy results back using proper indexing
262 | for sparse_idx, tensor_idx in enumerate(tensor_idxs):
263 | local_tensors[tensor_idx][sparse_idxs[sparse_idx]] = sparse_tensor[sparse_idx]
264 | else:
265 | await self._run_allreduce_inplace_(
266 | local_tensors, group_info, peer_fractions=peer_fractions, **kwargs
267 | )
268 | return user_gathered
269 | except BaseException as e:
270 | if isinstance(e, Exception):
271 | logger.error(e, exc_info=logger.getEffectiveLevel() <= 15)
272 | raise MatchmakingException(f"Unable to run All-Reduce: {e}")
273 |
--------------------------------------------------------------------------------
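The sparse branch of `_aggregate_with_group` above is driven by `PartitionedIndexSelector`: every parameter is split into roughly 1/sparse_avg random partitions, and at each snapped epoch only one partition per tensor is all-reduced with peers. Below is a minimal torch-only sketch of that indexing scheme, with hypothetical values for `p` and the parameter shape, and without any hivemind communication:

    import math
    import torch

    p = 0.25                                   # sparse_avg: fraction of elements averaged per round
    param = torch.randn(4, 6)

    num_partitions = min(math.ceil(1 / p), param.numel())
    # randomly assign every element to one of the partitions, keeping the parameter's shape
    partitions = torch.rand(param.numel()).argsort().view(param.shape) % num_partitions

    for epoch in range(num_partitions):
        mask = partitions == (epoch % num_partitions)  # PartitionedIndexSelector.get_indices
        sparse_slice = param[mask].contiguous()        # the tensor that would be all-reduced
        print(f"epoch {epoch}: averaging {sparse_slice.numel()} of {param.numel()} elements")

--------------------------------------------------------------------------------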
/src/node0/security/authorization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import signal
17 |
18 | from datetime import datetime, timedelta, timezone
19 | from functools import partial
20 | from pathlib import Path
21 |
22 | import requests
23 |
24 | from cryptography.hazmat.primitives import serialization
25 | from hivemind import PeerID
26 | from hivemind.proto import crypto_pb2
27 | from hivemind.proto.auth_pb2 import AccessToken
28 | from hivemind.utils.auth import TokenAuthorizerBase
29 | from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey
30 | from hivemind.utils.logging import get_logger
31 |
32 | from node0.security.integrity_check import verify_integrity
33 | from node0.utils.connection_test_server import TestServer
34 | from node0.utils.node_info import (
35 | BadRequestError,
36 | IntegrityError,
37 | NodeInfo,
38 | NotInAllowlistError,
39 | ServerUnavailableError,
40 | TestServerError,
41 | call_with_retries,
42 | )
43 |
44 |
45 | logger = get_logger(__name__)
46 |
47 |
48 | class PluralisAuthorizer(TokenAuthorizerBase):
49 | def __init__(
50 | self,
51 | peer_id: str,
52 | user_token: str,
53 | user_email: str,
54 | role: str,
55 | auth_server: str,
56 | node_info: NodeInfo,
57 | current_path: Path,
58 | announce_maddrs: str,
59 | host_port: int,
60 | check_integrity: bool,
61 | ):
62 | super().__init__()
63 |
64 | self.peer_id = peer_id
65 | self._user_token = user_token
66 | self._user_email = user_email
67 | self._role = role
68 | self._auth_server = auth_server
69 | self._node_info = node_info
70 | self._current_path = current_path
71 | self._authority_public_key = None
72 | self._check_integrity = check_integrity
73 | self.pipeline_stage = None
74 | self.reachable = "unknown"
75 | self.monitor_public_key = None
76 |
77 | # Parse announce address
78 | address_parts = announce_maddrs.split("/")
79 | self._announce_ip_address = address_parts[2]
80 | self._announce_port = int(address_parts[4])
81 | self._host_port = host_port
82 |
83 | async def get_token(self) -> AccessToken:
84 | """Hivemind calls this method to refresh the access token when necessary."""
85 |
86 | try:
87 | self.join_experiment()
88 | return self._local_access_token
89 | except NotInAllowlistError as e:
90 | logger.error(f"Authorization failed: {e}. Exiting run.")
91 | os.killpg(os.getpgrp(), signal.SIGTERM)
92 | except BadRequestError as e:
93 | logger.error(f"Authorization failed: {e}. Exiting run.")
94 | os.killpg(os.getpgrp(), signal.SIGTERM)
95 | except IntegrityError:
96 | logger.error("Authorization failed: verification failed. Exiting run.")
97 | os.killpg(os.getpgrp(), signal.SIGTERM)
98 | except Exception as e:
99 | logger.error(f"Authorization failed: {e}. Exiting run.")
100 | os.killpg(os.getpgrp(), signal.SIGTERM)
101 |
102 | def join_experiment(
103 | self,
104 | reset_reachability: bool = False,
105 | initial_join: bool = False,
106 | n_retries: int = 10,
107 | ) -> None:
108 | """Join experiment with retries."""
109 | call_with_retries(
110 | partial(self._join_experiment, reset_reachability, initial_join), n_retries=n_retries, initial_delay=3
111 | )
112 |
113 | def _join_experiment(self, reset_reachability: bool = False, initial_join: bool = False) -> None:
114 | """Send authorization request to join the experiment and receive access token."""
115 | try:
116 | # Check integrity of files
117 | if self._check_integrity:
118 | try:
119 | integrity_hash = verify_integrity(self._current_path, self._local_private_key)
120 | except Exception:
121 | raise IntegrityError("Verification failed.") from None
122 | else:
123 | integrity_hash = b"hash"
124 |
125 | url = f"{self._auth_server}/api/join"
126 | headers = {
127 | "Authorization": f"Bearer {self._user_token}",
128 | "request-type": "initial" if initial_join else "update",
129 | }
130 | json_body = {
131 | "peer_id": self.peer_id,
132 | "role": self._role,
133 | "peer_public_key": self.local_public_key.to_bytes().decode(),
134 | "device": self._node_info.device_name,
135 | "gpu_memory": self._node_info.gpu_memory,
136 | "ram": self._node_info.ram,
137 | "download_speed": self._node_info.download_speed,
138 | "upload_speed": self._node_info.upload_speed,
139 | "latency": self._node_info.latency,
140 | "integrity_hash": integrity_hash.decode(),
141 | "reset_reachability": reset_reachability,
142 | "email": self._user_email,
143 | "announce_ip_address": self._announce_ip_address,
144 | "announce_port": self._announce_port,
145 | }
146 |
147 | if initial_join and self._check_integrity:
148 | with TestServer(port=self._host_port) as server:
149 | response = requests.put(
150 | url,
151 | headers=headers,
152 | json=json_body,
153 | )
154 |
155 | response.raise_for_status()
156 |
157 | # Receive server message
158 | if not server.get_message():
159 | if not server.wait_for_message(timeout=5):
160 | raise TestServerError(
161 | "Port test failed. Make sure your port forwarding is correct"
162 | ) from None
163 |
164 | # Verify message content
165 | if not server.verify_message():
166 | raise TestServerError(
167 | "Port test failed, wrong message received. Please wait for few minutes before trying to join again"
168 | ) from None
169 |
170 | else:
171 | response = requests.put(
172 | url,
173 | headers=headers,
174 | json=json_body,
175 | )
176 |
177 | response.raise_for_status()
178 |
179 | response = response.json()
180 |
181 | self._authority_public_key = RSAPublicKey.from_bytes(response["auth_server_public_key"].encode())
182 | self.monitor_public_key = str(response["monitor_public_key"])
183 | self.pipeline_stage = response["stage_type"]
184 | self.reachable = response["reachable"]
185 |
186 | access_token = AccessToken()
187 | access_token.username = response["username"]
188 | access_token.public_key = response["peer_public_key"].encode()
189 | access_token.expiration_time = str(datetime.fromisoformat(response["expiration_time"]))
190 | access_token.signature = response["signature"].encode()
191 | self._local_access_token = access_token
192 |
193 | logger.info(
194 | f"Access for user {access_token.username} has been granted until {access_token.expiration_time} UTC"
195 | )
196 | except requests.exceptions.HTTPError as e:
197 | if e.response.status_code in [401, 403, 429]: # Unauthorized, blacklisted, blocked or IP rate limited
198 | try:
199 | error_detail = e.response.json()["detail"]
200 | except Exception:
201 | error_detail = "Request is blocked"
202 | raise NotInAllowlistError(error_detail) from None
203 | if e.response.status_code in [400, 413, 418, 422, 424]: # wrong request information
204 | raise BadRequestError(e.response.json()["detail"]) from None
205 | if e.response.status_code == 503: # can't join due to rate limiting
206 | raise ServerUnavailableError(e.response.json()["detail"]) from None
207 | raise e
208 | except Exception as e:
209 | raise e
210 |
211 | def is_token_valid(self, access_token: AccessToken) -> bool:
212 | """Verify that token is valid."""
213 | data = self._token_to_bytes(access_token)
214 | if not self._authority_public_key or not self._authority_public_key.verify(data, access_token.signature):
215 | logger.error("Access token has invalid signature")
216 | return False
217 |
218 | try:
219 | expiration_time = datetime.fromisoformat(access_token.expiration_time)
220 | except ValueError:
221 | logger.error(f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}")
222 | return False
223 | if expiration_time < datetime.now(timezone.utc):
224 | logger.error("Access token has expired")
225 | return False
226 |
227 | return True
228 |
229 | _MAX_LATENCY = timedelta(minutes=1)
230 |
231 | def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
232 | """Check if token has expired."""
233 | expiration_time = datetime.fromisoformat(access_token.expiration_time)
234 | return expiration_time < datetime.now(timezone.utc) + self._MAX_LATENCY
235 |
236 | @staticmethod
237 | def _token_to_bytes(access_token: AccessToken) -> bytes:
238 | """Convert access token to bytes."""
239 | return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode()
240 |
241 |
242 | def save_identity(private_key: RSAPrivateKey, identity_path: str) -> None:
243 | """Save private key to file.
244 |
245 | Args:
246 | private_key (RSAPrivateKey): local private key
247 | identity_path (str): path to save the key
248 |
249 | Raises:
250 | FileNotFoundError: can't create file
251 | """
252 | protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes())
253 |
254 | try:
255 | with open(identity_path, "wb") as f:
256 | f.write(protobuf.SerializeToString())
257 | except FileNotFoundError as e:
258 | raise FileNotFoundError(
259 | f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist"
260 | ) from e
261 | os.chmod(identity_path, 0o400)
262 |
263 |
264 | def authorize_with_pluralis(
265 | node_info: NodeInfo,
266 | user_token: str,
267 | user_email: str,
268 | role: str,
269 | auth_server: str,
270 | identity_path: str,
271 | current_path: Path,
272 | announce_maddrs: str,
273 | host_port: int,
274 | check_integrity: bool = True,
275 | ) -> PluralisAuthorizer:
276 | """Generate local keys and send authorization request to join the run.
277 |
278 | Args:
279 | node_info (NodeInfo): information about the node
280 | user_token (str): authentication token
281 | user_email (str): email address
282 | role (str): role in the swarm
283 | auth_server (str): authorization server URL
284 | identity_path (str): path to save/load private key
285 | current_path (Path): path to src/node0
286 | announce_maddrs (str): announce address
287 |         host_port (int): host port
288 | check_integrity (bool): flag to check integrity
289 |
290 | Returns:
291 | PluralisAuthorizer: authorizer instance
292 | """
293 | logger.info("Authorization started...")
294 |
295 | # Generate private key or read from file
296 | local_private_key = RSAPrivateKey.process_wide()
297 |
298 | if os.path.exists(identity_path):
299 | with open(identity_path, "rb") as f:
300 | key_data = crypto_pb2.PrivateKey.FromString(f.read()).data
301 | private_key = serialization.load_der_private_key(key_data, password=None)
302 | if local_private_key._process_wide_key:
303 | local_private_key._process_wide_key._private_key = private_key
304 | else:
305 | logger.error("Failed to initialize process-wide private key")
306 | raise RuntimeError("Process-wide key is None")
307 | else:
308 | save_identity(local_private_key, identity_path)
309 |
310 | # Get static peer id
311 | with open(identity_path, "rb") as f:
312 | peer_id = str(PeerID.from_identity(f.read()))
313 |
314 | # Authorize
315 | authorizer = PluralisAuthorizer(
316 | peer_id,
317 | user_token,
318 | user_email,
319 | role,
320 | auth_server,
321 | node_info,
322 | current_path,
323 | announce_maddrs,
324 | host_port,
325 | check_integrity,
326 | )
327 |
328 | try:
329 | authorizer.join_experiment(reset_reachability=True, initial_join=True)
330 | logger.info("Authorization completed")
331 | return authorizer
332 | except NotInAllowlistError as e:
333 | logger.error(f"Authorization failed: {e}. Exiting run.")
334 | exit(1)
335 | except BadRequestError as e:
336 | logger.error(f"Authorization failed: {e}. Exiting run.")
337 | exit(1)
338 | except IntegrityError:
339 | logger.error("Authorization failed: verification failed. Exiting run.")
340 | exit(1)
341 | except Exception as e:
342 | logger.error(f"Authorization failed: {e}. Exiting run.")
343 | exit(1)
344 |
--------------------------------------------------------------------------------
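For reference, the two token checks in `PluralisAuthorizer` (building the signed payload and testing expiry against a one-minute latency window) reduce to the short sketch below. It uses only the standard library, omits the RSA signature verification, and the helper names (`token_to_bytes`, `needs_refreshing`) are hypothetical:

    from datetime import datetime, timedelta, timezone

    MAX_LATENCY = timedelta(minutes=1)

    def token_to_bytes(username: str, public_key: bytes, expiration_time: str) -> bytes:
        # the payload that the authority signs and is_token_valid verifies
        return f"{username} {public_key} {expiration_time}".encode()

    def needs_refreshing(expiration_time_iso: str) -> bool:
        # refresh the token once it expires within the latency window
        return datetime.fromisoformat(expiration_time_iso) < datetime.now(timezone.utc) + MAX_LATENCY

    expires = (datetime.now(timezone.utc) + timedelta(seconds=30)).isoformat()
    print(token_to_bytes("alice", b"...", expires))
    print("needs refresh:", needs_refreshing(expires))  # True: 30 s is inside the 1-minute window

--------------------------------------------------------------------------------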
/src/node0/server/power_sgd_averager.py:
--------------------------------------------------------------------------------
1 | # This file contains code originally from Hivemind under MIT License
2 | # Original: Copyright 2020 Learning@home authors and collaborators
3 | # Modified by: Pluralis Research 2025
4 | #
5 | # Original code: MIT License (see THIRD_PARTY_LICENSES)
6 | # Modifications: Apache 2.0 License (see LICENSE)
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only;
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | import asyncio
13 | import contextlib
14 |
15 | from enum import Enum
16 | from typing import Any, Iterable, Optional, Sequence
17 |
18 | import torch
19 |
20 | from hivemind.averaging.allreduce import AveragingMode
21 | from hivemind.averaging.group_info import GroupInfo
22 | from hivemind.averaging.load_balancing import load_balance_peers
23 | from hivemind.averaging.matchmaking import MatchmakingException
24 | from hivemind.compression import CompressionInfo, TensorRole
25 | from hivemind.dht import DHT
26 | from hivemind.utils import get_logger
27 | from hivemind.utils.asyncio import enter_asynchronously
28 | from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
29 |
30 | from node0.server.HM_gradient_averager import GradientAverager
31 |
32 |
33 | GatheredData = Any
34 | logger = get_logger(__name__)
35 |
36 | def qr_orthogonalize(matrix, iters=1):
37 | """QR orthogonalization in-place for 2D matrix."""
38 | for _ in range(iters):
39 | Q, _ = torch.linalg.qr(matrix)
40 | matrix.copy_(Q)
41 | return matrix
42 |
43 | class AllReducePhases(Enum):
44 | PHASE_P = 1
45 | PHASE_Q = 2
46 |
47 |
48 | class PowerSGDGradientAverager(GradientAverager):
49 | """
50 | A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
51 |     For basic properties and guarantees of gradient averagers, please refer to the base class docstring.
52 | Put simply, this method approximates large gradient tensors (m,n) with a product of two
53 | smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
54 |
55 | As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
56 |     A high r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
57 |     A low r, e.g. 1-8, further accelerates communication, but may converge worse depending on the task.
58 |
59 | To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
60 | if some part of the gradient is "lost in compression", it will be added to the next iteration.
61 | This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
62 | and (b) if devices stay alive only for one step, training with small rank may converge slower.
63 | This is because error feedback takes multiple steps to kick in.
64 |
65 | Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
66 |     If a tensor has fewer than 2 dimensions or does not compress efficiently, it will be aggregated
67 |     normally, i.e. without PowerSGD. See min_compression_ratio for details.
68 |
69 | :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
70 |     matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).
71 |
72 | :param parameters: pytorch parameters for which to aggregate gradients
73 | :param averager_rank: rank of compressed gradients
74 | :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
75 | :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
76 | :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
77 | This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
78 | :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
79 | device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
80 | the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
81 | :param client_mode: if False, this averager will accept incoming requests from other peers.
82 | if True, the averager will only join existing groups where at least one peer has client_mode=False.
83 | By default, this flag is copied from DHTNode inside the ``dht`` instance.
84 | :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
85 | :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
86 | otherwise aggregate tensors as is
87 | :param averaged_grads: if provided, it will be used as a set of averagable gradients
88 | """
89 |
90 | def __init__(
91 | self,
92 | parameters: Iterable[torch.nn.Parameter],
93 | averager_rank: int,
94 | *,
95 | dht: DHT,
96 | prefix: str,
97 | reuse_grad_buffers: bool = False,
98 | accumulate_grads_on: Optional[torch.device] = None,
99 | client_mode: bool = None,
100 | warn: bool = True,
101 | min_compression_ratio: float = 0.5,
102 | averaged_grads: Optional[Sequence[torch.Tensor]] = None,
103 | reset_buffers_every_k_steps: int = 10,
104 | **kwargs,
105 | ):
106 | self.rank = averager_rank
107 | self.parameters = tuple(parameters)
108 | self._uncompressed_gradients_indexes = set(
109 | i
110 | for i, grad in enumerate(self._grads_from_parameters())
111 | if grad.ndim <= 1
112 | or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
113 |             # compute how many parameters remain after factorization
114 | )
115 | self._ms = [
116 | torch.zeros_like(grad, device="cpu").share_memory_()
117 | for idx, grad in enumerate(self._grads_from_parameters())
118 | if idx not in self._uncompressed_gradients_indexes
119 | ]
120 |
121 | self._ms_copy = [
122 | torch.zeros_like(grad, device="cpu").share_memory_()
123 | for idx, grad in enumerate(self._grads_from_parameters())
124 | if idx not in self._uncompressed_gradients_indexes
125 | ]
126 |
127 | self._qs = [
128 | torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
129 | for idx, grad in enumerate(self._grads_from_parameters())
130 | if idx not in self._uncompressed_gradients_indexes
131 | ]
132 |
133 | # Buffer reset tracking
134 | self.reset_buffers_every_k_steps = reset_buffers_every_k_steps
135 | self._step_count = 0
136 | self._last_successful_reset_step = 0
137 |
138 | super().__init__(
139 | self.parameters,
140 | dht=dht,
141 | prefix=prefix,
142 | reuse_grad_buffers=reuse_grad_buffers,
143 | accumulate_grads_on=accumulate_grads_on,
144 | client_mode=client_mode,
145 | warn=warn,
146 | averaged_grads=averaged_grads,
147 | **kwargs,
148 | )
149 |
150 | @contextlib.contextmanager
151 | def _register_allreduce_group(self, group_info: GroupInfo):
152 | """Register a given group for one or more all-reduce rounds"""
153 | try:
154 | for phase in list(AllReducePhases):
155 | self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
156 | self._pending_groups_registered.set()
157 | yield
158 | finally:
159 | for phase in list(AllReducePhases):
160 | maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
161 | if maybe_future and not maybe_future.done():
162 | logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
163 | self._pending_groups_registered.set()
164 |
165 | async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
166 | """Run aggregation in a given group and update tensors in place, return gathered metadata"""
167 | self._step_count += 1 # Increment buffer step count
168 | try:
169 | bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
170 | user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
171 | modes = tuple(map(AveragingMode, mode_ids))
172 |
173 | download_bandwidths = [
174 | thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
175 | ]
176 | peer_fractions = await asyncio.get_event_loop().run_in_executor(
177 | None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
178 | )
179 |
180 | async with enter_asynchronously(self.get_tensors()) as averaged_grads:
181 | averaged_grads_via_sgd = [
182 | grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
183 | ]
184 |
185 | err_norm = torch.nn.utils.get_total_norm(self._ms).item()
186 | logger.extra(f"Error norm val: {err_norm:.6f}")
187 |
188 | prepsgd_norm = torch.nn.utils.get_total_norm(averaged_grads_via_sgd).item()
189 | logger.extra(f"Prepsgd norm val: {prepsgd_norm:.6f}")
190 |
191 | # Adding noise to qs to prevent slow-down issues
192 | for q in self._qs:
193 | q.add_(torch.randn_like(q) * 1e-30)
194 |
195 | # Make a copy of _ms in case of fail
196 | for m, ms_copy in zip(self._ms, self._ms_copy):
197 | m.copy_(ms_copy)
198 |
199 | for grad, m in zip(averaged_grads_via_sgd, self._ms):
200 | m.add_(grad.to(m.device))
201 |
202 | ps = [
203 | torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
204 | for idx, grad in enumerate(averaged_grads_via_sgd)
205 | ]
206 | for p, q, m in zip(ps, self._qs, self._ms):
207 |                     # we use reshape for all matrices because PowerSGD works only with 2d tensors
208 | torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
209 |
210 | p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
211 |                 q_group_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
212 |
213 | await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
214 |
215 |                 for p in ps:
216 |                     qr_orthogonalize(p, iters=1)  # orthogonalizes p in place
217 |
218 | for p, q, m in zip(ps, self._qs, self._ms):
219 | torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
220 |
221 | # local error before allreduce on Q
222 | for p, q, m in zip(ps, self._qs, self._ms):
223 | new_m = torch.matmul(p, q.t()).reshape(m.size())
224 | m.sub_(new_m) # prev_err + grad - new_approx
225 |
226 | phase_q_tensors = self._qs + [
227 | grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
228 | ]
229 |
230 | await self._run_allreduce_inplace_(
231 |                     phase_q_tensors, group_info, q_group_id, peer_fractions=peer_fractions, **kwargs
232 | )
233 |
234 | for p, q, ms_copy, grad, m in zip(ps, self._qs, self._ms_copy, averaged_grads_via_sgd, self._ms):
235 | new_m = torch.matmul(p, q.t()).reshape(ms_copy.size())
236 | grad.copy_(new_m)
237 | ms_copy.copy_(m)
238 |
239 | postpsgd_norm = torch.nn.utils.get_total_norm(averaged_grads_via_sgd).item()
240 | logger.extra(f"Postpsgd norm val: {postpsgd_norm:.6f}")
241 |
242 | return user_gathered
243 | except BaseException as e:
244 | logger.error(e, exc_info=logger.getEffectiveLevel() <= 15)
245 | raise MatchmakingException(f"Unable to run All-Reduce: {e}")
246 |
247 | def get_current_state(self):
248 | """
249 |         Get the current gradient averager state; called when a newly joined peer requests it.
250 | """
251 | with torch.no_grad(), self.lock_averaged_tensors:
252 | grad_averager_buffers = [q for q in self._qs]
253 | grad_averager_buffers_infos = [
254 | CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
255 | for buffer, key in zip(grad_averager_buffers, enumerate(grad_averager_buffers))
256 | ]
257 |
258 | metadata = dict(group_bits=self.get_group_bits())
259 | return metadata, grad_averager_buffers, grad_averager_buffers_infos
260 |
261 | def load_state_from_peers(self, **kwargs):
262 | """
263 | Attempt to download the latest optimizer state from peers and update gradient averager buffers.
264 |         :returns: whether or not the averager succeeded in loading parameters
265 | """
266 | loaded_state = super().load_state_from_peers(**kwargs)
267 | if loaded_state is None:
268 | return
269 |
270 | metadata, flat_tensors = loaded_state
271 | logger.info("Starting loading gradient averager buffers from peers")
272 |
273 | if len(flat_tensors) != len(self._qs):
274 | logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
275 | return
276 |
277 | with torch.no_grad(), self.lock_averaged_tensors:
278 | for local_q, loaded_q in zip(self._qs, flat_tensors):
279 | local_q.copy_(loaded_q, non_blocking=True)
280 |
--------------------------------------------------------------------------------
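To make the compression scheme described in the PowerSGDGradientAverager docstring concrete, here is a self-contained torch-only sketch of a single PowerSGD round. The two all-reduce phases are omitted (everything stays local), and the shapes and rank are hypothetical, so this illustrates the rank-r factorization and error feedback rather than reproducing the repository's implementation:

    import torch

    m, n, r = 256, 512, 8
    grad = torch.randn(m, n)
    error = torch.zeros(m, n)          # error-feedback buffer, analogous to self._ms
    q = torch.randn(n, r)              # persistent Q buffer, analogous to self._qs

    M = grad + error                   # add the residual carried over from the previous round
    P = M @ q                          # phase P: the (m, r) factor that would be all-reduced
    P, _ = torch.linalg.qr(P)          # orthogonalize P (cf. qr_orthogonalize)
    q = M.t() @ P                      # phase Q: the (n, r) factor that would be all-reduced
    approx = P @ q.t()                 # rank-r approximation of the accumulated gradient
    error = M - approx                 # residual fed back into the next round

    sent = P.numel() + q.numel()
    print(f"communicated {sent} values instead of {grad.numel()} ({grad.numel() / sent:.1f}x fewer)")
    print(f"relative approximation error: {(error.norm() / M.norm()).item():.3f}")

--------------------------------------------------------------------------------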
/src/node0/server/HM_gradient_averager.py:
--------------------------------------------------------------------------------
1 | # This file contains code originally from Hivemind under MIT License
2 | # Original: Copyright 2020 Learning@home authors and collaborators
3 | # Modified by: Pluralis Research 2025
4 | #
5 | # Original code: MIT License (see THIRD_PARTY_LICENSES)
6 | # Modifications: Apache 2.0 License (see LICENSE)
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only;
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | import contextlib
13 |
14 | from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar
15 |
16 | import torch
17 |
18 | from hivemind.averaging.control import StepControl
19 | from hivemind.dht import DHT
20 | from hivemind.utils import DHTExpiration, get_logger
21 |
22 | from node0.server.HM_averager import DecentralizedAverager
23 |
24 |
25 | logger = get_logger(__name__)
26 |
27 |
28 | TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
29 | GradientAveragerFactory = Callable[..., TGradientAverager]
30 |
31 |
32 | class GradientAverager(DecentralizedAverager):
33 | """
34 | An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
35 | GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
36 |
37 | GradientAverager manages three sets of buffers:
38 | (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
39 | These tensors are typically stored on device and updated by torch autograd
40 | (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
41 | - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
42 | which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
43 | (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
44 |
45 | :param parameters: pytorch parameters for which to aggregate gradients
46 | :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
47 | :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
48 | :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
49 | This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
50 | :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
51 | device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
52 | the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
53 | :param client_mode: if False, this averager will accept incoming requests from other peers.
54 | if True, the averager will only join existing groups where at least one peer has client_mode=False.
55 | By default, this flag is copied from DHTNode inside the ``dht`` instance.
56 | :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
57 | :param averaged_grads: if provided, it will be used as a set of averagable gradients
58 | :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
59 |
60 |
61 | Example:
62 |
63 | >>> model = SuchModelMuchLayers()
64 | >>> opt = torch.optim.Adam(model.parameters())
65 | >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
66 | >>> next_step_time = hivemind.get_dht_time() + 60 # runs global steps every 60 seconds
67 | >>> next_step_control = None
68 | >>> while True:
69 | >>> # accumulate as many gradients as you can before next_step_time
70 | >>> loss = compute_loss(model, batch_size=32)
71 | >>> loss.backward()
72 | >>> grad_averager.accumulate_grads_(batch_size=32)
73 | >>> # [optional] next step in 5 seconds, start looking for peers in advance
74 |     >>>     if next_step_time - hivemind.get_dht_time() <= 5:
75 | >>> next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
76 | >>> # aggregate gradients and perform optimizer step
77 | >>> if hivemind.get_dht_time() >= next_step_time:
78 | >>> grad_averager.step(control=next_step_control)
79 | >>> with grad_averager.use_averaged_gradients(): # this will fill param.grads with aggregated gradients
80 | >>> opt.step() # update model parameters using averaged gradients
81 | >>> grad_averager.reset_accumulated_grads_() # prepare for next step
82 | >>> next_step_time = hivemind.get_dht_time() + 60
83 | >>> next_step_control = None
84 |
85 | """
86 |
87 | def __init__(
88 | self,
89 | parameters: Iterable[torch.nn.Parameter],
90 | *,
91 | dht: DHT,
92 | prefix: str,
93 | reuse_grad_buffers: bool = False,
94 | accumulate_grads_on: Optional[torch.device] = None,
95 | client_mode: bool = None,
96 | warn: bool = True,
97 | averaged_grads: Sequence[torch.Tensor] = (),
98 | **kwargs,
99 | ):
100 | if reuse_grad_buffers and accumulate_grads_on is not None:
101 | logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
102 | client_mode = client_mode if client_mode is not None else dht.client_mode
103 | self.parameters = tuple(parameters)
104 | self.reuse_grad_buffers = reuse_grad_buffers
105 | self.warn = warn
106 | self.local_samples_accumulated = 0
107 | self.local_times_accumulated = 0
108 | self._anchor_batch_size = None
109 | self._local_accumulators = None
110 | self.processed_batches = torch.tensor(0.0, device="cpu").share_memory_()
111 | if not reuse_grad_buffers:
112 | self._local_accumulators = tuple(
113 | torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
114 | )
115 | self._accumulators_used_in_step = False
116 | self._new_averaged_grads = False
117 |
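# Averaged gradients are stored as shared CPU tensors (created here unless provided) so background averaging processes can access them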
118 | with torch.no_grad():
119 | if not averaged_grads:
120 | averaged_grads = tuple(
121 | grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
122 | )
123 | else:
124 | if any(
125 | param_grad.size() != grad.size()
126 | for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
127 | ):
128 | raise ValueError("Averaged gradients do not have the same shape as gradients from parameters")
129 | super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
130 |
131 | def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
132 | """gradient buffers associated with parameters"""
133 | for param in self.parameters:
134 | if param.grad is None:
135 | param.grad = torch.zeros_like(param)
136 | yield param.grad
137 |
138 | @torch.no_grad()
139 | def _grad_accumulators(self) -> Iterator[torch.Tensor]:
140 | """averager-based gradient accumulators"""
141 | assert (self._local_accumulators is None) == self.reuse_grad_buffers
142 | yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
143 |
144 | @torch.no_grad()
145 | def accumulate_grads_(self, batch_size: int):
146 | """add current gradients to local grad accumulators (if used)"""
147 | if self._accumulators_used_in_step and self.warn:
148 | logger.warning(
149 | "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
150 | "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
151 | )
152 | self._accumulators_used_in_step = False # warn once per round
153 | if self._anchor_batch_size is None:
154 | # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
155 | self._anchor_batch_size = batch_size
156 | self.local_samples_accumulated += batch_size
157 | self.local_times_accumulated += 1
158 | if self.reuse_grad_buffers:
159 | pass # user is responsible for accumulating gradients in .grad buffers
160 | else:
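# Re-scale this batch's gradients by its size relative to the first (anchor) batch so unequal batch sizes contribute proportionally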
161 | alpha = float(batch_size) / self._anchor_batch_size
162 | for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
163 | grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
164 |
165 | def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
166 | """
167 | Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
168 |
169 | :param scheduled_time: expected time at which to perform the all-reduce. Can be changed later via control.scheduled_time
170 | :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
171 | :note: setting weight at this stage is not supported, please leave this parameter as None
172 | :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
173 | :note: in the current implementation, each step_control can only be used in one step.
174 | """
175 | assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
176 | return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
177 |
178 | def step(
179 | self,
180 | weight: Optional[float] = None,
181 | reset_accumulators: bool = True,
182 | control: Optional[StepControl] = None,
183 | timeout: Optional[float] = None,
184 | wait: bool = True,
185 | **kwargs,
186 | ):
187 | """
188 | Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
189 |
190 | :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
191 | :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
192 | :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
193 | :param timeout: if specified, wait for the averaging round for at most this many seconds (only if wait=True)
194 | :param wait: if True, wait for the step to finish (or fail); otherwise run the all-reduce in the background
195 | """
196 | if control is None:
197 | control = self.schedule_step(timeout=timeout, **kwargs)
198 | elif len(kwargs) > 0:
199 | raise RuntimeError(f"Averaging with a pre-scheduled group; extra parameters {kwargs} would have no effect")
200 | assert not control.triggered, f"This {type(control)} instance was already used"
201 | if self._new_averaged_grads and self.warn:
202 | logger.warning(
203 | "[warn=True] Starting new averaging round, but previous round results were not used. "
204 | "This may be a sign of incorrect optimizer behavior"
205 | )
206 |
207 | self.load_accumulators_into_averager_()
208 | self._accumulators_used_in_step = True
209 | self._new_averaged_grads = True
210 |
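# Unless explicitly overridden, weight this peer's contribution by the number of locally accumulated samples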
211 | control.weight = self.local_samples_accumulated if weight is None else weight
212 | if reset_accumulators:
213 | self.reset_accumulated_grads_()
214 | control.allow_allreduce()
215 |
216 | return control.result(timeout) if wait else control
217 |
218 | @torch.no_grad()
219 | def load_accumulators_into_averager_(self):
220 | """load locally accumulated gradients into the averager for aggregation"""
221 | # divide locally accumulated gradients by the number of times they were accumulated
222 | grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
223 | with self.get_tensors() as averaged_grads:
224 | for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
225 | averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
226 |
227 | @torch.no_grad()
228 | def reset_accumulated_grads_(self):
229 | """reset averager-internal gradient accumulators and the denominator"""
230 | self._accumulators_used_in_step = False
231 | self.local_samples_accumulated = self.local_times_accumulated = 0
232 | self._anchor_batch_size = None
233 | for grad_buf in self._grad_accumulators():
234 | grad_buf.zero_()
235 |
236 | @contextlib.contextmanager
237 | @torch.no_grad()
238 | def use_averaged_gradients(self):
239 | """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
240 | self._new_averaged_grads = False
241 | with self.get_tensors() as averaged_grads:
242 | assert len(averaged_grads) == len(self.parameters)
243 | try:
244 | old_grads = [param.grad for param in self.parameters]
245 | for param, new_grad in zip(self.parameters, averaged_grads):
246 | param.grad = new_grad
247 | yield averaged_grads
248 | finally:
249 | for param, old_grad in zip(self.parameters, old_grads):
250 | param.grad = old_grad
251 |
252 | def notify_used_averaged_gradients(self):
253 | """Notify averager that the results of a previous averaging round are accounted for"""
254 | self._new_averaged_grads = False
255 |
256 | def has_nan_grads(self) -> bool:
257 | """Check if any gradient contains NaN or inf values"""
258 | return any(not torch.isfinite(grad).all() for grad in self._grads_from_parameters())
259 |
--------------------------------------------------------------------------------
/src/node0/server/node0_server.py:
--------------------------------------------------------------------------------
1 | # This file contains code originally from Hivemind under MIT License
2 | # Original: Copyright 2020 Learning@home authors and collaborators
3 | # Modified by: Pluralis Research 2025
4 | #
5 | # Original code: MIT License (see THIRD_PARTY_LICENSES)
6 | # Modifications: Apache 2.0 License (see LICENSE)
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only;
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | import multiprocessing as mp
13 |
14 | from functools import partial
15 | from typing import Type
16 |
17 | import torch
18 |
19 | from hivemind.compression import CompressionBase, NoCompression
20 | from hivemind.dht import DHT
21 | from hivemind.moe.expert_uid import UID_DELIMITER
22 | from hivemind.moe.server import Server
23 | from hivemind.moe.server.layers import (
24 | add_custom_models_from_file,
25 | name_to_block,
26 | name_to_input,
27 | )
28 | from hivemind.moe.server.server import _generate_uids
29 | from hivemind.optim import Optimizer
30 | from hivemind.proto.runtime_pb2 import CompressionType
31 | from hivemind.utils.logging import get_logger
32 | from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor
33 |
34 | from node0.models.arguments import ModelArguments
35 | from node0.models.lr_schedule import schedule_name_to_scheduler
36 | from node0.server.HM_gradient_averager import GradientAveragerFactory
37 | from node0.server.module_collab import ModuleCollab
38 | from node0.utils import MonitorWorker
39 | from node0.utils.common import load_ss_components
40 | from node0.utils.dht_monitor import patch_dht_protocol_logging
41 | from node0.utils.get_parameters import get_parameter_store
42 |
43 |
44 | logger = get_logger(__name__)
45 |
46 |
47 | def get_ss_dim(expert_cls: str, model_conf: ModelArguments) -> int:
48 | """Utility function to get the input dimension to a stage."""
49 | if "head" in expert_cls:
50 | input_dim = model_conf.hidden_dim
51 | else:
52 | compression_length = int(model_conf.hidden_dim // model_conf.compression_rate)
53 | input_dim = compression_length + 1 # +1 for tokens
54 | return input_dim
55 |
56 |
57 | class Node0Server(Server):
58 | def __init__(self, optim, *args, **kwargs):
59 | super().__init__(*args, **kwargs)
60 |
61 | optim._start_monitor()
62 |
63 | @classmethod
64 | def create(
65 | cls,
66 | num_experts: int = 1,
67 | expert_uids: str = None,
68 | expert_pattern: str = None,
69 | expert_cls: str = "lm_body",
70 | model_conf: ModelArguments = None,
71 | optim_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW,
72 | scheduler: str = "none",
73 | num_warmup_steps: int | None = None,
74 | num_training_steps: int | None = None,
75 | clip_grad_norm: float | None = None,
76 | weight_decay: float | None = None,
77 | num_stages: int | None = None,
78 | num_handlers: int | None = None,
79 | min_batch_size: int = 1,
80 | max_batch_size: int = 4096,
81 | averaging_target_batch_size: int = 256,
82 | reuse_grad_buffers: bool = False,
83 | use_local_updates: bool = False,
84 | use_offloading: bool = False,
85 | matchmaking_time: float = 5.0,
86 | averaging_timeout: float = 10.0,
87 | request_timeout: float = 3.0,
88 | next_chunk_timeout: float = 10.0,
89 | load_state_timeout: float = 600,
90 | average_state_every: int = 1,
91 | sparse_avg: float = 0.0,
92 | max_allowed_stale: int = 0,
93 | grad_avg_factory: GradientAveragerFactory | None = None,
94 | optim_collab_cls: Type[torch.optim.Optimizer] = Optimizer,
95 | grad_averaging_compression: CompressionBase = NoCompression(),
96 | load_state_compression: CompressionBase = NoCompression(),
97 | device: str | None = None,
98 | initial_peers: list = [],
99 | compression: CompressionType = CompressionType.NONE,
100 | stats_report_interval: int = 60,
101 | custom_module_path: str | None = None,
102 | update_period: float = 30,
103 | expiration: float | None = None,
104 | monitor: MonitorWorker | None = None,
105 | upload_bw: float | None = None,
106 | *,
107 | start: bool,
108 | **kwargs,
109 | ) -> Server:
110 | """Instantiate a server for collaborative optimization.
111 |
112 | Args:
113 | start (bool): if True, starts the server right away
114 | num_experts (int, optional): run this many identical experts. Defaults to 1.
115 | expert_uids (str, optional): spawn experts with these exact uids, overrides num_experts and expert_pattern. Defaults to None.
116 | expert_pattern (str, optional): a string pattern for expert uids. Defaults to None.
117 | expert_cls (str, optional): expert type. Defaults to "lm_body".
118 | model_conf (ModelArguments, optional): model configuration. Defaults to None.
119 | optim_cls (Type[torch.optim.Optimizer], optional): optimizer class. Defaults to torch.optim.AdamW.
120 | scheduler (str, optional): if not `none`, the name of the expert LR scheduler. Defaults to "none".
121 | num_warmup_steps (int | None, optional): the number of warmup steps for LR scheduler. Defaults to None.
122 | num_training_steps (int | None, optional): the total number of steps for LR scheduler. Defaults to None.
123 | clip_grad_norm (float | None, optional): maximum gradient norm used for clipping. Defaults to None.
124 | num_handlers (int | None, optional): server will use this many parallel processes to handle incoming requests. Defaults to None.
125 | min_batch_size (int, optional): total number of examples in the same batch will be greater than this value. Defaults to 1.
126 | max_batch_size (int, optional): total number of examples in the same batch will not exceed this value. Defaults to 4096.
127 | averaging_target_batch_size (int): number of examples to accumulate across all peers before averaging. Defaults to 256.
128 | reuse_grad_buffers (bool, optional): if True, use model's .grad buffers for gradient accumulation. Defaults to False.
129 | use_local_updates (bool, optional): whether each node performs local weight updates between all-reduce rounds. Defaults to False.
130 | use_offloading (bool, optional): whether to offload the optimizer (passed as offload_optimizer). Defaults to False.
131 | matchmaking_time (float, optional): time window for nodes to find other nodes for allreduce. Defaults to 5.0.
132 | averaging_timeout (float, optional): timeout for nodes to perform the allreduce. Defaults to 10.0.
133 | optim_collab_cls (Type[torch.optim.Optimizer], optional): collaborative optimizer class. Defaults to Optimizer.
134 | device (str | None, optional): cuda or cpu. Defaults to None.
135 | initial_peers (list, optional): multiaddrs of one or more active DHT peers. Defaults to [].
136 | compression (CompressionType, optional): compression type. Defaults to CompressionType.NONE.
137 | stats_report_interval (int, optional): interval between two reports of batch processing performance statistics. Defaults to 60.
138 | custom_module_path (str | None, optional): path of a file with custom nn.modules. Defaults to None.
139 | update_period (float, optional): server will report experts to DHT once in this many seconds. Defaults to 30.
140 | expiration (float | None, optional): DHT entries will expire after this many seconds. Defaults to None.
141 | monitor (MonitorWorker | None, optional): monitor instance. Defaults to None.
142 | upload_bw (float | None, optional): upload bandwidth. Defaults to None.
143 | kwargs: any other params will be forwarded to DHT upon creation
144 |
145 | Returns:
146 | Server: collaborative training server
147 | """
148 | # Add custom layers
149 | if custom_module_path is not None:
150 | add_custom_models_from_file(custom_module_path)
151 |
152 | try:
153 | assert expert_cls in name_to_block
154 | except AssertionError:
155 | logger.error(
156 | f"Expert class {expert_cls} is not supported. Make sure you provided the correct custom_module_path: {custom_module_path}"
157 | )
158 | raise
159 |
160 | # Connect to DHT
161 | _ = patch_dht_protocol_logging()
162 | dht = DHT(initial_peers=initial_peers, start=True, startup_timeout=30, **kwargs)
163 | visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
164 | logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
165 |
166 | # Connect to monitor
167 | if monitor is not None:
168 | monitor.connect_dht(dht)
169 |
170 | # Generate uids
171 | try:
172 | assert (expert_pattern is None and num_experts is None and expert_uids is not None) or (
173 | num_experts is not None and expert_uids is None
174 | )
175 | except AssertionError:
176 | logger.error(
177 | "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both"
178 | )
179 | raise
180 |
181 | if expert_uids is None:
182 | expert_uids = []
183 |
184 | uids_to_generate = num_experts - len(expert_uids)
185 | if uids_to_generate > 0:
186 | logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}")
187 | expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht))
188 |
189 | # Get parameter store
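# Values fetched from the DHT parameter store for this stage override the corresponding arguments passed to create()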
190 | stage = expert_uids[0].split(".")[0]
191 | (
192 | averaging_target_batch_size,
193 | scheduler,
194 | num_warmup_steps,
195 | num_training_steps,
196 | averaging_timeout,
197 | matchmaking_time,
198 | request_timeout,
199 | load_state_timeout,
200 | ) = get_parameter_store(dht, stage)
201 |
202 | num_experts = len(expert_uids)
203 | num_handlers = num_handlers if num_handlers is not None else num_experts * 8
204 | device = device or (
205 | "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
206 | )
207 |
208 | # Scheduler
209 | scheduler_cls = schedule_name_to_scheduler[scheduler]
210 | if scheduler_cls is not None:
211 | scheduler_cls = partial(
212 | scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
213 | )
214 |
215 | # Initialize experts
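# Build the expected input tensor schema from a dummy batch; it is passed to the module backend below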
216 | input_dim = model_conf.hidden_dim if not model_conf.use_compression else get_ss_dim(expert_cls, model_conf)
217 | sequence_length = model_conf.max_seq_len
218 | sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, sequence_length, input_dim)
219 | if isinstance(sample_input, tuple):
220 | args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input)
221 | else:
222 | args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),)
223 |
224 | experts = {}
225 | for expert_uid in expert_uids:
226 | expert = name_to_block[expert_cls](model_conf)
227 | if monitor:
228 | active_period_timeout = (
229 | max(load_state_timeout, averaging_timeout + matchmaking_time + request_timeout) + 10
230 | )
231 | monitor.monitor_callback(expert, model_conf, active_period_timeout)
232 |
233 | assert averaging_target_batch_size is not None
234 |
235 | if model_conf.use_compression:
236 | exclude_params = ["rcv", "fixed_tok_embeddings"]
237 | else:
238 | exclude_params = []
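# When weight decay is enabled, split parameters into decay and no-decay groups (token embeddings are exempt)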
239 | if weight_decay:
240 | no_decay = ["tok_embeddings.weight"]
241 | params = [
242 | {
243 | "params": [
244 | p
245 | for n, p in expert.named_parameters()
246 | if not any(nd in n for nd in no_decay) and not any(ex in n for ex in exclude_params)
247 | ],
248 | "weight_decay": weight_decay,
249 | },
250 | {
251 | "params": [
252 | p
253 | for n, p in expert.named_parameters()
254 | if any(nd in n for nd in no_decay) and not any(ex in n for ex in exclude_params)
255 | ],
256 | "weight_decay": 0.0,
257 | },
258 | ]
259 | else:
260 | params = [p for n, p in expert.named_parameters() if not any(ex in n for ex in exclude_params)]
261 |
262 | if model_conf.use_compression:
263 | ss_comps = load_ss_components(model_conf.ss_component)
264 | expert.load_comp(ss_comps)
265 | logger.info("Succeded loading remote subspace components")
266 | expert.ss_regularize()
267 |
268 | optimizer_lock = mp.Lock()
269 |
270 | backend = ModuleCollab(
271 | optimizer_lock=optimizer_lock,
272 | name=expert_uid,
273 | module=expert,
274 | args_schema=args_schema,
275 | min_batch_size=min_batch_size,
276 | max_batch_size=max_batch_size,
277 | )
278 |
279 | backend.module.to(device)
280 |
281 | optim_collab = optim_collab_cls(
282 | model=expert,
283 | optimizer_lock=optimizer_lock,
284 | sparse_avg=sparse_avg,
285 | max_allowed_stale=max_allowed_stale,
286 | optimizer=optim_cls,
287 | params=params,
288 | dht=dht,
289 | run_id=expert_uid.split(UID_DELIMITER)[0],
290 | scheduler=scheduler_cls,
291 | target_batch_size=averaging_target_batch_size,
292 | matchmaking_time=matchmaking_time,
293 | averaging_timeout=averaging_timeout,
294 | load_state_timeout=load_state_timeout,
295 | average_state_every=average_state_every,
296 | grad_averager_factory=grad_avg_factory,
297 | grad_compression=grad_averaging_compression,
298 | reuse_grad_buffers=reuse_grad_buffers,
299 | use_local_updates=use_local_updates,
300 | offload_optimizer=use_offloading,
301 | delay_state_averaging=False,
302 | next_chunk_timeout=next_chunk_timeout,
303 | verbose=True,
304 | averager_opts={"bandwidth": upload_bw, "request_timeout": request_timeout},
305 | )
306 | optim_collab.load_state_from_peers(wait_for_end_round=True)
307 | backend.optimizer = optim_collab
308 | experts[expert_uid] = backend
309 |
310 | return cls(
311 | optim_collab,
312 | dht,
313 | experts,
314 | num_connection_handlers=num_handlers,
315 | device=device,
316 | checkpoint_dir=None,
317 | stats_report_interval=stats_report_interval,
318 | update_period=update_period,
319 | expiration=expiration,
320 | start=start,
321 | )
322 |
--------------------------------------------------------------------------------
/src/node0/server/optim.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Pluralis Research
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import multiprocessing as mp
16 | import os
17 | import signal
18 | import threading
19 | import time
20 |
21 | from typing import Callable, Optional
22 |
23 | import torch
24 |
25 | from hivemind import Optimizer
26 | from hivemind.optim.grad_scaler import GradScaler
27 | from hivemind.utils import get_dht_time, get_logger
28 |
29 | from node0.server.HM_gradient_averager import GradientAverager, GradientAveragerFactory
30 | from node0.server.state_averager_wrap import TrainingStateAverager
31 |
32 |
33 | logger = get_logger(__name__)
34 |
35 |
36 | class AutoStepOptimizer(Optimizer):
37 | """
38 | A class that extends Hivemind Optimizer and ensures step() is called at least once per auto_step_time.
39 | If step() hasn't been called externally within the auto_step_time window, it will be called automatically.
40 | """
41 |
42 | def __init__(
43 | self,
44 | model,
45 | optimizer_lock,
46 | sparse_avg: float = 0.0,
47 | auto_step_time: float = 3.0,
48 | max_allowed_stale: int = 0,
49 | grad_schedule_buffer: float = 5.0,
50 | *args,
51 | **kwargs,
52 | ):
53 | super().__init__(*args, **kwargs)
54 | self._auto_step_time = auto_step_time
55 | self.model = model
56 | self.max_allowed_stale = max_allowed_stale
57 | self.grad_schedule_buffer = grad_schedule_buffer
58 | self.optimizer_lock = optimizer_lock
59 | self._last_step_time: float = time.time()
60 | self._step_lock = mp.Lock()
61 | self._monitor_thread: Optional[threading.Thread] = None
62 | self._should_stop = threading.Event()
63 | self.in_update = False
64 |
65 | # Set state avg parameters
66 | self.state_averager.average_state_every = self.average_state_every
67 | self.state_averager.sparse_avg = sparse_avg
68 | if sparse_avg:
69 | self.state_averager.set_sparta_partitions()
70 | self.state_averager.run_in_background(await_ready=True)
71 |
72 | def _resync_state(self):
73 | if self._should_load_state_from_peers():
74 | logger.log(self.status_loglevel, "Peer is out of sync")
75 | self.load_state_from_peers()
76 | return True # local gradients were computed with out-of-sync parameters, must start over
77 | elif self._catchup_epoch():
78 | with self.tracker.pause_updates():
79 | logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
80 | self.state_averager.local_epoch = self.tracker.global_epoch
81 | self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
82 | return False
83 |
84 | def _should_load_state_from_peers(self) -> bool:
85 | return self.local_epoch < (self.tracker.global_epoch - self.max_allowed_stale)
86 |
87 | def _catchup_epoch(self) -> bool:
88 | return (self.tracker.global_epoch - self.max_allowed_stale) <= self.local_epoch < self.tracker.global_epoch
89 |
90 | def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
91 | return TrainingStateAverager(
92 | dht=self.dht,
93 | prefix=f"{self.run_id}_state_averager",
94 | min_matchmaking_time=self.matchmaking_time,
95 | allreduce_timeout=self.allreduce_timeout,
96 | shutdown_timeout=self.shutdown_timeout,
97 | offload_optimizer=self.offload_optimizer,
98 | custom_gradients=self.offload_optimizer,
99 | status_loglevel=self.status_loglevel,
100 | next_chunk_timeout=self.next_chunk_timeout,
101 | client_mode=self.client_mode,
102 | auxiliary=self.auxiliary,
103 | allow_state_sharing=False,
104 | start=True,
105 | **kwargs,
106 | )
107 |
108 | def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager:
109 | assert hasattr(self, "state_averager"), "must initialize state averager first"
110 | factory = factory if factory is not None else GradientAverager
111 | grad_averager = factory(
112 | dht=self.dht,
113 | prefix=f"{self.run_id}_grad_averager",
114 | parameters=self.state_averager.main_parameters,
115 | min_matchmaking_time=self.matchmaking_time,
116 | allreduce_timeout=self.allreduce_timeout,
117 | shutdown_timeout=self.shutdown_timeout,
118 | next_chunk_timeout=self.next_chunk_timeout,
119 | client_mode=self.client_mode,
120 | auxiliary=self.auxiliary,
121 | allow_state_sharing=False,
122 | start=True,
123 | **kwargs,
124 | )
125 | if self.offload_optimizer:
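# Tie the offloaded optimizer's .grad fields to the shared averaged-gradient buffers so optimizer steps use freshly averaged gradients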
126 | optimized_param_groups = self.state_averager.optimizer.param_groups
127 | optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
128 | with grad_averager.get_tensors() as averaged_gradients:
129 | assert len(averaged_gradients) == len(optimized_parameters)
130 | for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
131 | opt_param.grad = averaged_grad
132 | return grad_averager
133 |
134 | def _start_monitor(self) -> None:
135 | """Start the monitoring thread if it's not already running."""
136 | if self._monitor_thread is None or not self._monitor_thread.is_alive():
137 | self._should_stop.clear()
138 | self._monitor_thread = threading.Thread(target=self._monitor_step, daemon=True)
139 | self._monitor_thread.start()
140 |
141 | def _monitor_step(self) -> None:
142 | """Monitor thread that checks and calls step() if it wasn't called within auto_step_time window."""
143 | while not self._should_stop.is_set():
144 | time.sleep(1.0) # Check every 1s
145 |
146 | current_time = time.time()
147 | time_since_last_step = current_time - self._last_step_time
148 |
149 | if time_since_last_step >= self._auto_step_time:
150 | with self._step_lock:
151 | # Check again after acquiring lock in case step() was called
152 | if time.time() - self._last_step_time >= self._auto_step_time:
153 | self._auto_step()
154 |
155 | def _auto_step(self) -> None:
156 | """Internal method to call step() automatically."""
157 | try:
158 | # Perform an automatic step with a unit batch size when a global update round is due
159 | logger.debug(f"AutoStepOptimizer step at {time.strftime('%H:%M:%S')}")
160 | self._last_step_time = time.time()
161 | # self._check_update_version()
162 |
163 | if self._resync_state():
164 | return None
165 |
166 | self._maybe_schedule_gradient_averaging()
167 |
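# Run an actual optimizer step only when an averaging round has been scheduled and the collaboration is ready to advance the epoch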
168 | if self.in_update and self.tracker.ready_to_update_epoch:
169 | batch_size = 1
170 | self._step(batch_size=batch_size)
171 |
172 | # self.state_averager.allow_state_sharing = True
173 | except Exception as e:
174 | logger.error(f"Error in auto step: {e}")
175 |
176 | def step(self, batch_size: Optional[int] = None) -> None:
177 | """
178 | Override of the step method that updates the last step time.
179 | This should be called by external code.
180 | """
181 | with self._step_lock:
182 | # self._check_update_version()
183 |
184 | if self._resync_state():
185 | return None
186 |
187 | self._step(batch_size=batch_size)
188 | # self.state_averager.allow_state_sharing = True
189 |
190 | self._last_step_time = time.time()
191 |
192 | def stop_monitoring(self) -> None:
193 | """Stop the monitoring thread."""
194 | self._should_stop.set()
195 | if self._monitor_thread is not None:
196 | self._monitor_thread.join(timeout=1)
197 | self._monitor_thread = None
198 |
199 | def __del__(self):
200 | """Ensure the monitoring thread is stopped when the object is destroyed."""
201 | self.stop_monitoring()
202 |
203 | @property
204 | def ready_to_update_epoch(self) -> bool:
205 | """Whether or not this peer can increment epoch right away."""
206 | return (
207 | self.tracker.global_epoch > self.tracker.local_progress.epoch
208 | or self.tracker.global_progress.samples_accumulated >= self.tracker.target_batch_size
209 | )
210 |
211 | def _check_and_accumulate_gradients(self, batch_size: int) -> bool:
212 | """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
213 | if self.grad_averager.has_nan_grads():
214 | self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
215 | logger.error("Encountered incorrect value in grads, exiting run")
216 | os.killpg(os.getpgrp(), signal.SIGTERM)
217 |
218 | self.grad_averager.accumulate_grads_(batch_size)
219 | self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
220 | return True
221 |
222 | def _maybe_schedule_gradient_averaging(self) -> None:
223 | """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
224 | assert self.use_gradient_averaging
225 | if not self.in_update and self.ready_to_update_epoch:
226 | if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
227 | self.in_update = True
228 | eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
229 | logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
230 | self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
231 |
232 | def _step(
233 | self,
234 | closure: Optional[Callable[[], torch.Tensor]] = None,
235 | batch_size: Optional[int] = None,
236 | grad_scaler: Optional[GradScaler] = None,
237 | ):
238 | """
239 | Update training progress after accumulating another local batch. Depending on the configuration, this will
240 | report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
241 |
242 | :param closure: A closure that reevaluates the model and returns the loss.
243 | :param batch_size: optional override for batch_size_per_step from init.
244 | :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
245 | """
246 | if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
247 | raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
248 | if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
249 | raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
250 | if self.auxiliary and (closure is not None or batch_size is not None):
251 | raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
252 | batch_size = batch_size if batch_size is not None else self.batch_size_per_step
253 |
254 | # if delayed updates finished before step, apply these updates; otherwise do nothing
255 | self.state_averager.step(apply_delayed_updates=True)
256 |
257 | loss = None
258 | if closure is not None:
259 | with torch.enable_grad():
260 | loss = closure()
261 |
262 | # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
263 | self._check_and_accumulate_gradients(batch_size)
264 |
265 | self._maybe_schedule_gradient_averaging()
266 | # self._maybe_schedule_state_averaging()
267 |
268 | if self.in_update and self.tracker.ready_to_update_epoch:
269 | self.state_averager.allow_state_sharing = False # Prevent state sharing during AR.
270 | self.grad_averager.processed_batches.copy_(float(self.tracker.local_progress.samples_accumulated))
271 | self.state_averager.global_epoch = self.tracker.global_epoch
272 | self.grad_averager.global_epoch = self.tracker.global_epoch
273 | with self.optimizer_lock:
274 | self._update_global_epoch(grad_scaler)
275 |
276 | if self.model.model_args.use_compression:
277 | self.model.ss_regularize()
278 |
279 | self.in_update = False
280 | return loss
281 |
282 | def load_state_from_peers(self, wait_for_end_round=False, **kwargs):
283 | # Wait until the start of a gradient accumulation round before requesting state from peers
284 | logger.info("Waiting for peers in stage to finish step before joining")
285 | while True:
286 | if (
287 | self.tracker.fetched_global_progress_this_epoch.is_set()
288 | and self.tracker.global_progress.samples_accumulated < self.target_batch_size * 0.1
289 | ):
290 | break
291 | else:
292 | logger.info(f"Waiting for peers in stage to finish step before joining")
293 | time.sleep(self.tracker.max_refresh_period)
294 |
295 | self._load_state_from_peers(**kwargs)
296 |
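# Optionally wait until a new accumulation round begins so this peer joins at a round boundary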
297 | if wait_for_end_round:
298 | self.tracker.fetched_global_progress_this_epoch.clear()
299 | while True:
300 | if (
301 | self.tracker.fetched_global_progress_this_epoch.is_set()
302 | and self.tracker.global_progress.samples_accumulated < self.target_batch_size * 0.2
303 | ):
304 | break
305 | else:
306 | logger.info(f"Downloaded state, waiting for start of new round")
307 | time.sleep(0.5)
308 | logger.info(f"Joining run")
309 |
310 | def _load_state_from_peers(self, **kwargs):
311 | """
312 | Attempt to load the newest collaboration state from other peers within the same run_id.
313 |
314 | If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
315 | """
316 | # note: we tag along for the next all-reduce because the run may have already started and cancelling it
317 | # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
318 | if self.scheduled_grads is not None and not self.scheduled_grads.done():
319 | self._tag_along_with_zero_weight(self.scheduled_grads)
320 | self.scheduled_grads = None
321 |
322 | with self.tracker.pause_updates():
323 | while True:
324 | try:
325 | self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
326 | if self.grad_averager is not None:
327 | self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
328 | break
329 | except KeyboardInterrupt:
330 | raise
331 | except BaseException as e:
332 | logger.error(
333 | f"Failed to load state from peers: {e}, retrying ...",
334 | exc_info=logger.getEffectiveLevel() <= 15,
335 | )
336 | continue
337 |
338 | if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
339 | logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
340 | self.state_averager.local_epoch = self.tracker.global_epoch
341 |
342 | self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
343 |
344 | if not self.client_mode:
345 | self.state_averager.state_sharing_priority = self.local_epoch
346 |
347 | if self.use_gradient_averaging:
348 | self.grad_averager.reset_accumulated_grads_()
349 | if not self.client_mode:
350 | self.grad_averager.state_sharing_priority = self.local_epoch
351 |
352 | if hasattr(self.state_averager, "zclip") and self.state_averager.zclip.var.item() > 0.0:
353 | self.state_averager.zclip.initialized = True
354 |
--------------------------------------------------------------------------------