├── src └── node0 │ ├── __init__.py │ ├── security │ ├── integrity_check.cpython-311-aarch64-linux-gnu.so │ ├── integrity_check.cpython-311-x86_64-linux-gnu.so │ ├── validation.py │ └── authorization.py │ ├── utils │ ├── __init__.py │ ├── flops.py │ ├── dht_monitor.py │ ├── dht_partition.py │ ├── mem_monitor.py │ ├── logging.py │ ├── get_parameters.py │ ├── connection_test_server.py │ ├── network_throughput.py │ ├── common.py │ └── node_info.py │ ├── models │ ├── llama │ │ └── arguments.py │ ├── arguments.py │ └── lr_schedule.py │ ├── configs │ └── llama_8B_C.yaml │ ├── server │ ├── module_collab.py │ ├── matchmaking.py │ ├── state_averager_wrap.py │ ├── power_sgd_averager.py │ ├── HM_gradient_averager.py │ ├── node0_server.py │ └── optim.py │ └── run_server.py ├── .github ├── CODEOWNERS ├── assets │ └── dashboard-button.svg └── ISSUE_TEMPLATE │ └── bug_report.md ├── images ├── node0-logo-black.png ├── node0-logo-white.png ├── runpod_edit_pod.png ├── aws_inbound_rules.png ├── gcp_inbound_rules.png ├── lambda_inbound_rules.png ├── runpod_external_port.png ├── runpod_inbound_rules.png ├── tensordock_external_port.png └── tensordock_forwarded_ports.png ├── run.json ├── Dockerfile ├── pyproject.toml ├── THIRD_PARTY_LICENSES.md ├── docs ├── useful_links.md ├── network.md └── cloud.md ├── .gitignore └── LICENSE /src/node0/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @violettapluralis @gilpluralis @yanpluralis -------------------------------------------------------------------------------- /images/node0-logo-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/node0-logo-black.png -------------------------------------------------------------------------------- /images/node0-logo-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/node0-logo-white.png -------------------------------------------------------------------------------- /images/runpod_edit_pod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_edit_pod.png -------------------------------------------------------------------------------- /images/aws_inbound_rules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/aws_inbound_rules.png -------------------------------------------------------------------------------- /images/gcp_inbound_rules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/gcp_inbound_rules.png -------------------------------------------------------------------------------- /images/lambda_inbound_rules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/lambda_inbound_rules.png -------------------------------------------------------------------------------- /images/runpod_external_port.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_external_port.png -------------------------------------------------------------------------------- /images/runpod_inbound_rules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/runpod_inbound_rules.png -------------------------------------------------------------------------------- /images/tensordock_external_port.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/tensordock_external_port.png -------------------------------------------------------------------------------- /images/tensordock_forwarded_ports.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/images/tensordock_forwarded_ports.png -------------------------------------------------------------------------------- /src/node0/security/integrity_check.cpython-311-aarch64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/src/node0/security/integrity_check.cpython-311-aarch64-linux-gnu.so -------------------------------------------------------------------------------- /src/node0/security/integrity_check.cpython-311-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PluralisResearch/node0/HEAD/src/node0/security/integrity_check.cpython-311-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /run.json: -------------------------------------------------------------------------------- 1 | { 2 | "run_config": "llama_8B_C", 3 | "auth_server": "https://auth.pluralis.ai", 4 | "seeds": ["/ip4/34.215.69.49/tcp/49200/p2p/QmQQc9r9pKF5hw3MbXHqRhGAHh4JZKiuv1hCenmGKV9e2Y", "/ip4/54.70.227.132/tcp/49200/p2p/QmSzC8yq5yqhiTjP2gRaTaUZbjik2ZrdbCc8rP7dXRdP21", "/ip4/34.208.159.142/tcp/49200/p2p/QmNY87dq6tmNXZJhWYddyA5WbVLY6FEktU1QTXdraYCSa3", "/ip4/35.89.245.43/tcp/49200/p2p/QmcRSpPVX3itcEedowDdLrSdSN7kMTC8jkkbTN5UYsoRH4"] 5 | } -------------------------------------------------------------------------------- /.github/assets/dashboard-button.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 🚀 Live Dashboard 10 | Click to view real-time data 11 | -------------------------------------------------------------------------------- /src/node0/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .common import build_cls, clean_tmp, infer_expert_params 16 | from .logging import Node0Logger 17 | from .monitor import MonitorWorker 18 | from .dht_partition import update_initial_peers -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04 2 | 3 | WORKDIR /home 4 | # Set en_US.UTF-8 locale by default 5 | RUN echo "LC_ALL=en_US.UTF-8" >> /etc/environment 6 | 7 | # Install packages 8 | RUN apt-get update && apt-get install -y --no-install-recommends --force-yes \ 9 | build-essential \ 10 | rsync \ 11 | openssh-client \ 12 | curl \ 13 | wget \ 14 | git \ 15 | vim \ 16 | && apt-get clean autoclean && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} /tmp/* /var/tmp/* 17 | 18 | # Install conda 19 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh && \ 20 | bash install_miniconda.sh -b -p /opt/conda && rm install_miniconda.sh 21 | ENV PATH="/opt/conda/bin:${PATH}" 22 | 23 | # Accept conda TOS 24 | RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ 25 | conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r 26 | 27 | # Install torch 28 | RUN conda install python=3.11 pip 29 | 30 | # Install node0 lib 31 | COPY . node0/ 32 | RUN cd node0 && pip install . 33 | 34 | CMD ["bash"] 35 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "node0" 3 | version = "0.1.0" 4 | readme = "README.md" 5 | requires-python = ">=3.11,<3.12" 6 | license = "Apache-2.0" 7 | license-files = ["*LICEN[CS]E*"] 8 | dependencies = [ 9 | "PyYAML", 10 | "prometheus_client", 11 | "torch==2.7.0", 12 | "numpy>=1.17", 13 | "scipy>=1.2.1", 14 | "prefetch_generator>=1.0.1", 15 | "msgpack>=0.5.6", 16 | "sortedcontainers", 17 | "uvloop>=0.14.0", 18 | "grpcio-tools>=1.33.2", 19 | "protobuf>=5.29.0", 20 | "configargparse>=1.2.3", 21 | "py-multihash>=0.2.3", 22 | "cryptography>=3.4.6", 23 | "pydantic>=2.0.0", 24 | "packaging>=20.9", 25 | "varint>=1.0.2", 26 | "base58>=1.0.2", 27 | "netaddr>=1.3.0", 28 | "idna>=3.10", 29 | "py-cid>=0.3.0", 30 | "requests>=2.32.3", 31 | "speedtest-cli>=2.1.3", 32 | "psutil>=7.0.0", 33 | "hivemind @ git+https://github.com/learning-at-home/hivemind.git@4d5c41495be082490ea44cce4e9dd58f9926bb4e" 34 | ] 35 | 36 | [build-system] 37 | requires = ["hatchling"] 38 | build-backend = "hatchling.build" 39 | 40 | [tool.hatch.metadata] 41 | allow-direct-references = true 42 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Bug Description 11 | 12 | **Describe the bug** 13 | 14 | A clear and concise description of what the bug is and what you expected to happen. 15 | 16 | **Steps to Reproduce** 17 | 18 | Describe the steps to reproduce the issue, including the specific script/command you ran and any input parameters. 
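For example (hypothetical command shown for illustration; replace with whatever you actually ran and the flags you passed): `python3 generate_script.py --host_port 49200 --announce_port 49200`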
19 | 20 | ## Environment Information 21 | 22 | **Operating System:** 23 | - [ ] Windows + WSL (specify version: ) 24 | - [ ] Linux (specify distribution and version: ) 25 | 26 | **Python Version:** 27 | 28 | **Deployment Method:** 29 | - [ ] Docker 30 | - [ ] Local installation (pip/conda/venv) 31 | 32 | **Runtime Environment:** 33 | - [ ] Personal computer/laptop 34 | - [ ] Cloud service (specify which one): 35 | - [ ] AWS 36 | - [ ] Google Cloud Platform 37 | - [ ] Azure 38 | - [ ] Lambda Labs 39 | - [ ] Other: ___________ 40 | 41 | ## Error Details 42 | 43 | **Error Message:** 44 | ``` 45 | Paste the full error message/traceback here 46 | ``` 47 | 48 | **Log Output:** 49 | ``` 50 | Paste relevant log output here 51 | ``` 52 | 53 | **Any additional context or information that might be helpful:** 54 | -------------------------------------------------------------------------------- /src/node0/models/llama/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from node0.models.arguments import ModelArguments 16 | 17 | 18 | class LlamaArguments(ModelArguments): 19 | hidden_dim: int = 4096 20 | n_heads: int = 32 21 | n_kv_heads: int | None = None 22 | vocab_size: int = 50265 # Using AutoTokenizer.from_pretrained("facebook/opt-2.7b") 23 | # vocab_size: int = 50280 # Using AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") 24 | multiple_of: int = 256 25 | ffn_dim_multiplier: float | None = None 26 | norm_eps: float = 1e-5 27 | rope_theta: float = 10000 28 | max_seq_len: int = 512 29 | depth_init: bool = True 30 | constant_init: bool = False 31 | norm_type: str = "rmsnorm" 32 | -------------------------------------------------------------------------------- /src/node0/configs/llama_8B_C.yaml: -------------------------------------------------------------------------------- 1 | model_config: 2 | class_name: node0.models.llama.arguments.LlamaArguments 3 | init_args: 4 | hidden_dim: 4096 5 | n_heads: 32 6 | n_kv_heads: 8 7 | ffn_dim_multiplier: 1.3 8 | multiple_of: 1024 9 | rope_theta: 500000 10 | num_hidden_layers: 1 11 | qk_norm: True 12 | norm_reorder: True 13 | trainable_rmsnorm: False 14 | compression_rate: 100 15 | use_compression: True 16 | max_seq_len: 4096 17 | ss_component: 'https://d2exiwjpgw0bxb.cloudfront.net/subspace_compression/8B_C/subspace_comp.pt' 18 | 19 | optim_config: 20 | class_name: torch.optim.AdamW 21 | init_args: 22 | lr: 0.0003 23 | weight_decay: 0.1 24 | 25 | grad_avg_config: 26 | class_name: node0.server.power_sgd_averager.PowerSGDGradientAverager 27 | init_args: 28 | averager_rank: 64 29 | 30 | # Scheduler 31 | scheduler: linear 32 | num_warmup_steps: 4000 33 | num_training_steps: 100000 34 | 35 | # Training 36 | num_stages: 32 37 | clip_grad_norm: 1.0 38 | weight_decay: 0.0 39 | compression: NONE 40 | min_batch_size: 1 41 | max_batch_size: 1 42 | averaging_target_batch_size: 1024 43 | 
matchmaking_time: 45 44 | averaging_timeout: 120 45 | request_timeout: 3 46 | sparse_avg: 0.05 47 | average_state_every: 5 48 | load_state_timeout: 150 49 | max_allowed_stale: 5 50 | num_dht: 5 -------------------------------------------------------------------------------- /src/node0/models/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Self 16 | 17 | from pydantic import BaseModel, model_validator 18 | 19 | 20 | class ModelArguments(BaseModel): 21 | # Model parameters 22 | hidden_dim: int 23 | n_heads: int 24 | num_hidden_layers: int 25 | n_layers: int = 0 26 | stage: int | None = None 27 | 28 | # Attention projection parameters 29 | attn_proj: bool = False 30 | 31 | # QK norm 32 | qk_norm: bool = False 33 | norm_reorder: bool = False 34 | trainable_rmsnorm: bool = True 35 | 36 | # Compression parameters 37 | compression_rate: int | None = None 38 | use_compression: bool = False 39 | ss_component: str | None 40 | 41 | @model_validator(mode="after") 42 | def set_n_layers(self) -> Self: 43 | self.n_layers = self.num_hidden_layers 44 | return self 45 | -------------------------------------------------------------------------------- /src/node0/utils/flops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | 17 | from node0.models.arguments import ModelArguments 18 | 19 | 20 | def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: 21 | """Compute number of model parameters.""" 22 | num_params = sum(p.numel() for p in model.parameters()) 23 | if exclude_embedding and hasattr(model, "tok_embeddings"): 24 | num_params -= model.tok_embeddings.weight.numel() 25 | 26 | if exclude_embedding and hasattr(model, "fixed_tok_embeddings"): 27 | num_params -= model.fixed_tok_embeddings.weight.numel() 28 | 29 | return num_params 30 | 31 | 32 | def get_num_flop_per_token_fwd(num_params: int, model_config: ModelArguments, seq_len: int) -> int: 33 | """Compute number of flop per token in forward pass.""" 34 | layers, h, q, t = ( 35 | model_config.n_layers, 36 | model_config.n_heads, 37 | model_config.hidden_dim // model_config.n_heads, 38 | seq_len, 39 | ) 40 | 41 | flop_per_token = 2 * num_params + 4 * layers * h * q * t 42 | 43 | return flop_per_token 44 | 45 | 46 | def get_num_flop_per_token_bwd(num_params: int, model_config: ModelArguments, seq_len: int) -> int: 47 | """Compute number of flop per token in backward pass.""" 48 | layers, h, q, t = ( 49 | model_config.n_layers, 50 | model_config.n_heads, 51 | model_config.hidden_dim // model_config.n_heads, 52 | seq_len, 53 | ) 54 | 55 | flop_per_token = 4 * num_params + 8 * layers * h * q * t 56 | 57 | return flop_per_token 58 | -------------------------------------------------------------------------------- /THIRD_PARTY_LICENSES.md: -------------------------------------------------------------------------------- 1 | # Third-Party Dependencies 2 | 3 | This project uses the following open-source libraries: 4 | 5 | | Name | Version | License | 6 | |--------------------------|------------|--------------------------------------| 7 | | ConfigArgParse | 1.7.1 | MIT License | 8 | | PyYAML | 6.0.2 | MIT License | 9 | | base58 | 1.0.3 | MIT License | 10 | | cryptography | 45.0.7 | Apache-2.0 OR BSD-3-Clause | 11 | | grpcio-tools | 1.74.0 | Apache Software License | 12 | | hivemind | 1.2.0.dev0 | MIT License | 13 | | idna | 3.10 | BSD License | 14 | | msgpack | 1.1.1 | Apache 2.0 | 15 | | netaddr | 1.3.0 | BSD License | 16 | | numpy | 2.3.3 | BSD License | 17 | | packaging | 25.0 | Apache Software License; BSD License | 18 | | prefetch_generator | 1.0.3 | The Unlicense (Unlicense) | 19 | | prometheus_client | 0.22.1 | Apache-2.0 AND BSD-2-Clause | 20 | | protobuf | 6.32.1 | 3-Clause BSD License | 21 | | psutil | 7.0.0 | BSD License | 22 | | py-cid | 0.3.1 | MIT | 23 | | py-multihash | 2.0.1 | MIT License | 24 | | pydantic | 2.11.9 | MIT License | 25 | | requests | 2.32.5 | Apache Software License | 26 | | scipy | 1.16.2 | BSD License | 27 | | sortedcontainers | 2.4.0 | Apache Software License | 28 | | speedtest-cli | 2.1.3 | Apache Software License | 29 | | torch | 2.7.0 | BSD License | 30 | | uvloop | 0.21.0 | Apache Software License; MIT License | 31 | | varint | 1.0.2 | MIT License | 32 | -------------------------------------------------------------------------------- /src/node0/security/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from hivemind.dht.crypto import RSASignatureValidator 18 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator 19 | from hivemind.dht.validation import RecordValidatorBase 20 | from pydantic.v1 import BaseModel, StrictFloat, StrictInt, StrictStr 21 | 22 | 23 | if TYPE_CHECKING: 24 | # For type checkers, treat it as bytes 25 | BytesWithPublicKeyType = bytes 26 | else: 27 | # At runtime, use the actual validator 28 | BytesWithPublicKeyType = BytesWithPublicKey 29 | 30 | 31 | class WorkerMetricsV1(BaseModel): 32 | """Schema for worker metrics.""" 33 | 34 | peer_id: str 35 | num_flop: StrictFloat 36 | active_time: StrictFloat 37 | 38 | 39 | class WorkerPortV1(BaseModel): 40 | """Schema for worker port reachability.""" 41 | 42 | peer_id: str 43 | is_open: bool 44 | 45 | 46 | class RunParameters(BaseModel): 47 | peer_id: bytes 48 | averaging_target_batch_size: StrictInt 49 | scheduler: StrictStr 50 | num_warmup_steps: StrictInt 51 | num_training_steps: StrictInt 52 | averaging_timeout: StrictFloat 53 | matchmaking_time: StrictFloat 54 | request_timeout: StrictFloat 55 | load_state_timeout: StrictFloat 56 | time: StrictFloat 57 | 58 | 59 | class MetricSchema(BaseModel): 60 | """Force metrics keys to have signed subkeys.""" 61 | 62 | worker_metrics: dict[BytesWithPublicKeyType, WorkerMetricsV1] 63 | 64 | 65 | class PortSchema(BaseModel): 66 | """Force port keys to have signed subkeys.""" 67 | 68 | worker_ports: dict[BytesWithPublicKeyType, WorkerPortV1] 69 | 70 | 71 | class RunParametersSchema(BaseModel): 72 | paramstore: dict[BytesWithPublicKeyType, RunParameters | None] 73 | 74 | 75 | def make_validators(experiment_prefix: str, peer_id: str, stage: str) -> tuple[list[RecordValidatorBase], bytes]: 76 | """Create all validators""" 77 | metric_validator = SchemaValidator(MetricSchema, prefix=f"{experiment_prefix}_{stage}") 78 | port_validator = SchemaValidator(PortSchema, prefix=f"{experiment_prefix}_{peer_id}") 79 | param_validator = SchemaValidator(RunParametersSchema, prefix=stage.split(".")[0]) 80 | signature_validator = RSASignatureValidator() 81 | 82 | validators = [metric_validator, port_validator, param_validator, signature_validator] 83 | return validators, signature_validator.local_public_key 84 | -------------------------------------------------------------------------------- /src/node0/utils/dht_monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import threading 16 | import time 17 | 18 | from hivemind.dht.protocol import DHTProtocol 19 | from hivemind.utils.logging import get_logger 20 | 21 | 22 | logger = get_logger(__name__) 23 | 24 | _rpc_store_count = 0 25 | _rpc_find_count = 0 26 | _last_log_time = time.time() 27 | _counter_lock = threading.Lock() 28 | 29 | 30 | def patch_dht_protocol_logging(): 31 | """Monkey patch DHTProtocol to add RPC call counting""" 32 | # Import here to avoid issues with module loading order 33 | from hivemind.p2p import P2PContext 34 | from hivemind.proto import dht_pb2 35 | 36 | global _rpc_store_count, _rpc_find_count, _last_log_time 37 | 38 | # Store original methods 39 | original_rpc_store = DHTProtocol.rpc_store 40 | original_rpc_find = DHTProtocol.rpc_find 41 | 42 | logger.extra("[DHT Monitor] Patching DHTProtocol to monitor RPC calls") 43 | 44 | # Create wrapped methods 45 | async def counted_rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse: 46 | global _rpc_store_count, _rpc_find_count, _last_log_time 47 | 48 | with _counter_lock: 49 | _rpc_store_count += 1 50 | current_time = time.time() 51 | 52 | # Log every 60 seconds 53 | if current_time - _last_log_time >= 60: 54 | logger.extra(f"[DHT RPC Stats] Last 60s - rpc_store: {_rpc_store_count}, rpc_find: {_rpc_find_count}") 55 | _rpc_store_count = 0 56 | _rpc_find_count = 0 57 | _last_log_time = current_time 58 | 59 | return await original_rpc_store(self, request, context) 60 | 61 | async def counted_rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse: 62 | global _rpc_store_count, _rpc_find_count, _last_log_time 63 | 64 | with _counter_lock: 65 | _rpc_find_count += 1 66 | current_time = time.time() 67 | 68 | # Log every 60 seconds 69 | if current_time - _last_log_time >= 60: 70 | logger.extra(f"[DHT RPC Stats] Last 60s - rpc_store: {_rpc_store_count}, rpc_find: {_rpc_find_count}") 71 | _rpc_store_count = 0 72 | _rpc_find_count = 0 73 | _last_log_time = current_time 74 | 75 | return await original_rpc_find(self, request, context) 76 | 77 | # Copy over important attributes from original methods 78 | counted_rpc_store.__name__ = "rpc_store" 79 | counted_rpc_store.__qualname__ = "DHTProtocol.rpc_store" 80 | counted_rpc_find.__name__ = "rpc_find" 81 | counted_rpc_find.__qualname__ = "DHTProtocol.rpc_find" 82 | 83 | # Patch with the new methods 84 | DHTProtocol.rpc_store = counted_rpc_store 85 | DHTProtocol.rpc_find = counted_rpc_find 86 | 87 | logger.extra("[DHT Monitor] RPC monitoring active - will log stats every 60 seconds") 88 | 89 | return None 90 | -------------------------------------------------------------------------------- /src/node0/utils/dht_partition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def partition_array(x: int, y: int) -> list[list[int]]: 17 | """ 18 | Partition array [0, 1, 2, ..., x-1] into y partitions as equally as possible. 19 | 20 | Args: 21 | x (int): Size of array (values 0 to x-1) 22 | y (int): Number of partitions 23 | 24 | Returns: 25 | list[list[int]]: list of lists representing the partitions 26 | """ 27 | if x == 0: 28 | return [] 29 | 30 | if x <= y: 31 | # Each value gets its own partition 32 | return [[i] for i in range(y) if i < x] 33 | 34 | base_size = x // y 35 | remainder = x % y 36 | 37 | partitions = [] 38 | start = 0 39 | 40 | for i in range(y): 41 | # First 'remainder' partitions get an extra element 42 | size = base_size + (1 if i < remainder else 0) 43 | partitions.append(list(range(start, start + size))) 44 | start += size 45 | 46 | return partitions 47 | 48 | 49 | def stage_to_dht_map(dht_partition: list[list[int]]) -> list[int]: 50 | """Map stage index to dht index 51 | 52 | Args: 53 | dht_partition (list[list[int]]): list of lists representing the partitions 54 | 55 | Returns: 56 | list[int]: stage to dht mapping 57 | """ 58 | stage_to_dht = [part for part, sublist in enumerate(dht_partition) for _ in sublist] 59 | return stage_to_dht 60 | 61 | 62 | def update_initial_peers( 63 | initial_peers: list[str], 64 | pipeline_stage: str, 65 | num_stages: int, 66 | num_dht: int, 67 | ) -> list[str]: 68 | """Update the list of initial peers with correct ports that match the given stage 69 | 70 | Args: 71 | initial_peers (list[str]): list of multiaddress 72 | pipeline_stage (str): stage type in the format: head-X, body-X, tail-X (X is int) 73 | num_stages (int): total number of stages 74 | num_dht (int): number of worker DHTs 75 | 76 | Raises: 77 | ValueError: wrong stage type 78 | 79 | Returns: 80 | list[str]: initial_peers 81 | """ 82 | 83 | # Calculate port offset according to stage 84 | stage_idx = int(pipeline_stage.split("-")[1]) 85 | dht_worker_partitions = partition_array(num_stages, num_dht) 86 | stage_to_dht = stage_to_dht_map(dht_worker_partitions) 87 | port_offset = stage_to_dht[stage_idx] 88 | 89 | # Update initial peers ports 90 | for i, peeri in enumerate(initial_peers): 91 | try: 92 | # Extract baseline port 93 | parts = peeri.split("/") 94 | port_index = parts.index("tcp") + 1 95 | base_port = int(parts[port_index]) 96 | parts[port_index] = str(base_port + port_offset) 97 | initial_peers[i] = "/".join(parts) 98 | except (ValueError, IndexError) as e: 99 | raise ValueError(f"Invalid multiaddress format in peer {i}: {peeri}. Error: {e}") from e 100 | 101 | return initial_peers 102 | -------------------------------------------------------------------------------- /src/node0/models/lr_schedule.py: -------------------------------------------------------------------------------- 1 | # This file contains code originally from Hivemind under MIT License 2 | # Original: Copyright 2020 Learning@home authors and collaborators 3 | # Modified by: Pluralis Research 2025 4 | # 5 | # Original code: MIT License (see THIRD_PARTY_LICENSES) 6 | # Modifications: Apache 2.0 License (see LICENSE) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only; 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | import math 13 | 14 | from torch.optim.lr_scheduler import LambdaLR 15 | 16 | 17 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1): 18 | """ 19 | Create a schedule with a learning rate that decreases following a cosine curve from the initial lr set in the 20 | optimizer to min_lr_ratio * initial_lr, after a warmup period during which it increases linearly from 0 to 21 | the initial lr set in the optimizer. 22 | 23 | Args: 24 | optimizer (:class:`~torch.optim.Optimizer`): 25 | The optimizer for which to schedule the learning rate. 26 | num_warmup_steps (:obj:`int`): 27 | The number of steps for the warmup phase. 28 | num_training_steps (:obj:`int`): 29 | The total number of training steps. 30 | min_lr_ratio (:obj:`float`, optional): 31 | The minimum learning rate as a ratio of the initial learning rate. Default: 0.1 32 | Return: 33 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 34 | """ 35 | 36 | def lr_lambda(current_step: int): 37 | if current_step < num_warmup_steps: 38 | # Linear warmup: 0 to 1.0 39 | return float(current_step) / float(max(1, num_warmup_steps)) 40 | 41 | # Cosine decay: 1.0 to min_lr_ratio 42 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 43 | progress = min(progress, 1.0) # Clamp to [0, 1] 44 | 45 | # Cosine annealing formula 46 | cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) 47 | return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_factor 48 | 49 | return LambdaLR(optimizer, lr_lambda) 50 | 51 | 52 | # https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py 53 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1): 54 | """ 55 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0.1x, after 56 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 57 | Args: 58 | optimizer (:class:`~torch.optim.Optimizer`): 59 | The optimizer for which to schedule the learning rate. 60 | num_warmup_steps (:obj:`int`): 61 | The number of steps for the warmup phase. 62 | num_training_steps (:obj:`int`): 63 | The total number of training steps. 64 | min_lr_ratio (:obj:`float`, optional): 65 | The minimum learning rate as a ratio of the initial learning rate. Default: 0.1 66 | Return: 67 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 68 | """ 69 | 70 | def lr_lambda(current_step: int): 71 | if current_step < num_warmup_steps: 72 | return float(current_step) / float(max(1, num_warmup_steps)) 73 | return max( 74 | min_lr_ratio, 75 | float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)), 76 | ) 77 | 78 | return LambdaLR(optimizer, lr_lambda) 79 | 80 | 81 | schedule_name_to_scheduler = { 82 | "linear": get_linear_schedule_with_warmup, 83 | "cosine": get_cosine_schedule_with_warmup, 84 | "none": None, 85 | } 86 | -------------------------------------------------------------------------------- /src/node0/utils/mem_monitor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import threading 16 | import time 17 | 18 | import numpy as np 19 | import psutil 20 | import torch 21 | 22 | from hivemind.utils.logging import get_logger 23 | 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | class MemoryTracker(threading.Thread): 29 | """ 30 | Monitor memory 31 | 32 | """ 33 | 34 | def __init__( 35 | self, 36 | interval: int = 60, 37 | ): 38 | super().__init__() 39 | 40 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 41 | 42 | self.log_gpumem = True 43 | if self.device == "cpu": 44 | logger.info("GPU memory monitor skipped. Device is cpu.") 45 | self.log_gpumem = False 46 | 47 | self.interval = interval 48 | if self.log_gpumem: 49 | self.total_gpu_vram = torch.cuda.get_device_properties(self.device).total_memory / 1024**3 # Gb 50 | 51 | self.total_ram = psutil.virtual_memory().total / 1024**3 # Gb 52 | 53 | self.start() 54 | 55 | def get_gpu_memory_usage(self, device=0): 56 | """ 57 | Returns the current GPU memory usage 58 | """ 59 | 60 | allocated = torch.cuda.memory_allocated(device) / 1024**3 # Gb 61 | reserved = torch.cuda.memory_reserved(device) / 1024**3 # Gb 62 | 63 | return { 64 | "allocated": allocated, 65 | "reserved": reserved, 66 | } 67 | 68 | def get_memory_usage(self): 69 | """ 70 | Returns the current RAM usage 71 | """ 72 | 73 | ram = psutil.virtual_memory() 74 | used_ram_gb = ram.used / 1024**3 # Gb 75 | 76 | return { 77 | "used": used_ram_gb, 78 | } 79 | 80 | def run(self): 81 | # Store initial network counters 82 | logger.info("Running memory monitor") 83 | start_time = time.time() 84 | alloc = [] 85 | reserv = [] 86 | ram_used = [] 87 | try: 88 | while True: 89 | if time.time() - start_time > self.interval: 90 | if self.log_gpumem: 91 | logger.info(f"GPU mem size is {round(self.total_gpu_vram)}Gb") 92 | logger.info(f"Allocated GPU mem is {np.mean(alloc) / self.total_gpu_vram:.2f}") 93 | logger.info(f"Reserved GPU mem is {np.mean(reserv) / self.total_gpu_vram:.2f}") 94 | logger.info(f"Total RAM mem is {round(self.total_ram)}Gb") 95 | logger.info(f"Used RAM mem is {np.mean(ram_used) / self.total_ram:.2f}") 96 | start_time = time.time() 97 | alloc = [] 98 | reserv = [] 99 | ram_used = [] 100 | 101 | if self.log_gpumem: 102 | vram_usage = self.get_gpu_memory_usage() 103 | alloc.append(vram_usage["allocated"]) 104 | reserv.append(vram_usage["reserved"]) 105 | ram_usage = self.get_memory_usage() 106 | ram_used.append(ram_usage["used"]) 107 | time.sleep(0.5) 108 | 109 | except KeyboardInterrupt: 110 | print("\n GPU monitoring stopped by user.") 111 | -------------------------------------------------------------------------------- /src/node0/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | from logging.handlers import RotatingFileHandler 18 | 19 | from hivemind.utils import get_logger, use_hivemind_log_handler 20 | from hivemind.utils.logging import TextStyle, always_log_caller 21 | 22 | 23 | class CustomFormatter(logging.Formatter): 24 | """ 25 | A formatter that allows a log time and caller info to be overridden via 26 | ``logger.log(level, message, extra={"origin_created": ..., "caller": ...})``. 27 | """ 28 | 29 | _LEVEL_TO_COLOR = { 30 | logging.DEBUG: TextStyle.PURPLE, 31 | logging.INFO: TextStyle.BLUE, 32 | logging.WARNING: TextStyle.ORANGE, 33 | logging.ERROR: TextStyle.RED, 34 | logging.CRITICAL: TextStyle.RED, 35 | } 36 | 37 | def format(self, record: logging.LogRecord) -> str: 38 | if hasattr(record, "origin_created"): 39 | record.created = record.origin_created 40 | record.msecs = (record.created - int(record.created)) * 1000 41 | 42 | if record.levelno > logging.INFO or always_log_caller: 43 | if not hasattr(record, "caller"): 44 | record.caller = f"{record.name}.{record.funcName}:{record.lineno}" 45 | record.caller_block = f" [{TextStyle.BOLD}{record.caller}{TextStyle.RESET}]" 46 | else: 47 | record.caller_block = "" 48 | 49 | # Aliases for the format argument 50 | record.levelcolor = ( 51 | self._LEVEL_TO_COLOR[record.levelno] if record.levelno in self._LEVEL_TO_COLOR else TextStyle.BLUE 52 | ) 53 | record.bold = TextStyle.BOLD 54 | record.reset = TextStyle.RESET 55 | 56 | return super().format(record) 57 | 58 | 59 | class Node0Logger: 60 | def __init__(self, log_level: str = "INFO"): 61 | """Instantiate logger. 62 | 63 | Args: 64 | log_level (str, optional): logging level. Defaults to "INFO". 65 | log_file (str | None, optional): file to save logs. Defaults to None. 
66 | """ 67 | # Add extra log level 68 | EXTRA_LEVEL = 15 # between INFO (20) and DEBUG (10) 69 | logging.addLevelName(EXTRA_LEVEL, "EXTRA") 70 | 71 | # Create a custom method for the logger 72 | def extra(self, message, *args, **kwargs): 73 | if self.isEnabledFor(EXTRA_LEVEL): 74 | self._log(EXTRA_LEVEL, message, args, **kwargs) 75 | 76 | logging.Logger.extra = extra 77 | use_hivemind_log_handler("in_root_logger") 78 | 79 | # Convert log level string to logging constant 80 | numeric_level = getattr(logging, log_level.upper(), log_level) 81 | 82 | # Configure root logger 83 | self.root_logger = get_logger() 84 | self.root_logger.setLevel(numeric_level) 85 | 86 | formatter = CustomFormatter( 87 | fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}]{caller_block} {message}", 88 | style="{", 89 | datefmt="%b %d %H:%M:%S", 90 | ) 91 | 92 | for handler in list(self.root_logger.handlers): 93 | if isinstance(handler, logging.StreamHandler) and not isinstance(handler, RotatingFileHandler): 94 | handler.setFormatter(formatter) 95 | 96 | def add_monitor_handler(self, monitor_handler: logging.Handler): 97 | """Attach monitor handler to report logs.""" 98 | self.root_logger.addHandler(monitor_handler) 99 | -------------------------------------------------------------------------------- /docs/useful_links.md: -------------------------------------------------------------------------------- 1 | # Useful Links 2 | 3 | ### Hugging Face Access Tokens 4 | 5 | * **Creating and managing a HF "Access Token"** 6 | [https://huggingface.co/docs/hub/security-tokens](https://huggingface.co/docs/hub/security-tokens) 7 | 8 | 9 | ### SSH Keys & Remote Login 10 | 11 | * **How to generate an SSH key pair** 12 | [https://www.digitalocean.com/community/tutorials/how-to-create-ssh-keys-with-openssh-on-macos-or-linux](https://www.digitalocean.com/community/tutorials/how-to-create-ssh-keys-with-openssh-on-macos-or-linux) 13 | * **How to connect to a server using SSH** 14 | [https://www.digitalocean.com/community/tutorials/how-to-use-ssh-to-connect-to-a-remote-server](https://www.digitalocean.com/community/tutorials/how-to-use-ssh-to-connect-to-a-remote-server) 15 | 16 | 17 | ### File Permissions & Ownership 18 | 19 | * **`chmod`, `chown`, and UNIX file permission basics** 20 | [https://www.digitalocean.com/community/tutorials/how-to-set-permissions-linux](https://www.digitalocean.com/community/tutorials/how-to-set-permissions-linux) 21 | * **Why “UNPROTECTED PRIVATE KEY FILE!” and how to fix it** 22 | [https://www.cyberciti.biz/faq/warning-unprotected-private-key-file-ssh-linux-unix-error/](https://www.cyberciti.biz/faq/warning-unprotected-private-key-file-ssh-linux-unix-error/) 23 | 24 | 25 | ### Basic Linux Package Management & Tools 26 | 27 | * **`apt` (Debian/Ubuntu) package manager intro** 28 | [https://help.ubuntu.com/community/AptGet/Howto](https://help.ubuntu.com/community/AptGet/Howto) 29 | * **Installing and using `lsof` to see open ports / GPU processes** 30 | [https://linux.die.net/man/8/lsof](https://linux.die.net/man/8/lsof) 31 | 32 | 33 | ### Installing Docker 34 | 35 | * **Official Docker Engine install guide (Ubuntu)** 36 | [https://docs.docker.com/engine/install/ubuntu/](https://docs.docker.com/engine/install/ubuntu/) 37 | * **Post-install steps (add your user to the `docker` group)** 38 | [https://docs.docker.com/engine/install/linux-postinstall/](https://docs.docker.com/engine/install/linux-postinstall/) 39 | 40 | 41 | ### Installing Conda (Miniconda) 42 | 43 | * **Miniconda install 
instructions (Linux)** 44 | [https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html](https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html) 45 | * **Conda cheat sheet (creating/activating envs, installing packages)** 46 | [https://docs.conda.io/projects/conda/en/latest/user-guide/cheatsheet.html](https://docs.conda.io/projects/conda/en/latest/user-guide/cheatsheet.html) 47 | 48 | 49 | ### Checking NVIDIA GPU & Drivers 50 | 51 | * **Installing NVIDIA drivers on Ubuntu** 52 | [https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html](https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html) 53 | * **Using `nvidia-smi` to verify GPU availability** 54 | [https://www.gpu-mart.com/blog/monitor-gpu-utilization-with-nvidia-smi](https://www.gpu-mart.com/blog/monitor-gpu-utilization-with-nvidia-smi) 55 | 56 | 57 | ### Git Basics 58 | 59 | * **Cloning a repository** 60 | [https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) 61 | * **Generating an SSH key for GitHub (if needed)** 62 | [https://docs.github.com/en/authentication/connecting-to-github-with-ssh](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) 63 | 64 | 65 | ### Firewall / Opening Ports 66 | 67 | * **Opening a Port on Linux to Allow TCP Connections** 68 | [https://www.digitalocean.com/community/tutorials/opening-a-port-on-linux](https://www.digitalocean.com/community/tutorials/opening-a-port-on-linux) 69 | * **General Linux firewall basics (iptables / firewalld)** 70 | [https://www.digitalocean.com/community/tutorials/iptables-essentials-common-firewall-rules-and-commands](https://www.digitalocean.com/community/tutorials/iptables-essentials-common-firewall-rules-and-commands) 71 | 72 | 73 | ### Opening a port on Cloud Providers 74 | 75 | * Check out the [network guide](network.md). 76 | 77 | -------------------------------------------------------------------------------- /src/node0/server/module_collab.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Tuple 16 | 17 | import torch 18 | 19 | from hivemind.moe.server import ModuleBackend 20 | from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack 21 | 22 | 23 | class ModuleCollab(ModuleBackend): 24 | def __init__(self, optimizer_lock, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | 27 | self.optimizer_lock = optimizer_lock 28 | 29 | def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: 30 | """ 31 | Apply backward pass to an aggregated batch of requests. Used by Runtime, do not call this manually 32 | To submit a request for asynchronous processing, please use ``ModuleBackend.backward_pool.submit_task``. 
33 | 34 | Subclassing: 35 | This method receives a sequence of torch tensors following ``nested_flatten(self.backward_schema)``; 36 | 37 | It should return gradients w.r.t. inputs that follow ``nested_flatten(self.forward_schema)``; 38 | 39 | Runtime doesn't guarantee that backward will be performed in the same order and for the same data 40 | as forward, so we recommend stateless backward pass that re-runs expert forward pass inside backward. 41 | 42 | Please make sure to call ``ModuleBackend.on_backward`` after each call to backward 43 | """ 44 | (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema) 45 | 46 | with torch.enable_grad(): 47 | with self.optimizer_lock: 48 | args = [ 49 | tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach() 50 | for tensor in args 51 | ] 52 | kwargs = { 53 | input_key: ( 54 | tensor.detach().requires_grad_(True) if tensor.is_floating_point() else tensor.detach() 55 | ) 56 | for input_key, tensor in kwargs.items() 57 | } 58 | 59 | batch_size = args[0].size(0) 60 | 61 | outputs = self.module(*args, **kwargs) 62 | assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure" 63 | 64 | outputs_flat = tuple(nested_flatten(outputs)) 65 | 66 | grad_outputs_flat = tuple( 67 | map( 68 | lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True), 69 | nested_flatten(grad_outputs), 70 | outputs_flat, 71 | ) 72 | ) 73 | torch.autograd.backward( 74 | outputs_flat, grad_tensors=grad_outputs_flat, create_graph=False, retain_graph=False 75 | ) 76 | self.on_backward(batch_size) 77 | 78 | return tuple( 79 | x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) for x in nested_flatten((args, kwargs)) 80 | ) 81 | 82 | def on_backward(self, batch_size: int) -> None: 83 | """ 84 | Train the expert for one step. This method is called by ``ModuleBackend.backward`` after computing gradients. 85 | """ 86 | if self.optimizer is not None: 87 | self.optimizer.step(batch_size=batch_size) 88 | self.optimizer.zero_grad() 89 | 90 | if self.scheduler is not None: 91 | self.scheduler.step() 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | logs*/ 7 | 8 | # C extensions 9 | *.so 10 | !integrity_check.*.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # Ruff stuff: 174 | .ruff_cache/ 175 | 176 | # PyPI configuration file 177 | .pypirc 178 | 179 | 180 | .vscode 181 | -------------------------------------------------------------------------------- /src/node0/utils/get_parameters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import signal 17 | 18 | from typing import Any 19 | 20 | from hivemind.dht import DHT 21 | from hivemind.utils import get_logger 22 | from hivemind.utils.timed_storage import ValueWithExpiration 23 | 24 | from node0.security.validation import RunParameters 25 | 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | def get_parameter_store(dht: DHT, prefix: str) -> tuple[Any, ...]: 31 | """ 32 | Retrieve the most recent run parameters from the distributed hash table (DHT). 33 | 34 | This function fetches run parameters stored by peers in the DHT, validates them 35 | using RSA signature verification, and returns the averaging target batch size 36 | from the most recently updated entry. 37 | 38 | Args: 39 | dht (DHT): The distributed hash table instance to retrieve parameters from. 40 | prefix (str): Prefix string used to construct the DHT storage key for lookup. 41 | 42 | Returns: 43 | int: The averaging target batch size from the most recent valid peer entry. 44 | """ 45 | 46 | param_store_key = f"{prefix}_paramstore" 47 | param_store_result = dht.get(param_store_key, latest=True) 48 | 49 | if not isinstance(param_store_result, ValueWithExpiration): 50 | logger.error("Could not retrieve run parameters from peers. 
Exiting run.") 51 | os.killpg(os.getpgrp(), signal.SIGTERM) 52 | raise RuntimeError("Could not retrieve run parameters from peers") 53 | 54 | metadata = param_store_result.value 55 | 56 | valid_peer_entries = [ 57 | RunParameters.parse_obj(peer_value.value) for peer_value in metadata.values() if peer_value.value is not None 58 | ] 59 | 60 | last_time = -float("inf") 61 | averaging_target_batch_size = 0 62 | scheduler = "" 63 | num_warmup_steps = 0 64 | num_training_steps = 0 65 | averaging_timeout = 0 66 | matchmaking_time = 0 67 | request_timeout = 0 68 | load_state_timeout = 0 69 | 70 | for val in valid_peer_entries: 71 | if val.time > last_time: 72 | averaging_target_batch_size = val.averaging_target_batch_size 73 | scheduler = val.scheduler 74 | num_warmup_steps = val.num_warmup_steps 75 | num_training_steps = val.num_training_steps 76 | averaging_timeout = val.averaging_timeout 77 | matchmaking_time = val.matchmaking_time 78 | request_timeout = val.request_timeout 79 | load_state_timeout = val.load_state_timeout 80 | last_time = val.time 81 | 82 | if ( 83 | averaging_target_batch_size <= 0 84 | or scheduler not in ["linear", "cosine"] 85 | or num_warmup_steps <= 0 86 | or num_training_steps <= 0 87 | or averaging_timeout <= 0 88 | or matchmaking_time <= 0 89 | or request_timeout <= 0 90 | or load_state_timeout <= 0 91 | ): 92 | logger.error("Could not retrieve run parameters from peers. Exiting run.") 93 | os.killpg(os.getpgrp(), signal.SIGTERM) 94 | return 95 | 96 | logger.info( 97 | f"Got runtime training parameters: " 98 | f"averaging_target_batch_size = {averaging_target_batch_size}, " 99 | f"scheduler = {scheduler}, " 100 | f"num_warmup_steps = {num_warmup_steps}, " 101 | f"num_training_steps = {num_training_steps}, " 102 | f"averaging_timeout = {averaging_timeout}, " 103 | f"matchmaking_time = {matchmaking_time}, " 104 | f"request_timeout = {request_timeout}, " 105 | f"load_state_timeout = {load_state_timeout}" 106 | ) 107 | return ( 108 | averaging_target_batch_size, 109 | scheduler, 110 | num_warmup_steps, 111 | num_training_steps, 112 | averaging_timeout, 113 | matchmaking_time, 114 | request_timeout, 115 | load_state_timeout, 116 | ) 117 | -------------------------------------------------------------------------------- /src/node0/utils/connection_test_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import socket 16 | import threading 17 | import time 18 | 19 | from hivemind.utils.logging import get_logger 20 | 21 | from node0.utils.node_info import TestServerError 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | class TestServer: 28 | def __init__(self, host="0.0.0.0", port=49200): 29 | self.host = host 30 | self.port = port 31 | self.server_socket = None 32 | self.running = False 33 | self.thread = None 34 | self.received_message = None # Store single message 35 | self.message_received = threading.Event() # Signal when message arrives 36 | 37 | def __enter__(self): 38 | self.start() 39 | return self 40 | 41 | def __exit__(self, exc_type, exc_val, exc_tb): 42 | self.close() 43 | 44 | def start(self): 45 | """Start the server""" 46 | try: 47 | self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 48 | self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 49 | self.server_socket.bind((self.host, self.port)) 50 | self.server_socket.listen(1) 51 | self.server_socket.settimeout(1) 52 | self.running = True 53 | except OSError as e: 54 | logger.error(f"Failed to bind to {self.host}:{self.port}: {e}") 55 | if self.server_socket: 56 | self.server_socket.close() 57 | self.server_socket = None 58 | raise TestServerError(f"Failed to start server on {self.host}:{self.port}: {e}") from None 59 | 60 | def server_loop(): 61 | logger.info(f"Test server listening on {self.host}:{self.port}") 62 | while self.running and not self.message_received.is_set(): 63 | try: 64 | client_socket, addr = self.server_socket.accept() 65 | logger.info("Test server connection received") 66 | 67 | # Receive the message 68 | try: 69 | client_socket.settimeout(2) 70 | data = client_socket.recv(1024) 71 | if data: 72 | message = data.decode("utf-8", errors="ignore") 73 | self.received_message = {"message": message, "from": "", "timestamp": time.time()} 74 | self.message_received.set() # Signal message received 75 | logger.info(f"Received message: {message}") 76 | else: 77 | logger.info("No data received") 78 | except TimeoutError: 79 | logger.info("Timeout waiting for data") 80 | except Exception as e: 81 | logger.error(f"Error receiving data: {e}") 82 | finally: 83 | client_socket.close() 84 | 85 | except TimeoutError: 86 | continue 87 | except OSError: 88 | break 89 | 90 | self.thread = threading.Thread(target=server_loop) 91 | self.thread.daemon = True 92 | self.thread.start() 93 | time.sleep(0.1) # Let server start 94 | 95 | def close(self): 96 | """Close the server""" 97 | if self.running: 98 | self.running = False 99 | if self.server_socket: 100 | self.server_socket.close() 101 | if self.thread: 102 | self.thread.join(timeout=1) 103 | logger.info("Test server closed") 104 | 105 | def get_message(self): 106 | """Get the received message (or None if not received yet)""" 107 | return self.received_message 108 | 109 | def wait_for_message(self, timeout=10): 110 | """Wait for the message to be received. 
Returns True if received, False if timeout""" 111 | return self.message_received.wait(timeout) 112 | 113 | def verify_message(self): 114 | """Verify the received message contains expected content""" 115 | if self.received_message and "auth" in self.received_message["message"]: 116 | return True 117 | return False 118 | -------------------------------------------------------------------------------- /docs/network.md: -------------------------------------------------------------------------------- 1 | # Network Configuration Guide 2 | 3 | This guide provides instructions for configuring port 49200 to be accessible for external connections across various cloud compute providers and personal networks. 4 | 5 | ## AWS 6 | 1. Navigate to your EC2 instance and click on the Security Group attached to it (found under the Security tab) 7 | 2. Click "Inbound rules" → "Edit inbound rules" 8 | 3. Click "Add rule" and configure: 9 | - **Type:** Custom TCP 10 | - **Port range:** 49200 11 | - **Source:** 0.0.0.0/0 (to allow traffic from any source) 12 | 13 | AWS Inbound Rules 14 | 15 | ## Google Cloud Platform (GCP) 16 | 1. Go to "VPC network" → "Firewall" → "Create firewall rule" 17 | 2. Configure the rule: 18 | - **Target:** Choose "All instances in the network" or "Specified target tags" and specify your instance 19 | - **Source IPv4 ranges:** 0.0.0.0/0 20 | - **Protocols and ports:** Select "Specified protocols and ports" 21 | - **TCP:** 49200 22 | 23 | GCP Inbound Rules 24 | 25 | ## RunPod 26 | RunPod assigns random external port mappings, so you need to configure port forwarding: 27 | 28 | 1. After deploying a Pod, click the three horizontal lines (Pod Settings) → "Edit Pod" 29 | 30 | GCP Inbound Rules 31 | 32 | 2. Under "Expose TCP Ports" add `49200` and save (this will restart the Pod) 33 | 34 | GCP Inbound Rules 35 | 36 | 3. Once restarted, click "Connect" to see the external port mapping for internal port 49200. In the example below, the external port is 17969. 37 | 38 | GCP Inbound Rules 39 | 40 | 4. Use these flags when running generate_script.py (see **Changing exposed port** section in README for details): 41 | 42 | ```bash 43 | --host_port 49200 # The internal port the library will listen on 44 | --announce_port 17969 # The external port other peers will connect to 45 | ``` 46 | 47 | ## Tensordock 48 | ### Distributed Compute 49 | When provisioning an instance using the Distributed Compute option, Tensordock allows you to request a specific internal port and then assigns a nrandom external port mapping to that port, so you need to configure port forwarding: 50 | 51 | 1. During provisioning setup, under the Port Forwarding section → "Request Port" and choose `49200`: 52 | 53 | Tensordock Inbound Rules 54 | 55 | 2. Once deployed, note the randomly assigned external port: 56 | 57 | Tensordock Forwarded Ports 58 | 59 | 3. Use these flags when running generate_script.py (see **Changing exposed port** section in README for details). `--announce_port` should specify your randomly assigned external port: 60 | 61 | ```bash 62 | --host_port 49200 # The internal port the library will listen on 63 | --announce_port 10009 # The external port other peers will connect to 64 | ``` 65 | 66 | ## Lambda Labs 67 | 1. Navigate to the Firewall page in your Lambda Cloud dashboard 68 | 2. Click "Edit" in the Inbound Rules section 69 | 3. 
Configure the new rule: 70 | - **Rule type:** Custom TCP 71 | - **Port range:** 49200 72 | - **Source:** 0.0.0.0/0 (to allow traffic from any source) 73 | 74 | Lambda Inbound Rules 75 | 76 | ## Personal Computer 77 | 78 | Set up your firewall to allow traffic from the outside world to the port 49200/tcp. 79 | 80 | If you have a router, set it up to allow connections from the outside world (port 49200/tcp) to your computer. 81 | 82 | ## WSL 83 | 84 | If you're using Windows with WSL 2, you'll need to configure port forwarding: 85 | 86 | 1. **Create WSL configuration file:** 87 | Create a `.wslconfig` file in your Windows user home directory (e.g., `C:\Users\YourUsername\.wslconfig`) with the following content: 88 | ``` 89 | [wsl2] 90 | localhostforwarding=true 91 | ``` 92 | 93 | 2. **Configure port proxy (run PowerShell as Administrator):** 94 | 95 | Find your WSL container IP: 96 | ```powershell 97 | ((wsl hostname -I) -split " ")[0] 98 | ``` 99 | 100 | Add the port proxy (replace `` with the IP from the previous command): 101 | ```powershell 102 | netsh interface portproxy add v4tov4 listenport=49200 listenaddress=0.0.0.0 connectport=49200 connectaddress= 103 | ``` 104 | 105 | Open the firewall for the port: 106 | ```powershell 107 | netsh advfirewall firewall add rule name="node0" dir=in action=allow protocol=TCP localport=49200 108 | ``` 109 | 110 | 3. **If you have a router**, set it up to allow connections from the outside world (port 49200/tcp) to your computer. 111 | 4. **Restart your computer** to apply the changes. 112 | -------------------------------------------------------------------------------- /src/node0/utils/network_throughput.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import threading 17 | import time 18 | 19 | import psutil 20 | 21 | from hivemind.utils.logging import get_logger 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | class NetworkMonitor(threading.Thread): 28 | """ 29 | Monitor network throughput for a specific network interface or all interfaces. 30 | 31 | :param interface: Specific network interface to monitor (e.g., 'eth0') 32 | :param interval: Time between measurements in seconds 33 | :param duration: Total monitoring duration in seconds (None for continuous monitoring) 34 | """ 35 | 36 | def __init__( 37 | self, 38 | interface: str | None = None, 39 | interval: int = 60, 40 | duration: int | None = None, 41 | ): 42 | super().__init__() 43 | 44 | self.interface = interface 45 | self.interval = interval 46 | self.duration = duration 47 | 48 | self.start() 49 | 50 | def get_tcp_connection_count(self) -> int: 51 | """ 52 | Get the count of TCP connections from host system. 
53 | """ 54 | try: 55 | tcp_file = "/proc/1/root/proc/net/tcp" 56 | if os.path.exists(tcp_file): 57 | with open(tcp_file, "r") as f: 58 | return len(f.readlines()) - 1 # Subtract 1 for header line 59 | except (OSError, PermissionError): 60 | pass 61 | 62 | # method fails, return -1 to indicate unavailable 63 | return -1 64 | 65 | def run(self): 66 | # Store initial network counters 67 | logger.info("Running network bandwidth monitor") 68 | initial_counters = psutil.net_io_counters(pernic=True) 69 | start_time = time.time() 70 | 71 | try: 72 | # Determine interfaces to monitor 73 | if self.interface: 74 | interfaces = [self.interface] 75 | else: 76 | interfaces = [iface for iface in initial_counters.keys() if iface != "lo"] 77 | 78 | # Monitoring loop 79 | while True: 80 | # Wait for the interval 81 | time.sleep(self.interval) 82 | 83 | # Get current network counters 84 | current_counters = psutil.net_io_counters(pernic=True) 85 | current_time = time.time() 86 | 87 | # Calculate time elapsed 88 | elapsed = current_time - start_time 89 | 90 | # Process each interface 91 | bytes_sent = 0 92 | bytes_recv = 0 93 | for iface in interfaces: 94 | if iface not in current_counters: 95 | continue 96 | 97 | # Calculate throughput 98 | initial = initial_counters.get(iface, None) 99 | if not initial: 100 | continue 101 | 102 | bytes_sent += current_counters[iface].bytes_sent - initial.bytes_sent 103 | bytes_recv += current_counters[iface].bytes_recv - initial.bytes_recv 104 | 105 | # Calculate megabytes sent and received per second 106 | bytes_sent = (bytes_sent) / elapsed / (1024 * 1024) 107 | bytes_recv = (bytes_recv) / elapsed / (1024 * 1024) 108 | 109 | bits_sent = bytes_sent * 8 110 | bits_recv = bytes_recv * 8 111 | 112 | # Get current TCP connection counts 113 | tcp_count = self.get_tcp_connection_count() 114 | 115 | # Log results 116 | logger.info( 117 | f"Time {time.strftime('%Y-%m-%d %H:%M:%S'):<20} " 118 | f"Interface Agg " 119 | f"Sent {bits_sent:>10.2f} Mbps " 120 | f"Rcv {bits_recv:>10.2f} Mbps " 121 | ) 122 | 123 | logger.info(f"Time {time.strftime('%Y-%m-%d %H:%M:%S'):<20} Number open connections {tcp_count}") 124 | 125 | # Update initial counters and start time 126 | initial_counters = current_counters 127 | start_time = current_time 128 | 129 | # Check if monitoring duration is specified 130 | if self.duration and current_time - start_time >= self.duration: 131 | break 132 | 133 | except KeyboardInterrupt: 134 | print("\nMonitoring stopped by user.") 135 | -------------------------------------------------------------------------------- /src/node0/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import importlib 16 | import io 17 | 18 | from functools import partial 19 | from pathlib import Path 20 | from typing import Any 21 | 22 | import requests 23 | import torch 24 | 25 | from hivemind.utils.logging import get_logger 26 | 27 | 28 | logger = get_logger(__name__) 29 | 30 | 31 | def build_cls(class_path: str, init_args: dict, partial_init: bool = False) -> Any: 32 | """Instantiate class. 33 | 34 | Args: 35 | class_path (str): full path to class 36 | init_args (dict): class init arguments 37 | partial_init (bool, optional): if True, return partial function. Defaults to False. 38 | 39 | Raises: 40 | Exception: wrong class path/init arguments 41 | 42 | Returns: 43 | Any: class instance or partial function 44 | """ 45 | try: 46 | # Split module and class names 47 | module_name, class_name = class_path.rsplit(".", maxsplit=1) 48 | 49 | # Import module and get class 50 | module = importlib.import_module(module_name) 51 | class_ = getattr(module, class_name) 52 | 53 | # Instantiate class 54 | if partial_init: 55 | instance = partial(class_, **init_args) 56 | else: 57 | instance = class_(**init_args) 58 | return instance 59 | except Exception as e: 60 | logger.error(f"Can't initialize class {class_path}: {e}", exc_info=logger.getEffectiveLevel() <= 15) 61 | raise 62 | 63 | 64 | def infer_expert_params( 65 | pipeline_stage: str, 66 | max_experts: int = 1024, 67 | ) -> dict: 68 | """Infer required expert and stage parameters from pipeline_stage. 69 | 70 | Args: 71 | pipeline_stage str: stage type in the format: head-X, body-X, tail-X (X is int) 72 | max_experts (int, optional): max number of experts per stage. Defaults to 1024. 73 | 74 | Raises: 75 | ValueError: wrong stage type 76 | 77 | Returns: 78 | dict: expert_pattern, expert_class, stage 79 | """ 80 | try: 81 | stage, stage_idx = pipeline_stage.split("-") 82 | stage_idx = int(stage_idx) 83 | assert stage in ["head", "body", "tail"] 84 | except Exception as e: 85 | raise ValueError("Wrong stage type. It should be one of: head-X, body-X, tail-X (X is int).") from e 86 | 87 | stage_idx_for_pattern = "" if stage in ["head", "tail"] else stage_idx 88 | expert_pattern = f"{stage}{stage_idx_for_pattern}.0.[32:{max_experts}]" 89 | stage_name = f"{stage}{stage_idx_for_pattern}.0." 
90 | expert_class = f"lm_{stage}" 91 | 92 | params = { 93 | "expert_pattern": expert_pattern, 94 | "expert_cls": expert_class, 95 | "stage": stage_idx, 96 | "stage_name": stage_name, 97 | } 98 | 99 | return params 100 | 101 | 102 | def load_ss_components(ss_url: str): 103 | """ 104 | Loads rcv and fixed_embeddings from URL location 105 | 106 | Args: 107 | ss_url (str): URL to the subspace compression file 108 | 109 | Returns: 110 | Dict of: rcv, fixed_tok_weight 111 | """ 112 | try: 113 | response = requests.get(ss_url) 114 | response.raise_for_status() 115 | 116 | buffer = io.BytesIO(response.content) 117 | ss_comp_dict = torch.load(buffer, map_location="cpu") 118 | 119 | except requests.RequestException as e: 120 | raise RuntimeError(f"Failed to download from URL: {e}") from e 121 | except RuntimeError as e: 122 | raise RuntimeError(f"Remote loading of subspace compression components failed: {e}") from e 123 | 124 | return ss_comp_dict 125 | 126 | 127 | def clean_tmp(): 128 | """Remove tmp hivemind socket files.""" 129 | tmp_dir = Path("/tmp") 130 | 131 | if not tmp_dir.exists(): 132 | return 133 | 134 | try: 135 | matching_files = list(tmp_dir.glob("hivemind*")) 136 | 137 | if len(matching_files) > 0: 138 | for fpath in matching_files: 139 | try: 140 | fpath.unlink() 141 | except Exception as e: 142 | logger.error( 143 | f"Can't remove file {fpath}: {e}. Please remove all hivemind* files from /tmp folder (see README for instructions to stop the server). Exiting run." 144 | ) 145 | exit(1) 146 | 147 | matching_files = list(tmp_dir.glob("hivemind*")) 148 | if len(matching_files) > 0: 149 | logger.error( 150 | "Old hivemind* files are found in /tmp folder. Please clean the folder (see README for instructions to stop the server). Exiting run." 151 | ) 152 | exit(1) 153 | logger.info("/tmp folder was cleaned") 154 | except Exception as e: 155 | logger.warning(f"Can't check tmp folder: {e}") 156 | return 157 | -------------------------------------------------------------------------------- /docs/cloud.md: -------------------------------------------------------------------------------- 1 | # Cloud Options 2 | We list various cloud options and how to set them up. The cheapest option is RunPod. 3 | 4 | ## AWS (Amazon Web Services) 5 | 6 | ### How to set up 7 | 8 | **Step 1: Launch a GPU Instance** 9 | 10 | 1. **Log into AWS**: Go to the [AWS Management Console](https://aws.amazon.com/console/). Make an account or log in if you have one. 11 | 2. **Create a new EC2 instance**: 12 | * Go to **EC2** > **Launch Instance**. 13 | * Choose a name for the instance. 14 | * Select an **AMI (Amazon Machine Image)** that is Unix based, supports GPU, and has CUDA and PyTorch installed. For example, 15 | * `Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.7 (Ubuntu 22.04)`. 16 | * **Choose an Instance Type**: 17 | * Select an instance type with a **GPU** that has at least 16GB VRAM and 32GB RAM, e.g. `g4dn.2xlarge` which has an NVIDIA T4 18 | * Choose a key pair 19 | * If you do not have a key pair, click **Create new key pair** and download the **.pem key** 20 | 3. **Configure the Instance**: 21 | * Set up **security groups** to allow SSH (port 22) and a specific port for protocol communication (port 49200, though you can change and specify this) 22 | * We recommend minimum of 80GB of storage 23 | 1. **Review and Launch** the instance. 24 | 25 | **Step 2: Edit Security Group** 26 | 1. 
Follow the AWS section in the [network guide](network.md) to configuring port 49200 to be accessible for external connections. 27 | 28 | **Step 3: Connect to the Instance** 29 | 30 | 1. **SSH into your instance**: 31 | * Find your public ip and connect via SSH: 32 | ```bash 33 | ssh -i your-key.pem ubuntu@your-ec2-public-ip 34 | ``` 35 | 36 | ## GCP (Google Cloud Platform) 37 | 38 | The cheapest option is a `NVIDIA T4` (16GB VRAM) and `n1-standard-8` (8 vCPU, 4 core, 30 GB memory) for $0.51/hour. 39 | 40 | ### How to set up 41 | 42 | **Step 1: Create a GPU-enabled VM** 43 | 44 | 1. **Log into Google Cloud** Go to the [Google Cloud Console](https://console.cloud.google.com/). Make an account or log in if you have one. 45 | 2. **Create a new VM instance**: 46 | 47 | * Go to **Compute Engine** > **VM instances** > **Create Instance**. 48 | * Choose a name and region for the instance. 49 | * Change from General Purpose to **GPUs** and select a **GPU** and **Machine Type** 50 | * E.g. Choose 1 `NVIDIA T4` and `n1-standard-8` (8 vCPU, 4 core, 30 GB memory) 51 | * In the **OS and storage** tab change the image to one that is Unix based, supports GPU, and has CUDA and PyTorch installed 52 | * E.g. OS `Deep Learning on Linux` and image `Deep Learning VM for PyTorch 2.4 with CUDA 12.4 M129` 53 | * In the **Security** tab click **Manage Access**. Under **Add manually generated SSH keys** click **Add item**, enter your SSH public key, and click **Save**. 54 | * Click **Create** 55 | 56 | **Step 2: Edit Firewall settings** 57 | 1. Follow the GCP section in the [network guide](network.md) to configuring port 49200 to be accessible for external connections. 58 | 59 | **Step 2: Connect to the Instance** 60 | 61 | 1. **Set up SSH Keys** 62 | * Go into your instance and click **Edit** 63 | * Under **SSH Keys** click **Add item**, enter your SSH public key, and click **Save**. 64 | 2. **SSH into your VM**: 65 | * Find your external ip and username (this is linked with your SSH key) under the instance details and connect via SSH: 66 | ```bash 67 | ssh ubuntu@your-external-ip 68 | ``` 69 | 70 | 71 | ## RunPod 72 | 73 | The cheapest option is a RTX 2000 Ada: 16GB VRAM, 31Gb RAM, 6 vCPUs for $0.23/hour. 74 | 75 | RunPod launches your workspace within a docker container, so it is difficult to launch docker within the docker container. 76 | We recommend using conda instead. See the [installing guide](installing.md) for how to install conda. 77 | 78 | RunPod also assigns random external port mappings, so we need to find and specify that external port. See the RunPod section in the [network guide](network.md) 79 | 80 | Finally, if you need to install anything else with RunPod, note that most standard packages are not installed, so run `apt update` first. 81 | 82 | ### How to set up 83 | **Step 1: Launch a GPU Pod** 84 | 85 | 1. **Log into RunPod**: Go to the [RunPod Console](https://www.runpod.io/console/home). Make an account or log in if you have one. 86 | 2. **Set SSH Keys**: 87 | * Go to **Settings** and under **SSH Public Keys** add your public SSH key. If you have not made a SSH key yet, follow [this guide from RunPod](https://docs.runpod.io/pods/configuration/use-ssh). 88 | 3. **Create a new Pod**: 89 | * Go to **Pods** to see available pods, and choose a Pod 90 | * E.g. RTX 2000 Ada: 16GB VRAM, 31Gb RAM, 6 vCPUs for $0.23/hour. 91 | * Choose a Pod name and **Pod Template** 92 | * Want one with CUDA and PyTorch installed, the default `RunPod Pytorch 2.1` works. 
93 | * Ensure that SSH Terminal Access is enabled 94 | * Click **Deploy On-Demand** 95 | 96 | **Step 2: Edit the Pod** 97 | 1. Follow the RunPod section in the [network guide](network.md) to edit the Pod to expose a TCP port. 98 | 99 | **Step 3: Connect to the Pod** 100 | 101 | 1. **SSH into your Pod**: 102 | * Go to **Connect** and in the **SSH** tab look at the ssh command under **SSH over exposed TCP**. 103 | 104 | 105 | ## Tensordock 106 | 107 | Tensordock offers low-cost consumer GPUs as low as an RTX A4000 for $0.105/hr. 108 | 109 | The distributed compute option in Tensordock also assigns random external port mappings, so we need to find and specify that external port. See the Tensordock section in the [network guide](network.md) 110 | 111 | ### How to set up 112 | 1. **Log into Tensordock**: Go to the [Tensordock Deploy Dashboard](https://dashboard.tensordock.com/deploy). Make an account or log in if you have one. 113 | 2. **Set SSH Keys**: 114 | * Go to **Secrets** and click **Add Secret** to add your public SSH key. If you have not made a SSH key yet, follow [this guide for Windows](https://learn.microsoft.com/en-us/viva/glint/setup/sftp-ssh-key-gen) and [this guide for Linux](https://www.digitalocean.com/community/tutorials/how-to-configure-ssh-key-based-authentication-on-a-linux-server). 115 | * Choose a **Name** for your SSH Key, choose **Type** as `SSH Key` and enter your public key value under **Value**. The public key value will look something like this `ssh-rsa ...` 116 | 3. **Deploy a GPU** 117 | * Go to **Deploy GPU** to see available GPUs, and choose a GPU 118 | * E.g. RTX 4000: 16GB VRAM for $0.105/hour. 119 | * Choose a Instance Name, configure the resource with CPU Cores, RAM and Storage options and choose a location and select the OS. We recommend `Ubuntu 24.04 LTS`. 120 | * Click **Deploy Instance** 121 | 4. **Connect to your instance** 122 | * Click on **My Servers** and you should see the newly provisioned GPU instance. You can click the instance to get details about the instance 123 | * Instructions for connecting to the instance using SSH can be found under the **Access** section 124 | 125 | #### CUDA \& Docker setup 126 | You may need to setup your Tensordock instances with NVIDIA toolkit and Docker (if using Docker). 127 | 128 | To install NVIDIA toolkit, run the following commands in your instance CLI: 129 | ``` 130 | sudo apt update 131 | sudo apt install -y nvidia-container-toolkit 132 | sudo systemctl restart docker 133 | ``` 134 | 135 | To install Docker, run the following commands in your instance CLI: 136 | ``` 137 | sudo apt install apt-transport-https ca-certificates curl software-properties-common -y 138 | sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc 139 | echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null 140 | sudo apt update 141 | sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin -y 142 | sudo groupadd docker 143 | sudo usermod -aG docker $USER 144 | ``` 145 | 146 | ## Lambda Labs 147 | Lambda does not support 16GB GPUs, the cheapest options is a RTX 6000 (24 GB VRAM) with 14 vCPUs, 46 GiB RAM, 0.5 TiB SSD for $0.50 / hr. 148 | 149 | ### How to set up 150 | **Step 1: Launch a GPU Instance** 151 | 152 | 1. 
**Log into Lambda**: Go to the [Lambda instances](https://cloud.lambda.ai/instances). Make an account or log in if you have one. 153 | 2. **Set SSH Keys**: 154 | * Go to **SSH Keys** and add your public SSH key. 155 | 3. **Create a new Instance**: 156 | * Go to **Instances** and select **Launch an Instance** to see available instances, and choose an instance 157 | * E.g. 1x RTX 6000 (24 GB), for $0.50/hour. 158 | * Choose a **Region** and **FileSystem**. If you don't have a filesystem, select **Create a filesystem** 159 | * Click **Launch** 160 | 161 | **Step 2: Edit the Firewall** 162 | 1. Follow the Lambda Labs section in the [network guide](network.md) to edit the firewall to expose a TCP port. 163 | 164 | **Step 3: Connect to the Instance** 165 | 166 | 1. **SSH into your Instance**: 167 | * Once the instance has booted, look at the SSH command under **SSH Login**. 168 | 169 | -------------------------------------------------------------------------------- /src/node0/utils/node_info.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | import time 17 | 18 | from collections.abc import Callable 19 | from typing import Any 20 | 21 | import psutil 22 | import speedtest 23 | import torch 24 | 25 | from hivemind.utils.logging import get_logger 26 | from pydantic import BaseModel 27 | 28 | 29 | logger = get_logger(__name__) 30 | 31 | 32 | SPEEDTEST_SERVERS = [ 33 | { 34 | "url": "http://speedtest-wa.waveip.org:8080/speedtest/upload.php", 35 | "lat": "47.6062", 36 | "lon": "-122.3321", 37 | "name": "Seattle, WA", 38 | "country": "United States", 39 | "cc": "US", 40 | "sponsor": "Wave", 41 | "id": "60635", 42 | "host": "speedtest-wa.waveip.org:8080", 43 | }, 44 | { 45 | "url": "http://speedtest.sea1.nitelusa.net:8080/speedtest/upload.php", 46 | "lat": "47.6062", 47 | "lon": "-122.3321", 48 | "name": "Seattle, WA", 49 | "country": "United States", 50 | "cc": "US", 51 | "sponsor": "Nitel", 52 | "id": "12192", 53 | "host": "speedtest.sea1.nitelusa.net:8080", 54 | }, 55 | { 56 | "url": "http://us-sea02.speed.misaka.one:8080/speedtest/upload.php", 57 | "lat": "47.6062", 58 | "lon": "-122.3321", 59 | "name": "Seattle, WA", 60 | "country": "United States", 61 | "cc": "US", 62 | "sponsor": "Misaka Network, Inc.", 63 | "id": "50679", 64 | "host": "us-sea02.speed.misaka.one:8080", 65 | }, 66 | { 67 | "url": "http://sea-speedtest.net.sangoma.net:8080/speedtest/upload.php", 68 | "lat": "47.6062", 69 | "lon": "-122.3321", 70 | "name": "Seattle, WA", 71 | "country": "United States", 72 | "cc": "US", 73 | "sponsor": "Sangoma", 74 | "id": "63939", 75 | "host": "sea-speedtest.net.sangoma.net:8080", 76 | }, 77 | { 78 | "url": "http://wa01svp-speed11.svc.tds.net:8080/speedtest/upload.php", 79 | "lat": "47.6062", 80 | "lon": "-122.3321", 81 | "name": "Seattle, WA", 82 | "country": "United States", 83 | "cc": "US", 84 | "sponsor": "TDS 
Telecom", 85 | "id": "70208", 86 | "host": "wa01svp-speed11.svc.tds.net:8080", 87 | }, 88 | { 89 | "url": "http://speedtest-jp.tgb-host.com:8080/speedtest/upload.php", 90 | "lat": "35.6833", 91 | "lon": "139.6833", 92 | "name": "Tokyo", 93 | "country": "Japan", 94 | "cc": "JP", 95 | "sponsor": "7 BULL", 96 | "id": "65101", 97 | "host": "speedtest-jp.tgb-host.com:8080", 98 | }, 99 | { 100 | "url": "http://speedtest.jp230.hnd.jp.ctcsci.com:8080/speedtest/upload.php", 101 | "lat": "35.6833", 102 | "lon": "139.6833", 103 | "name": "Tokyo", 104 | "country": "Japan", 105 | "cc": "JP", 106 | "sponsor": "CTCSCI TECH LTD", 107 | "id": "62217", 108 | "host": "speedtest.jp230.hnd.jp.ctcsci.com:8080", 109 | }, 110 | { 111 | "url": "http://speedtest.3s-labo.com:8080/speedtest/upload.php", 112 | "lat": "35.6074", 113 | "lon": "140.1065", 114 | "name": "Chiba", 115 | "country": "Japan", 116 | "cc": "JP", 117 | "sponsor": "3s-labo", 118 | "id": "70451", 119 | "host": "speedtest.3s-labo.com:8080", 120 | }, 121 | { 122 | "url": "http://sto-ste-speedtest1.bahnhof.net:8080/speedtest/upload.php", 123 | "lat": "59.3294", 124 | "lon": "18.0686", 125 | "name": "Stockholm", 126 | "country": "Sweden", 127 | "cc": "SE", 128 | "sponsor": "Bahnhof AB", 129 | "id": "34024", 130 | "host": "sto-ste-speedtest1.bahnhof.net:8080", 131 | }, 132 | { 133 | "url": "http://fd.sunet.se:8080/speedtest/upload.php", 134 | "lat": "59.3294", 135 | "lon": "18.0686", 136 | "name": "Stockholm", 137 | "country": "Sweden", 138 | "cc": "SE", 139 | "sponsor": "SUNET", 140 | "id": "26852", 141 | "host": "fd.sunet.se:8080", 142 | }, 143 | { 144 | "url": "http://speedtest-sth.netatonce.net:8080/speedtest/upload.php", 145 | "lat": "59.3294", 146 | "lon": "18.0686", 147 | "name": "Stockholm", 148 | "country": "Sweden", 149 | "cc": "SE", 150 | "sponsor": "Net at Once Sweden AB", 151 | "id": "63781", 152 | "host": "speedtest-sth.netatonce.net:8080", 153 | }, 154 | { 155 | "url": "http://se-speedt02.hy.nis.telia.net:8080/speedtest/upload.php", 156 | "lat": "59.3294", 157 | "lon": "18.0686", 158 | "name": "Stockholm", 159 | "country": "Sweden", 160 | "cc": "SE", 161 | "sponsor": "Telia Sweden AB", 162 | "id": "45936", 163 | "host": "se-speedt02.hy.nis.telia.net:8080", 164 | }, 165 | { 166 | "url": "http://speedtest-sth.84grams.net:8080/speedtest/upload.php", 167 | "lat": "59.3294", 168 | "lon": "18.0686", 169 | "name": "Stockholm", 170 | "country": "Sweden", 171 | "cc": "SE", 172 | "sponsor": "84 Grams AB", 173 | "id": "53521", 174 | "host": "speedtest-sth.84grams.net:8080", 175 | }, 176 | ] 177 | 178 | 179 | class NodeInfo(BaseModel): 180 | device_name: str 181 | gpu_memory: float # GB 182 | ram: float # GB 183 | download_speed: float | None 184 | upload_speed: float | None 185 | latency: float | None 186 | 187 | 188 | class NonRetriableError(Exception): 189 | pass 190 | 191 | 192 | class RetriableError(Exception): 193 | pass 194 | 195 | 196 | class NotInAllowlistError(NonRetriableError): 197 | pass 198 | 199 | 200 | class BadRequestError(NonRetriableError): 201 | pass 202 | 203 | 204 | class IntegrityError(NonRetriableError): 205 | pass 206 | 207 | 208 | class ServerUnavailableError(RetriableError): 209 | pass 210 | 211 | 212 | class TestServerError(NonRetriableError): 213 | pass 214 | 215 | 216 | def call_with_retries(func: Callable, n_retries: int = 10, initial_delay: float = 1.0) -> Any: 217 | """Call the function with retries. 218 | 219 | Args: 220 | func (Callable): function to call 221 | n_retries (int, optional): number of retries attempts. 
Defaults to 10. 222 | initial_delay (float, optional): delay in sec between attempts. Defaults to 1.0. 223 | 224 | Returns: 225 | Any: output of the function 226 | """ 227 | i = 0 228 | while True: 229 | try: 230 | i += 1 231 | return func() 232 | except NonRetriableError: 233 | raise 234 | except ServerUnavailableError as e: 235 | error_msg = str(e) 236 | if "Our servers are currently at full capacity" in error_msg: 237 | match = re.search(r"Retry in (\d+) s", error_msg) 238 | if match: 239 | delay = int(match.group(1)) 240 | logger.warning( 241 | f"Failed to call function with exception: Our servers are currently at full capacity. Retrying in {delay} sec" 242 | ) 243 | time.sleep(delay) 244 | else: 245 | raise 246 | else: 247 | if i >= n_retries: 248 | raise 249 | 250 | delay = initial_delay * (2**i) 251 | logger.warning(f"Failed to call function with exception: {e}. Retrying in {delay:.1f} sec") 252 | time.sleep(delay) 253 | except Exception as e: 254 | if i >= n_retries: 255 | raise 256 | 257 | delay = initial_delay * (2**i) 258 | logger.warning(f"Failed to call function with exception: {e}. Retrying in {delay:.1f} sec") 259 | time.sleep(delay) 260 | 261 | 262 | def robust_internet_speed() -> tuple[float | None, float | None, float | None]: 263 | try: 264 | return call_with_retries(test_internet_speed) 265 | except Exception: 266 | logger.error("An error occurred during the speed test, skipping") 267 | return (None, None, None) 268 | 269 | 270 | def test_internet_speed() -> tuple[float, float, float]: 271 | """Measure download/upload internet speed.""" 272 | logger.info("Testing internet speed...") 273 | 274 | st = speedtest.Speedtest(secure=True) 275 | try: 276 | st.get_best_server(SPEEDTEST_SERVERS) 277 | logger.info(f"Best speed test server: {st.best['country']}") 278 | except Exception: 279 | pass 280 | 281 | # Perform the download speed test 282 | download_speed = st.download() / 1000000 # Convert to Mbps 283 | 284 | # Perform the upload speed test 285 | upload_speed = st.upload() / 1000000 # Convert to Mbps 286 | 287 | # Latency 288 | latency = float(st.results.ping) # ms 289 | 290 | # Print the results 291 | logger.info(f"Download Speed: {download_speed:.2f} Mbps") 292 | logger.info(f"Upload Speed: {upload_speed:.2f} Mbps") 293 | logger.info(f"Latency: {latency:.2f} ms") 294 | 295 | return (download_speed, upload_speed, latency) 296 | 297 | 298 | def get_device_info() -> tuple[str, float, float]: 299 | """Get device name and memory""" 300 | ram = float(psutil.virtual_memory().total) / 1024**3 301 | 302 | if not torch.cuda.is_available(): 303 | logger.error("CUDA is not available. 
Exiting run.") 304 | exit(1) 305 | 306 | device_info = torch.cuda.get_device_properties() 307 | device_name = device_info.name 308 | gpu_memory = device_info.total_memory / 1024**3 # GB 309 | return device_name, gpu_memory, ram 310 | 311 | 312 | def get_node_info() -> NodeInfo: 313 | """Collect information about the node.""" 314 | device_name, gpu_memory, ram = get_device_info() 315 | download_speed, upload_speed, latency = robust_internet_speed() 316 | 317 | node_info = NodeInfo( 318 | device_name=device_name, 319 | gpu_memory=gpu_memory, 320 | ram=ram, 321 | download_speed=download_speed, 322 | upload_speed=upload_speed, 323 | latency=latency, 324 | ) 325 | return node_info 326 | -------------------------------------------------------------------------------- /src/node0/run_server.py: -------------------------------------------------------------------------------- 1 | # This file contains code originally from Hivemind under MIT License 2 | # Original: Copyright 2020 Learning@home authors and collaborators 3 | # Modified by: Pluralis Research 2025 4 | # 5 | # Original code: MIT License (see THIRD_PARTY_LICENSES) 6 | # Modifications: Apache 2.0 License (see LICENSE) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only; 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | import argparse 13 | import json 14 | import multiprocessing as mp 15 | import os 16 | import platform 17 | 18 | from functools import partial 19 | from pathlib import Path 20 | 21 | import torch 22 | import yaml 23 | 24 | from hivemind.compression import NoCompression, ScaledFloat16Compression 25 | from hivemind.proto.runtime_pb2 import CompressionType 26 | from hivemind.utils.logging import get_logger 27 | 28 | from node0.security.authorization import authorize_with_pluralis 29 | from node0.security.validation import make_validators 30 | from node0.server.HM_gradient_averager import GradientAverager 31 | from node0.server.node0_server import Node0Server 32 | from node0.server.optim import AutoStepOptimizer 33 | from node0.utils import ( 34 | MonitorWorker, 35 | Node0Logger, 36 | build_cls, 37 | clean_tmp, 38 | infer_expert_params, 39 | update_initial_peers, 40 | ) 41 | from node0.utils.mem_monitor import MemoryTracker 42 | from node0.utils.network_throughput import NetworkMonitor 43 | from node0.utils.node_info import get_node_info 44 | 45 | 46 | logger = get_logger(__name__) 47 | 48 | if platform.system().lower() == "darwin": 49 | # Necessary for forks to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93 50 | os.environ.setdefault("no_proxy", "*") 51 | os.environ.setdefault("OBJC_DISABLE_INITIALIZE_FORK_SAFETY", "YES") 52 | mp.set_start_method("fork", force=True) 53 | 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser() 57 | # Connection arguments 58 | parser.add_argument( 59 | "--host_maddrs", 60 | type=str, 61 | help="Multiaddrs to listen for external connections from other p2p instances", 62 | ) 63 | parser.add_argument( 64 | "--announce_maddrs", 65 | type=str, 66 | help="Visible multiaddrs the host announces for external connections from other p2p instances", 67 | ) 68 | parser.add_argument( 69 | "--initial_peers", 70 | type=str, 71 | nargs="*", 72 | required=True, 73 | default=[], 74 | help="Multiaddrs of one or more active DHT peers", 75 | ) 76 | 77 | # Experiment arguments 78 | 
parser.add_argument("--run_config", type=str, required=True, help="Run configuration file.") 79 | parser.add_argument( 80 | "--custom_module_path", 81 | type=str, 82 | help="Path of a file with custom nn.modules, wrapped into special decorator", 83 | ) 84 | parser.add_argument( 85 | "--num_handlers", 86 | type=int, 87 | default=5, 88 | help="Server will use this many processes to handle incoming requests", 89 | ) 90 | 91 | # Authorization parameters 92 | parser.add_argument( 93 | "--identity_path", 94 | type=str, 95 | default="private.key", 96 | help="Path to identity file to be used in P2P", 97 | ) 98 | parser.add_argument("--token", type=str, required=True, help="HuggingFace token") 99 | parser.add_argument("--email", type=str, default="", help="Email address") 100 | parser.add_argument("--auth_server", type=str, required=True, help="Authentication server URL") 101 | 102 | args = parser.parse_args() 103 | args = vars(args) 104 | 105 | # Read run config file 106 | with open(args.pop("run_config")) as f: 107 | run_config = yaml.safe_load(f) 108 | 109 | # Combine arguments 110 | run_config.update({k: v for k, v in args.items() if v is not None}) 111 | return run_config 112 | 113 | 114 | def main(): 115 | if not torch.backends.openmp.is_available(): 116 | # Necessary to prevent the server from freezing after forks 117 | torch.set_num_threads(1) 118 | 119 | # Parse arguments 120 | args = parse_args() 121 | 122 | # Set up logging 123 | log_level = args.pop("log_level", "info").upper() 124 | node0_logger = Node0Logger(log_level=log_level) 125 | 126 | # Logging and monitoring 127 | save_clean_logs = args.pop("save_clean_logs", False) 128 | terminate_AR_fail = args.pop("terminate_AR_fail", True) 129 | experiment_prefix = args.pop("experiment_prefix", "pluralis") 130 | monitor = MonitorWorker( 131 | stats_report_interval=args["stats_report_interval"] if "stats_report_interval" in args else 60, 132 | experiment_prefix=experiment_prefix, 133 | log_file="logs/server.log", 134 | save_clean_logs=save_clean_logs, 135 | terminate_AR_fail=terminate_AR_fail, 136 | ) 137 | node0_logger.add_monitor_handler(monitor.queue_handler) 138 | monitor.start() 139 | 140 | logger.info(f"Running with configuration: {json.dumps(args, indent=4)}") 141 | 142 | # Check pytorch version 143 | if not torch.__version__.startswith("2.7"): 144 | logger.error("Wrong pytorch version. 
Please install torch 2.7") 145 | exit(1) 146 | 147 | # Clean tmp folder 148 | clean_tmp() 149 | 150 | # Collect information about the node 151 | node_info = get_node_info() 152 | 153 | # Authorize 154 | check_integrity = args.pop("check_integrity", True) 155 | authorizer = authorize_with_pluralis( 156 | node_info=node_info, 157 | user_token=args.pop("token"), 158 | user_email=args.pop("email"), 159 | role="worker", 160 | auth_server=args.pop("auth_server"), 161 | identity_path=args["identity_path"], 162 | current_path=Path(__file__).resolve().parent, 163 | announce_maddrs=args["announce_maddrs"], 164 | host_port=int(args["host_maddrs"].split("/")[4]), 165 | check_integrity=check_integrity, 166 | ) 167 | pipeline_stage = str(authorizer.pipeline_stage) 168 | args["host_maddrs"] = [args["host_maddrs"]] 169 | args["announce_maddrs"] = [args["announce_maddrs"]] 170 | 171 | expert_params = infer_expert_params(pipeline_stage) 172 | stage_name = expert_params.pop("stage_name") 173 | args.update(expert_params) 174 | 175 | # Add validators 176 | validators, public_key = make_validators( 177 | experiment_prefix=experiment_prefix, 178 | peer_id=authorizer.peer_id, 179 | stage=stage_name, 180 | ) 181 | 182 | # Add expert parameters to args 183 | num_worker_dhts = args.pop("num_dht") - 1 184 | args["initial_peers"] = update_initial_peers( 185 | args["initial_peers"], pipeline_stage, args["num_stages"], num_dht=num_worker_dhts 186 | ) 187 | 188 | # Add authorization details to log monitor 189 | monitor.add_auth_info( 190 | authorizer=authorizer, 191 | peer_id=authorizer.peer_id, 192 | stage=stage_name, 193 | local_public_key=public_key, 194 | ) 195 | 196 | # Set BW for the peer 197 | default_bandwidth = args.pop("bandwidth", 20) 198 | if node_info.upload_speed: 199 | bandwidth = node_info.upload_speed 200 | else: 201 | bandwidth = default_bandwidth 202 | 203 | # Build model arguments 204 | model_config = args.pop("model_config") 205 | model_args = build_cls(model_config["class_name"], model_config["init_args"]) 206 | model_args.stage = args.pop("stage") 207 | 208 | # Build optimizer 209 | optim_config = args.pop("optim_config") 210 | optim_cls = build_cls(optim_config["class_name"], optim_config["init_args"], partial_init=True) 211 | 212 | # Build gradient averager 213 | grad_avg_config = args.pop("grad_avg_config", None) 214 | is_tail = "tail" in expert_params["expert_pattern"] 215 | if "request_timeout" in args: 216 | request_timeout = args["request_timeout"] 217 | else: 218 | request_timeout = 3.0 219 | if ( 220 | grad_avg_config 221 | and ("averager_rank" in grad_avg_config["init_args"]) 222 | and grad_avg_config["init_args"]["averager_rank"] > 0 223 | and not is_tail 224 | ): 225 | # If detected to be mac, replace psgd 226 | if platform.system().lower() == "darwin": 227 | grad_avg_config["class_name"] = "node0.server.power_sgd_averager_mac.PowerSGDGradientAverager" 228 | grad_avg_config["init_args"]["request_timeout"] = request_timeout 229 | grad_avg_factory = build_cls(grad_avg_config["class_name"], grad_avg_config["init_args"], partial_init=True) 230 | elif ( 231 | grad_avg_config 232 | and ("grad_compression_factor" in grad_avg_config["init_args"]) 233 | and grad_avg_config["init_args"]["grad_compression_factor"] > 0 234 | and not is_tail 235 | ): 236 | grad_avg_config["init_args"]["request_timeout"] = request_timeout 237 | grad_avg_factory = build_cls(grad_avg_config["class_name"], grad_avg_config["init_args"], partial_init=True) 238 | else: 239 | grad_avg_factory = partial(GradientAverager, 
bandwidth=bandwidth, request_timeout=request_timeout) 240 | 241 | # Optimizer compressions 242 | grad_averaging_compression = args.pop("grad_averaging_compression", "NoCompression") 243 | if grad_averaging_compression == "NoCompression": 244 | grad_averaging_compression = NoCompression() 245 | elif grad_averaging_compression == "Float16Compression": 246 | grad_averaging_compression = ScaledFloat16Compression() 247 | else: 248 | raise ValueError("grad_averaging_compression must be NoCompression or Float16Compression") 249 | 250 | load_state_compression = args.pop("load_state_compression", "NoCompression") 251 | if load_state_compression == "NoCompression": 252 | load_state_compression = NoCompression() 253 | elif load_state_compression == "Float16Compression": 254 | load_state_compression = ScaledFloat16Compression() 255 | else: 256 | raise ValueError("load_state_compression must be NoCompression or Float16Compression") 257 | 258 | # Compression 259 | compression_type = args.pop("compression") 260 | compression = getattr(CompressionType, compression_type) 261 | 262 | # Select device 263 | if node_info.device_name == "cpu": 264 | device = "cpu" 265 | elif node_info.device_name.startswith("mps"): 266 | device = "mps" 267 | else: 268 | device = "cuda" 269 | 270 | logger.info(f"Using {device} device") 271 | 272 | # Log machine usage 273 | log_machine_usage = args.pop("log_machine_usage", False) 274 | if log_machine_usage: 275 | NetworkMonitor() 276 | MemoryTracker() 277 | 278 | # Start server 279 | server = Node0Server.create( 280 | model_conf=model_args, 281 | optim_cls=optim_cls, 282 | grad_avg_factory=grad_avg_factory, 283 | optim_collab_cls=AutoStepOptimizer, 284 | grad_averaging_compression=grad_averaging_compression, 285 | load_state_compression=load_state_compression, 286 | start=True, 287 | compression=compression, 288 | record_validators=validators, 289 | authorizer=authorizer, 290 | monitor=monitor, 291 | upload_bw=bandwidth, 292 | device=device, 293 | **args, 294 | ) 295 | 296 | try: 297 | server.join() 298 | except KeyboardInterrupt: 299 | logger.info("Caught KeyboardInterrupt, shutting down") 300 | 301 | 302 | if __name__ == "__main__": 303 | main() 304 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /src/node0/server/matchmaking.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import asyncio 16 | import re 17 | import threading 18 | import time 19 | 20 | from typing import List, Optional, Tuple 21 | 22 | import numpy as np 23 | 24 | from hivemind.averaging.control import StepControl 25 | from hivemind.averaging.group_info import GroupInfo 26 | from hivemind.dht import DHT 27 | from hivemind.moe.expert_uid import ExpertUID 28 | from hivemind.p2p import PeerID 29 | from hivemind.utils import DHTExpiration, TimedStorage, ValueWithExpiration, get_dht_time, get_logger 30 | 31 | 32 | GroupKey = str 33 | Endpoint = str 34 | GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$") # e.g. bert_exp4_averaging.0b01001101 35 | logger = get_logger(__name__) 36 | 37 | 38 | def is_valid_group(maybe_group: str) -> bool: 39 | """A group identifier must contain group type, followed by one or more .-separated indices, and any ?metadata""" 40 | return bool(GROUP_PATTERN.fullmatch(maybe_group)) 41 | 42 | 43 | class GroupKeyManager: 44 | """ 45 | Simplified utility class that manages a single fixed group for all nodes 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dht: DHT, 51 | prefix: str, 52 | fixed_group_key: Optional[str] = None, 53 | ): 54 | self.dht = dht 55 | self.prefix = prefix 56 | self.peer_id = dht.peer_id 57 | 58 | # Use a fixed group key - either provided or default to "global" 59 | if fixed_group_key: 60 | if not is_valid_group(fixed_group_key): 61 | raise ValueError(f"Invalid fixed group key: {fixed_group_key}") 62 | self._fixed_key = fixed_group_key 63 | else: 64 | # Default fixed group key with empty bits (all nodes use same group) 65 | self._fixed_key = f"{self.prefix}.0b" 66 | self.group_bits = "" 67 | 68 | @property 69 | def current_key(self) -> GroupKey: 70 | """Return the fixed group key that all nodes use""" 71 | return self._fixed_key 72 | 73 | async def declare_averager(self, peer_id: PeerID, expiration_time: float, looking_for_group: bool = True) -> bool: 74 | """ 75 | Add (or remove) the averager to the fixed group 76 | 77 | :param peer_id: averager public peer_id for incoming requests 78 | :param expiration_time: intent to run allreduce before this timestamp 79 | :param looking_for_group: by default (True), declare the averager as "looking for group"; 80 | If False, mark that the averager is no longer looking for group 81 | :return: True if declared, False if declaration was rejected by DHT peers 82 | """ 83 | expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf"))) 84 | return await self.dht.store( 85 | key=self._fixed_key, 86 | subkey=peer_id.to_bytes(), 87 | value=looking_for_group, 88 | expiration_time=expiration_time, 89 | return_future=True, 90 | ) 91 | 92 | async def 
get_averagers(self, only_active: bool = True) -> List[Tuple[PeerID, DHTExpiration]]: 93 | """ 94 | Find and return averagers in the fixed group 95 | 96 | :param only_active: if True, return only active averagers that are looking for group 97 | if False, return all averagers in the group regardless of status 98 | :return: peer_ids and expirations of every matching averager 99 | """ 100 | result = await self.dht.get(self._fixed_key, latest=True, return_future=True) 101 | if result is None or not isinstance(result.value, dict): 102 | logger.debug(f"Fixed group not found: {self._fixed_key}, creating new group") 103 | return [] 104 | 105 | averagers = [] 106 | for key, looking_for_group in result.value.items(): 107 | try: 108 | if only_active and not looking_for_group.value: 109 | continue 110 | averagers.append((PeerID(key), looking_for_group.expiration_time)) 111 | except Exception as e: 112 | logger.warning(f"Could not parse peer key {key} ({looking_for_group}, exc={e})") 113 | return averagers 114 | 115 | async def join_group(self, expiration_time: float) -> bool: 116 | """ 117 | Join the fixed group (convenience method) 118 | 119 | :param expiration_time: when this declaration expires 120 | :return: True if successfully joined 121 | """ 122 | return await self.declare_averager(self.peer_id, expiration_time, looking_for_group=True) 123 | 124 | async def leave_group(self, expiration_time: float) -> bool: 125 | """ 126 | Leave the fixed group (convenience method) 127 | 128 | :param expiration_time: original expiration time used when joining 129 | :return: True if successfully left 130 | """ 131 | return await self.declare_averager(self.peer_id, expiration_time, looking_for_group=False) 132 | 133 | 134 | class Matchmaking: 135 | """ 136 | Simplified matchmaking that works with a single fixed group 137 | 138 | :type dht: an instance of hivemind.DHT. Server will use DHT for all network interactions. 139 | :param prefix: Prefix of the stage. i.e head, body0, tail. 140 | :param request_timeout: This timeout is backward compatible with HM Matchmaking. It is used to cancel matchmaking in case of no responses. 
141 | :param min_matchmaking_time: How long before matchmaking operation is cancelled 142 | :param check_interval: How often to poll the DHT to check for new peers in the group 143 | :param update_period: How often to update the peer table with all peers in the stage 144 | """ 145 | 146 | def __init__( 147 | self, 148 | dht: DHT, 149 | prefix: str, 150 | request_timeout: float = 5.0, 151 | min_matchmaking_time: float = 10.0, 152 | check_interval: float = 1.0, 153 | update_period: float = 3.0, 154 | ): 155 | self.group_key_manager = GroupKeyManager(dht, prefix) 156 | 157 | self.dht = dht 158 | self.prefix = prefix 159 | self.peer_id = self.group_key_manager.peer_id 160 | self.request_timeout = request_timeout 161 | self.min_matchmaking_time = min_matchmaking_time 162 | self.check_interval = check_interval 163 | 164 | # Parameters for update peer table 165 | self.update_period = update_period 166 | self.peer_table = TimedStorage[ExpertUID, PeerID]() 167 | self.is_alive = threading.Event() 168 | self.is_alive.set() 169 | self.update_trigger, self.update_finished = threading.Event(), threading.Event() 170 | self.update_period, self.last_update = update_period, get_dht_time() 171 | self.update_thread = threading.Thread(target=self.update_peers_in_background, daemon=True) 172 | self.update_thread.start() 173 | 174 | @property 175 | def max_peers(self): 176 | return len(self.peer_table) 177 | 178 | @property 179 | def peer_set(self): 180 | res = set() 181 | for index, expert_info in self.peer_table.items(): 182 | res.add(expert_info.value._bytes) 183 | return res 184 | 185 | def update_peers_in_background(self): 186 | while self.is_alive.is_set(): 187 | time_to_next_update = max(0.0, self.last_update + self.update_period - get_dht_time()) 188 | try: 189 | self.update_trigger.wait(timeout=time_to_next_update) 190 | # update triggered by main thread 191 | except TimeoutError: 192 | pass # update triggered by refresh_period 193 | 194 | self.update_trigger.clear() 195 | response = self.dht.get(self.prefix.split("_")[0] + ".0.", latest=True) 196 | if isinstance(response, ValueWithExpiration) and isinstance(response.value, dict): 197 | for index, expert_info in response.value.items(): 198 | try: 199 | (uid, endpoint), expiration_time = expert_info 200 | self._add_peer(uid, endpoint, expiration_time) 201 | except Exception as e: 202 | logger.warning(f"Skipping malformed peer info {expert_info} (exc={e})") 203 | else: 204 | logger.warning( 205 | f"Could not refresh peer, dht info key contains {response}, will retry in {time_to_next_update}s" 206 | ) 207 | self.last_update = get_dht_time() 208 | self.update_finished.set() 209 | 210 | def _add_peer(self, uid: ExpertUID, endpoint: Endpoint, expiration_time: DHTExpiration): 211 | self.peer_table.store(uid, PeerID(endpoint), expiration_time) 212 | logger.debug(f"Storing peer: {uid}, expiration time = {expiration_time:.3f}.") 213 | 214 | async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]: 215 | """ 216 | Look for peers in the fixed group and form a group if enough peers are available 217 | 218 | :param step: To get the step schedule time 219 | :return: GroupInfo if group formed successfully, None otherwise 220 | """ 221 | timeout = self.min_matchmaking_time 222 | 223 | new_expiration_time = float(get_dht_time() + timeout) 224 | await self.group_key_manager.join_group(new_expiration_time) 225 | 226 | # Wait and retry logic 227 | start_time = time.time() 228 | 229 | # Accumulate all peers that issue join_group. 
Wait to match peer_table 230 | all_peerIds = set() 231 | while time.time() - start_time < timeout: 232 | # Get all active averagers in the group 233 | averagers = await self.group_key_manager.get_averagers(only_active=True) 234 | 235 | _peerIds = {peer_id.to_bytes() for peer_id, _ in averagers} 236 | all_peerIds = all_peerIds.union(_peerIds) 237 | 238 | # We have enough peers, proceed with group formation 239 | if (len(all_peerIds) == self.max_peers) and (self.max_peers > 0): 240 | break 241 | 242 | # Wait for either the peer_table to populate or to find all peers in the table 243 | logger.debug(f"Not enough peers yet: {len(all_peerIds)} < {self.max_peers}, waiting...") 244 | await asyncio.sleep(self.check_interval) 245 | 246 | if len(all_peerIds) == 0: 247 | # Timeout reached without finding enough peers 248 | logger.info("Timeout: no peers found in the group") 249 | return None 250 | 251 | # Create group info with all available peers 252 | all_peerIds = sorted(list(all_peerIds)) 253 | peer_ids = [PeerID(peer_id) for peer_id in all_peerIds] 254 | 255 | # Use a fixed, hard-coded group ID; peer IDs are sorted so every peer builds the same order 256 | sorted_peer_ids = sorted([str(pid) for pid in peer_ids]) 257 | group_id = b"O[\x9aU\xcf%\xf0(\x90Nq\xdf!\x8b\x85)&\x0c\xe9r" 258 | gathered = tuple(step.data_for_gather for peer_id in sorted_peer_ids) 259 | 260 | group_info = GroupInfo(group_id=group_id, peer_ids=tuple(peer_ids), gathered=gathered) 261 | 262 | end_time = time.time() 263 | logger.extra( 264 | f"Formed group with {len(peer_ids)} peers out of {self.max_peers} in {end_time - start_time:.3f} secs" 265 | ) 266 | return group_info 267 | 268 | async def leave_group(self, expiration_time: float) -> bool: 269 | """ 270 | Leave the current group 271 | 272 | :param expiration_time: original expiration time 273 | :return: True if successfully left 274 | """ 275 | return await self.group_key_manager.leave_group(expiration_time) 276 | 277 | 278 | class MatchmakingException(Exception): 279 | """An internal exception that marks undesired edge cases during averaging""" 280 | -------------------------------------------------------------------------------- /src/node0/server/state_averager_wrap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
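# Illustrative sketch (hypothetical, not part of the repository): one way the fixed-group
# matchmaking defined in src/node0/server/matchmaking.py above could be exercised on its own.
# The standalone DHT, the "head" prefix and the 30-second expiration window are made-up values
# for the example; in the real server the Matchmaking instance is driven by the averager.

import asyncio

from hivemind.dht import DHT
from hivemind.utils import get_dht_time

from node0.server.matchmaking import Matchmaking


async def demo_fixed_group() -> None:
    dht = DHT(start=True)  # throwaway single-node DHT, no swarm peers
    mm = Matchmaking(dht, prefix="head")  # hypothetical stage prefix

    # Declare this peer under the single fixed group key for the next 30 seconds,
    # then list everyone currently looking for the same group.
    expires_at = get_dht_time() + 30
    await mm.group_key_manager.join_group(expires_at)
    averagers = await mm.group_key_manager.get_averagers(only_active=True)
    print(f"{len(averagers)} peer(s) declared under key {mm.group_key_manager.current_key}")

    await mm.group_key_manager.leave_group(expires_at)
    dht.shutdown()


if __name__ == "__main__":
    asyncio.run(demo_fixed_group())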
14 | 15 | import asyncio 16 | import math 17 | import os 18 | import random 19 | import signal 20 | import time 21 | 22 | from typing import Any, Optional 23 | 24 | import torch 25 | 26 | from hivemind.averaging.allreduce import AveragingMode 27 | from hivemind.averaging.group_info import GroupInfo 28 | from hivemind.averaging.load_balancing import load_balance_peers 29 | from hivemind.compression import CompressionInfo, deserialize_torch_tensor 30 | from hivemind.p2p import PeerID 31 | from hivemind.proto import averaging_pb2 32 | from hivemind.utils import MPFuture, get_logger 33 | from hivemind.utils.asyncio import aiter_with_timeout, enter_asynchronously 34 | from hivemind.utils.streaming import combine_from_streaming 35 | from hivemind.utils.tensor_descr import TensorDescriptor 36 | from hivemind.utils.timed_storage import ValueWithExpiration 37 | 38 | from node0.server.HM_state_averager import ( 39 | TrainingStateAverager as HivemindTrainingStateAverager, 40 | ) 41 | from node0.server.matchmaking import MatchmakingException 42 | 43 | 44 | GatheredData = Any 45 | logger = get_logger(__name__) 46 | 47 | 48 | class IndexSelector: 49 | def __init__(self, p): 50 | self.state = {} 51 | self.p = p 52 | 53 | def get_indices(self, param): 54 | return torch.ones(param.shape).bool() 55 | 56 | 57 | class PartitionedIndexSelector(IndexSelector): 58 | def __init__(self, p, param): 59 | super().__init__(p) 60 | self.state[param] = {} 61 | self._set_partition(param) 62 | 63 | def _set_partition(self, param): 64 | param_state = self.state[param] 65 | param_state["num_partitions"] = min(math.ceil(1 / self.p), param.numel()) 66 | param_state["partitions"] = ( 67 | torch.rand(param.numel(), device=param.device).argsort().view(param.shape) % param_state["num_partitions"] 68 | ) 69 | 70 | def get_indices(self, param, curr_partition): 71 | curr_partition = curr_partition % self.state[param]["num_partitions"] 72 | indices = (self.state[param]["partitions"] == curr_partition).bool() 73 | 74 | return indices 75 | 76 | 77 | class TrainingStateAverager(HivemindTrainingStateAverager): 78 | """ 79 | A class that extends Hivemind TrainingStateAverager and prevents too many 80 | consecutive calls to load_state_from_peers. 
81 | """ 82 | 83 | def __init__( 84 | self, 85 | sparse_avg=0.0, 86 | average_state_every=1, 87 | call_limit: int = 1, 88 | *args, 89 | **kwargs, 90 | ): 91 | kwargs["start"] = False 92 | super().__init__(*args, **kwargs) 93 | 94 | self.zclip_warmup = None 95 | if hasattr(self, "zclip"): 96 | self.zclip_warmup = self.zclip.warmup_steps 97 | 98 | self._call_limit = call_limit 99 | self._consecutive_fails = 0 100 | self._request_timeout = kwargs["request_timeout"] 101 | self.sparse_avg = sparse_avg 102 | self.average_state_every = average_state_every 103 | self.partition_selector = [] 104 | 105 | def set_sparta_partitions(self): 106 | with self.get_tensors() as local_tensors: 107 | torch.manual_seed(1337) 108 | for i, p in enumerate(local_tensors): 109 | self.partition_selector.append(PartitionedIndexSelector(self.sparse_avg, p)) 110 | 111 | async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None): 112 | # Adapted from /hivemind/averaging/averager.py 113 | if timeout is not None: 114 | timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout 115 | try: 116 | key_manager = self._matchmaking.group_key_manager 117 | peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None) 118 | peer_priority = { 119 | PeerID(peer_id): (float(info.value), random.random()) # using randomness as a tie breaker 120 | for peer_id, info in peer_priority.items() 121 | if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int)) 122 | } 123 | 124 | if not isinstance(peer_priority, dict) or len(peer_priority) == 0: 125 | logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}") 126 | future.set_result(None) 127 | return 128 | 129 | metadata = None 130 | for peer in sorted(peer_priority.keys(), key=peer_priority.get, reverse=True): 131 | if peer != self.peer_id: 132 | t0 = time.monotonic() 133 | logger.info(f"Downloading parameters from peer {peer}") 134 | try: 135 | stub = self.get_stub(self._p2p, peer, namespace=self.prefix) 136 | stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest()) 137 | current_tensor_parts, tensors = [], [] 138 | 139 | async for message in aiter_with_timeout(stream, timeout=timeout): 140 | if message.metadata: 141 | metadata = self.serializer.loads(message.metadata) 142 | if message.tensor_part.dtype and current_tensor_parts: 143 | # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one 144 | tensor = deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)) 145 | tensors.append(tensor) 146 | current_tensor_parts = [] 147 | current_tensor_parts.append(message.tensor_part) 148 | 149 | if current_tensor_parts: 150 | tensor = deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)) 151 | tensors.append(tensor) 152 | 153 | if not metadata: 154 | logger.debug(f"Peer {peer} did not send its state") 155 | continue 156 | 157 | t1 = time.monotonic() 158 | logger.info(f"Finished downloading state in {t1 - t0:.3f}s from {peer}") 159 | self._consecutive_fails = 0 160 | # Check if any gradient contains NaN or inf values 161 | has_nans = any(not torch.isfinite(t).all() for t in tensors) 162 | if has_nans: 163 | logger.error(f"Failed to load state from peer.") 164 | logger.error(f"Downloaded state contains invalid values. 
Exiting the run.") 165 | os.killpg(os.getpgrp(), signal.SIGTERM) 166 | future.set_result((metadata, tensors)) 167 | return 168 | except Exception as e: 169 | self._consecutive_fails = self._consecutive_fails + 1 170 | 171 | if isinstance(e, TimeoutError) and self._consecutive_fails < self._call_limit: 172 | logger.info( 173 | f"{self._consecutive_fails}/{self._call_limit} load state timeout before ending session." 174 | ) 175 | 176 | if isinstance(e, TimeoutError) and self._consecutive_fails >= self._call_limit: 177 | logger.error( 178 | f"Failed to load state from peers. " 179 | "Too many TimeoutErrors were caught when trying to _load_state_from_peers. " 180 | "This problem may occur due to slow internet connection, or temporary overload of the peer-to-peer network. Exiting run." 181 | ) 182 | os.killpg(os.getpgrp(), signal.SIGTERM) 183 | else: 184 | logger.error( 185 | f"Failed to load state from {peer} - {repr(e)}. Exiting run.", 186 | exc_info=logger.getEffectiveLevel() <= 15, 187 | ) 188 | os.killpg(os.getpgrp(), signal.SIGTERM) 189 | 190 | finally: 191 | if not future.done(): 192 | future.set_result(None) 193 | 194 | async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData: 195 | """Run sparse aggregation in a given group and update tensors in place, return gathered metadata""" 196 | try: 197 | bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered)) 198 | user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes))) 199 | modes = tuple(map(AveragingMode, mode_ids)) 200 | 201 | # compute optimal part sizes from peer bandwidths; 202 | download_bandwidths = [ 203 | thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes) 204 | ] 205 | peer_fractions = await asyncio.get_event_loop().run_in_executor( 206 | None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size 207 | ) 208 | 209 | epoch = self.global_epoch 210 | snapped_epoch = int(round(epoch / self.average_state_every)) 211 | async with enter_asynchronously(self.get_tensors()) as local_tensors: 212 | if self.sparse_avg > 0: 213 | tensor_idxs = [] # index of tensor in local_tensors 214 | sparse_idxs = [] # the sparse indices that index into the averaged tensor 215 | sparse_tensor = [] # the sparse tensor to be averaged 216 | 217 | tensor_infos = kwargs["tensor_infos"] 218 | for i, val in enumerate(zip(local_tensors, tensor_infos)): 219 | p, ti = val 220 | should_include = False 221 | 222 | # Determine if tensor should be included 223 | if i < len(self.main_parameters): 224 | if self.main_parameters[i].requires_grad: 225 | should_include = True 226 | else: 227 | should_include = True 228 | if self.zclip_warmup: 229 | # Start averaging zclip (mean,var) after 2*zclip_warmup to account for 230 | # peers that joined between step 0 and zclip warmup 231 | epoch_round = int(snapped_epoch * self.average_state_every) 232 | zclip_in_warmup = epoch_round <= 2 * self.zclip_warmup 233 | is_zclip = ti.key == "zclip_mean" or ti.key == "zclip_var" 234 | if zclip_in_warmup and is_zclip: 235 | should_include = False 236 | 237 | # Only process tensors that should be included 238 | if should_include: 239 | tensor_idxs.append(i) 240 | _idx = self.partition_selector[i].get_indices(p, snapped_epoch) 241 | sparse_tensor.append(p[_idx].contiguous()) 242 | sparse_idxs.append(_idx) 243 | 244 | # Build tensor info using proper indexing 245 | tensor_infos_sparse = [] 246 | for sparse_idx, 
tensor_idx in enumerate(tensor_idxs): 247 | desc = TensorDescriptor( 248 | size=sparse_tensor[sparse_idx].shape, 249 | dtype=sparse_tensor[sparse_idx].dtype, 250 | device=sparse_tensor[sparse_idx].device, 251 | requires_grad=local_tensors[tensor_idx].requires_grad, 252 | ) 253 | tensor_infos_sparse.append(CompressionInfo(key=sparse_idx, descriptor=desc)) 254 | 255 | tensor_infos_sparse = tuple(tensor_infos_sparse) 256 | kwargs["tensor_infos"] = tensor_infos_sparse 257 | await self._run_allreduce_inplace_( 258 | sparse_tensor, group_info, peer_fractions=peer_fractions, **kwargs 259 | ) 260 | 261 | # Copy results back using proper indexing 262 | for sparse_idx, tensor_idx in enumerate(tensor_idxs): 263 | local_tensors[tensor_idx][sparse_idxs[sparse_idx]] = sparse_tensor[sparse_idx] 264 | else: 265 | await self._run_allreduce_inplace_( 266 | local_tensors, group_info, peer_fractions=peer_fractions, **kwargs 267 | ) 268 | return user_gathered 269 | except BaseException as e: 270 | if isinstance(e, Exception): 271 | logger.error(e, exc_info=logger.getEffectiveLevel() <= 15) 272 | raise MatchmakingException(f"Unable to run All-Reduce: {e}") 273 | -------------------------------------------------------------------------------- /src/node0/security/authorization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
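# Illustrative sketch (hypothetical, not part of the repository): what the PartitionedIndexSelector
# in src/node0/server/state_averager_wrap.py above computes for a toy tensor. With sparse_avg
# p = 0.25, each parameter is split into ceil(1/p) = 4 disjoint random partitions and a given
# averaging round only exchanges one partition (~25% of the elements). The tensor shape and
# the value of p are made up for the example.

import math

import torch

p = 0.25
param = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# Mirrors PartitionedIndexSelector._set_partition: a random permutation of the element
# indices, bucketed modulo the number of partitions.
num_partitions = min(math.ceil(1 / p), param.numel())
partitions = torch.rand(param.numel()).argsort().view(param.shape) % num_partitions

for step in range(num_partitions):
    mask = partitions == step  # mirrors get_indices(param, curr_partition)
    print(f"round {step}: averaging {int(mask.sum())} of {param.numel()} elements")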
14 | 15 | import os 16 | import signal 17 | 18 | from datetime import datetime, timedelta, timezone 19 | from functools import partial 20 | from pathlib import Path 21 | 22 | import requests 23 | 24 | from cryptography.hazmat.primitives import serialization 25 | from hivemind import PeerID 26 | from hivemind.proto import crypto_pb2 27 | from hivemind.proto.auth_pb2 import AccessToken 28 | from hivemind.utils.auth import TokenAuthorizerBase 29 | from hivemind.utils.crypto import RSAPrivateKey, RSAPublicKey 30 | from hivemind.utils.logging import get_logger 31 | 32 | from node0.security.integrity_check import verify_integrity 33 | from node0.utils.connection_test_server import TestServer 34 | from node0.utils.node_info import ( 35 | BadRequestError, 36 | IntegrityError, 37 | NodeInfo, 38 | NotInAllowlistError, 39 | ServerUnavailableError, 40 | TestServerError, 41 | call_with_retries, 42 | ) 43 | 44 | 45 | logger = get_logger(__name__) 46 | 47 | 48 | class PluralisAuthorizer(TokenAuthorizerBase): 49 | def __init__( 50 | self, 51 | peer_id: str, 52 | user_token: str, 53 | user_email: str, 54 | role: str, 55 | auth_server: str, 56 | node_info: NodeInfo, 57 | current_path: Path, 58 | announce_maddrs: str, 59 | host_port: int, 60 | check_integrity: bool, 61 | ): 62 | super().__init__() 63 | 64 | self.peer_id = peer_id 65 | self._user_token = user_token 66 | self._user_email = user_email 67 | self._role = role 68 | self._auth_server = auth_server 69 | self._node_info = node_info 70 | self._current_path = current_path 71 | self._authority_public_key = None 72 | self._check_integrity = check_integrity 73 | self.pipeline_stage = None 74 | self.reachable = "unknown" 75 | self.monitor_public_key = None 76 | 77 | # Parse announce address 78 | address_parts = announce_maddrs.split("/") 79 | self._announce_ip_address = address_parts[2] 80 | self._announce_port = int(address_parts[4]) 81 | self._host_port = host_port 82 | 83 | async def get_token(self) -> AccessToken: 84 | """Hivemind calls this method to refresh the access token when necessary.""" 85 | 86 | try: 87 | self.join_experiment() 88 | return self._local_access_token 89 | except NotInAllowlistError as e: 90 | logger.error(f"Authorization failed: {e}. Exiting run.") 91 | os.killpg(os.getpgrp(), signal.SIGTERM) 92 | except BadRequestError as e: 93 | logger.error(f"Authorization failed: {e}. Exiting run.") 94 | os.killpg(os.getpgrp(), signal.SIGTERM) 95 | except IntegrityError: 96 | logger.error("Authorization failed: verification failed. Exiting run.") 97 | os.killpg(os.getpgrp(), signal.SIGTERM) 98 | except Exception as e: 99 | logger.error(f"Authorization failed: {e}. 
Exiting run.") 100 | os.killpg(os.getpgrp(), signal.SIGTERM) 101 | 102 | def join_experiment( 103 | self, 104 | reset_reachability: bool = False, 105 | initial_join: bool = False, 106 | n_retries: int = 10, 107 | ) -> None: 108 | """Join experiment with retries.""" 109 | call_with_retries( 110 | partial(self._join_experiment, reset_reachability, initial_join), n_retries=n_retries, initial_delay=3 111 | ) 112 | 113 | def _join_experiment(self, reset_reachability: bool = False, initial_join: bool = False) -> None: 114 | """Send authorization request to join the experiment and receive access token.""" 115 | try: 116 | # Check integrity of files 117 | if self._check_integrity: 118 | try: 119 | integrity_hash = verify_integrity(self._current_path, self._local_private_key) 120 | except Exception: 121 | raise IntegrityError("Verification failed.") from None 122 | else: 123 | integrity_hash = b"hash" 124 | 125 | url = f"{self._auth_server}/api/join" 126 | headers = { 127 | "Authorization": f"Bearer {self._user_token}", 128 | "request-type": "initial" if initial_join else "update", 129 | } 130 | json_body = { 131 | "peer_id": self.peer_id, 132 | "role": self._role, 133 | "peer_public_key": self.local_public_key.to_bytes().decode(), 134 | "device": self._node_info.device_name, 135 | "gpu_memory": self._node_info.gpu_memory, 136 | "ram": self._node_info.ram, 137 | "download_speed": self._node_info.download_speed, 138 | "upload_speed": self._node_info.upload_speed, 139 | "latency": self._node_info.latency, 140 | "integrity_hash": integrity_hash.decode(), 141 | "reset_reachability": reset_reachability, 142 | "email": self._user_email, 143 | "announce_ip_address": self._announce_ip_address, 144 | "announce_port": self._announce_port, 145 | } 146 | 147 | if initial_join and self._check_integrity: 148 | with TestServer(port=self._host_port) as server: 149 | response = requests.put( 150 | url, 151 | headers=headers, 152 | json=json_body, 153 | ) 154 | 155 | response.raise_for_status() 156 | 157 | # Receive server message 158 | if not server.get_message(): 159 | if not server.wait_for_message(timeout=5): 160 | raise TestServerError( 161 | "Port test failed. Make sure your port forwarding is correct" 162 | ) from None 163 | 164 | # Verify message content 165 | if not server.verify_message(): 166 | raise TestServerError( 167 | "Port test failed, wrong message received. 
Please wait for few minutes before trying to join again" 168 | ) from None 169 | 170 | else: 171 | response = requests.put( 172 | url, 173 | headers=headers, 174 | json=json_body, 175 | ) 176 | 177 | response.raise_for_status() 178 | 179 | response = response.json() 180 | 181 | self._authority_public_key = RSAPublicKey.from_bytes(response["auth_server_public_key"].encode()) 182 | self.monitor_public_key = str(response["monitor_public_key"]) 183 | self.pipeline_stage = response["stage_type"] 184 | self.reachable = response["reachable"] 185 | 186 | access_token = AccessToken() 187 | access_token.username = response["username"] 188 | access_token.public_key = response["peer_public_key"].encode() 189 | access_token.expiration_time = str(datetime.fromisoformat(response["expiration_time"])) 190 | access_token.signature = response["signature"].encode() 191 | self._local_access_token = access_token 192 | 193 | logger.info( 194 | f"Access for user {access_token.username} has been granted until {access_token.expiration_time} UTC" 195 | ) 196 | except requests.exceptions.HTTPError as e: 197 | if e.response.status_code in [401, 403, 429]: # Unauthorized, blacklisted, blocked or IP rate limited 198 | try: 199 | error_detail = e.response.json()["detail"] 200 | except Exception: 201 | error_detail = "Request is blocked" 202 | raise NotInAllowlistError(error_detail) from None 203 | if e.response.status_code in [400, 413, 418, 422, 424]: # wrong request information 204 | raise BadRequestError(e.response.json()["detail"]) from None 205 | if e.response.status_code == 503: # can't join due to rate limiting 206 | raise ServerUnavailableError(e.response.json()["detail"]) from None 207 | raise e 208 | except Exception as e: 209 | raise e 210 | 211 | def is_token_valid(self, access_token: AccessToken) -> bool: 212 | """Verify that token is valid.""" 213 | data = self._token_to_bytes(access_token) 214 | if not self._authority_public_key or not self._authority_public_key.verify(data, access_token.signature): 215 | logger.error("Access token has invalid signature") 216 | return False 217 | 218 | try: 219 | expiration_time = datetime.fromisoformat(access_token.expiration_time) 220 | except ValueError: 221 | logger.error(f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}") 222 | return False 223 | if expiration_time < datetime.now(timezone.utc): 224 | logger.error("Access token has expired") 225 | return False 226 | 227 | return True 228 | 229 | _MAX_LATENCY = timedelta(minutes=1) 230 | 231 | def does_token_need_refreshing(self, access_token: AccessToken) -> bool: 232 | """Check if token has expired.""" 233 | expiration_time = datetime.fromisoformat(access_token.expiration_time) 234 | return expiration_time < datetime.now(timezone.utc) + self._MAX_LATENCY 235 | 236 | @staticmethod 237 | def _token_to_bytes(access_token: AccessToken) -> bytes: 238 | """Convert access token to bytes.""" 239 | return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode() 240 | 241 | 242 | def save_identity(private_key: RSAPrivateKey, identity_path: str) -> None: 243 | """Save private key to file. 
244 | 245 | Args: 246 | private_key (RSAPrivateKey): local private key 247 | identity_path (str): path to save the key 248 | 249 | Raises: 250 | FileNotFoundError: can't create file 251 | """ 252 | protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes()) 253 | 254 | try: 255 | with open(identity_path, "wb") as f: 256 | f.write(protobuf.SerializeToString()) 257 | except FileNotFoundError as e: 258 | raise FileNotFoundError( 259 | f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist" 260 | ) from e 261 | os.chmod(identity_path, 0o400) 262 | 263 | 264 | def authorize_with_pluralis( 265 | node_info: NodeInfo, 266 | user_token: str, 267 | user_email: str, 268 | role: str, 269 | auth_server: str, 270 | identity_path: str, 271 | current_path: Path, 272 | announce_maddrs: str, 273 | host_port: int, 274 | check_integrity: bool = True, 275 | ) -> PluralisAuthorizer: 276 | """Generate local keys and send authorization request to join the run. 277 | 278 | Args: 279 | node_info (NodeInfo): information about the node 280 | user_token (str): authentication token 281 | user_email (str): email address 282 | role (str): role in the swarm 283 | auth_server (str): authorization server URL 284 | identity_path (str): path to save/load private key 285 | current_path (Path): path to src/node0 286 | announce_maddrs (str): announce address 287 | host_port: (int): host port 288 | check_integrity (bool): flag to check integrity 289 | 290 | Returns: 291 | PluralisAuthorizer: authorizer instance 292 | """ 293 | logger.info("Authorization started...") 294 | 295 | # Generate private key or read from file 296 | local_private_key = RSAPrivateKey.process_wide() 297 | 298 | if os.path.exists(identity_path): 299 | with open(identity_path, "rb") as f: 300 | key_data = crypto_pb2.PrivateKey.FromString(f.read()).data 301 | private_key = serialization.load_der_private_key(key_data, password=None) 302 | if local_private_key._process_wide_key: 303 | local_private_key._process_wide_key._private_key = private_key 304 | else: 305 | logger.error("Failed to initialize process-wide private key") 306 | raise RuntimeError("Process-wide key is None") 307 | else: 308 | save_identity(local_private_key, identity_path) 309 | 310 | # Get static peer id 311 | with open(identity_path, "rb") as f: 312 | peer_id = str(PeerID.from_identity(f.read())) 313 | 314 | # Authorize 315 | authorizer = PluralisAuthorizer( 316 | peer_id, 317 | user_token, 318 | user_email, 319 | role, 320 | auth_server, 321 | node_info, 322 | current_path, 323 | announce_maddrs, 324 | host_port, 325 | check_integrity, 326 | ) 327 | 328 | try: 329 | authorizer.join_experiment(reset_reachability=True, initial_join=True) 330 | logger.info("Authorization completed") 331 | return authorizer 332 | except NotInAllowlistError as e: 333 | logger.error(f"Authorization failed: {e}. Exiting run.") 334 | exit(1) 335 | except BadRequestError as e: 336 | logger.error(f"Authorization failed: {e}. Exiting run.") 337 | exit(1) 338 | except IntegrityError: 339 | logger.error("Authorization failed: verification failed. Exiting run.") 340 | exit(1) 341 | except Exception as e: 342 | logger.error(f"Authorization failed: {e}. 
Exiting run.") 343 | exit(1) 344 | -------------------------------------------------------------------------------- /src/node0/server/power_sgd_averager.py: -------------------------------------------------------------------------------- 1 | # This file contains code originally from Hivemind under MIT License 2 | # Original: Copyright 2020 Learning@home authors and collaborators 3 | # Modified by: Pluralis Research 2025 4 | # 5 | # Original code: MIT License (see THIRD_PARTY_LICENSES) 6 | # Modifications: Apache 2.0 License (see LICENSE) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only; 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | import asyncio 13 | import contextlib 14 | 15 | from enum import Enum 16 | from typing import Any, Iterable, Optional, Sequence 17 | 18 | import torch 19 | 20 | from hivemind.averaging.allreduce import AveragingMode 21 | from hivemind.averaging.group_info import GroupInfo 22 | from hivemind.averaging.load_balancing import load_balance_peers 23 | from hivemind.averaging.matchmaking import MatchmakingException 24 | from hivemind.compression import CompressionInfo, TensorRole 25 | from hivemind.dht import DHT 26 | from hivemind.utils import get_logger 27 | from hivemind.utils.asyncio import enter_asynchronously 28 | from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_ 29 | 30 | from node0.server.HM_gradient_averager import GradientAverager 31 | 32 | 33 | GatheredData = Any 34 | logger = get_logger(__name__) 35 | 36 | def qr_orthogonalize(matrix, iters=1): 37 | """QR orthogonalization in-place for 2D matrix.""" 38 | for _ in range(iters): 39 | Q, _ = torch.linalg.qr(matrix) 40 | matrix.copy_(Q) 41 | return matrix 42 | 43 | class AllReducePhases(Enum): 44 | PHASE_P = 1 45 | PHASE_Q = 2 46 | 47 | 48 | class PowerSGDGradientAverager(GradientAverager): 49 | """ 50 | A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727 51 | For basic properties and guaranties of gradient averagers, please refer to the base class docstring. 52 | Put simply, this method approximates large gradient tensors (m,n) with a product of two 53 | smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank). 54 | 55 | As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n). 56 | High r, e.g. sqrt(max(m, n)) typically reduce communication by 2-8x without affecting convergence. 57 | Low r, e.g. 1-8, further accelerate communication, but may converge worse depending on the task. 58 | 59 | To maintain convergence with low r, this averager uses the error feedback strategy. Put simply, 60 | if some part of the gradient is "lost in compression", it will be added to the next iteration. 61 | This has two implications: (a) it needs more RAM in order to store the "feedback buffers" 62 | and (b) if devices stay alive only for one step, training with small rank may converge slower. 63 | This is because error feedback takes multiple steps to kick in. 64 | 65 | Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1). 66 | If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated 67 | normally, i.e. without powerSGD. See min_compression_ratio for details. 68 | 69 | :note: due to the above rule, PowerSGD is *not* shape-invariant. 
For instance, a 70 | matrix of shape (256, 256) be compressed differently if you .reshape it to (32, 32, 32). 71 | 72 | :param parameters: pytorch parameters for which to aggregate gradients 73 | :param averager_rank: rank of compressed gradients 74 | :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs 75 | :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes 76 | :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps. 77 | This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all 78 | :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same 79 | device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at 80 | the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect. 81 | :param client_mode: if False, this averager will accept incoming requests from other peers. 82 | if True, the averager will only join existing groups where at least one peer has client_mode=False. 83 | By default, this flag is copied from DHTNode inside the ``dht`` instance. 84 | :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results 85 | :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor, 86 | otherwise aggregate tensors as is 87 | :param averaged_grads: if provided, it will be used as a set of averagable gradients 88 | """ 89 | 90 | def __init__( 91 | self, 92 | parameters: Iterable[torch.nn.Parameter], 93 | averager_rank: int, 94 | *, 95 | dht: DHT, 96 | prefix: str, 97 | reuse_grad_buffers: bool = False, 98 | accumulate_grads_on: Optional[torch.device] = None, 99 | client_mode: bool = None, 100 | warn: bool = True, 101 | min_compression_ratio: float = 0.5, 102 | averaged_grads: Optional[Sequence[torch.Tensor]] = None, 103 | reset_buffers_every_k_steps: int = 10, 104 | **kwargs, 105 | ): 106 | self.rank = averager_rank 107 | self.parameters = tuple(parameters) 108 | self._uncompressed_gradients_indexes = set( 109 | i 110 | for i, grad in enumerate(self._grads_from_parameters()) 111 | if grad.ndim <= 1 112 | or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio 113 | # compute how much parameters are left after factorization 114 | ) 115 | self._ms = [ 116 | torch.zeros_like(grad, device="cpu").share_memory_() 117 | for idx, grad in enumerate(self._grads_from_parameters()) 118 | if idx not in self._uncompressed_gradients_indexes 119 | ] 120 | 121 | self._ms_copy = [ 122 | torch.zeros_like(grad, device="cpu").share_memory_() 123 | for idx, grad in enumerate(self._grads_from_parameters()) 124 | if idx not in self._uncompressed_gradients_indexes 125 | ] 126 | 127 | self._qs = [ 128 | torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_() 129 | for idx, grad in enumerate(self._grads_from_parameters()) 130 | if idx not in self._uncompressed_gradients_indexes 131 | ] 132 | 133 | # Buffer reset tracking 134 | self.reset_buffers_every_k_steps = reset_buffers_every_k_steps 135 | self._step_count = 0 136 | self._last_successful_reset_step = 0 137 | 138 | super().__init__( 139 | self.parameters, 140 | dht=dht, 141 | prefix=prefix, 142 | reuse_grad_buffers=reuse_grad_buffers, 143 | 
accumulate_grads_on=accumulate_grads_on, 144 | client_mode=client_mode, 145 | warn=warn, 146 | averaged_grads=averaged_grads, 147 | **kwargs, 148 | ) 149 | 150 | @contextlib.contextmanager 151 | def _register_allreduce_group(self, group_info: GroupInfo): 152 | """Register a given group for one or more all-reduce rounds""" 153 | try: 154 | for phase in list(AllReducePhases): 155 | self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future() 156 | self._pending_groups_registered.set() 157 | yield 158 | finally: 159 | for phase in list(AllReducePhases): 160 | maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None) 161 | if maybe_future and not maybe_future.done(): 162 | logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.") 163 | self._pending_groups_registered.set() 164 | 165 | async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData: 166 | """Run aggregation in a given group and update tensors in place, return gathered metadata""" 167 | self._step_count += 1 # Increment buffer step count 168 | try: 169 | bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered)) 170 | user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes))) 171 | modes = tuple(map(AveragingMode, mode_ids)) 172 | 173 | download_bandwidths = [ 174 | thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes) 175 | ] 176 | peer_fractions = await asyncio.get_event_loop().run_in_executor( 177 | None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size 178 | ) 179 | 180 | async with enter_asynchronously(self.get_tensors()) as averaged_grads: 181 | averaged_grads_via_sgd = [ 182 | grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes 183 | ] 184 | 185 | err_norm = torch.nn.utils.get_total_norm(self._ms).item() 186 | logger.extra(f"Error norm val: {err_norm:.6f}") 187 | 188 | prepsgd_norm = torch.nn.utils.get_total_norm(averaged_grads_via_sgd).item() 189 | logger.extra(f"Prepsgd norm val: {prepsgd_norm:.6f}") 190 | 191 | # Adding noise to qs to prevent slow-down issues 192 | for q in self._qs: 193 | q.add_(torch.randn_like(q) * 1e-30) 194 | 195 | # Make a copy of _ms in case of fail 196 | for m, ms_copy in zip(self._ms, self._ms_copy): 197 | m.copy_(ms_copy) 198 | 199 | for grad, m in zip(averaged_grads_via_sgd, self._ms): 200 | m.add_(grad.to(m.device)) 201 | 202 | ps = [ 203 | torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu") 204 | for idx, grad in enumerate(averaged_grads_via_sgd) 205 | ] 206 | for p, q, m in zip(ps, self._qs, self._ms): 207 | # we use reshape for all matrixes because PowerSGD works only with 2d tensors 208 | torch.matmul(m.reshape(-1, q.size(0)), q, out=p) 209 | 210 | p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode() 211 | q_groud_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode() 212 | 213 | await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs) 214 | 215 | for p in ps: 216 | p = qr_orthogonalize(p, iters=1) 217 | 218 | for p, q, m in zip(ps, self._qs, self._ms): 219 | torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q) 220 | 221 | # local error before allreduce on Q 222 | for p, q, m in zip(ps, self._qs, self._ms): 223 | new_m = torch.matmul(p, q.t()).reshape(m.size()) 224 | m.sub_(new_m) 
# prev_err + grad - new_approx 225 | 226 | phase_q_tensors = self._qs + [ 227 | grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes 228 | ] 229 | 230 | await self._run_allreduce_inplace_( 231 | phase_q_tensors, group_info, q_groud_id, peer_fractions=peer_fractions, **kwargs 232 | ) 233 | 234 | for p, q, ms_copy, grad, m in zip(ps, self._qs, self._ms_copy, averaged_grads_via_sgd, self._ms): 235 | new_m = torch.matmul(p, q.t()).reshape(ms_copy.size()) 236 | grad.copy_(new_m) 237 | ms_copy.copy_(m) 238 | 239 | postpsgd_norm = torch.nn.utils.get_total_norm(averaged_grads_via_sgd).item() 240 | logger.extra(f"Postpsgd norm val: {postpsgd_norm:.6f}") 241 | 242 | return user_gathered 243 | except BaseException as e: 244 | logger.error(e, exc_info=logger.getEffectiveLevel() <= 15) 245 | raise MatchmakingException(f"Unable to run All-Reduce: {e}") 246 | 247 | def get_current_state(self): 248 | """ 249 | Get current gradient averager state and when requested by a newbie peer. 250 | """ 251 | with torch.no_grad(), self.lock_averaged_tensors: 252 | grad_averager_buffers = [q for q in self._qs] 253 | grad_averager_buffers_infos = [ 254 | CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT) 255 | for buffer, key in zip(grad_averager_buffers, enumerate(grad_averager_buffers)) 256 | ] 257 | 258 | metadata = dict(group_bits=self.get_group_bits()) 259 | return metadata, grad_averager_buffers, grad_averager_buffers_infos 260 | 261 | def load_state_from_peers(self, **kwargs): 262 | """ 263 | Attempt to download the latest optimizer state from peers and update gradient averager buffers. 264 | :returns: whether or the averager succeeded in loading parameters 265 | """ 266 | loaded_state = super().load_state_from_peers(**kwargs) 267 | if loaded_state is None: 268 | return 269 | 270 | metadata, flat_tensors = loaded_state 271 | logger.info("Starting loading gradient averager buffers from peers") 272 | 273 | if len(flat_tensors) != len(self._qs): 274 | logger.error("Failed to load state from peer, received invalid parameters, extras or metadata") 275 | return 276 | 277 | with torch.no_grad(), self.lock_averaged_tensors: 278 | for local_q, loaded_q in zip(self._qs, flat_tensors): 279 | local_q.copy_(loaded_q, non_blocking=True) 280 | -------------------------------------------------------------------------------- /src/node0/server/HM_gradient_averager.py: -------------------------------------------------------------------------------- 1 | # This file contains code originally from Hivemind under MIT License 2 | # Original: Copyright 2020 Learning@home authors and collaborators 3 | # Modified by: Pluralis Research 2025 4 | # 5 | # Original code: MIT License (see THIRD_PARTY_LICENSES) 6 | # Modifications: Apache 2.0 License (see LICENSE) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only; 9 | # you may not use this file except in compliance with the License. 
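# Illustrative sketch (hypothetical, not part of the repository): one PowerSGD-style round on a
# toy gradient, mirroring the P/Q phases of PowerSGDGradientAverager in
# src/node0/server/power_sgd_averager.py above. The matrix shape and rank are made up; in the
# averager the matmuls run on shared-memory buffers and the two phases are all-reduced between peers.

import torch

m, n, rank = 64, 32, 4
grad = torch.randn(m, n)   # stands in for one 2-D gradient plus accumulated error feedback
q = torch.randn(n, rank)   # persistent Q buffer, randomly initialized

p = grad @ q               # PHASE_P payload: (m, rank) values instead of (m, n)
p, _ = torch.linalg.qr(p)  # orthogonalize P (cf. qr_orthogonalize above)
q = grad.t() @ p           # PHASE_Q payload: (n, rank)

approx = p @ q.t()         # rank-4 approximation written back into the gradient
error = grad - approx      # error feedback carried into the next round (the _ms buffer)

sent, full = p.numel() + q.numel(), grad.numel()
print(f"communicated {sent} values instead of {full} ({full / sent:.1f}x less)")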
10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | import contextlib 13 | 14 | from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar 15 | 16 | import torch 17 | 18 | from hivemind.averaging.control import StepControl 19 | from hivemind.dht import DHT 20 | from hivemind.utils import DHTExpiration, get_logger 21 | 22 | from node0.server.HM_averager import DecentralizedAverager 23 | 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager") 29 | GradientAveragerFactory = Callable[..., TGradientAverager] 30 | 31 | 32 | class GradientAverager(DecentralizedAverager): 33 | """ 34 | An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers. 35 | GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below). 36 | 37 | GradientAverager manages three sets of buffers: 38 | (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad). 39 | These tensors are typically stored on device and updated by torch autograd 40 | (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated. 41 | - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators, 42 | which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually 43 | (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory 44 | 45 | :param parameters: pytorch parameters for which to aggregate gradients 46 | :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs 47 | :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes 48 | :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps. 49 | This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all 50 | :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same 51 | device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at 52 | the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect. 53 | :param client_mode: if False, this averager will accept incoming requests from other peers. 54 | if True, the averager will only join existing groups where at least one peer has client_mode=False. 55 | By default, this flag is copied from DHTNode inside the ``dht`` instance. 
56 | :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results 57 | :param averaged_grads: if provided, it will be used as a set of averagable gradients 58 | :param kwargs: see DecentralizedAverager keyword arguments for additional parameters 59 | 60 | 61 | Example: 62 | 63 | >>> model = SuchModelMuchLayers() 64 | >>> opt = torch.optim.Adam(model.parameters()) 65 | >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...)) 66 | >>> next_step_time = hivemind.get_dht_time() + 60 # runs global steps every 60 seconds 67 | >>> next_step_control = None 68 | >>> while True: 69 | >>> # accumulate as many gradients as you can before next_step_time 70 | >>> loss = compute_loss(model, batch_size=32) 71 | >>> loss.backward() 72 | >>> grad_averager.accumulate_grads_(batch_size=32) 73 | >>> # [optional] next step in 5 seconds, start looking for peers in advance 74 | >>> if next_step_time - hivemind.get_dht_time() <= 5 75 | >>> next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time) 76 | >>> # aggregate gradients and perform optimizer step 77 | >>> if hivemind.get_dht_time() >= next_step_time: 78 | >>> grad_averager.step(control=next_step_control) 79 | >>> with grad_averager.use_averaged_gradients(): # this will fill param.grads with aggregated gradients 80 | >>> opt.step() # update model parameters using averaged gradients 81 | >>> grad_averager.reset_accumulated_grads_() # prepare for next step 82 | >>> next_step_time = hivemind.get_dht_time() + 60 83 | >>> next_step_control = None 84 | 85 | """ 86 | 87 | def __init__( 88 | self, 89 | parameters: Iterable[torch.nn.Parameter], 90 | *, 91 | dht: DHT, 92 | prefix: str, 93 | reuse_grad_buffers: bool = False, 94 | accumulate_grads_on: Optional[torch.device] = None, 95 | client_mode: bool = None, 96 | warn: bool = True, 97 | averaged_grads: Sequence[torch.Tensor] = (), 98 | **kwargs, 99 | ): 100 | if reuse_grad_buffers and accumulate_grads_on is not None: 101 | logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True") 102 | client_mode = client_mode if client_mode is not None else dht.client_mode 103 | self.parameters = tuple(parameters) 104 | self.reuse_grad_buffers = reuse_grad_buffers 105 | self.warn = warn 106 | self.local_samples_accumulated = 0 107 | self.local_times_accumulated = 0 108 | self._anchor_batch_size = None 109 | self._local_accumulators = None 110 | self.processed_batches = torch.tensor(0.0, device="cpu").share_memory_() 111 | if not reuse_grad_buffers: 112 | self._local_accumulators = tuple( 113 | torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters() 114 | ) 115 | self._accumulators_used_in_step = False 116 | self._new_averaged_grads = False 117 | 118 | with torch.no_grad(): 119 | if not averaged_grads: 120 | averaged_grads = tuple( 121 | grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters() 122 | ) 123 | else: 124 | if any( 125 | param_grad.size() != grad.size() 126 | for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads) 127 | ): 128 | raise ValueError("Averaged gradients don't have same shape as gradients from parameters") 129 | super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs) 130 | 131 | def _grads_from_parameters(self) -> Iterator[torch.Tensor]: 132 | """gradient buffers associated with parameters""" 133 | for param in self.parameters: 134 | if param.grad 
is None: 135 | param.grad = torch.zeros_like(param) 136 | yield param.grad 137 | 138 | @torch.no_grad() 139 | def _grad_accumulators(self) -> Iterator[torch.Tensor]: 140 | """averager-based gradient accumulators""" 141 | assert (self._local_accumulators is None) == self.reuse_grad_buffers 142 | yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators 143 | 144 | @torch.no_grad() 145 | def accumulate_grads_(self, batch_size: int): 146 | """add current gradients to local grad accumulators (if used)""" 147 | if self._accumulators_used_in_step and self.warn: 148 | logger.warning( 149 | "[warn=True] Gradient accumulators were not reset since the last averaging round. Please " 150 | "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)" 151 | ) 152 | self._accumulators_used_in_step = False # warn once per round 153 | if self._anchor_batch_size is None: 154 | # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size 155 | self._anchor_batch_size = batch_size 156 | self.local_samples_accumulated += batch_size 157 | self.local_times_accumulated += 1 158 | if self.reuse_grad_buffers: 159 | pass # user is responsible for accumulating gradients in .grad buffers 160 | else: 161 | alpha = float(batch_size) / self._anchor_batch_size 162 | for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()): 163 | grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha) 164 | 165 | def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl: 166 | """ 167 | Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time. 168 | 169 | :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time 170 | :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc 171 | :note: setting weight at this stage is not supported, please leave this parameter as None 172 | :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group 173 | :note: in the current implementation, each step_control can only be used in one step. 
174 | """ 175 | assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported" 176 | return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs) 177 | 178 | def step( 179 | self, 180 | weight: Optional[float] = None, 181 | reset_accumulators: bool = True, 182 | control: Optional[StepControl] = None, 183 | timeout: Optional[float] = None, 184 | wait: bool = True, 185 | **kwargs, 186 | ): 187 | """ 188 | Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators 189 | 190 | :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples 191 | :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds 192 | :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step 193 | :param timeout: if specified, await for averaging round for at most this number of seconds (if wait=True) 194 | :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background 195 | """ 196 | if control is None: 197 | control = self.schedule_step(timeout=timeout, **kwargs) 198 | elif len(kwargs) > 0: 199 | raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect") 200 | assert not control.triggered, f"This {type(control)} instance was already used" 201 | if self._new_averaged_grads and self.warn: 202 | logger.warning( 203 | "[warn=True] Starting new averaging round, but previous round results were not used. " 204 | "This may be a sign of incorrect optimizer behavior" 205 | ) 206 | 207 | self.load_accumulators_into_averager_() 208 | self._accumulators_used_in_step = True 209 | self._new_averaged_grads = True 210 | 211 | control.weight = self.local_samples_accumulated if weight is None else weight 212 | if reset_accumulators: 213 | self.reset_accumulated_grads_() 214 | control.allow_allreduce() 215 | 216 | return control.result(timeout) if wait else control 217 | 218 | @torch.no_grad() 219 | def load_accumulators_into_averager_(self): 220 | """load locally accumulated gradients into the averager for aggregation""" 221 | # divide locally accumulated gradients by the number of times they were accumulated 222 | grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0 223 | with self.get_tensors() as averaged_grads: 224 | for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads): 225 | averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale) 226 | 227 | @torch.no_grad() 228 | def reset_accumulated_grads_(self): 229 | """reset averager-internal gradient accumulators and the denominator""" 230 | self._accumulators_used_in_step = False 231 | self.local_samples_accumulated = self.local_times_accumulated = 0 232 | self._anchor_batch_size = None 233 | for grad_buf in self._grad_accumulators(): 234 | grad_buf.zero_() 235 | 236 | @contextlib.contextmanager 237 | @torch.no_grad() 238 | def use_averaged_gradients(self): 239 | """Substitute model's main gradients with averaged gradients (does not respect device placement)""" 240 | self._new_averaged_grads = False 241 | with self.get_tensors() as averaged_grads: 242 | assert len(averaged_grads) == len(self.parameters) 243 | try: 244 | old_grads = [param.grad for param in self.parameters] 245 | for param, new_grad in zip(self.parameters, averaged_grads): 246 | param.grad = new_grad 247 | yield averaged_grads 248 
| finally: 249 | for param, old_grad in zip(self.parameters, old_grads): 250 | param.grad = old_grad 251 | 252 | def notify_used_averaged_gradients(self): 253 | """Notify averager that the results of a previous averaging round are accounted for""" 254 | self._new_averaged_grads = False 255 | 256 | def has_nan_grads(self) -> bool: 257 | """Check if any gradient contains NaN or inf values""" 258 | return any(not torch.isfinite(grad).all() for grad in self._grads_from_parameters()) 259 | -------------------------------------------------------------------------------- /src/node0/server/node0_server.py: -------------------------------------------------------------------------------- 1 | # This file contains code originally from Hivemind under MIT License 2 | # Original: Copyright 2020 Learning@home authors and collaborators 3 | # Modified by: Pluralis Research 2025 4 | # 5 | # Original code: MIT License (see THIRD_PARTY_LICENSES) 6 | # Modifications: Apache 2.0 License (see LICENSE) 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License") for modifications only; 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 11 | 12 | import multiprocessing as mp 13 | 14 | from functools import partial 15 | from typing import Type 16 | 17 | import torch 18 | 19 | from hivemind.compression import CompressionBase, NoCompression 20 | from hivemind.dht import DHT 21 | from hivemind.moe.expert_uid import UID_DELIMITER 22 | from hivemind.moe.server import Server 23 | from hivemind.moe.server.layers import ( 24 | add_custom_models_from_file, 25 | name_to_block, 26 | name_to_input, 27 | ) 28 | from hivemind.moe.server.server import _generate_uids 29 | from hivemind.optim import Optimizer 30 | from hivemind.proto.runtime_pb2 import CompressionType 31 | from hivemind.utils.logging import get_logger 32 | from hivemind.utils.tensor_descr import DUMMY_BATCH_SIZE, BatchTensorDescriptor 33 | 34 | from node0.models.arguments import ModelArguments 35 | from node0.models.lr_schedule import schedule_name_to_scheduler 36 | from node0.server.HM_gradient_averager import GradientAveragerFactory 37 | from node0.server.module_collab import ModuleCollab 38 | from node0.utils import MonitorWorker 39 | from node0.utils.common import load_ss_components 40 | from node0.utils.dht_monitor import patch_dht_protocol_logging 41 | from node0.utils.get_parameters import get_parameter_store 42 | 43 | 44 | logger = get_logger(__name__) 45 | 46 | 47 | def get_ss_dim(expert_cls: str, model_conf: ModelArguments) -> int: 48 | """Utility function to get the input dimension to a stage.""" 49 | if "head" in expert_cls: 50 | input_dim = model_conf.hidden_dim 51 | else: 52 | compression_length = int(model_conf.hidden_dim // model_conf.compression_rate) 53 | input_dim = compression_length + 1 # +1 for tokens 54 | return input_dim 55 | 56 | 57 | class Node0Server(Server): 58 | def __init__(self, optim, *args, **kwargs): 59 | super().__init__(*args, **kwargs) 60 | 61 | optim._start_monitor() 62 | 63 | @classmethod 64 | def create( 65 | cls, 66 | num_experts: int = 1, 67 | expert_uids: str = None, 68 | expert_pattern: str = None, 69 | expert_cls: str = "lm_body", 70 | model_conf: ModelArguments = None, 71 | optim_cls: Type[torch.optim.Optimizer] = torch.optim.AdamW, 72 | scheduler: str = "none", 73 | num_warmup_steps: int | None = None, 74 | num_training_steps: int | None = None, 75 | clip_grad_norm: float | None = None, 76 | 
weight_decay: float | None = None, 77 | num_stages: int | None = None, 78 | num_handlers: int | None = None, 79 | min_batch_size: int = 1, 80 | max_batch_size: int = 4096, 81 | averaging_target_batch_size: int = 256, 82 | reuse_grad_buffers: bool = False, 83 | use_local_updates: bool = False, 84 | use_offloading: bool = False, 85 | matchmaking_time: float = 5.0, 86 | averaging_timeout: float = 10.0, 87 | request_timeout: float = 3.0, 88 | next_chunk_timeout: float = 10.0, 89 | load_state_timeout: float = 600, 90 | average_state_every: int = 1, 91 | sparse_avg: float = 0.0, 92 | max_allowed_stale: int = 0, 93 | grad_avg_factory: GradientAveragerFactory | None = None, 94 | optim_collab_cls: Type[torch.optim.Optimizer] = Optimizer, 95 | grad_averaging_compression: CompressionBase = NoCompression(), 96 | load_state_compression: CompressionBase = NoCompression(), 97 | device: str | None = None, 98 | initial_peers: list = [], 99 | compression: CompressionType = CompressionType.NONE, 100 | stats_report_interval: int = 60, 101 | custom_module_path: str | None = None, 102 | update_period: float = 30, 103 | expiration: float | None = None, 104 | monitor: MonitorWorker | None = None, 105 | upload_bw: float | None = None, 106 | *, 107 | start: bool, 108 | **kwargs, 109 | ) -> Server: 110 | """Instantiate a server for collaborative optimization. 111 | 112 | Args: 113 | start (bool): if True, starts the server right away 114 | num_experts (int, optional): run this many identical experts. Defaults to 1. 115 | expert_uids (str, optional): spawn experts with these exact uids, overrides num_experts and expert_pattern. Defaults to None. 116 | expert_pattern (str, optional): a string pattern for experts uids. Defaults to None. 117 | expert_cls (str, optional): expert type. Defaults to "lm_body". 118 | model_conf (BaseModel, optional): model config class. Defaults to None. 119 | optim_cls (Type[torch.optim.Optimizer], optional): optimizer class. Defaults to torch.optim.AdamW. 120 | scheduler (str, optional): if not `none`, the name of the expert LR scheduler. Defaults to "none". 121 | num_warmup_steps (int | None, optional): the number of warmup steps for LR scheduler. Defaults to None. 122 | num_training_steps (int | None, optional): the total number of steps for LR scheduler. Defaults to None. 123 | clip_grad_norm (float | None, optional): maximum gradient norm used for clipping. Defaults to None. 124 | num_handlers (int | None, optional): server will use this many parallel processes to handle incoming requests. Defaults to None. 125 | min_batch_size (int, optional): total num examples in the same batch will be greater than this value. Defaults to 1. 126 | max_batch_size (int, optional): total num examples in the same batch will not exceed this value. Defaults to 4096. 127 | averaging_target_batch_size (int): number of examples to accumulate across all peers before averaging. Defaults to 256. 128 | reuse_grad_buffers (bool, optional): if True, use model's .grad buffers for gradient accumulation. Defaults to False. 129 | use_local_updates (bool, optional): whether each node performs local weights updates between the allreduce. Defaults to False. 130 | use_offloading (bool, optional): perform gradient offloading. Defaults to False. 131 | matchmaking_time (float, optional): time window for nodes to find other nodes for allreduce. Defaults to 5.0. 132 | averaging_timeout (float, optional): timeout for nodes to perform the allreduce. Defaults to 10.0. 
133 | optim_collab_cls (Type[torch.optim.Optimizer], optional): collaborative optimizer class. Defaults to Optimizer. 134 | device (str | None, optional): cuda or cpu. Defaults to None. 135 | initial_peers (list, optional): multiaddrs of one or more active DHT peers. Defaults to []. 136 | compression (CompressionType, optional): compression type. Defaults to CompressionType.NONE. 137 | stats_report_interval (int | None, optional): interval between two reports of batch processing performance statistics. Defaults to None. 138 | custom_module_path (str | None, optional): path of a file with custom nn.modules. Defaults to None. 139 | update_period (float, optional): server will report experts to DHT once in this many seconds. Defaults to 30. 140 | expiration (float | None, optional): DHT entries will expire after this many seconds. Defaults to None. 141 | monitor (MonitorWorker | None, optional): monitor instance. Defaults to None. 142 | upload_bw (float | None, optional): upload bandwidth. Defaults to None. 143 | kwargs: any other params will be forwarded to DHT upon creation 144 | 145 | Returns: 146 | Server: collaborative training server 147 | """ 148 | # Add custom layers 149 | if custom_module_path is not None: 150 | add_custom_models_from_file(custom_module_path) 151 | 152 | try: 153 | assert expert_cls in name_to_block 154 | except: 155 | logger.error( 156 | f"Expert class {expert_cls} is not supported. Make sure you provided correct custom_module_path: {custom_module_path}" 157 | ) 158 | raise 159 | 160 | # Connect to DHT 161 | _ = patch_dht_protocol_logging() 162 | dht = DHT(initial_peers=initial_peers, start=True, startup_timeout=30, **kwargs) 163 | visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()] 164 | logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}") 165 | 166 | # Connect to monitor 167 | if monitor is not None: 168 | monitor.connect_dht(dht) 169 | 170 | # Generate uids 171 | try: 172 | assert (expert_pattern is None and num_experts is None and expert_uids is not None) or ( 173 | num_experts is not None and expert_uids is None 174 | ) 175 | except: 176 | logger.error( 177 | "Please provide either expert_uids *or* num_experts (possibly with expert_pattern), but not both)" 178 | ) 179 | raise 180 | 181 | if expert_uids is None: 182 | expert_uids = [] 183 | 184 | uids_to_generate = num_experts - len(expert_uids) 185 | if uids_to_generate > 0: 186 | logger.info(f"Generating {uids_to_generate} expert uids from pattern {expert_pattern}") 187 | expert_uids.extend(_generate_uids(uids_to_generate, expert_pattern, dht)) 188 | 189 | # Get parameter store 190 | stage = expert_uids[0].split(".")[0] 191 | ( 192 | averaging_target_batch_size, 193 | scheduler, 194 | num_warmup_steps, 195 | num_training_steps, 196 | averaging_timeout, 197 | matchmaking_time, 198 | request_timeout, 199 | load_state_timeout, 200 | ) = get_parameter_store(dht, stage) 201 | 202 | num_experts = len(expert_uids) 203 | num_handlers = num_handlers if num_handlers is not None else num_experts * 8 204 | device = device or ( 205 | "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" 206 | ) 207 | 208 | # Scheduler 209 | scheduler_cls = schedule_name_to_scheduler[scheduler] 210 | if scheduler_cls is not None: 211 | scheduler_cls = partial( 212 | scheduler_cls, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps 213 | ) 214 | 215 | # Initialize experts 216 | input_dim = model_conf.hidden_dim if not 
model_conf.use_compression else get_ss_dim(expert_cls, model_conf) 217 | sequence_length = model_conf.max_seq_len 218 | sample_input = name_to_input[expert_cls](DUMMY_BATCH_SIZE, sequence_length, input_dim) 219 | if isinstance(sample_input, tuple): 220 | args_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in sample_input) 221 | else: 222 | args_schema = (BatchTensorDescriptor.from_tensor(sample_input, compression),) 223 | 224 | experts = {} 225 | for expert_uid in expert_uids: 226 | expert = name_to_block[expert_cls](model_conf) 227 | if monitor: 228 | active_period_timeout = ( 229 | max(load_state_timeout, averaging_timeout + matchmaking_time + request_timeout) + 10 230 | ) 231 | monitor.monitor_callback(expert, model_conf, active_period_timeout) 232 | 233 | assert averaging_target_batch_size is not None 234 | 235 | if model_conf.use_compression: 236 | exclude_params = ["rcv", "fixed_tok_embeddings"] 237 | else: 238 | exclude_params = [] 239 | if weight_decay: 240 | no_decay = ["tok_embeddings.weight"] 241 | params = [ 242 | { 243 | "params": [ 244 | p 245 | for n, p in expert.named_parameters() 246 | if not any(nd in n for nd in no_decay) and not any(ex in n for ex in exclude_params) 247 | ], 248 | "weight_decay": weight_decay, 249 | }, 250 | { 251 | "params": [ 252 | p 253 | for n, p in expert.named_parameters() 254 | if any(nd in n for nd in no_decay) and not any(ex in n for ex in exclude_params) 255 | ], 256 | "weight_decay": 0.0, 257 | }, 258 | ] 259 | else: 260 | params = [p for n, p in expert.named_parameters() if not any(ex in n for ex in exclude_params)] 261 | 262 | if model_conf.use_compression: 263 | ss_comps = load_ss_components(model_conf.ss_component) 264 | expert.load_comp(ss_comps) 265 | logger.info("Succeded loading remote subspace components") 266 | expert.ss_regularize() 267 | 268 | optimizer_lock = mp.Lock() 269 | 270 | backend = ModuleCollab( 271 | optimizer_lock=optimizer_lock, 272 | name=expert_uid, 273 | module=expert, 274 | args_schema=args_schema, 275 | min_batch_size=min_batch_size, 276 | max_batch_size=max_batch_size, 277 | ) 278 | 279 | backend.module.to(device) 280 | 281 | optim_collab = optim_collab_cls( 282 | model=expert, 283 | optimizer_lock=optimizer_lock, 284 | sparse_avg=sparse_avg, 285 | max_allowed_stale=max_allowed_stale, 286 | optimizer=optim_cls, 287 | params=params, 288 | dht=dht, 289 | run_id=expert_uid.split(UID_DELIMITER)[0], 290 | scheduler=scheduler_cls, 291 | target_batch_size=averaging_target_batch_size, 292 | matchmaking_time=matchmaking_time, 293 | averaging_timeout=averaging_timeout, 294 | load_state_timeout=load_state_timeout, 295 | average_state_every=average_state_every, 296 | grad_averager_factory=grad_avg_factory, 297 | grad_compression=grad_averaging_compression, 298 | reuse_grad_buffers=reuse_grad_buffers, 299 | use_local_updates=use_local_updates, 300 | offload_optimizer=use_offloading, 301 | delay_state_averaging=False, 302 | next_chunk_timeout=next_chunk_timeout, 303 | verbose=True, 304 | averager_opts={"bandwidth": upload_bw, "request_timeout": request_timeout}, 305 | ) 306 | optim_collab.load_state_from_peers(wait_for_end_round=True) 307 | backend.optimizer = optim_collab 308 | experts[expert_uid] = backend 309 | 310 | return cls( 311 | optim_collab, 312 | dht, 313 | experts, 314 | num_connection_handlers=num_handlers, 315 | device=device, 316 | checkpoint_dir=None, 317 | stats_report_interval=stats_report_interval, 318 | update_period=update_period, 319 | expiration=expiration, 320 | start=start, 
321 | ) 322 | -------------------------------------------------------------------------------- /src/node0/server/optim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Pluralis Research 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import multiprocessing as mp 16 | import os 17 | import signal 18 | import threading 19 | import time 20 | 21 | from typing import Callable, Optional 22 | 23 | import torch 24 | 25 | from hivemind import Optimizer 26 | from hivemind.optim.grad_scaler import GradScaler 27 | from hivemind.utils import get_dht_time, get_logger 28 | 29 | from node0.server.HM_gradient_averager import GradientAverager, GradientAveragerFactory 30 | from node0.server.state_averager_wrap import TrainingStateAverager 31 | 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | class AutoStepOptimizer(Optimizer): 37 | """ 38 | A class that extends Hivemind Optimizer and ensures step() is called at least once per auto_step_time. 39 | If step() hasn't been called externally within the auto_step_time window, it will be called automatically. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | model, 45 | optimizer_lock, 46 | sparse_avg: float = 0.0, 47 | auto_step_time: float = 3.0, 48 | max_allowed_stale: int = 0, 49 | grad_schedule_buffer: float = 5.0, 50 | *args, 51 | **kwargs, 52 | ): 53 | super().__init__(*args, **kwargs) 54 | self._auto_step_time = auto_step_time 55 | self.model = model 56 | self.max_allowed_stale = max_allowed_stale 57 | self.grad_schedule_buffer = grad_schedule_buffer 58 | self.optimizer_lock = optimizer_lock 59 | self._last_step_time: float = time.time() 60 | self._step_lock = mp.Lock() 61 | self._monitor_thread: Optional[threading.Thread] = None 62 | self._should_stop = threading.Event() 63 | self.in_update = False 64 | 65 | # Set state avg parameters 66 | self.state_averager.average_state_every = self.average_state_every 67 | self.state_averager.sparse_avg = sparse_avg 68 | if sparse_avg: 69 | self.state_averager.set_sparta_partitions() 70 | self.state_averager.run_in_background(await_ready=True) 71 | 72 | def _resync_state(self): 73 | if self._should_load_state_from_peers(): 74 | logger.log(self.status_loglevel, "Peer is out of sync") 75 | self.load_state_from_peers() 76 | return True # local gradients were computed with out-of-sync parameters, must start over 77 | elif self._catchup_epoch(): 78 | with self.tracker.pause_updates(): 79 | logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}") 80 | self.state_averager.local_epoch = self.tracker.global_epoch 81 | self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0) 82 | return False 83 | 84 | def _should_load_state_from_peers(self) -> bool: 85 | return self.local_epoch < (self.tracker.global_epoch - self.max_allowed_stale) 86 | 87 | def _catchup_epoch(self) -> bool: 88 | return (self.tracker.global_epoch - self.max_allowed_stale) <= 
self.local_epoch < self.tracker.global_epoch 89 | 90 | def _make_state_averager(self, **kwargs) -> TrainingStateAverager: 91 | return TrainingStateAverager( 92 | dht=self.dht, 93 | prefix=f"{self.run_id}_state_averager", 94 | min_matchmaking_time=self.matchmaking_time, 95 | allreduce_timeout=self.allreduce_timeout, 96 | shutdown_timeout=self.shutdown_timeout, 97 | offload_optimizer=self.offload_optimizer, 98 | custom_gradients=self.offload_optimizer, 99 | status_loglevel=self.status_loglevel, 100 | next_chunk_timeout=self.next_chunk_timeout, 101 | client_mode=self.client_mode, 102 | auxiliary=self.auxiliary, 103 | allow_state_sharing=False, 104 | start=True, 105 | **kwargs, 106 | ) 107 | 108 | def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager: 109 | assert hasattr(self, "state_averager"), "must initialize state averager first" 110 | factory = factory if factory is not None else GradientAverager 111 | grad_averager = factory( 112 | dht=self.dht, 113 | prefix=f"{self.run_id}_grad_averager", 114 | parameters=self.state_averager.main_parameters, 115 | min_matchmaking_time=self.matchmaking_time, 116 | allreduce_timeout=self.allreduce_timeout, 117 | shutdown_timeout=self.shutdown_timeout, 118 | next_chunk_timeout=self.next_chunk_timeout, 119 | client_mode=self.client_mode, 120 | auxiliary=self.auxiliary, 121 | allow_state_sharing=False, 122 | start=True, 123 | **kwargs, 124 | ) 125 | if self.offload_optimizer: 126 | optimized_param_groups = self.state_averager.optimizer.param_groups 127 | optimized_parameters = [param for group in optimized_param_groups for param in group["params"]] 128 | with grad_averager.get_tensors() as averaged_gradients: 129 | assert len(averaged_gradients) == len(optimized_parameters) 130 | for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients): 131 | opt_param.grad = averaged_grad 132 | return grad_averager 133 | 134 | def _start_monitor(self) -> None: 135 | """Start the monitoring thread if it's not already running.""" 136 | if self._monitor_thread is None or not self._monitor_thread.is_alive(): 137 | self._should_stop.clear() 138 | self._monitor_thread = threading.Thread(target=self._monitor_step, daemon=True) 139 | self._monitor_thread.start() 140 | 141 | def _monitor_step(self) -> None: 142 | """Monitor thread that checks and calls step() if it wasn't called within auto_step_time window.""" 143 | while not self._should_stop.is_set(): 144 | time.sleep(1.0) # Check every 1s 145 | 146 | current_time = time.time() 147 | time_since_last_step = current_time - self._last_step_time 148 | 149 | if time_since_last_step >= self._auto_step_time: 150 | with self._step_lock: 151 | # Check again after acquiring lock in case step() was called 152 | if time.time() - self._last_step_time >= self._auto_step_time: 153 | self._auto_step() 154 | 155 | def _auto_step(self) -> None: 156 | """Internal method to call step() automatically.""" 157 | try: 158 | # Call the parent class's step method with one batch size 159 | logger.debug(f"AutoStepOptimizer step at {time.strftime('%H:%M:%S')}") 160 | self._last_step_time = time.time() 161 | # self._check_update_version() 162 | 163 | if self._resync_state(): 164 | return None 165 | 166 | self._maybe_schedule_gradient_averaging() 167 | 168 | if self.in_update and self.tracker.ready_to_update_epoch: 169 | batch_size = 1 170 | self._step(batch_size=batch_size) 171 | 172 | # self.state_averager.allow_state_sharing = True 173 | except Exception as e: 174 | 
logger.error(f"Error in auto step: {e}") 175 | 176 | def step(self, batch_size: Optional[int] = None) -> None: 177 | """ 178 | Override of the step method that updates the last step time. 179 | This should be called by external code. 180 | """ 181 | with self._step_lock: 182 | # self._check_update_version() 183 | 184 | if self._resync_state(): 185 | return None 186 | 187 | self._step(batch_size=batch_size) 188 | # self.state_averager.allow_state_sharing = True 189 | 190 | self._last_step_time = time.time() 191 | 192 | def stop_monitoring(self) -> None: 193 | """Stop the monitoring thread.""" 194 | self._should_stop.set() 195 | if self._monitor_thread is not None: 196 | self._monitor_thread.join(timeout=1) 197 | self._monitor_thread = None 198 | 199 | def __del__(self): 200 | """Ensure the monitoring thread is stopped when the object is destroyed.""" 201 | self.stop_monitoring() 202 | 203 | @property 204 | def ready_to_update_epoch(self) -> bool: 205 | """Whether or not this peer can increment epoch right away.""" 206 | return ( 207 | self.tracker.global_epoch > self.tracker.local_progress.epoch 208 | or self.tracker.global_progress.samples_accumulated >= self.tracker.target_batch_size 209 | ) 210 | 211 | def _check_and_accumulate_gradients(self, batch_size: int) -> bool: 212 | """Check if gradients are valid, accumulate and return True; otherwise, reset and return False""" 213 | if self.grad_averager.has_nan_grads(): 214 | self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0) 215 | logger.error("Encountered incorrect value in grads, exiting run") 216 | os.killpg(os.getpgrp(), signal.SIGTERM) 217 | 218 | self.grad_averager.accumulate_grads_(batch_size) 219 | self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated) 220 | return True 221 | 222 | def _maybe_schedule_gradient_averaging(self) -> None: 223 | """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch""" 224 | assert self.use_gradient_averaging 225 | if not self.in_update and self.ready_to_update_epoch: 226 | if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done(): 227 | self.in_update = True 228 | eta_seconds = self.tracker.estimated_next_update_time - get_dht_time() 229 | logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec") 230 | self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout) 231 | 232 | def _step( 233 | self, 234 | closure: Optional[Callable[[], torch.Tensor]] = None, 235 | batch_size: Optional[int] = None, 236 | grad_scaler: Optional[GradScaler] = None, 237 | ): 238 | """ 239 | Update training progress after accumulating another local batch size. Depending on the configuration, this will 240 | report progress to peers, run global or local optimizer step, average parameters or schedule background tasks. 241 | 242 | :param closure: A closure that reevaluates the model and returns the loss. 243 | :param batch_size: optional override for batch_size_per_step from init. 244 | :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details. 
245 | """ 246 | if grad_scaler is not None and not isinstance(grad_scaler, GradScaler): 247 | raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)") 248 | if self.batch_size_per_step is None and batch_size is None and not self.auxiliary: 249 | raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step") 250 | if self.auxiliary and (closure is not None or batch_size is not None): 251 | raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler") 252 | batch_size = batch_size if batch_size is not None else self.batch_size_per_step 253 | 254 | # if delayed updates finished before step, apply these updates; otherwise do nothing 255 | self.state_averager.step(apply_delayed_updates=True) 256 | 257 | loss = None 258 | if closure is not None: 259 | with torch.enable_grad(): 260 | loss = closure() 261 | 262 | # accumulate gradients toward target batch size, then aggregate with peers and run optimizer 263 | self._check_and_accumulate_gradients(batch_size) 264 | 265 | self._maybe_schedule_gradient_averaging() 266 | # self._maybe_schedule_state_averaging() 267 | 268 | if self.in_update and self.tracker.ready_to_update_epoch: 269 | self.state_averager.allow_state_sharing = False # Prevent state sharing during AR. 270 | self.grad_averager.processed_batches.copy_(float(self.tracker.local_progress.samples_accumulated)) 271 | self.state_averager.global_epoch = self.tracker.global_epoch 272 | self.grad_averager.global_epoch = self.tracker.global_epoch 273 | with self.optimizer_lock: 274 | self._update_global_epoch(grad_scaler) 275 | 276 | if self.model.model_args.use_compression: 277 | self.model.ss_regularize() 278 | 279 | self.in_update = False 280 | return loss 281 | 282 | def load_state_from_peers(self, wait_for_end_round=False, **kwargs): 283 | # Wait for a while grad accumulation round before requesting state from peers 284 | logger.info(f"Waiting for peers in stage to finish step before joining") 285 | while True: 286 | if ( 287 | self.tracker.fetched_global_progress_this_epoch.is_set() 288 | and self.tracker.global_progress.samples_accumulated < self.target_batch_size * 0.1 289 | ): 290 | break 291 | else: 292 | logger.info(f"Waiting for peers in stage to finish step before joining") 293 | time.sleep(self.tracker.max_refresh_period) 294 | 295 | self._load_state_from_peers(**kwargs) 296 | 297 | if wait_for_end_round: 298 | self.tracker.fetched_global_progress_this_epoch.clear() 299 | while True: 300 | if ( 301 | self.tracker.fetched_global_progress_this_epoch.is_set() 302 | and self.tracker.global_progress.samples_accumulated < self.target_batch_size * 0.2 303 | ): 304 | break 305 | else: 306 | logger.info(f"Downloaded state, waiting for start of new round") 307 | time.sleep(0.5) 308 | logger.info(f"Joining run") 309 | 310 | def _load_state_from_peers(self, **kwargs): 311 | """ 312 | Attempt to load the newest collaboration state from other peers within the same run_id. 313 | 314 | If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place. 315 | """ 316 | # note: we tag along for the next all-reduce because the run may have already started and cancelling it 317 | # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds. 
318 | if self.scheduled_grads is not None and not self.scheduled_grads.done(): 319 | self._tag_along_with_zero_weight(self.scheduled_grads) 320 | self.scheduled_grads = None 321 | 322 | with self.tracker.pause_updates(): 323 | while True: 324 | try: 325 | self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs) 326 | if self.grad_averager is not None: 327 | self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs) 328 | break 329 | except KeyboardInterrupt: 330 | raise 331 | except BaseException as e: 332 | logger.error( 333 | f"Failed to load state from peers: {e}, retrying ...", 334 | exc_info=logger.getEffectiveLevel() <= 15, 335 | ) 336 | continue 337 | 338 | if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch: 339 | logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}") 340 | self.state_averager.local_epoch = self.tracker.global_epoch 341 | 342 | self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0) 343 | 344 | if not self.client_mode: 345 | self.state_averager.state_sharing_priority = self.local_epoch 346 | 347 | if self.use_gradient_averaging: 348 | self.grad_averager.reset_accumulated_grads_() 349 | if not self.client_mode: 350 | self.grad_averager.state_sharing_priority = self.local_epoch 351 | 352 | if hasattr(self.state_averager, "zclip") and self.state_averager.zclip.var.item() > 0.0: 353 | self.state_averager.zclip.initialized = True 354 | --------------------------------------------------------------------------------
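A note on the averager code above: `accumulate_grads_` adds each local backward pass into the accumulator with `alpha = batch_size / anchor_batch_size` (the anchor being the first batch size seen in the round), and `load_accumulators_into_averager_` divides the accumulated sum by `local_times_accumulated` just before all-reduce, so smaller batches contribute proportionally less to the local average. The sketch below is not part of the repository; `LocalGradAccumulator` is a hypothetical name used purely to illustrate that scaling with plain CPU tensors.

import torch


class LocalGradAccumulator:
    """Toy stand-in that reproduces only the scaling logic of the gradient averager above."""

    def __init__(self, shape: int):
        self.buffer = torch.zeros(shape)
        self.anchor_batch_size = None
        self.times_accumulated = 0

    def accumulate(self, grad: torch.Tensor, batch_size: int) -> None:
        # mirrors accumulate_grads_: rescale relative to the first batch size seen this round
        if self.anchor_batch_size is None:
            self.anchor_batch_size = batch_size
        self.times_accumulated += 1
        alpha = float(batch_size) / self.anchor_batch_size
        self.buffer.add_(grad, alpha=alpha)

    def averaged(self) -> torch.Tensor:
        # mirrors load_accumulators_into_averager_: divide by the number of accumulation calls
        scale = 1.0 / self.times_accumulated if self.times_accumulated else 0.0
        return self.buffer * scale


acc = LocalGradAccumulator(3)
acc.accumulate(torch.ones(3), batch_size=32)      # full-size batch, weight 1.0
acc.accumulate(2 * torch.ones(3), batch_size=16)  # half-size batch, weight 0.5
print(acc.averaged())  # tensor([1., 1., 1.])

Under these assumptions, a batch half the anchor size ends up with half the weight in the local average, which matches the `alpha` rescaling in `accumulate_grads_` followed by the final division by `local_times_accumulated` in the real class.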