├── neurons
│   └── __init__.py
├── distributed_training
│   ├── base
│   │   ├── __init__.py
│   │   ├── miner.py
│   │   └── neuron.py
│   ├── validator
│   │   └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── dist.py
│   │   ├── chain.py
│   │   ├── upload_worker.py
│   │   ├── dendrite.py
│   │   ├── progress_tracker.py
│   │   ├── r2.py
│   │   ├── weight_utils.py
│   │   ├── logger.py
│   │   ├── misc.py
│   │   ├── uids.py
│   │   ├── config.py
│   │   └── compression.py
│   ├── averaging
│   │   ├── exceptions.py
│   │   └── avg_handler.py
│   ├── __init__.py
│   ├── scripts
│   │   └── cleanup_r2_bucket.py
│   └── protocol.py
├── pyproject.toml
├── assets
│   ├── error_asyncio_timeout.png
│   ├── error_download_state_from_peers.png
│   ├── error_failed_to_connect_to_DHT.png
│   └── error_could_not_find_a_group_error.png
├── .env.example
├── Makefile
├── requirements.txt
├── min.compute.yml
├── .gitignore
├── setup.py
├── README.md
├── run_miner.sh
├── run_validator.sh
└── eval
    └── eval_loss.py
/neurons/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /distributed_training/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 -------------------------------------------------------------------------------- /distributed_training/validator/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward import forward 2 | -------------------------------------------------------------------------------- /distributed_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config 2 | from . import misc 3 | from . 
import uids 4 | -------------------------------------------------------------------------------- /assets/error_asyncio_timeout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMFODA/DistributedTraining/HEAD/assets/error_asyncio_timeout.png -------------------------------------------------------------------------------- /assets/error_download_state_from_peers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMFODA/DistributedTraining/HEAD/assets/error_download_state_from_peers.png -------------------------------------------------------------------------------- /assets/error_failed_to_connect_to_DHT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMFODA/DistributedTraining/HEAD/assets/error_failed_to_connect_to_DHT.png -------------------------------------------------------------------------------- /assets/error_could_not_find_a_group_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMFODA/DistributedTraining/HEAD/assets/error_could_not_find_a_group_error.png -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | export R2_ACCOUNT_ID= 2 | export R2_BUCKET_NAME= 3 | export R2_READ_ACCESS_KEY_ID= 4 | export R2_READ_SECRET_ACCESS_KEY= 5 | export R2_WRITE_ACCESS_KEY_ID= 6 | export R2_WRITE_SECRET_ACCESS_KEY= 7 | export R2_ADMIN_ACCESS_KEY_ID= 8 | export R2_ADMIN_SECRET_ACCESS_KEY= 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: black isort flake8 2 | 3 | virtual: 4 | 5 | install: 6 | pip install -e . && env/bin/python post_install.py 7 | 8 | black: # Formats code with black 9 | black --config pyproject.toml ./ 10 | 11 | isort: # Sorts imports using isort 12 | isort *.py 13 | 14 | flake8: # Lints code using flake8 15 | flake8 *.py 16 | -------------------------------------------------------------------------------- /distributed_training/utils/dist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | def global_dist_checkpoint(check: bool, process_group): 6 | """ 7 | Returns True iff *all ranks* reported check=True. 8 | Performs a distributed MIN reduction. 
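    A minimal usage sketch (assumes torch.distributed is initialized and that
    `process_group` is an existing group; the surrounding names are illustrative):

        ok = global_dist_checkpoint(loss.isfinite().all().item(), process_group)
        if not ok:
            ...  # at least one rank reported a failure; every rank sees False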
9 | """ 10 | t = torch.tensor([1 if check else 0], dtype=torch.int, device="cpu") 11 | dist.all_reduce(t, op=dist.ReduceOp.MIN, group=process_group) 12 | return t.item() == 1 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.44.1 2 | bitarray==3.0.0 3 | datasets==3.0.2 4 | einops==0.8.1 5 | memory-profiler==0.61.0 6 | transformers==4.39.3 7 | wandb==0.19.11 8 | python-dotenv==1.0.1 9 | python-logging-loki==0.3.1 10 | speedtest-cli==2.1.3 11 | loguru==0.7.2 12 | flake8==7.0.0 13 | black==23.7.0 14 | isort==5.13.2 15 | expecttest==0.2.1 16 | torch==2.5.1 17 | influxdb-client==1.48.0 18 | sentencepiece==0.2.0 19 | openskill==6.1.3 20 | muon-optimizer @ git+https://github.com/KellerJordan/Muon@f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd 21 | rich==14.1.0 22 | bittensor-cli==9.11.2 23 | boto3==1.40.45 24 | -------------------------------------------------------------------------------- /distributed_training/averaging/exceptions.py: -------------------------------------------------------------------------------- 1 | class AllReduceError(Exception): 2 | """Base exception for AllReduce-related errors.""" 3 | 4 | pass 5 | 6 | 7 | class GradientAveragingTimeoutError(AllReduceError): 8 | """Raised when gradient averaging step times out.""" 9 | 10 | pass 11 | 12 | 13 | class GradientAveragingError(AllReduceError): 14 | """Raised when gradient averaging fails for non-timeout reasons.""" 15 | 16 | pass 17 | 18 | 19 | class StateAveragingError(AllReduceError): 20 | """Raised when state averaging fails.""" 21 | 22 | pass 23 | 24 | 25 | class ModelStateError(AllReduceError): 26 | """Raised when model weights are corrupted after an all reduce.""" 27 | 28 | pass 29 | -------------------------------------------------------------------------------- /distributed_training/utils/chain.py: -------------------------------------------------------------------------------- 1 | from distributed_training import __run__ 2 | 3 | 4 | def log_r2_to_chain(self): 5 | if self.master: 6 | try: 7 | metadata = ( 8 | self.config.r2.account_id 9 | + self.config.r2.read.access_key_id 10 | + self.config.r2.read.secret_access_key 11 | ) 12 | self.subtensor.commit(self.wallet, self.config.netuid, str(metadata)) 13 | self.r2_credentials_logged_to_chain = True 14 | self.logger.info(f"Metadata Dict Succesfully Logged To Chain.") 15 | except Exception as e: 16 | self.peer_id_logged_to_chain = False 17 | self.logger.info( 18 | f"Unable To Log Bucket Data To Chain Due To Error {e}. Retrying On The Next Step." 19 | ) 20 | -------------------------------------------------------------------------------- /min.compute.yml: -------------------------------------------------------------------------------- 1 | # PLEASE READ THIS 2 | # THIS SUBNET INCENTIVIZES COMPUTE AND BANDWIDTH. THE MORE DATA YOU CAN TRAIN ON AND 3 | # THE FASTER YOU CAN UPLOAD & DOWNLOAD GRADIENTS THE MORE INCENTIVE YOU'LL GAIN. 4 | 5 | # RUNPOD'S RTX 6xA40'S WITH 3 GBPS UPLOAD & DOWNLOAD SPEEDS PERFORM WELL. 
6 | 7 | version: '1.0' 8 | 9 | compute_spec: 10 | 11 | miner: 12 | 13 | cpu: 14 | min_vcpu: 48 15 | min_memory_gb: 300 16 | min_disk_gb: 1000 17 | 18 | gpu: 19 | required: True 20 | min_vram_gb: 48 21 | recommended_gpu: "RTX A40" 22 | min_gpus: 4 23 | 24 | os: 25 | name: "Ubuntu" 26 | version: 20.04 27 | 28 | validator: 29 | 30 | cpu: 31 | min_vcpu: 48 32 | min_memory_gb: 300 33 | min_disk_gb: 1000 34 | 35 | gpu: 36 | required: True 37 | min_vram_gb: 48 38 | recommended_gpu: "RTX A6000" 39 | min_gpus: 4 40 | 41 | os: 42 | name: "Ubuntu" 43 | version: 20.04 44 | 45 | network_spec: 46 | bandwidth: # Gbps. THIS WILL ALSO IMPACT YOUR SCORE. THE QUICKER YOU DOWNLOAD & UPLOAD THE BETTER. 47 | min_download : 15 48 | recommended_download: 20 49 | min_upload: 15 50 | recommended_upload: 20 -------------------------------------------------------------------------------- /distributed_training/__init__.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Karim Foda 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 18 | 19 | # version nomenclature = __training_type__.__model__.__other_changes__ 20 | __version__ = "1.2.6" 21 | __run__ = "4" 22 | version_split = __version__.split(".") 23 | __spec_version__ = ( 24 | (1000 * int(version_split[0])) 25 | + (10 * int(version_split[1])) 26 | + (1 * int(version_split[2])) 27 | ) 28 | 29 | # Import all submodules. 30 | from . 
import protocol 31 | -------------------------------------------------------------------------------- /distributed_training/utils/upload_worker.py: -------------------------------------------------------------------------------- 1 | # upload_worker.py 2 | import sys 3 | import boto3 4 | from botocore.config import Config 5 | 6 | from distributed_training.utils.r2 import ( 7 | upload_folder_to_r2, 8 | archive_root_bucket, 9 | restore_from_epoch, 10 | ) 11 | 12 | if __name__ == "__main__": 13 | bucket = sys.argv[1] 14 | r2_account_id = sys.argv[2] 15 | r2_write_access_key_id = sys.argv[3] 16 | r2_write_secret_access_key = sys.argv[4] 17 | tag = sys.argv[5] 18 | archive = sys.argv[6] 19 | epoch = tag.split(".")[1] 20 | prefix = f"epoch-{epoch}/" 21 | restore = "True" 22 | 23 | r2_write = boto3.client( 24 | "s3", 25 | endpoint_url=f"https://{r2_account_id}.r2.cloudflarestorage.com", 26 | aws_access_key_id=r2_write_access_key_id, 27 | aws_secret_access_key=r2_write_secret_access_key, 28 | region_name="auto", 29 | config=Config( 30 | retries={"max_attempts": 10, "mode": "adaptive"}, # or "standard" 31 | connect_timeout=30, 32 | read_timeout=120, 33 | max_pool_connections=50, 34 | ), 35 | ) 36 | 37 | upload_folder_to_r2(r2_write, bucket, prefix) 38 | # Only archive on the miner side after an AllReduce. 39 | # The archive flag arrives as a string because it is passed through subprocess argv. 40 | if archive == "True": 41 | archive_root_bucket(r2_write, bucket, epoch) 42 | 43 | # if restore == "True": 44 | # restore_from_epoch(r2_write, bucket, epoch) 45 | 46 | # r2_write.upload_file( 47 | # str("/root/llama-4b-ws-4/metadata.json"), bucket, f"metadata.json" 48 | # ) 49 | -------------------------------------------------------------------------------- /distributed_training/scripts/cleanup_r2_bucket.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import boto3 5 | from dotenv import load_dotenv 6 | 7 | load_dotenv() 8 | 9 | bucket = "llama-1b-ws-4-000" 10 | 11 | r2 = boto3.client( 12 | "s3", 13 | endpoint_url=f"https://{os.environ['R2_ACCOUNT_ID']}.r2.cloudflarestorage.com", 14 | aws_access_key_id=os.environ["R2_ADMIN_ACCESS_KEY_ID"], 15 | aws_secret_access_key=os.environ["R2_ADMIN_SECRET_ACCESS_KEY"], 16 | region_name="auto", 17 | ) 18 | 19 | 20 | # 1️⃣ Delete all objects 21 | def delete_all_objects(bucket): 22 | print("Deleting objects...") 23 | paginator = r2.get_paginator("list_objects_v2") 24 | for page in paginator.paginate(Bucket=bucket): 25 | objs = page.get("Contents", []) 26 | if not objs: 27 | continue 28 | to_delete = [{"Key": o["Key"]} for o in objs] 29 | r2.delete_objects(Bucket=bucket, Delete={"Objects": to_delete}) 30 | print(f"Deleted {len(to_delete)} objects") 31 | print("✅ All objects deleted") 32 | 33 | 34 | # 2️⃣ Abort any ongoing multipart uploads 35 | def abort_all_multipart(bucket): 36 | print("Aborting multipart uploads...") 37 | while True: 38 | resp = r2.list_multipart_uploads(Bucket=bucket) 39 | uploads = resp.get("Uploads", []) 40 | if not uploads: 41 | break 42 | for u in uploads: 43 | r2.abort_multipart_upload( 44 | Bucket=bucket, Key=u["Key"], UploadId=u["UploadId"] 45 | ) 46 | print(f"Aborted {len(uploads)} uploads") 47 | # small delay to let R2 finalize the aborts 48 | time.sleep(0.5) 49 | print("✅ All multipart uploads aborted") 50 | 51 | 52 | # 3️⃣ Now delete the bucket itself 53 | def delete_bucket(bucket): 54 | print("Deleting bucket...") 55 | r2.delete_bucket(Bucket=bucket) 56 
| print("✅ Bucket deleted") 57 | 58 | 59 | # Execute 60 | delete_all_objects(bucket) 61 | abort_all_multipart(bucket) 62 | delete_bucket(bucket) 63 | -------------------------------------------------------------------------------- /distributed_training/protocol.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Karim Foda 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 18 | 19 | from typing import Optional 20 | 21 | import bittensor as bt 22 | import pydantic 23 | 24 | 25 | class IsAlive(bt.Synapse): 26 | answer: Optional[str] = None 27 | completion: str = pydantic.Field( 28 | "", 29 | title="Completion", 30 | description="Completion status of the current StreamPrompting object. " 31 | "This attribute is mutable and can be updated.", 32 | ) 33 | epoch: Optional[int] = None 34 | 35 | 36 | class AllReduce(bt.Synapse): 37 | answer: Optional[str] = None 38 | completion: bool = pydantic.Field( 39 | None, 40 | title="Completion", 41 | description="Completion status of the current StreamPrompting object. " 42 | "This attribute is mutable and can be updated.", 43 | ) 44 | min_group_size: Optional[int] = None 45 | request_timeout: Optional[float] = None 46 | allreduce_timeout: Optional[float] = None 47 | next_chunk_timeout: Optional[float] = None 48 | min_matchmaking_time: Optional[float] = None 49 | -------------------------------------------------------------------------------- /distributed_training/utils/dendrite.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | 3 | from bittensor.core.axon import Axon 4 | from bittensor.core.dendrite import DendriteMixin 5 | from bittensor.utils.registration import torch, use_torch 6 | from typing import Optional, Union, List 7 | from bittensor_wallet import Keypair, Wallet 8 | 9 | 10 | class DTDendriteMixin(DendriteMixin): 11 | def __init__(self, wallet, connection_limit=100): 12 | self._connection_limit = connection_limit 13 | super().__init__(wallet) 14 | 15 | @property 16 | async def session(self) -> aiohttp.ClientSession: 17 | """ 18 | An asynchronous property that provides access to the internal `aiohttp `_ 19 | client session. 20 | 21 | This property ensures the management of HTTP connections in an efficient way. It lazily 22 | initializes the `aiohttp.ClientSession `_ 23 | on its first use. 
The session is then reused for subsequent HTTP requests, offering performance benefits by 24 | reusing underlying connections. 25 | 26 | This is used internally by the dendrite when querying axons, and should not be used directly 27 | unless absolutely necessary for your application. 28 | 29 | Returns: 30 | aiohttp.ClientSession: The active ``aiohttp`` client session instance. 31 | If no session exists, a new one is created and returned. This session is used for asynchronous HTTP requests 32 | within the dendrite, adhering to the async nature of the network interactions in the Bittensor framework. 33 | 34 | Example usage:: 35 | 36 | import bittensor # Import bittensor 37 | wallet = bittensor.Wallet( ... ) # Initialize a wallet 38 | dendrite = bittensor.Dendrite(wallet=wallet) # Initialize a dendrite instance with the wallet 39 | 40 | async with (await dendrite.session).post( # Use the session to make an HTTP POST request 41 | url, # URL to send the request to 42 | headers={...}, # Headers dict to be sent with the request 43 | json={...}, # JSON body data to be sent with the request 44 | timeout=10, # Timeout duration in seconds 45 | ) as response: 46 | json_response = await response.json() # Extract the JSON response from the server 47 | 48 | """ 49 | if self._session is None: 50 | self._session = aiohttp.ClientSession( 51 | connector=aiohttp.TCPConnector(limit=self._connection_limit) 52 | ) 53 | return self._session 54 | 55 | 56 | # For back-compatibility with torch 57 | BaseModel: Union["torch.nn.Module", object] = torch.nn.Module if use_torch() else object 58 | 59 | 60 | class DTDendrite(DTDendriteMixin, BaseModel): # type: ignore 61 | def __init__( 62 | self, 63 | wallet: Optional[Union["Wallet", "Keypair"]] = None, 64 | connection_limit: int = 100, 65 | ): 66 | if use_torch(): 67 | torch.nn.Module.__init__(self) 68 | DTDendriteMixin.__init__(self, wallet, connection_limit) 69 | 70 | 71 | if not use_torch(): 72 | 73 | async def call(self, *args, **kwargs): 74 | return await self.forward(*args, **kwargs) 75 | 76 | DTDendrite.__call__ = call 77 | 78 | 79 | async def async_dendrite_forward( 80 | wallet: Optional["Wallet"] = None, 81 | axons: Optional[List["Axon"]] = None, 82 | synapse=None, 83 | connection_limit: int = 100, 84 | timeout: float = 30.0, 85 | ): 86 | async with DTDendrite(wallet, connection_limit=connection_limit) as d: 87 | return await d(axons or [], synapse=synapse, timeout=timeout) 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | bittensor-subnet-template/ 162 | wandb/ 163 | .vscode/ 164 | logs* 165 | bittensor/ 166 | hivemind/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # KMFODA 4 | # Copyright © 2023 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 12 | # the Software. 13 | 14 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 15 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import codecs 21 | import os 22 | import re 23 | import subprocess 24 | from io import open 25 | from os import path 26 | 27 | from pkg_resources import parse_requirements 28 | from setuptools import find_packages, setup 29 | from setuptools.command.develop import develop 30 | from setuptools.command.egg_info import egg_info 31 | from setuptools.command.install import install 32 | 33 | 34 | def custom_command(): 35 | import pip 36 | 37 | pip.main( 38 | [ 39 | "install", 40 | "git+https://github.com/learning-at-home/hivemind.git@3a4cc15e29ce51b20c5d415a4c579abbae435718", 41 | ] 42 | ) 43 | pip.main(["install", "bittensor==9.12.2"]) 44 | pip.main(["install", "py-multihash==2.0.1"]) 45 | 46 | # Install Go and HFDownloader 47 | try: 48 | subprocess.run(["apt-get", "update"], check=True) 49 | subprocess.run(["apt-get", "install", "-y", "golang"], check=True) 50 | subprocess.run( 51 | ["go", "install", "github.com/lxe/hfdownloader@latest"], check=True 52 | ) 53 | 54 | # Add Go bin to PATH in venv activate script 55 | if "VIRTUAL_ENV" in os.environ: 56 | activate_script = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "activate") 57 | if os.path.exists(activate_script): 58 | with open(activate_script, "a") as f: 59 | f.write('\nexport PATH="$HOME/go/bin:$PATH"\n') 60 | # Also update current session's PATH 61 | go_bin_path = os.path.expanduser("~/go/bin") 62 | if go_bin_path not in os.environ["PATH"]: 63 | os.environ["PATH"] = f"{go_bin_path}:{os.environ['PATH']}" 64 | 65 | except Exception as e: 66 | raise RuntimeError(f"Failed to install Go and HFDownloader: {str(e)}") 67 | 68 | 69 | class CustomInstallCommand(install): 70 | def run(self): 71 | custom_command() 72 | install.run(self) 73 | 74 | 75 | class CustomDevelopCommand(develop): 76 | def run(self): 77 | custom_command() 78 | develop.run(self) 79 | 80 | 81 | class CustomEggInfoCommand(egg_info): 82 | def run(self): 83 | custom_command() 84 | egg_info.run(self) 85 | 86 | 87 | with open("requirements.txt") as requirements_file: 88 | requirements = list(map(str, 
parse_requirements(requirements_file))) 89 | 90 | here = path.abspath(path.dirname(__file__)) 91 | 92 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 93 | long_description = f.read() 94 | 95 | # loading version from distributed_training/__init__.py 96 | with codecs.open( 97 | os.path.join(here, "distributed_training/__init__.py"), encoding="utf-8" 98 | ) as init_file: 99 | version_match = re.search( 100 | r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M 101 | ) 102 | version_string = version_match.group(1) 103 | 104 | setup( 105 | name="distributed_training", 106 | version=version_string, 107 | description="distributed_training", 108 | long_description=long_description, 109 | long_description_content_type="text/markdown", 110 | url="https://github.com/dstrbtd/DistributedTraining", 111 | author="KMFODA", 112 | packages=find_packages(), 113 | include_package_data=True, 114 | author_email="", 115 | license="MIT", 116 | python_requires=">=3.8", 117 | cmdclass={ 118 | "install": CustomInstallCommand, 119 | "develop": CustomDevelopCommand, 120 | "egg_info": CustomEggInfoCommand, 121 | }, 122 | setup_requires=["pip"], 123 | install_requires=requirements, 124 | classifiers=[ 125 | "Development Status :: 3 - Alpha", 126 | "Intended Audience :: Developers", 127 | "Topic :: Software Development :: Build Tools", 128 | # Pick your license as you wish 129 | "License :: OSI Approved :: MIT License", 130 | "Programming Language :: Python :: 3 :: Only", 131 | "Programming Language :: Python :: 3.8", 132 | "Programming Language :: Python :: 3.9", 133 | "Programming Language :: Python :: 3.10", 134 | "Topic :: Scientific/Engineering", 135 | "Topic :: Scientific/Engineering :: Mathematics", 136 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 137 | "Topic :: Software Development", 138 | "Topic :: Software Development :: Libraries", 139 | "Topic :: Software Development :: Libraries :: Python Modules", 140 | ], 141 | ) 142 | -------------------------------------------------------------------------------- /distributed_training/utils/progress_tracker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch.distributed as dist 3 | 4 | from dataclasses import dataclass 5 | from distributed_training import __run__ 6 | from pydantic import BaseModel, Field, StrictBool, StrictFloat, confloat, conint 7 | from typing import Optional 8 | 9 | @dataclass(frozen=False) 10 | class GlobalTrainingProgress: 11 | epoch: int 12 | samples_accumulated: int 13 | 14 | 15 | class LocalTrainingProgress(BaseModel): 16 | peer_id: bytes 17 | epoch: conint(ge=0, strict=True) 18 | samples_accumulated: conint(ge=0, strict=True) 19 | samples_per_second: confloat(ge=0.0, strict=True) 20 | time: StrictFloat 21 | client_mode: StrictBool 22 | inner_step: conint(ge=0, strict=True) 23 | loss: confloat(ge=0.0, strict=True) 24 | 25 | 26 | 27 | 28 | 29 | 30 | class SkillRating(BaseModel): 31 | mu: float 32 | sigma: float 33 | 34 | 35 | class LossProfile(BaseModel): 36 | before: float = 0.0 37 | after: float = 0.0 38 | absolute: float = 0.0 39 | relative: float = 0.0 40 | score: float = 0.0 41 | 42 | 43 | class Chaindata(BaseModel): 44 | last_updated_block: int = 0 45 | 46 | 47 | class ScoreAllReduce(BaseModel): 48 | peer_id: Optional[str] = None 49 | score: float = 0.0 50 | count: int = 0 51 | 52 | 53 | class ScoreTrain(BaseModel): 54 | r2_hash: Optional[str] = None 55 | model_id: Optional[str] = None 56 | account_id: 
Optional[str] = "x" * 32 57 | access_key_id: Optional[str] = "x" * 32 58 | secret_access_key: Optional[str] = "x" * 64 59 | is_valid: bool = True 60 | random: LossProfile = Field(default_factory=LossProfile) 61 | assigned: LossProfile = Field(default_factory=LossProfile) 62 | openskill_rating: float = 0.0 63 | score: float = 0.0 64 | updated_time: float = 0 65 | revision: str = "0.0.0" 66 | openskill_rating: SkillRating = Field( 67 | default_factory=lambda: SkillRating(mu=25.0, sigma=8.333) 68 | ) 69 | 70 | 71 | class ScoreTotal(BaseModel): 72 | score: float = 0.0 73 | 74 | 75 | class UidTracker(BaseModel): 76 | uid: int 77 | all_reduce: ScoreAllReduce = Field(default_factory=ScoreAllReduce) 78 | train: ScoreTrain = Field(default_factory=ScoreTrain) 79 | total: ScoreTotal = Field(default_factory=ScoreTotal) 80 | chaindata: Chaindata = Field(default_factory=Chaindata) 81 | 82 | 83 | def get_r2_client(self, uid: int, donwload_on_all_ranks: bool): 84 | if uid == self.uid: 85 | account_id = self.config.r2.account_id 86 | access_key_id = self.config.r2.read.access_key_id 87 | secret_access_key = self.config.r2.read.secret_access_key 88 | elif uid == self.master_uid: 89 | return self.r2["global"] 90 | elif donwload_on_all_ranks: 91 | if self.master: 92 | account_id = self.uid_tracker[uid].train.account_id 93 | access_key_id = self.uid_tracker[uid].train.access_key_id 94 | secret_access_key = self.uid_tracker[uid].train.secret_access_key 95 | else: 96 | account_id = self.config.r2.account_id 97 | access_key_id = self.config.r2.read.access_key_id 98 | secret_access_key = self.config.r2.read.secret_access_key 99 | commitment = [account_id + access_key_id + secret_access_key] 100 | dist.broadcast_object_list(commitment, src=0, group=self.gloo_group) 101 | self.logger.debug(f"UID {uid:03d}: Commitment - {commitment}") 102 | account_id = commitment[0][:32] 103 | access_key_id = commitment[0][32:64] 104 | secret_access_key = commitment[0][64:] 105 | else: 106 | account_id = self.uid_tracker[uid].train.account_id 107 | access_key_id = self.uid_tracker[uid].train.access_key_id 108 | secret_access_key = self.uid_tracker[uid].train.secret_access_key 109 | 110 | self.logger.debug(account_id, access_key_id, secret_access_key) 111 | 112 | if ( 113 | (account_id == "x" * 32) 114 | or (access_key_id == "x" * 32) 115 | or (secret_access_key == "x" * 64) 116 | ): 117 | raise ValueError(f"UID {uid:03d} has no R2 credentials.") 118 | 119 | return self.session.client( 120 | "s3", 121 | endpoint_url=f"https://{account_id}.r2.cloudflarestorage.com", 122 | aws_access_key_id=access_key_id, 123 | aws_secret_access_key=secret_access_key, 124 | region_name="auto", 125 | ) 126 | 127 | 128 | def get_progress( 129 | self, 130 | local_or_global: str, 131 | bucket_name: str = None, 132 | uid: int = None, 133 | epoch: int = None, 134 | donwload_on_all_ranks=True, 135 | ): 136 | # local_or_global is used for miners 137 | # uid is used for validators to cycle through progress of different uids 138 | if (local_or_global != "global") and (bucket_name is None) and (uid is None): 139 | bucket_name = self.config.r2.bucket_name 140 | elif (local_or_global != "global") and (uid == self.master_uid): 141 | bucket_name = f"{self.config.neuron.global_model_name}-{uid:03d}" 142 | elif (uid is not None) and (uid != self.master_uid): 143 | bucket_name = f"{self.config.neuron.global_model_name}-{uid:03d}" 144 | elif (local_or_global == "global") or (uid == self.master_uid): 145 | bucket_name = self.config.neuron.global_model_name 146 | 147 | 
try: 148 | if uid is not None: 149 | r2 = get_r2_client(self, uid, donwload_on_all_ranks) 150 | else: 151 | r2 = self.r2[local_or_global] 152 | 153 | obj = r2.get_object(Bucket=bucket_name, Key="metadata.json") 154 | data = obj["Body"].read() 155 | metadata = json.loads(data) 156 | local_epoch = metadata["outer_step"] 157 | local_inner_step = metadata["inner_step"] 158 | local_peer_id = metadata["peer_id"] 159 | return local_epoch, local_inner_step, local_peer_id 160 | except Exception as e: 161 | self.logger.debug(f"Error in get_progress: {str(e)}") 162 | return None, 0, None 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # **Distributed Training Subnet** 5 | [![Discord Chat](https://img.shields.io/discord/308323056592486420.svg)](https://discord.gg/bittensor) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 7 | 8 |
9 | 10 | --- 11 | 12 | # Overview 13 | [Blog post](https://distributed-training.notion.site/Decentralised-Distributed-Training-fd21bdfa72294dfeab8fb092770212b9) 14 | 15 | # Minimum Requirements 16 | 17 | [min.compute.yml](min.compute.yml) 18 | 19 | # Installation 20 | This repository requires Python 3.10 or higher. To install, simply clone this repository and install the requirements. 21 | 22 | 1. Install this repository 23 | ```bash 24 | git clone https://github.com/KMFODA/DistributedTraining 25 | cd DistributedTraining 26 | pip install -e . 27 | ``` 28 | 29 | 2. Log in to wandb: 30 | ```bash 31 | wandb login 32 | ``` 33 | 34 | 3. Log in to huggingface 35 | ```bash 36 | huggingface-cli login 37 | ``` 38 | 39 | 4. Register your hotkey 40 | ```bash 41 | btcli subnets register --subtensor.network finney --netuid $NETUID --wallet.name $WALLET_NAME --wallet.hotkey $HOTKEY_NAME 42 | ``` 43 | 44 | 5. Create an R2 bucket whose name follows the format below. 45 | ```bash 46 | f"{self.config.neuron.global_model_name}-{uid:03d}" 47 | ``` 48 | For example, uid `1` under the global_model_name `llama-4b-ws-4` would use the bucket name `llama-4b-ws-4-001`. 49 | 50 | 6. Create READ and WRITE API tokens that cover all your applicable R2 buckets. Add them to `.env.example` and rename the file `.env`. 51 | 52 | 7. Install [PM2](https://pm2.io/docs/runtime/guide/installation/) and the [`jq` package](https://jqlang.github.io/jq/) on your system. 53 | 54 | **On Linux**: 55 | ```bash 56 | sudo apt update && sudo apt install jq && sudo apt install npm && sudo npm install pm2 -g && pm2 update 57 | ``` 58 | **On Mac OS** 59 | ```bash 60 | brew update && brew install jq && brew install npm && sudo npm install pm2 -g && pm2 update 61 | ``` 62 | 63 | --- 64 | # Running a Miner 65 | Once you have installed this repo you can run a miner with **auto updates enabled** using the following commands: 66 | ```bash 67 | chmod +x run_miner.sh 68 | pm2 start run_miner.sh --name distributed_training_miner_auto_update -- 69 | --nproc_per_node # Must be aligned with the maximum number of GPUs on your VM 70 | --netuid # Must be obtained by following the instructions in the docs/running_on_*.md files 71 | --subtensor.chain_endpoint # Must be obtained by following the instructions in the docs/running_on_*.md files 72 | --wallet.name # Must be created using the bittensor-cli 73 | --wallet.hotkey # Must be created using the bittensor-cli 74 | --axon.port 75 | --dht.port 76 | --dht.ip 77 | --show_all_rank_logs # Only enable for debugging 78 | ``` 79 | --- 80 | 81 | # Running a Validator 82 | Once you have installed this repo you can then run a validator with **auto updates enabled** using the following command: 83 | ```bash 84 | chmod +x run_validator.sh 85 | pm2 start run_validator.sh --name distributed_training_auto_update -- 86 | --nproc_per_node # Must be aligned with the maximum number of GPUs on your VM 87 | --netuid # Must be obtained by following the instructions in the docs/running_on_*.md files 88 | --subtensor.chain_endpoint # Must be obtained by following the instructions in the docs/running_on_*.md files 89 | --wallet.name # Must be created using the bittensor-cli 90 | --wallet.hotkey # Must be created using the bittensor-cli 91 | --logging.debug # Run in debug mode, alternatively --logging.trace for trace mode 92 | --axon.port 93 | --dht.port 94 | --dht.ip 95 | --show_all_rank_logs # Only enable for debugging 96 | ``` 97 | 98 | 99 | 100 | --- 101 | 102 | ## Known Errors 103 | Currently this subnet still relies on the awesome 
[hivemind](https://github.com/learning-at-home/hivemind) library to facilitate the all-reduce part of distributed training. This library runs multiple asynchronous processes in the background, and sometimes these fail. It is designed so that training still progresses even if some of these failures occur. Here are some of the most common errors. 104 | 105 | **Asyncio Timeout Error**: 106 | ![Image](assets/error_asyncio_timeout.png) 107 | 108 | This happens when one of the various async processes times out. If your logs continue after this error and you still receive validator calls, your miner will still gain incentive. 109 | 110 | **Load State From Peer Error**: 111 | ![Image](assets/error_download_state_from_peers.png) 112 | 113 | This happens when a validator tries to pull the latest model state from another peer and fails to do so within the timeout period. This is most likely due to low bandwidth on either your side or your peer's. So long as the bandwidth you report on WandB is above the minimum requirements, this won't impact your incentive. 114 | 115 | **Averaging step failed: could not find a group**: 116 | ![Image](assets/error_could_not_find_a_group_error.png) 117 | 118 | This occurs when your miner hasn't been able to find a group to join to perform the all-reduce round. This might be due to low bandwidth or issues with your DHT connecting to other DHTs. Make sure your bandwidth is above the minimum requirements and that you aren't running any other background processes or miners on the same machine. Getting this error once shouldn't have a huge impact on incentive, but if it keeps repeating, incentives will drop. 119 | 120 | **Failed to connect to DHT address**: 121 | ![Image](assets/error_failed_to_connect_to_DHT.png) 122 | 123 | This error indicates that you are failing to connect to some of the DHT addresses in the initial_peers list. This isn't a breaking error as long as you have just one successful attempt at the end of these retries. Many retries are expected, as nodes drop out of training and leave their DHTs idle in the background. 124 | 125 | ## License 126 | This repository is licensed under the MIT License. 127 | ```text 128 | # The MIT License (MIT) 129 | # Copyright © 2023 Yuma Rao 130 | 131 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 132 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 133 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 134 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 135 | 136 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 137 | # the Software. 138 | 139 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 140 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 141 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 142 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 143 | # DEALINGS IN THE SOFTWARE. 
144 | ``` 145 | -------------------------------------------------------------------------------- /distributed_training/utils/r2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import os 4 | import tempfile 5 | import filelock 6 | import datetime 7 | import pathlib 8 | import json 9 | import torch.distributed as dist 10 | 11 | from boto3.s3.transfer import TransferConfig 12 | from concurrent.futures import ThreadPoolExecutor 13 | from botocore.client import BaseClient 14 | from s3transfer.manager import TransferManager 15 | 16 | from distributed_training import __run__ 17 | 18 | ACCEPTED_FILES = [ 19 | "config.json", 20 | "model.safetensors", 21 | "gradients.pt", 22 | "inner_optimizer.rank0001-of-4.pt", 23 | "inner_optimizer.rank0002-of-4.pt", 24 | "inner_optimizer.rank0003-of-4.pt", 25 | "inner_optimizer.rank0004-of-4.pt", 26 | "outer_optimizer.pt", 27 | ] 28 | 29 | 30 | def upload_folder_to_r2(r2, bucket, prefix="", max_workers=8): 31 | local_folder = pathlib.Path(bucket) 32 | 33 | files = [p for p in local_folder.rglob("*") if p.is_file()] 34 | # print(f"Uploading {len(files)} files with {max_workers} threads...") 35 | 36 | pbar = tqdm.tqdm(total=len(files)) 37 | 38 | def _upload(path): 39 | key = f"{prefix}{path.relative_to(local_folder)}" 40 | 41 | # key = str(path.relative_to(local_folder)) 42 | size = os.path.getsize(path) 43 | 44 | if key.split("/")[-1] not in ACCEPTED_FILES: 45 | return key 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | if size < 512 * 1024**2: # < 512 MB 55 | threshold, workers = 16, 6 56 | elif size < 2 * 1024**3: # < 2 GB 57 | threshold, workers = 32, 10 58 | else: # > 2 GB 59 | threshold, workers = 64, 14 60 | 61 | cfg = TransferConfig( 62 | multipart_threshold=threshold * 1024 * 1024, # threshold is in MB 63 | multipart_chunksize=threshold * 1024 * 1024, 64 | max_concurrency=workers, 65 | use_threads=True, 66 | ) 67 | r2.upload_file(str(path), bucket, key, Config=cfg) 68 | pbar.update(1) 69 | return key 70 | 71 | with ThreadPoolExecutor(max_workers=max_workers) as ex: 72 | for i, key in enumerate(ex.map(_upload, files), 1): 73 | if i % 10 == 0 or i == len(files): 74 | pass 75 | # print(f"Uploaded {i}/{len(files)}") 76 | 77 | # print("✅ Upload complete") 78 | # print(datetime.datetime.now()) 79 | 80 | 81 | def archive_root_bucket(r2: BaseClient, bucket: str, epoch: int): 82 | print("⌛️ Archive start") 83 | print(datetime.datetime.now()) 84 | # multipart thresholds/chunks; tune as needed 85 | tcfg = TransferConfig( 86 | multipart_threshold=8 * 1024 * 1024, # 8MB 87 | multipart_chunksize=64 * 1024 * 1024, # 64MB 88 | max_concurrency=4, 89 | use_threads=True, 90 | ) 91 | 92 | archive_prefix = f"epoch-{epoch}/" 93 | paginator = r2.get_paginator("list_objects_v2") 94 | with TransferManager(r2, config=tcfg) as tm: 95 | futures = [] 96 | for page in paginator.paginate(Bucket=bucket): 97 | for obj in page.get("Contents", []): 98 | key = obj["Key"] 99 | 100 | # ✅ skip pseudo-folders or empty keys 101 | if ( 102 | (not key) 103 | or ("epoch-" in key) 104 | or (obj["Size"] == 0) 105 | or (key not in ACCEPTED_FILES) 106 | ): 107 | continue 108 | 109 | dest_key = f"{archive_prefix}{key}" 110 | 111 | futures.append( 112 | tm.copy( 113 | copy_source={"Bucket": bucket, "Key": key}, 114 | bucket=bucket, 115 | key=dest_key, 116 | extra_args={"MetadataDirective": "COPY"}, 117 | ) 118 | ) 119 
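        # (Note: tm.copy performs server-side CopyObject calls, so the archived bytes never transit this host.)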
120 | # wait for all copies to finish (raises on failure) 121 | for f in futures: 122 | f.result() 123 | r2.close() 124 | print("✅ Archive complete") 125 | print(datetime.datetime.now()) 126 | 127 | 128 | def restore_from_epoch(r2: BaseClient, bucket: str, epoch: int): 129 | """ 130 | Copies all objects from epoch-{epoch}/ back into the main bucket root. 131 | """ 132 | tcfg = TransferConfig( 133 | multipart_threshold=8 * 1024 * 1024, # 8MB 134 | multipart_chunksize=64 * 1024 * 1024, # 64MB 135 | max_concurrency=4, 136 | use_threads=True, 137 | ) 138 | 139 | source_prefix = f"epoch-{epoch}/" 140 | paginator = r2.get_paginator("list_objects_v2") 141 | 142 | with TransferManager(r2, config=tcfg) as tm: 143 | futures = [] 144 | for page in paginator.paginate(Bucket=bucket, Prefix=source_prefix): 145 | for obj in page.get("Contents", []): 146 | key = obj["Key"] 147 | 148 | # skip empty or malformed entries 149 | if not key or obj["Size"] == 0: 150 | continue 151 | 152 | # remove the epoch prefix so files go to root 153 | dest_key = key[len(source_prefix) :] 154 | 155 | # skip if key would become empty (folder marker) 156 | if not dest_key: 157 | continue 158 | 159 | futures.append( 160 | tm.copy( 161 | copy_source={"Bucket": bucket, "Key": key}, 162 | bucket=bucket, 163 | key=dest_key, 164 | extra_args={"MetadataDirective": "COPY"}, 165 | ) 166 | ) 167 | 168 | for f in futures: 169 | f.result() 170 | 171 | r2.close() 172 | 173 | 174 | def r2_download( 175 | self, 176 | r2, 177 | bucket, 178 | key, 179 | donwload_on_all_ranks=True, 180 | run_on_all_ranks=True, 181 | destination=None, 182 | ): 183 | if destination is None: 184 | fd, destination_path = tempfile.mkstemp() 185 | os.close(fd) 186 | else: 187 | destination_path = destination 188 | destination_path = os.path.join( 189 | destination_path, os.path.basename(key.split("/")[-1]) 190 | ) 191 | 192 | # Let only the master perform the actual download 193 | if (self.master) or (donwload_on_all_ranks): 194 | try: 195 | os.makedirs(os.path.dirname(destination_path), exist_ok=True) 196 | lock_path = destination_path + ".lock" 197 | with filelock.FileLock(lock_path): 198 | r2.download_file(bucket, key, destination_path) 199 | success = torch.tensor([1], dtype=torch.int, device="cuda") 200 | except Exception as e: 201 | self.logger.info(f"Download failed due to error: {e}") 202 | success = torch.tensor([0], dtype=torch.int, device="cuda") 203 | else: 204 | success = torch.tensor([0], dtype=torch.int, device="cuda") 205 | 206 | if donwload_on_all_ranks or run_on_all_ranks: 207 | # Broadcast success flag from master to everyone 208 | dist.broadcast(success, src=0) 209 | 210 | # If master failed, all ranks raise the same error 211 | if success.item() == 0: 212 | raise RuntimeError("Master rank failed during r2_download().") 213 | 214 | return destination_path 215 | 216 | 217 | def log_peerid_to_r2(self, prefix=""): 218 | if self.master: 219 | # Save metadata 220 | metadata = { 221 | "run": int(__run__), 222 | "outer_step": int(self.local_progress.epoch), 223 | "inner_step": int(self.local_progress.inner_step), 224 | "peer_id": str(self.dht.peer_id.to_base58()), 225 | } 226 | with open(os.path.join(self.output_dir, f"metadata.json"), "w") as f: 227 | json.dump(metadata, f, indent=4, sort_keys=True) 228 | # Upload Peer Metadata With Updated Peer ID 229 | self.r2["write"].upload_file( 230 | str(os.path.join(self.output_dir, "metadata.json")), 231 | self.config.r2.bucket_name, 232 | f"{prefix}metadata.json", 233 | ) 234 | 
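# A usage sketch for r2_download above (assuming `self` is a neuron with
# .master/.logger set and torch.distributed initialized; the bucket and key
# names are illustrative):
#
#   path = r2_download(
#       self, r2, "llama-4b-ws-4-001", "model.safetensors",
#       donwload_on_all_ranks=False, destination=self.output_dir,
#   )
#
# Only rank 0 downloads the object; the success flag is broadcast so all ranks
# either continue together or raise together.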
-------------------------------------------------------------------------------- /distributed_training/utils/weight_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple, List, Union, Any 3 | import bittensor 4 | from numpy import ndarray, dtype, floating, complexfloating 5 | 6 | U32_MAX = 4294967295 7 | U16_MAX = 65535 8 | 9 | 10 | def normalize_max_weight(x: np.ndarray, limit: float = 0.1) -> np.ndarray: 11 | r"""Normalizes the numpy array x so that sum(x) = 1 and the max value is not greater than the limit. 12 | Args: 13 | x (:obj:`np.ndarray`): 14 | Array to be max_value normalized. 15 | limit: float: 16 | Max value after normalization. 17 | Returns: 18 | y (:obj:`np.ndarray`): 19 | Normalized x array. 20 | """ 21 | epsilon = 1e-7 # For numerical stability after normalization 22 | 23 | weights = x.copy() 24 | values = np.sort(weights) 25 | 26 | if x.sum() == 0 or len(x) * limit <= 1: 27 | return np.ones_like(x) / x.size 28 | else: 29 | estimation = values / values.sum() 30 | 31 | if estimation.max() <= limit: 32 | return weights / weights.sum() 33 | 34 | # Find the cumulative sum and sorted array 35 | cumsum = np.cumsum(estimation, 0) 36 | 37 | # Determine the index of cutoff 38 | estimation_sum = np.array( 39 | [(len(values) - i - 1) * estimation[i] for i in range(len(values))] 40 | ) 41 | n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum() 42 | 43 | # Determine the cutoff based on the index 44 | cutoff_scale = (limit * cumsum[n_values - 1] - epsilon) / ( 45 | 1 - (limit * (len(estimation) - n_values)) 46 | ) 47 | cutoff = cutoff_scale * values.sum() 48 | 49 | # Applying the cutoff 50 | weights[weights > cutoff] = cutoff 51 | 52 | y = weights / weights.sum() 53 | 54 | return y 55 | 56 | 57 | def convert_weights_and_uids_for_emit( 58 | uids: np.ndarray, weights: np.ndarray 59 | ) -> Tuple[List[int], List[int]]: 60 | r"""Converts weights into an integer u16 representation; the max weight is mapped to U16_MAX. 61 | Args: 62 | uids (:obj:`np.ndarray,`): 63 | Array of uids as destinations for passed weights. 64 | weights (:obj:`np.ndarray,`): 65 | Array of weights. 66 | Returns: 67 | weight_uids (List[int]): 68 | Uids as a list. 69 | weight_vals (List[int]): 70 | Weights as a list. 71 | """ 72 | # Checks. 73 | uids = np.asarray(uids) 74 | weights = np.asarray(weights) 75 | 76 | # Get non-zero weights and corresponding uids 77 | non_zero_weights = weights[weights > 0] 78 | non_zero_weight_uids = uids[weights > 0] 79 | 80 | # Debugging information 81 | bittensor.logging.debug(f"weights: {weights}") 82 | bittensor.logging.debug(f"non_zero_weights: {non_zero_weights}") 83 | bittensor.logging.debug(f"uids: {uids}") 84 | bittensor.logging.debug(f"non_zero_weight_uids: {non_zero_weight_uids}") 85 | 86 | if np.min(weights) < 0: 87 | raise ValueError( 88 | "Passed weight is negative, cannot exist on chain {}".format(weights) 89 | ) 90 | if np.min(uids) < 0: 91 | raise ValueError("Passed uid is negative, cannot exist on chain {}".format(uids)) 92 | if len(uids) != len(weights): 93 | raise ValueError( 94 | "Passed weights and uids must have the same length, got {} and {}".format( 95 | len(uids), len(weights) 96 | ) 97 | ) 98 | if np.sum(weights) == 0: 99 | bittensor.logging.debug("nothing to set on chain") 100 | return [], [] # Nothing to set on chain. 
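    # Worked example for the branch below: weights [0.2, 0.4] max-upscale to
    # [0.5, 1.0], which then map to the uint16 values [32768, 65535].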
101 | else: 102 | max_weight = float(np.max(weights)) 103 | weights = [ 104 | float(value) / max_weight for value in weights 105 | ] # max-upscale values (max_weight = 1). 106 | bittensor.logging.debug( 107 | f"setting on chain max: {max_weight} and weights: {weights}" 108 | ) 109 | 110 | weight_vals = [] 111 | weight_uids = [] 112 | for i, (weight_i, uid_i) in enumerate(list(zip(weights, uids))): 113 | uint16_val = round( 114 | float(weight_i) * int(U16_MAX) 115 | ) # convert to int representation. 116 | 117 | # Filter zeros 118 | if uint16_val != 0: # Filter zeros 119 | weight_vals.append(uint16_val) 120 | weight_uids.append(uid_i) 121 | bittensor.logging.debug(f"final params: {weight_uids} : {weight_vals}") 122 | return weight_uids, weight_vals 123 | 124 | 125 | def process_weights_for_netuid( 126 | uids, 127 | weights: np.ndarray, 128 | netuid: int, 129 | subtensor: "bittensor.subtensor", 130 | metagraph: "bittensor.metagraph" = None, 131 | exclude_quantile: int = 0, 132 | ) -> Union[ 133 | tuple[ 134 | ndarray[Any, dtype[Any]], 135 | Union[ 136 | Union[ 137 | ndarray[Any, dtype[floating[Any]]], 138 | ndarray[Any, dtype[complexfloating[Any, Any]]], 139 | ], 140 | Any, 141 | ], 142 | ], 143 | tuple[ndarray[Any, dtype[Any]], ndarray], 144 | tuple[Any, ndarray], 145 | ]: 146 | bittensor.logging.debug("process_weights_for_netuid()") 147 | bittensor.logging.debug(f"weights: {weights}") 148 | bittensor.logging.debug("netuid", netuid) 149 | bittensor.logging.debug("subtensor", subtensor) 150 | bittensor.logging.debug("metagraph", metagraph) 151 | 152 | # Get latest metagraph from chain if metagraph is None. 153 | if metagraph is None: 154 | metagraph = subtensor.metagraph(netuid) 155 | 156 | # Cast weights to floats. 157 | if not isinstance(weights, np.ndarray) or weights.dtype != np.float32: 158 | weights = weights.astype(np.float32) 159 | 160 | # Network configuration parameters from the subtensor. 161 | # These parameters determine the range of acceptable weights for each neuron. 162 | quantile = exclude_quantile / U16_MAX 163 | min_allowed_weights = subtensor.min_allowed_weights(netuid=netuid) 164 | max_weight_limit = subtensor.max_weight_limit(netuid=netuid) 165 | bittensor.logging.debug("quantile", quantile) 166 | bittensor.logging.debug("min_allowed_weights", min_allowed_weights) 167 | bittensor.logging.debug("max_weight_limit", max_weight_limit) 168 | 169 | # Find all non zero weights. 170 | non_zero_weight_idx = np.argwhere(weights > 0).squeeze() 171 | non_zero_weight_idx = np.atleast_1d(non_zero_weight_idx) 172 | non_zero_weight_uids = uids[non_zero_weight_idx] 173 | non_zero_weights = weights[non_zero_weight_idx] 174 | if non_zero_weights.size == 0 or metagraph.n < min_allowed_weights: 175 | bittensor.logging.warning("No non-zero weights, returning all ones.") 176 | final_weights = np.ones(metagraph.n) / metagraph.n 177 | bittensor.logging.debug(f"final_weights: {final_weights}") 178 | return np.arange(len(final_weights)), final_weights 179 | 180 | elif non_zero_weights.size < min_allowed_weights: 181 | bittensor.logging.warning( 182 | "No non-zero weights less than min allowed weight, returning all ones." 
183 | ) 184 | weights = np.ones(metagraph.n) * 1e-5 # creating minimum even non-zero weights 185 | weights[non_zero_weight_idx] += non_zero_weights 186 | bittensor.logging.debug(f"weights: {weights}") 187 | normalized_weights = normalize_max_weight(x=weights, limit=max_weight_limit) 188 | return np.arange(len(normalized_weights)), normalized_weights 189 | 190 | bittensor.logging.debug(f"non_zero_weights: {non_zero_weights}") 191 | 192 | # Compute the exclude quantile and find the weights in the lowest quantile 193 | max_exclude = max(0, len(non_zero_weights) - min_allowed_weights) / len( 194 | non_zero_weights 195 | ) 196 | exclude_quantile = min([quantile, max_exclude]) 197 | lowest_quantile = np.quantile(non_zero_weights, exclude_quantile) 198 | bittensor.logging.debug("max_exclude", max_exclude) 199 | bittensor.logging.debug("exclude_quantile", exclude_quantile) 200 | bittensor.logging.debug("lowest_quantile", lowest_quantile) 201 | 202 | # Exclude all weights below the allowed quantile. 203 | non_zero_weight_uids = non_zero_weight_uids[lowest_quantile <= non_zero_weights] 204 | non_zero_weights = non_zero_weights[lowest_quantile <= non_zero_weights] 205 | bittensor.logging.debug(f"non_zero_weight_uids {non_zero_weight_uids}") 206 | bittensor.logging.debug(f"non_zero_weights {non_zero_weights}") 207 | 208 | # Normalize weights and return. 209 | normalized_weights = normalize_max_weight( 210 | x=non_zero_weights, limit=max_weight_limit 211 | ) 212 | bittensor.logging.debug(f"final_weights: {normalized_weights}") 213 | return non_zero_weight_uids, normalized_weights 214 | -------------------------------------------------------------------------------- /distributed_training/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import logging 5 | import bittensor as bt 6 | import logging_loki 7 | import traceback 8 | 9 | from dotenv import load_dotenv 10 | from hivemind.utils.logging import use_hivemind_log_handler 11 | from logging.handlers import QueueHandler, QueueListener, RotatingFileHandler 12 | from multiprocessing import Queue 13 | 14 | from distributed_training import __version__, __spec_version__ 15 | from bittensor.utils.btlogging import format as bt_format 16 | 17 | load_dotenv() 18 | 19 | 20 | class LokiHandler(logging_loki.LokiHandler): 21 | """ 22 | Custom Loki logging handler that safely handles errors. 23 | 24 | Overrides `handleError` to log any exceptions that occur during logging 25 | to Loki, without shutting down the emitter. This ensures that retry logic 26 | remains active instead of terminating on the first failure. 27 | """ 28 | 29 | def handleError(self, record): 30 | logging.getLogger(__name__).error("Loki logging error", exc_info=True) 31 | # No emitter.close() here — keeps retry alive 32 | 33 | 34 | class JSONFormatter(logging.Formatter): 35 | """ 36 | Formats log records as JSON for Loki ingestion. 37 | 38 | Adds extra metadata about the miner, such as: 39 | - network and netuid 40 | - hotkey 41 | - neuron type 42 | - IP/port 43 | - Bittensor and spec version numbers 44 | - UID 45 | 46 | If exception info is present, it is included as a formatted string. 
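    Illustrative example of a formatted record (field values depend on the
    running neuron):

        {"level": "info", "module": "validator", "netuid": 1, "uid": 1,
         "message": "...", "version": "1.2.6", ...}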
47 | """ 48 | 49 | def __init__(self, miner): 50 | self.network = miner.config.subtensor.network 51 | self.netuid = miner.config.netuid 52 | self.hotkey = miner.wallet.hotkey.ss58_address if miner.master else None 53 | self.version = __version__ 54 | self.spec_version = __spec_version__ 55 | self.run_id = None 56 | self.ip = ( 57 | miner.config.axon.ip 58 | if miner.config.axon.ip != "[::]" 59 | else bt.utils.networking.get_external_ip() 60 | ) 61 | self.port = miner.config.axon.port 62 | self.uid = miner.uid 63 | self.neuron_type = "validator" 64 | 65 | def format(self, record): 66 | msg = record.getMessage() 67 | 68 | log_record = { 69 | "level": record.levelname.lower(), 70 | "module": record.module, 71 | "func_name": record.funcName, 72 | "thread": record.threadName, 73 | "netuid": self.netuid, 74 | "network": self.network, 75 | "neuron_type": self.neuron_type, 76 | "hotkey": self.hotkey, 77 | "uid": self.uid, 78 | "ip": self.ip, 79 | "port": self.port, 80 | "message": msg, 81 | "filename": record.filename, 82 | "lineno": record.lineno, 83 | "version": self.version, 84 | "spec_version": self.spec_version, 85 | } 86 | 87 | if record.exc_info: 88 | log_record["exception"] = "".join( 89 | traceback.format_exception(*record.exc_info) 90 | ) 91 | 92 | return json.dumps(log_record) 93 | 94 | 95 | def hive_log_filter(record): 96 | """ 97 | Filters out noisy Hivemind loggers that are not relevant to application output. 98 | 99 | Returns: 100 | bool: True if the record should be logged, False otherwise. 101 | """ 102 | return record.name not in { 103 | "hivemind.dht.protocol", 104 | "hivemind.optim.progress_tracker", 105 | "hivemind.p2p.p2p_daemon_bindings.control", 106 | } 107 | 108 | 109 | class RankFilter(logging.Filter): 110 | def __init__(self, rank, show_all_rank_logs): 111 | super().__init__() 112 | self.rank = rank 113 | self.show_all_rank_logs = show_all_rank_logs 114 | 115 | def filter(self, record): 116 | # Add ANSI escape code for bold: \033[1m … \033[0m 117 | record.rank = f"\033[1mRank {self.rank}\033[0m" 118 | # return True 119 | return self.rank == 0 if (self.show_all_rank_logs is False) else True 120 | 121 | 122 | def setup_logging(self, local_logfile="logs_mylogfile.txt", config=None): 123 | """ 124 | Configure and start logging for the distributed training miner. 125 | 126 | This includes: 127 | - Bittensor terminal logging with custom emoji mapping 128 | - Loki logging via a background queue listener 129 | - File logging for Hivemind output (filtered) 130 | - Disabling noisy loggers and default Hivemind handlers 131 | 132 | Args: 133 | self: Miner instance containing config, wallet, and UID. 134 | local_logfile (str): Path to the local log file. 135 | config: Optional Bittensor config object for logging. 
136 | """ 137 | 138 | # Configure Bittensor terminal output 139 | bt_format.emoji_map.update( 140 | { 141 | ":rocket:": "🚀", 142 | ":lock:": "🔒", 143 | ":unlock:": "🔓", 144 | ":lightning:": "⚡", 145 | ":error:": "❗", 146 | ":info:": "ℹ️", 147 | ":idle:": "😴", 148 | ":network:": "🌐", 149 | ":memory:": "💾", 150 | ":training:": "🏋️", 151 | ":progress:": "📈", 152 | ":wait:": "⏳", 153 | ":clock:": "⏱️", 154 | ":signal:": "📶", 155 | ":upload:": "🔼", 156 | ":broadcast:": "📡", 157 | ":sync:": "🔄", 158 | ":send:": "📤", 159 | ":receive:": "📥", 160 | ":pages:": "📑", 161 | } 162 | ) 163 | bt.logging(config=config or bt.config()) 164 | 165 | bt_logger = logging.getLogger("bittensor") 166 | 167 | # Default to INFO if no flags are set 168 | if not ( 169 | getattr(config.logging, "debug", False) 170 | or getattr(config.logging, "trace", False) 171 | or getattr(config.logging, "info", False) 172 | ): 173 | bt_logger.setLevel(logging.INFO) 174 | 175 | # Prepare root logger 176 | root_logger = logging.getLogger() 177 | root_logger.setLevel(logging.DEBUG) # Capture all levels 178 | 179 | # Attach rank filter to all loggers 180 | rank_filter = RankFilter(self.local_rank, self.config.neuron.show_all_rank_logs) 181 | root_logger.addFilter(rank_filter) 182 | bt_logger.addFilter(rank_filter) 183 | 184 | # Create rank-aware formatter 185 | terminal_formatter = logging.Formatter( 186 | " %(rank)s | %(message)s", 187 | ) 188 | for handler in bt_logger.handlers: 189 | handler.addFilter(rank_filter) 190 | handler.setFormatter(terminal_formatter) 191 | 192 | # Loki handler with extra labels 193 | loki_handler = LokiHandler( 194 | url="https://logs-prod-006.grafana.net/loki/api/v1/push", 195 | tags={ 196 | "application": "distributed_training", 197 | "level": "dynamic", # Will be overridden dynamically 198 | "hotkey": self.wallet.hotkey.ss58_address if self.master else None, 199 | "netuid": str(self.config.netuid), 200 | }, 201 | auth=("944477", os.getenv("LOKI_KEY")), 202 | version="1", 203 | ) 204 | loki_handler.setLevel(logging.DEBUG) 205 | loki_handler.setFormatter(JSONFormatter(self)) 206 | 207 | # Wrap emit so level label matches log level 208 | original_emit = loki_handler.emit 209 | 210 | def dynamic_label_emit(record): 211 | loki_handler.emitter.tags["level"] = record.levelname.lower() 212 | original_emit(record) 213 | 214 | loki_handler.emit = dynamic_label_emit 215 | 216 | # File handler for Hivemind 217 | if os.path.exists(local_logfile) and self.master: 218 | shutil.copyfile(local_logfile, local_logfile.replace(".txt", "_archive.txt")) 219 | os.remove(local_logfile) 220 | 221 | # Setup hivemind logger 222 | hivemind_logger = logging.getLogger("hivemind") 223 | hivemind_logger.handlers.clear() 224 | hivemind_logger.setLevel(logging.DEBUG) 225 | file_handler = logging.FileHandler(local_logfile) 226 | file_handler.setLevel(logging.DEBUG) 227 | file_handler.addFilter(hive_log_filter) 228 | file_handler.addFilter( 229 | RankFilter(self.local_rank, self.config.neuron.rank_0_only_log) 230 | ) 231 | file_handler.setFormatter( 232 | logging.Formatter( 233 | "%(asctime)s - rank %(rank)s - %(name)s - %(levelname)s - %(message)s" 234 | ) 235 | ) 236 | hivemind_logger.addHandler(file_handler) 237 | hivemind_logger.propagate = False 238 | 239 | # Setup queue logging 240 | log_queue = Queue(-1) 241 | queue_handler = QueueHandler(log_queue) 242 | root_logger.addHandler(queue_handler) 243 | 244 | listener = QueueListener(log_queue, loki_handler, file_handler) 245 | listener = QueueListener(log_queue, loki_handler) 246 | 
listener.start() 247 | 248 | # Disable noisy hivemind default logging 249 | use_hivemind_log_handler("nowhere") 250 | 251 | for name in logging.root.manager.loggerDict: 252 | if name.startswith("hivemind"): 253 | logger = logging.getLogger(name) 254 | logger.addHandler(file_handler) 255 | logger.propagate = False 256 | logger.setLevel(logging.DEBUG) 257 | 258 | # Disable propagation for other loggers 259 | for name, logger in logging.root.manager.loggerDict.items(): 260 | if isinstance(logger, logging.Logger): 261 | if name not in ["bittensor"]: 262 | logger.propagate = False 263 | if not any(isinstance(h, RotatingFileHandler) for h in logger.handlers): 264 | logger.addHandler(file_handler) 265 | -------------------------------------------------------------------------------- /run_miner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Inspired by https://github.com/surcyf123/smart-scrape/blob/main/run.sh 3 | 4 | # Initialize variables 5 | script="neurons/miner.py" 6 | autoRunLoc=$(readlink -f "$0") 7 | proc_name="distributed_training_miner" 8 | args=() 9 | version_location="./distributed_training/__init__.py" 10 | version="__version__" 11 | NODES=1 12 | 13 | old_args=$@ 14 | 15 | # Check if pm2 is installed 16 | if ! command -v pm2 &> /dev/null 17 | then 18 | echo "pm2 could not be found. To install see: https://pm2.keymetrics.io/docs/usage/quick-start/" 19 | exit 1 20 | fi 21 | 22 | # Checks if $1 is smaller than $2 23 | # If $1 is smaller than or equal to $2, then true. 24 | # else false. 25 | version_less_than_or_equal() { 26 | [ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ] 27 | } 28 | 29 | # Checks if $1 is smaller than $2 30 | # If $1 is smaller than $2, then true. 31 | # else false. 32 | version_less_than() { 33 | [ "$1" = "$2" ] && return 1 || version_less_than_or_equal $1 $2 34 | } 35 | 36 | # Returns the difference between 37 | # two versions as a numerical value. 38 | get_version_difference() { 39 | local tag1="$1" 40 | local tag2="$2" 41 | 42 | # Extract the version numbers from the tags 43 | local version1=$(echo "$tag1" | sed 's/v//') 44 | local version2=$(echo "$tag2" | sed 's/v//') 45 | 46 | # Split the version numbers into an array 47 | IFS='.' read -ra version1_arr <<< "$version1" 48 | IFS='.' 
read -ra version2_arr <<< "$version2" 49 | 50 | # Calculate the numerical difference 51 | local diff=0 52 | for i in "${!version1_arr[@]}"; do 53 | local num1=${version1_arr[$i]} 54 | local num2=${version2_arr[$i]} 55 | 56 | # Compare the numbers and update the difference 57 | if (( num1 > num2 )); then 58 | diff=$((diff + num1 - num2)) 59 | elif (( num1 < num2 )); then 60 | diff=$((diff + num2 - num1)) 61 | fi 62 | done 63 | 64 | strip_quotes $diff 65 | } 66 | 67 | read_version_value() { 68 | # Read each line in the file 69 | while IFS= read -r line; do 70 | # Check if the line contains the variable name 71 | if [[ "$line" == *"$version"* ]]; then 72 | # Extract the value of the variable 73 | local value=$(echo "$line" | awk -F '=' '{print $2}' | tr -d ' ') 74 | strip_quotes $value 75 | return 0 76 | fi 77 | done < "$version_location" 78 | 79 | echo "" 80 | } 81 | 82 | check_package_installed() { 83 | local package_name="$1" 84 | os_name=$(uname -s) 85 | 86 | if [[ "$os_name" == "Linux" ]]; then 87 | # Use dpkg-query to check if the package is installed 88 | if dpkg-query -W -f='${Status}' "$package_name" 2>/dev/null | grep -q "installed"; then 89 | return 1 90 | else 91 | return 0 92 | fi 93 | elif [[ "$os_name" == "Darwin" ]]; then 94 | if brew list --formula | grep -q "^$package_name$"; then 95 | return 1 96 | else 97 | return 0 98 | fi 99 | else 100 | echo "Unknown operating system" 101 | return 0 102 | fi 103 | } 104 | 105 | check_variable_value_on_github() { 106 | local repo="$1" 107 | local file_path="$2" 108 | local variable_name="$3" 109 | 110 | local url="https://api.github.com/repos/$repo/contents/$file_path" 111 | local response=$(curl -s "$url") 112 | 113 | # Check if the response contains an error message 114 | if [[ $response =~ "message" ]]; then 115 | echo "Error: Failed to retrieve file contents from GitHub." 116 | return 1 117 | fi 118 | 119 | # Extract the content from the response 120 | local content=$(echo "$response" | tr -d '\n' | jq -r '.content') 121 | 122 | if [[ "$content" == "null" ]]; then 123 | echo "File '$file_path' not found in the repository." 124 | return 1 125 | fi 126 | 127 | # Decode the Base64-encoded content 128 | local decoded_content=$(echo "$content" | base64 --decode) 129 | 130 | # Extract the variable value from the content 131 | local variable_value=$(echo "$decoded_content" | grep "$variable_name" | awk -F '=' '{print $2}' | tr -d ' ') 132 | 133 | if [[ -z "$variable_value" ]]; then 134 | echo "Variable '$variable_name' not found in the file '$file_path'." 
135 |         return 1
136 |     fi
137 | 
138 |     strip_quotes $variable_value
139 | }
140 | 
141 | strip_quotes() {
142 |     local input="$1"
143 | 
144 |     # Remove leading and trailing quotes using parameter expansion
145 |     local stripped="${input#\"}"
146 |     stripped="${stripped%\"}"
147 | 
148 |     echo "$stripped"
149 | }
150 | 
151 | # Loop through all command line arguments
152 | while [[ $# -gt 0 ]]; do
153 |   arg="$1"
154 | 
155 |   # Detect node count argument early
156 |   if [[ "$arg" == "--nodes" || "$arg" == "--nproc_per_node" ]]; then
157 |     if [[ $# -gt 1 && "$2" != -* ]]; then
158 |       NODES="$2"
159 |       shift 2
160 |       continue
161 |     fi
162 |   fi
163 | 
164 |   # Check if the argument starts with a hyphen (flag)
165 |   if [[ "$arg" == -* ]]; then
166 |     # Check if the argument has a value
167 |     if [[ $# -gt 1 && "$2" != -* ]]; then
168 |       if [[ "$arg" == "--script" ]]; then
169 |         script="$2";
170 |         shift 2
171 |       else
172 |         # Keep the flag and its value as separate quoted args
173 |         args+=("'$arg'");
174 |         args+=("'$2'");
175 |         shift 2
176 |       fi
177 |     else
178 |       # Flag with no value; pass it through as a single quoted arg
179 |       args+=("'$arg'");
180 |       shift
181 |     fi
182 |   else
183 |     # Argument is not a flag, add it as is (no trailing space)
184 |     args+=("'$arg'");
185 |     shift
186 |   fi
187 | done
188 | 
189 | # Check if script argument was provided
190 | if [[ -z "$script" ]]; then
191 |     echo "The --script argument is required."
192 |     exit 1
193 | fi
194 | 
195 | branch=$(git branch --show-current)            # get current branch.
196 | echo watching branch: $branch
197 | echo pm2 process name: $proc_name
198 | 
199 | # Get the current version locally.
200 | current_version=$(read_version_value)
201 | 
202 | # Check if script is already running with pm2
203 | if pm2 status | grep -q $proc_name; then
204 |     echo "The script is already running with pm2. Stopping and restarting..."
205 |     pkill -9 python
206 |     pm2 delete $proc_name
207 | fi
208 | 
209 | # Run the Python script with the arguments using pm2
210 | echo "Running $script with the following pm2 config:"
211 | 
212 | # Join the arguments with commas using printf
213 | joined_args=$(printf "%s," "${args[@]}")
214 | 
215 | # Remove the trailing comma
216 | joined_args=${joined_args%,}
217 | 
218 | # Create the pm2 config file
219 | echo "module.exports = {
220 |   apps : [{
221 |     name   : '$proc_name',
222 |     script : '$script',
223 |     interpreter: 'torchrun',
224 |     interpreter_args: '--nproc_per_node=' + '$NODES',
225 |     min_uptime: '5m',
226 |     max_restarts: '5',
227 |     args: [$joined_args]
228 |   }]
229 | }" > app.config.js
230 | 
231 | # Print configuration to be used
232 | cat app.config.js
233 | 
234 | pm2 start app.config.js
235 | 
236 | # Check if packages are installed.
237 | check_package_installed "jq"
238 | if [ "$?"
-eq 1 ]; then
239 |     while true; do
240 | 
241 |         # First ensure that this is a git installation
242 |         if [ -d "./.git" ]; then
243 | 
244 |             # check value on github remotely
245 |             latest_version=$(check_variable_value_on_github "dstrbtd/DistributedTraining" "distributed_training/__init__.py" "__version__ ")
246 | 
247 |             # If the file has been updated
248 |             if version_less_than $current_version $latest_version; then
249 |                 echo "latest version $latest_version"
250 |                 echo "current version $current_version"
251 |                 diff=$(get_version_difference $latest_version $current_version)
252 |                 if [ "$diff" -gt 0 ]; then
253 |                     echo "current miner version:" "$current_version"
254 |                     echo "latest miner version:" "$latest_version"
255 | 
256 |                     # Pull latest changes
257 |                     # Failed git pull will return a non-zero output
258 |                     if git pull origin $branch; then
259 |                         # latest_version is newer than current_version, should download and reinstall.
260 |                         echo "New version published. Updating the local copy."
261 | 
262 |                         # Install latest changes just in case.
263 |                         pip install -e .
264 | 
265 |                         # # Run the Python script with the arguments using pm2
266 |                         # TODO (shib): Remove this pm2 del in the next spec version update.
267 |                         pm2 del auto_run_validator
268 |                         echo "Restarting PM2 process"
269 |                         pkill -9 python
270 |                         pm2 restart $proc_name
271 | 
272 |                         # Update current version:
273 |                         current_version=$(read_version_value)
274 |                         echo ""
275 | 
276 |                         # Restart autorun script
277 |                         echo "Restarting script..."
278 |                         ./$(basename $0) $old_args && exit
279 |                     else
280 |                         echo "**Will not update**"
281 |                         echo "It appears you have made changes on your local copy. Please stash your changes using git stash."
282 |                     fi
283 |                 else
284 |                     # Versions differ but no numeric difference was found; do not auto-update.
285 |                     echo "**Will not update**"
286 |                     echo "Could not determine the version difference. Please manually update to the latest version and re-run this script."
287 |                 fi
288 |             else
289 |                 echo "**Skipping update**"
290 |                 echo "$current_version is the same as or newer than $latest_version. You are likely running locally."
291 |             fi
292 |         else
293 |             echo "The installation does not appear to have been done through Git. Please install from source at https://github.com/dstrbtd/DistributedTraining and rerun this script."
294 |         fi
295 | 
296 |         # Wait about 20 minutes
297 |         # This should be plenty of time for nodes to catch up
298 |         # and should prevent any rate limitations by GitHub.
299 |         sleep 1200
300 |     done
301 | else
302 |     echo "Missing package 'jq'. Please install it for your system first."
303 | fi
--------------------------------------------------------------------------------
/run_validator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Inspired by https://github.com/surcyf123/smart-scrape/blob/main/run.sh
3 | 
4 | # Initialize variables
5 | script="neurons/validator.py"
6 | autoRunLoc=$(readlink -f "$0")
7 | proc_name="distributed_training_validator"
8 | args=()
9 | version_location="./distributed_training/__init__.py"
10 | version="__version__"
11 | NODES=1
12 | 
13 | old_args=$@
14 | 
15 | # Check if pm2 is installed
16 | if ! command -v pm2 &> /dev/null
17 | then
18 |     echo "pm2 could not be found. To install see: https://pm2.keymetrics.io/docs/usage/quick-start/"
19 |     exit 1
20 | fi
21 | 
22 | # Checks if $1 is smaller than or equal to $2.
23 | # If $1 is smaller than or equal to $2, then true.
24 | # else false.
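# For example (illustrative):
#   version_less_than_or_equal "4.0.1" "4.1.0"   # true (exit 0): "4.0.1" sorts first under sort -V
#   version_less_than "4.1.0" "4.1.0"            # false (exit 1): equal versions are not "less than"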
25 | version_less_than_or_equal() { 26 | [ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ] 27 | } 28 | 29 | # Checks if $1 is smaller than $2 30 | # If $1 is smaller than $2, then true. 31 | # else false. 32 | version_less_than() { 33 | [ "$1" = "$2" ] && return 1 || version_less_than_or_equal $1 $2 34 | } 35 | 36 | # Returns the difference between 37 | # two versions as a numerical value. 38 | get_version_difference() { 39 | local tag1="$1" 40 | local tag2="$2" 41 | 42 | # Extract the version numbers from the tags 43 | local version1=$(echo "$tag1" | sed 's/v//') 44 | local version2=$(echo "$tag2" | sed 's/v//') 45 | 46 | # Split the version numbers into an array 47 | IFS='.' read -ra version1_arr <<< "$version1" 48 | IFS='.' read -ra version2_arr <<< "$version2" 49 | 50 | # Calculate the numerical difference 51 | local diff=0 52 | for i in "${!version1_arr[@]}"; do 53 | local num1=${version1_arr[$i]} 54 | local num2=${version2_arr[$i]} 55 | 56 | # Compare the numbers and update the difference 57 | if (( num1 > num2 )); then 58 | diff=$((diff + num1 - num2)) 59 | elif (( num1 < num2 )); then 60 | diff=$((diff + num2 - num1)) 61 | fi 62 | done 63 | 64 | strip_quotes $diff 65 | } 66 | 67 | read_version_value() { 68 | # Read each line in the file 69 | while IFS= read -r line; do 70 | # Check if the line contains the variable name 71 | if [[ "$line" == *"$version"* ]]; then 72 | # Extract the value of the variable 73 | local value=$(echo "$line" | awk -F '=' '{print $2}' | tr -d ' ') 74 | strip_quotes $value 75 | return 0 76 | fi 77 | done < "$version_location" 78 | 79 | echo "" 80 | } 81 | 82 | check_package_installed() { 83 | local package_name="$1" 84 | os_name=$(uname -s) 85 | 86 | if [[ "$os_name" == "Linux" ]]; then 87 | # Use dpkg-query to check if the package is installed 88 | if dpkg-query -W -f='${Status}' "$package_name" 2>/dev/null | grep -q "installed"; then 89 | return 1 90 | else 91 | return 0 92 | fi 93 | elif [[ "$os_name" == "Darwin" ]]; then 94 | if brew list --formula | grep -q "^$package_name$"; then 95 | return 1 96 | else 97 | return 0 98 | fi 99 | else 100 | echo "Unknown operating system" 101 | return 0 102 | fi 103 | } 104 | 105 | check_variable_value_on_github() { 106 | local repo="$1" 107 | local file_path="$2" 108 | local variable_name="$3" 109 | 110 | local url="https://api.github.com/repos/$repo/contents/$file_path" 111 | local response=$(curl -s "$url") 112 | 113 | # Check if the response contains an error message 114 | if [[ $response =~ "message" ]]; then 115 | echo "Error: Failed to retrieve file contents from GitHub." 116 | return 1 117 | fi 118 | 119 | # Extract the content from the response 120 | local content=$(echo "$response" | tr -d '\n' | jq -r '.content') 121 | 122 | if [[ "$content" == "null" ]]; then 123 | echo "File '$file_path' not found in the repository." 124 | return 1 125 | fi 126 | 127 | # Decode the Base64-encoded content 128 | local decoded_content=$(echo "$content" | base64 --decode) 129 | 130 | # Extract the variable value from the content 131 | local variable_value=$(echo "$decoded_content" | grep "$variable_name" | awk -F '=' '{print $2}' | tr -d ' ') 132 | 133 | if [[ -z "$variable_value" ]]; then 134 | echo "Variable '$variable_name' not found in the file '$file_path'." 
135 |         return 1
136 |     fi
137 | 
138 |     strip_quotes $variable_value
139 | }
140 | 
141 | strip_quotes() {
142 |     local input="$1"
143 | 
144 |     # Remove leading and trailing quotes using parameter expansion
145 |     local stripped="${input#\"}"
146 |     stripped="${stripped%\"}"
147 | 
148 |     echo "$stripped"
149 | }
150 | 
151 | # Loop through all command line arguments
152 | while [[ $# -gt 0 ]]; do
153 |   arg="$1"
154 | 
155 |   # Detect node count argument early
156 |   if [[ "$arg" == "--nodes" || "$arg" == "--nproc_per_node" ]]; then
157 |     if [[ $# -gt 1 && "$2" != -* ]]; then
158 |       NODES="$2"
159 |       shift 2
160 |       continue
161 |     fi
162 |   fi
163 | 
164 |   # Check if the argument starts with a hyphen (flag)
165 |   if [[ "$arg" == -* ]]; then
166 |     # Check if the argument has a value
167 |     if [[ $# -gt 1 && "$2" != -* ]]; then
168 |       if [[ "$arg" == "--script" ]]; then
169 |         script="$2";
170 |         shift 2
171 |       else
172 |         # Keep the flag and its value as separate quoted args
173 |         args+=("'$arg'");
174 |         args+=("'$2'");
175 |         shift 2
176 |       fi
177 |     else
178 |       # Flag with no value; pass it through as a single quoted arg
179 |       args+=("'$arg'");
180 |       shift
181 |     fi
182 |   else
183 |     # Argument is not a flag, add it as is (no trailing space)
184 |     args+=("'$arg'");
185 |     shift
186 |   fi
187 | done
188 | 
189 | # Check if script argument was provided
190 | if [[ -z "$script" ]]; then
191 |     echo "The --script argument is required."
192 |     exit 1
193 | fi
194 | 
195 | branch=$(git branch --show-current)            # get current branch.
196 | echo watching branch: $branch
197 | echo pm2 process name: $proc_name
198 | 
199 | # Get the current version locally.
200 | current_version=$(read_version_value)
201 | 
202 | # Check if script is already running with pm2
203 | if pm2 status | grep -q $proc_name; then
204 |     echo "The script is already running with pm2. Stopping and restarting..."
205 |     pkill -9 python
206 |     pm2 delete $proc_name
207 | fi
208 | 
209 | # Run the Python script with the arguments using pm2
210 | echo "Running $script with the following pm2 config:"
211 | 
212 | # Join the arguments with commas using printf
213 | joined_args=$(printf "%s," "${args[@]}")
214 | 
215 | # Remove the trailing comma
216 | joined_args=${joined_args%,}
217 | 
218 | # Create the pm2 config file
219 | echo "module.exports = {
220 |   apps : [{
221 |     name   : '$proc_name',
222 |     script : '$script',
223 |     interpreter: 'torchrun',
224 |     interpreter_args: '--nproc_per_node=' + '$NODES',
225 |     min_uptime: '5m',
226 |     max_restarts: '5',
227 |     args: [$joined_args]
228 |   }]
229 | }" > app.config.js
230 | 
231 | # Print configuration to be used
232 | cat app.config.js
233 | 
234 | pm2 start app.config.js
235 | 
236 | # Check if packages are installed.
237 | check_package_installed "jq"
238 | if [ "$?"
-eq 1 ]; then
239 |     while true; do
240 | 
241 |         # First ensure that this is a git installation
242 |         if [ -d "./.git" ]; then
243 | 
244 |             # check value on github remotely
245 |             latest_version=$(check_variable_value_on_github "dstrbtd/DistributedTraining" "distributed_training/__init__.py" "__version__ ")
246 | 
247 |             # If the file has been updated
248 |             if version_less_than $current_version $latest_version; then
249 |                 echo "latest version $latest_version"
250 |                 echo "current version $current_version"
251 |                 diff=$(get_version_difference $latest_version $current_version)
252 |                 if [ "$diff" -gt 0 ]; then
253 |                     echo "current validator version:" "$current_version"
254 |                     echo "latest validator version:" "$latest_version"
255 | 
256 |                     # Pull latest changes
257 |                     # Failed git pull will return a non-zero output
258 |                     if git pull origin $branch; then
259 |                         # latest_version is newer than current_version, should download and reinstall.
260 |                         echo "New version published. Updating the local copy."
261 | 
262 |                         # Install latest changes just in case.
263 |                         # pip uninstall -y distributed_training && pip freeze --exclude-editable | cut -d "@" -f1 | xargs pip uninstall -y && pip install -e .
264 |                         pip install -e .
265 | 
266 |                         # # Run the Python script with the arguments using pm2
267 |                         # TODO (shib): Remove this pm2 del in the next spec version update.
268 |                         pm2 del auto_run_validator
269 |                         echo "Restarting PM2 process"
270 |                         pkill -9 python
271 |                         pm2 restart $proc_name
272 | 
273 |                         # Update current version:
274 |                         current_version=$(read_version_value)
275 |                         echo ""
276 | 
277 |                         # Restart autorun script
278 |                         echo "Restarting script..."
279 |                         ./$(basename $0) $old_args && exit
280 |                     else
281 |                         echo "**Will not update**"
282 |                         echo "It appears you have made changes on your local copy. Please stash your changes using git stash."
283 |                     fi
284 |                 else
285 |                     # Versions differ but no numeric difference was found; do not auto-update.
286 |                     echo "**Will not update**"
287 |                     echo "Could not determine the version difference. Please manually update to the latest version and re-run this script."
288 |                 fi
289 |             else
290 |                 echo "**Skipping update**"
291 |                 echo "$current_version is the same as or newer than $latest_version. You are likely running locally."
292 |             fi
293 |         else
294 |             echo "The installation does not appear to have been done through Git. Please install from source at https://github.com/dstrbtd/DistributedTraining and rerun this script."
295 |         fi
296 | 
297 |         # Wait about 20 minutes
298 |         # This should be plenty of time for validators to catch up
299 |         # and should prevent any rate limitations by GitHub.
300 |         sleep 1200
301 |     done
302 | else
303 |     echo "Missing package 'jq'. Please install it for your system first."
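    # For example (illustrative; package names may differ per system):
    #   Debian/Ubuntu:     sudo apt-get install -y jq
    #   macOS (Homebrew):  brew install jq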
304 | fi -------------------------------------------------------------------------------- /distributed_training/utils/misc.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Opentensor Foundation 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 18 | 19 | import re 20 | import time 21 | from functools import lru_cache, update_wrapper 22 | from ipaddress import ip_address 23 | from math import floor 24 | from typing import Any, Callable 25 | 26 | import bittensor as bt 27 | import hivemind 28 | import speedtest 29 | from hivemind import utils 30 | 31 | import wandb 32 | from distributed_training import __run__, __version__ 33 | from copy import deepcopy 34 | from dataclasses import is_dataclass, asdict 35 | 36 | 37 | # LRU Cache with TTL 38 | def ttl_cache(maxsize: int = 128, typed: bool = False, ttl: int = -1): 39 | """ 40 | Decorator that creates a cache of the most recently used function calls with a time-to-live (TTL) feature. 41 | The cache evicts the least recently used entries if the cache exceeds the `maxsize` or if an entry has 42 | been in the cache longer than the `ttl` period. 43 | 44 | Args: 45 | maxsize (int): Maximum size of the cache. Once the cache grows to this size, subsequent entries 46 | replace the least recently used ones. Defaults to 128. 47 | typed (bool): If set to True, arguments of different types will be cached separately. For example, 48 | f(3) and f(3.0) will be treated as distinct calls with distinct results. Defaults to False. 49 | ttl (int): The time-to-live for each cache entry, measured in seconds. If set to a non-positive value, 50 | the TTL is set to a very large number, effectively making the cache entries permanent. Defaults to -1. 51 | 52 | Returns: 53 | Callable: A decorator that can be applied to functions to cache their return values. 54 | 55 | The decorator is useful for caching results of functions that are expensive to compute and are called 56 | with the same arguments frequently within short periods of time. The TTL feature helps in ensuring 57 | that the cached values are not stale. 
58 | 59 | Example: 60 | @ttl_cache(ttl=10) 61 | def get_data(param): 62 | # Expensive data retrieval operation 63 | return data 64 | """ 65 | if ttl <= 0: 66 | ttl = 65536 67 | hash_gen = _ttl_hash_gen(ttl) 68 | 69 | def wrapper(func: Callable) -> Callable: 70 | @lru_cache(maxsize, typed) 71 | def ttl_func(ttl_hash, *args, **kwargs): 72 | return func(*args, **kwargs) 73 | 74 | def wrapped(*args, **kwargs) -> Any: 75 | th = next(hash_gen) 76 | return ttl_func(th, *args, **kwargs) 77 | 78 | return update_wrapper(wrapped, func) 79 | 80 | return wrapper 81 | 82 | 83 | def _ttl_hash_gen(seconds: int): 84 | """ 85 | Internal generator function used by the `ttl_cache` decorator to generate a new hash value at regular 86 | time intervals specified by `seconds`. 87 | 88 | Args: 89 | seconds (int): The number of seconds after which a new hash value will be generated. 90 | 91 | Yields: 92 | int: A hash value that represents the current time interval. 93 | 94 | This generator is used to create time-based hash values that enable the `ttl_cache` to determine 95 | whether cached entries are still valid or if they have expired and should be recalculated. 96 | """ 97 | start_time = time.time() 98 | while True: 99 | yield floor((time.time() - start_time) / seconds) 100 | 101 | 102 | # 12 seconds updating block. 103 | @ttl_cache(maxsize=1, ttl=12) 104 | def ttl_get_block(self) -> int: 105 | """ 106 | Retrieves the current block number from the blockchain. This method is cached with a time-to-live (TTL) 107 | of 12 seconds, meaning that it will only refresh the block number from the blockchain at most every 12 seconds, 108 | reducing the number of calls to the underlying blockchain interface. 109 | 110 | Returns: 111 | int: The current block number on the blockchain. 112 | 113 | This method is useful for applications that need to access the current block number frequently and can 114 | tolerate a delay of up to 12 seconds for the latest information. By using a cache with TTL, the method 115 | efficiently reduces the workload on the blockchain interface. 
116 | 
117 |     Example:
118 |         current_block = ttl_get_block(self)
119 | 
120 |     Note: self here is the miner or validator instance
121 |     """
122 |     return self.subtensor.get_current_block()
123 | 
124 | 
125 | def to_plain_dict(obj):
126 |     if isinstance(obj, dict):
127 |         return deepcopy(obj)
128 |     if is_dataclass(obj):
129 |         return asdict(obj)
130 |     if hasattr(obj, "model_dump"):  # pydantic v2
131 |         return obj.model_dump()
132 |     if hasattr(obj, "dict"):  # pydantic v1
133 |         return obj.dict()
134 |     if hasattr(obj, "to_dict"):
135 |         return obj.to_dict()
136 |     return deepcopy(getattr(obj, "__dict__", {}))
137 | 
138 | 
139 | def sanitize_wandb_config(cfg):
140 |     cfg_dict = to_plain_dict(cfg)
141 |     # remove the entire sensitive subtree
142 |     cfg_dict.pop("r2", None)
143 |     return cfg_dict
144 | 
145 | 
146 | def load_wandb(self, config, wallet, neuron_type, peer_id):
147 |     run_name = f"{neuron_type[0].upper()}{self.uid:03}"
148 | 
149 |     tags = [peer_id, __version__, self.wallet.hotkey.ss58_address, f"run{__run__}"]
150 | 
151 |     run_id = "_".join([run_name] + tags[2:]).lower()
152 | 
153 |     wandb_run = wandb.init(
154 |         id=run_id,
155 |         name=run_name,
156 |         anonymous="allow",
157 |         resume="allow",
158 |         tags=tags,
159 |         project=config.neuron.wandb_project,
160 |         entity=config.neuron.wandb_entity,
161 |         config={},
162 |         allow_val_change=True,
163 |     )
164 | 
165 |     sanitized_config = sanitize_wandb_config(config)
166 |     wandb_run.config.update(sanitized_config, allow_val_change=True)
167 | 
168 |     return wandb_run
169 | 
170 | 
171 | def get_bandwidth():
172 |     # Get speedtest results
173 |     s = speedtest.Speedtest()
174 |     s.get_servers()
175 |     s.get_best_server()
176 |     s.download()
177 |     s.upload()
178 |     results = s.results.dict()
179 | 
180 |     # Copy key metrics to a formatted bandwidth_dict
181 |     bandwidth_dict = {}
182 |     for key in ["download", "upload"]:  # speedtest reports these in bits/s
183 |         bandwidth_dict[f"all_reduce/{key}"] = float(f"{results[key] / 1e6:.2f}")  # convert to Mbps
184 |     bandwidth_dict["all_reduce/ping"] = float(f"{results['ping']:.2f}")  # ping is already in ms
185 | 
186 |     return bandwidth_dict
187 | 
188 | 
189 | def init_dht(self):
190 |     if self.master:
191 |         # Init DHT and model
192 |         if self.config.dht.ip:
193 |             version = "4"
194 |             address = self.config.dht.ip
195 |             announce_maddrs = [f"/ip{version}/{address}/tcp/{self.config.dht.port}"]
196 |         else:
197 |             address = bt.utils.networking.get_external_ip()
198 |             self.logger.info(f"Received public IP address of this machine: {address}")
199 |             version = ip_address(address).version
200 |             announce_maddrs = [f"/ip{version}/{address}/tcp/{self.config.dht.port}"]
201 | 
202 |         # Init list of available DHT addresses from wandb
203 |         api = wandb.Api()
204 |         initial_peers_list = self.config.neuron.initial_peers
205 | 
206 |         validator_runs = api.runs(
207 |             f"{self.config.neuron.wandb_entity}/{self.config.neuron.wandb_project.replace('_validators','').replace('_miners','')}_validators"
208 |         )
209 |         for ru in validator_runs:
210 |             if ru.state == "running":
211 |                 if "dht_addresses" not in ru.config["neuron"].keys():
212 |                     continue
213 |                 else:
214 |                     for peer in ru.config["neuron"]["dht_addresses"]:
215 |                         if peer not in initial_peers_list:
216 |                             initial_peers_list.append(peer)
217 | 
218 |         miner_runs = api.runs(
219 |             f"{self.config.neuron.wandb_entity}/{self.config.neuron.wandb_project.replace('_validators','').replace('_miners','')}_miners"
220 |         )
221 |         for ru in miner_runs:
222 |             if ru.state == "running":
223 |                 if "dht_addresses" not in ru.config["neuron"].keys():
224 |                     continue
225 |                 else:
226 |                     for peer in ru.config["neuron"]["dht_addresses"]:
227 |                         if peer not in
initial_peers_list: 228 | initial_peers_list.append(peer) 229 | 230 | # Init DHT 231 | retries = 0 232 | buffer = 5 233 | max_retries = buffer * len(initial_peers_list) 234 | successful_connection = False 235 | while successful_connection is False: 236 | if (retries == max_retries) and (successful_connection is False): 237 | raise Exception("Max retries reached, operation failed.") 238 | for attempt in range(0, buffer): 239 | for initial_peer in initial_peers_list: 240 | try: 241 | # Init DHT 242 | self.dht = hivemind.DHT( 243 | host_maddrs=[ 244 | f"/ip4/0.0.0.0/tcp/{self.config.dht.port}", 245 | f"/ip4/0.0.0.0/udp/{self.config.dht.port}/quic", 246 | ], 247 | initial_peers=[initial_peer], 248 | announce_maddrs=announce_maddrs, 249 | start=True, 250 | ) 251 | self.logger.info( 252 | f"Successfully initialised dht using initial_peer as {initial_peer}" 253 | ) 254 | successful_connection = True 255 | utils.log_visible_maddrs( 256 | self.dht.get_visible_maddrs(), only_p2p=True 257 | ) 258 | # Add DHT address to wandb config 259 | self.config.neuron.dht_addresses = [ 260 | re.sub( 261 | "ip4/?(.*?)/", 262 | f"ip{version}/{address}/", 263 | str(addr), 264 | flags=re.DOTALL, 265 | ) 266 | for addr in self.dht.get_visible_maddrs() 267 | ] 268 | return 269 | except Exception as e: 270 | self.logger.error( 271 | f"Attempt {retries + 1} to init DHT using initial_peer as {initial_peer} failed with error: {e}" 272 | ) 273 | retries += 1 274 | time.sleep(5) 275 | self.logger.error("Retrying...") 276 | -------------------------------------------------------------------------------- /distributed_training/base/miner.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2025 dstrbtd.ai 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 
17 | 
18 | import asyncio
19 | import threading
20 | import time
21 | import traceback
22 | import torch.distributed as dist
23 | 
24 | import bittensor as bt
25 | 
26 | from enum import Enum
27 | from distributed_training.base.neuron import BaseNeuron
28 | from distributed_training.utils.chain import log_r2_to_chain
29 | from distributed_training.utils.misc import get_bandwidth
30 | from distributed_training.utils.state_loader import load_state_from_peer
31 | from distributed_training.utils.progress_tracker import get_progress
32 | 
33 | 
34 | class TrainingStatus(Enum):
35 |     ERROR = "❗ | Error"
36 |     RUNNING = "🏋️ | Training"
37 |     STOPPED = "😴 | Stopped"
38 |     PAUSED = "🔄 | Paused"
39 | 
40 | 
41 | class BaseMinerNeuron(BaseNeuron):
42 |     """
43 |     Base class for Bittensor miners.
44 |     """
45 | 
46 |     neuron_type: str = "MinerNeuron"
47 | 
48 |     def __init__(self, config=None):
49 |         super().__init__(config=config)
50 | 
51 |         # Warn if allowing incoming requests from anyone.
52 |         if not self.config.blacklist.force_validator_permit:
53 |             self.logger.warning(
54 |                 "You are allowing non-validators to send requests to your miner. This is a security risk."
55 |             )
56 |         if self.config.blacklist.allow_non_registered:
57 |             self.logger.warning(
58 |                 "You are allowing non-registered entities to send requests to your miner. This is a security risk."
59 |             )
60 | 
61 |         if self.master:
62 |             # The axon handles request processing, allowing validators to send this miner requests.
63 |             self.axon = bt.axon(
64 |                 wallet=self.wallet,
65 |                 config=self.config,
66 |                 port=self.config.axon.port,
67 |                 ip=self.config.axon.ip,
68 |                 external_ip=self.config.axon.external_ip,
69 |                 external_port=self.config.axon.external_port,
70 |             )
71 | 
72 |             # Attach handlers that determine which functions are called when servicing a request.
73 |             self.logger.info("Attaching forward function to miner axon.")
74 |             self.axon.attach(
75 |                 forward_fn=self.is_alive,
76 |                 blacklist_fn=self.blacklist_is_alive,
77 |                 # priority_fn=self.priority,
78 |             ).attach(
79 |                 forward_fn=self.all_reduce,
80 |                 blacklist_fn=self.blacklist_all_reduce,
81 |             )
82 |             self.logger.info(f"Axon created: {self.axon}")
83 | 
84 |         # Instantiate runners
85 |         self.should_exit: bool = False
86 |         self.is_running: bool = False
87 |         self.thread: threading.Thread = None
88 |         self.lock = asyncio.Lock()
89 | 
90 |         # self.config.neuron.disable_set_weights = True
91 | 
92 |         # Log PeerID to chain flag
93 |         self.r2_credentials_logged_to_chain = False
94 | 
95 |     # def run(rank, self, world_size):
96 |     def run(self):
97 |         """
98 |         Initiates and manages the main loop for the miner on the Bittensor network. The main loop handles graceful shutdown on keyboard interrupts and logs unforeseen errors.
99 | 
100 |         This function performs the following primary tasks:
101 |         1. Checks for registration on the Bittensor network.
102 |         2. Starts the miner's axon, making it active on the network.
103 |         3. Periodically resynchronizes with the chain, updating the metagraph with the latest network state and setting weights.
104 | 
105 |         The miner continues its operations until `should_exit` is set to True or an external interruption occurs.
106 |         During each epoch of its operation, the miner waits for new blocks on the Bittensor network, updates its
107 |         knowledge of the network (metagraph), and sets its weights. This process ensures the miner remains active
108 |         and up-to-date with the network's latest state.
109 | 
110 |         Note:
111 |             - The function leverages the global configurations set during the initialization of the miner.
112 |             - The miner's axon serves as its interface to the Bittensor network, handling incoming and outgoing requests.
113 | 
114 |         Raises:
115 |             KeyboardInterrupt: If the miner is stopped by a manual interruption.
116 |             Exception: For unforeseen errors during the miner's operation, which are logged for diagnosis.
117 |         """
118 |         self.logger.info("Synced metagraph")
119 | 
120 |         # This loop maintains the miner's operations until intentionally stopped.
121 |         try:
122 |             dist.barrier()
123 |             self.resume_training()
124 | 
125 |             while not self.should_exit:
126 |                 try:
127 |                     if self.master:
128 |                         if self.r2_credentials_logged_to_chain is False:
129 |                             log_r2_to_chain(self)
130 | 
131 |                         if not self.config.neuron.dont_wandb_log:
132 |                             if self.event != {}:
133 |                                 self.event.update(self.get_miner_info())
134 |                                 try:
135 |                                     self.bandwidth = get_bandwidth()
136 |                                     self.event.update(self.bandwidth)
137 |                                 except Exception:
138 |                                     self.logger.debug("Error getting bandwidth metrics")
139 |                                 if self.master:
140 |                                     self.wandb.log(self.event)
141 |                                 self.event = {}
142 | 
143 |                     self.logger.debug(
144 |                         "training_active is set: "
145 |                         f"{self.training_active.is_set()} "
146 |                         "(pre dataset)"
147 |                     )
148 |                     # Wait if training is paused
149 |                     self.training_active.wait()
150 | 
151 |                     self.logger.debug(":pages: Fetching fineweb-edu pages")
152 |                     dataset = self.training_loop.run_until_complete(
153 |                         self.fetch_training_data()
154 |                     )
155 | 
156 |                     # Wait if training is paused
157 |                     self.logger.debug(
158 |                         "training_active is set: "
159 |                         f"{self.training_active.is_set()} "
160 |                         "(post dataset)"
161 |                     )
162 |                     self.training_active.wait()
163 | 
164 |                     if self.master:
165 |                         self.model.config.block_list.append(self.current_block)
166 |                     self._process_training_batch(dataset)
167 | 
168 |                 except Exception as e:
169 |                     self.logger.warning(f"Training Loop Failed with error: {e}")
170 |                     self.training_status = TrainingStatus.ERROR
171 |                     self.training_error = str(e)
172 |                     break
173 | 
174 |             # Await the training task to ensure it completes before exiting
175 |             self.training_status = TrainingStatus.STOPPED
176 | 
177 |         # If someone intentionally stops the miner, it'll safely terminate operations.
178 |         except KeyboardInterrupt:
179 |             self.should_exit = True
180 |             if self.master:
181 |                 self.axon.stop()
182 |             self.logger.success(
183 |                 ":white_heavy_check_mark: Miner killed by keyboard interrupt."
184 |             )
185 |             exit()
186 | 
187 |         # In case of unforeseen errors, the miner will log the error and continue operations.
188 |         except Exception as e:
189 |             self.logger.error(traceback.format_exc())
190 | 
191 |     def load_state(self, reset_last_allreduce_block=False):
192 |         self.global_progress.epoch = get_progress(self, "global")[0]
193 |         if self.local_progress.epoch != self.global_progress.epoch:
194 |             self.logger.info(
195 |                 f"Local Epoch {self.local_progress.epoch} Behind Global Epoch {self.global_progress.epoch}. Loading Latest Model State."
196 |             )
197 |             self.pause_training()
198 |             # If there's an ongoing upload, check if it's done
199 |             while self.current_upload_future and not self.current_upload_future.done():
200 |                 self.logger.info(
201 |                     "Previous upload still in progress. Waiting until upload is complete."
202 |                 )
203 |                 time.sleep(1)
204 |             if self.global_progress.epoch == 0:
205 |                 load_state_from_peer(self, epoch=self.global_progress.epoch)
206 |             else:
207 |                 load_state_from_peer(
208 |                     self,
209 |                     uid=self.uid,
210 |                     epoch=self.global_progress.epoch,
211 |                 )
212 |             self.model.config.block_list = []
213 |             self.resume_training()
214 |         if reset_last_allreduce_block:
215 |             self.last_allreduce_block = None
216 | 
217 |     def run_in_background_thread(self):
218 |         """
219 |         Starts the miner's operations in a separate background thread.
220 |         This is useful for non-blocking operations.
221 |         """
222 |         if not self.is_running:
223 |             self.logger.debug("Starting miner in background thread.")
224 |             self.should_exit = False
225 |             self.thread = threading.Thread(target=self.run, daemon=True)
226 |             self.thread.start()
227 |             self.is_running = True
228 |             self.logger.debug("Started")
229 | 
230 |     def stop_run_thread(self):
231 |         """
232 |         Stops the miner's operations that are running in the background thread.
233 |         """
234 |         if self.is_running:
235 |             self.logger.debug("Stopping miner in background thread.")
236 |             self.should_exit = True
237 |             self.thread.join(5)
238 |             self.is_running = False
239 |             self.logger.debug("Stopped")
240 | 
241 |     def __enter__(self):
242 |         """
243 |         Runs the miner's main loop (blocking) upon entering the context.
244 |         This method facilitates the use of the miner in a 'with' statement.
245 |         """
246 |         # self.run_in_background_thread()
247 |         self.run()
248 |         return self
249 | 
250 |     def __exit__(self, exc_type, exc_value, traceback):
251 |         """
252 |         Stops the miner's background operations upon exiting the context.
253 |         This method facilitates the use of the miner in a 'with' statement.
254 | 
255 |         Args:
256 |             exc_type: The type of the exception that caused the context to be exited.
257 |                 None if the context was exited without an exception.
258 |             exc_value: The instance of the exception that caused the context to be exited.
259 |                 None if the context was exited without an exception.
260 |             traceback: A traceback object encoding the stack trace.
261 |                 None if the context was exited without an exception.
262 |         """
263 |         self.stop_run_thread()
264 | 
265 |     def resync_metagraph(self):
266 |         """Resyncs the metagraph and updates the hotkeys based on the new metagraph."""
267 |         self.logger.info("resync_metagraph()")
268 | 
269 |         # Sync the metagraph.
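        # This pulls the latest axons, hotkeys and stake from the chain so that
        # subsequent UID checks and weight setting use up-to-date network state.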
270 | self.metagraph.sync(subtensor=self.subtensor) 271 | -------------------------------------------------------------------------------- /distributed_training/utils/uids.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import hashlib 3 | import random 4 | import requests 5 | import traceback 6 | from typing import List 7 | 8 | import bittensor as bt 9 | import distributed_training 10 | import numpy as np 11 | from bittensor.core.chain_data import decode_account_id 12 | from hivemind.p2p import PeerID 13 | from hivemind.utils.timed_storage import ValueWithExpiration 14 | from distributed_training.utils.state_loader import get_progress 15 | from distributed_training import __run__ 16 | 17 | 18 | async def check_uid(self, dendrite, axon, uid, epoch=None): 19 | try: 20 | response = await dendrite( 21 | axon, 22 | distributed_training.protocol.IsAlive(), 23 | deserialize=False, 24 | timeout=2.3, 25 | ) 26 | if response.is_success: 27 | if (epoch is not None) and (response.epoch == epoch): 28 | self.logger.trace(f"UID {uid} is active and on epoch {epoch}") 29 | return True 30 | elif (epoch is not None) and (response.epoch != epoch): 31 | self.logger.trace(f"UID {uid} is active but not on epoch {epoch}") 32 | return False 33 | else: 34 | self.logger.trace(f"UID {uid} is active.") 35 | return True 36 | else: 37 | self.logger.trace(f"UID {uid} is not active.") 38 | return False 39 | except Exception as e: 40 | self.logger.error(f"Error checking UID {uid}: {e}\n{traceback.format_exc()}") 41 | # loop.close() 42 | return False 43 | 44 | 45 | async def check_uid_availability( 46 | self, 47 | dendrite, 48 | metagraph: "bt.metagraph.Metagraph", 49 | uid: int, 50 | vpermit_tao_limit: int, 51 | epoch: int = None, 52 | ) -> bool: 53 | """Check if uid is available. The UID should be available if it is serving and has less than vpermit_tao_limit stake 54 | Args: 55 | metagraph (:obj: bt.metagraph.Metagraph): Metagraph object 56 | uid (int): uid to be checked 57 | vpermit_tao_limit (int): Validator permit tao limit 58 | Returns: 59 | bool: True if uid is available, False otherwise 60 | """ 61 | # Filter non serving axons. 62 | if not metagraph.axons[uid].is_serving: 63 | return False 64 | 65 | # Filter validator permit > 1024 stake. 66 | if metagraph.validator_permit[uid]: 67 | if metagraph.S[uid] > vpermit_tao_limit: 68 | return False 69 | 70 | # Filter for miners that are processing other responses 71 | if not await check_uid(self, dendrite, metagraph.axons[uid], uid, epoch): 72 | return False 73 | # Available otherwise. 74 | return True 75 | 76 | 77 | async def get_random_uids( 78 | self, dendrite, k: int, exclude: List[int] = None, epoch: int = None 79 | ) -> np.ndarray: 80 | """Returns k available random uids from the metagraph. 81 | Args: 82 | k (int): Number of uids to return. 83 | exclude (List[int]): List of uids to exclude from the random sampling. 84 | Returns: 85 | uids (np.ndarray): Randomly sampled available uids. 86 | Notes: 87 | If `k` is larger than the number of available `uids`, set `k` to the number of available `uids`. 
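    Example (illustrative; called from a validator's async forward pass):
        miner_uids = await get_random_uids(self, dendrite, k=25, epoch=5)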
88 | """ 89 | candidate_uids = [] 90 | avail_uids = [] 91 | uids = [i for i in range(self.metagraph.n)] 92 | random.shuffle(uids) 93 | 94 | responses = [] 95 | attempt = 0 96 | limit = self.config.neuron.uid_isalive_limit 97 | while (sum(responses) < k) and ( 98 | (attempt < (int(self.metagraph.n / limit) - 1)) or (attempt == 0) 99 | ): 100 | tasks = [] 101 | if limit > int(self.metagraph.n): 102 | limit = int(self.metagraph.n) 103 | 104 | for i in range((attempt * limit), (attempt * limit) + limit): 105 | # The dendrite client queries the network. 106 | tasks.append( 107 | check_uid_availability( 108 | self, 109 | dendrite, 110 | self.metagraph, 111 | uids[i], 112 | self.config.neuron.vpermit_tao_limit, 113 | epoch, 114 | ) 115 | ) 116 | responses += await asyncio.gather(*tasks) 117 | attempt += 1 118 | 119 | for i, response in enumerate(responses): 120 | if response == False: 121 | self.failed_is_alive_counter[uids[i]] += 1 122 | else: 123 | self.failed_is_alive_counter[uids[i]] = 0 124 | 125 | for uid, uid_is_available in zip(uids, (responses)): 126 | uid_is_not_excluded = exclude is None or uid not in exclude 127 | if uid_is_available: 128 | avail_uids.append(uid) 129 | if uid_is_not_excluded: 130 | candidate_uids.append(uid) 131 | 132 | # Check if candidate_uids contain enough for querying, if not grab all avaliable uids 133 | available_uids = candidate_uids 134 | if len(candidate_uids) < k: 135 | uids = np.array(available_uids) 136 | else: 137 | uids = np.array(random.sample(available_uids, k)) 138 | return uids 139 | 140 | 141 | def get_next_uids_manual(self, epoch: int, k: int = 25) -> List[int]: 142 | try: 143 | for uid in self.uid_tracker.keys(): 144 | self.uid_tracker[ 145 | uid 146 | ].train.revision = f"{__run__}.{epoch}.{get_progress(self, 'local', uid=uid, donwload_on_all_ranks=False)[1]}" 147 | 148 | # Rank miners based off train_similarity_score_last_updated 149 | uids = list( 150 | dict( 151 | sorted( 152 | ( 153 | (uid, rec) 154 | for uid, rec in self.uid_tracker.items() 155 | if rec.train.revision.split(".")[-1] != "0" 156 | ), 157 | key=lambda item: ( 158 | not item[1].train.is_valid, 159 | item[1].train.updated_time, 160 | ), 161 | ) 162 | ).keys() 163 | ) 164 | uids = uids[:k] 165 | return uids 166 | 167 | except Exception as e: 168 | self.logger.info(f"Error getting UID manually: {e}") 169 | 170 | 171 | def get_next_uid_api(self, epoch: int = None) -> List[int]: 172 | try: 173 | # raise Exception("Forcing manual UID retrieval") 174 | response = requests.get( 175 | url=self.uid_api_url, headers={"Authorization": self.uid_api_get_token} 176 | ) 177 | uids = response.json()["uids"] 178 | 179 | assert uids != self.miner_uids 180 | assert type(uids) == list 181 | assert all(isinstance(uid, int) for uid in uids) 182 | return uids 183 | except Exception as e: 184 | self.logger.info( 185 | f"Error {e} getting UID from: {self.uid_api_url}. Attempting to get UID manually." 
186 |         )
187 |         uids = get_next_uids_manual(self, epoch, k=self.config.neuron.sample_size)
188 |         return uids
189 | 
190 | 
191 | def post_next_uid_api(self, epoch: int = None):
192 |     uids = get_next_uids_manual(self, epoch, k=self.config.neuron.sample_size)
193 |     try:
194 |         response = requests.post(
195 |             url=self.uid_api_url,
196 |             json={"uids": uids},
197 |             headers={"Authorization": self.uid_api_post_token},
198 |         )
199 |         if response.status_code != 200:
200 |             raise Exception(
201 |                 f"UID post request failed with error: Resp {response.status_code}"
202 |             )
203 |     except Exception as e:
204 |         self.logger.info(
205 |             f"Error {e} posting UIDs to: {self.uid_api_url}."
206 |         )
207 | 
208 | 
209 | def update_run_peerid_list(self):
210 |     prefix = self.grad_averager.matchmaking_kwargs["prefix"]
211 |     metadata, _ = self.dht.get(f"{prefix}.all_averagers", latest=True) or (
212 |         {},
213 |         None,
214 |     )
215 |     self.run_peer_id_list = [
216 |         str(PeerID(peer_id))
217 |         for peer_id, info in metadata.items()
218 |         if isinstance(info, ValueWithExpiration)
219 |         and isinstance(info.value, (float, int))
220 |     ]
221 | 
222 | 
223 | def decode_metadata(encoded_ss58: tuple, metadata: dict) -> tuple[str, str]:
224 |     decoded_key = decode_account_id(encoded_ss58[0])
225 |     commitment = metadata["info"]["fields"][0][0]
226 |     bytes_tuple = commitment[next(iter(commitment.keys()))][0]
227 |     return decoded_key, bytes(bytes_tuple).decode()
228 | 
229 | 
230 | def hash_r2_creds(account_id, access_key_id, secret_key):
231 |     concat = f"{account_id}:{access_key_id}:{secret_key}"
232 |     return hashlib.sha256(concat.encode()).hexdigest()
233 | 
234 | 
235 | def map_uid_to_peerid(self):
236 |     result = {}
237 |     try:
238 |         subtensor = bt.subtensor(config=self.config)
239 |         result = subtensor.substrate.query_map(
240 |             module="Commitments",
241 |             storage_function="CommitmentOf",
242 |             params=[self.config.netuid],
243 |             block_hash=None,
244 |         )
245 |         hotkey_to_uid = dict(zip(self.metagraph.hotkeys, self.metagraph.uids.tolist()))
246 |     except Exception as e:
247 |         self.logger.info(f"Error {e} when querying UID commitments")
248 | 
249 |     for key, value in result:
250 |         try:
251 |             hotkey, metadata = decode_metadata(key, value.value)
252 |             if hotkey not in hotkey_to_uid:
253 |                 continue
254 | 
255 |             uid = hotkey_to_uid[hotkey]
256 |             last_updated_block = value.value.get("block", 0)
257 |             if last_updated_block is None:
258 |                 last_updated_block = 0
259 | 
260 |             concatenated = metadata
261 | 
262 |             if len(concatenated) != 128:
263 |                 raise ValueError(
264 |                     f"Commitment {concatenated} is of length {len(concatenated)} but should be of length 128."
265 | ) 266 | 267 | account_id = concatenated[:32] 268 | access_key_id = concatenated[32:64] 269 | secret_access_key = concatenated[64:] 270 | r2_hash = hash_r2_creds(account_id, access_key_id, secret_access_key) 271 | 272 | self.uid_tracker[uid].chaindata.last_updated_block = last_updated_block 273 | self.uid_tracker[uid].train.r2_hash = r2_hash 274 | self.uid_tracker[uid].train.account_id = account_id 275 | self.uid_tracker[uid].train.access_key_id = access_key_id 276 | self.uid_tracker[uid].train.secret_access_key = secret_access_key 277 | 278 | if uid == self.uid: 279 | peer_id = str(self.dht.peer_id.to_base58()) 280 | else: 281 | peer_id = get_progress( 282 | self, "local", uid=uid, donwload_on_all_ranks=False 283 | )[2] 284 | 285 | if peer_id != self.uid_tracker[uid].all_reduce.peer_id: 286 | uid_peerid_metadata = [ 287 | metadata.all_reduce.peer_id 288 | for key, metadata in self.uid_tracker.items() 289 | if key != uid 290 | ] 291 | if peer_id in uid_peerid_metadata: 292 | uid_list = [ 293 | uid 294 | for uid, metadata in self.uid_tracker.items() 295 | if metadata.all_reduce.peer_id == peer_id 296 | ] 297 | for uid_i in uid_list: 298 | if ( 299 | self.uid_tracker[uid_i].chaindata.last_updated_block 300 | is not None 301 | ) and ( 302 | self.uid_tracker[uid_i].chaindata.last_updated_block 303 | > last_updated_block 304 | ): 305 | self.uid_tracker[uid_i].chaindata.last_updated_block = 0 306 | self.uid_tracker[uid_i].all_reduce.peer_id = None 307 | else: 308 | self.uid_tracker[uid].all_reduce.peer_id = peer_id 309 | self.uid_tracker[ 310 | uid 311 | ].chaindata.last_updated_block = last_updated_block 312 | else: 313 | self.uid_tracker[uid].all_reduce.peer_id = peer_id 314 | self.uid_tracker[ 315 | uid 316 | ].chaindata.last_updated_block = last_updated_block 317 | 318 | self.logger.debug(f"Retrieved commitment for UID {uid}: {metadata}") 319 | 320 | except Exception as e: 321 | self.logger.debug(f"Failed to decode commitment for UID {uid}: {e}") 322 | continue 323 | 324 | self.logger.debug("Finished extracting commitments for all uids") 325 | -------------------------------------------------------------------------------- /distributed_training/utils/config.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Opentensor Foundation 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 
18 | 19 | import argparse 20 | import os 21 | 22 | import bittensor as bt 23 | import torch 24 | from distributed_training import __run__, __version__ 25 | from dataclasses import dataclass, field 26 | 27 | 28 | @dataclass 29 | class R2Access: 30 | access_key_id: str | None = None 31 | secret_access_key: str | None = None 32 | 33 | 34 | @dataclass 35 | class R2Config: 36 | bucket_name: str | None = None 37 | account_id: str | None = None 38 | read: R2Access = field(default_factory=R2Access) 39 | write: R2Access = field(default_factory=R2Access) 40 | 41 | 42 | def check_config(cls, config: "bt.Config"): 43 | r"""Checks/validates the config namespace object.""" 44 | bt.logging.check_config(config) 45 | 46 | full_path = os.path.expanduser( 47 | "{}/{}/{}/netuid{}/{}".format( 48 | config.logging.logging_dir, 49 | config.wallet.name, 50 | config.wallet.hotkey, 51 | config.netuid, 52 | config.neuron.name, 53 | ) 54 | ) 55 | print("full path:", full_path) 56 | config.neuron.full_path = os.path.expanduser(full_path) 57 | if not os.path.exists(config.neuron.full_path): 58 | os.makedirs(config.neuron.full_path, exist_ok=True) 59 | 60 | 61 | def add_args(cls, parser, prefix=None): 62 | """ 63 | Adds relevant arguments to the parser for operation. 64 | """ 65 | # Netuid Arg: The netuid of the subnet to connect to. 66 | parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1) 67 | 68 | neuron_type = "validator" if "miner" not in cls.__name__.lower() else "miner" 69 | prefix_str = "" if prefix is None else prefix + "." 70 | try: 71 | default_name = os.getenv("BT_WALLET_NAME") or "default" 72 | default_hotkey = os.getenv("BT_WALLET_HOTKEY") or "default" 73 | default_path = os.getenv("BT_WALLET_PATH") or "~/.bittensor/wallets/" 74 | parser.add_argument( 75 | "--no_prompt", 76 | dest="no_prompt", 77 | action="store_true", 78 | help="""Set true to avoid prompting the user.""", 79 | default=False, 80 | ) 81 | parser.add_argument( 82 | "--" + prefix_str + "wallet.name", 83 | required=False, 84 | default=default_name, 85 | help="The name of the wallet to unlock for running bittensor " 86 | "(name mock is reserved for mocking this wallet)", 87 | ) 88 | parser.add_argument( 89 | "--" + prefix_str + "wallet.hotkey", 90 | required=False, 91 | default=default_hotkey, 92 | help="The name of the wallet's hotkey.", 93 | ) 94 | parser.add_argument( 95 | "--" + prefix_str + "wallet.path", 96 | required=False, 97 | default=default_path, 98 | help="The path to your bittensor wallets", 99 | ) 100 | except argparse.ArgumentError: 101 | pass 102 | 103 | parser.add_argument( 104 | "--dht.port", 105 | type=int, 106 | help="The port to use for the DHT.", 107 | default=8009, 108 | ) 109 | 110 | parser.add_argument( 111 | "--dht.ip", 112 | type=str, 113 | help="The IP address to use in announce_maddrs", 114 | ) 115 | 116 | parser.add_argument( 117 | "--neuron.events_retention_size", 118 | type=str, 119 | help="Events retention size.", 120 | default=2 * 1024 * 1024 * 1024, # 2 GB 121 | ) 122 | 123 | parser.add_argument( 124 | "--neuron.name", 125 | type=str, 126 | help="Trials for this neuron go in neuron.root / (wallet_cold - wallet_hot) / neuron.name. 
", 127 | default=neuron_type, 128 | ) 129 | 130 | parser.add_argument( 131 | "--neuron.device", 132 | type=str, 133 | help="Device to run on.", 134 | default=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 135 | ) 136 | 137 | parser.add_argument( 138 | "--neuron.epoch_length", 139 | type=int, 140 | help="The default epoch length (how often we set weights, measured in 12 second blocks).", 141 | default=160, 142 | ) 143 | 144 | parser.add_argument( 145 | "--neuron.dont_save_events", 146 | action="store_true", 147 | help="If set, we dont save events to a log file.", 148 | default=False, 149 | ) 150 | 151 | parser.add_argument( 152 | "--neuron.initial_peers", 153 | type=str, 154 | nargs="+", 155 | help="The addresses for the DHT", 156 | default=[ 157 | "/ip4/161.97.156.125/tcp/8000/p2p/12D3KooWHDMAeP3zKHrALtREoc2vfYQL7HhsMN7WBvft7hgBCRhK", 158 | ], 159 | ) 160 | 161 | parser.add_argument( 162 | "--neuron.blocks_per_allreduce", 163 | type=int, 164 | help="Amount of blocks between each all reduce", 165 | default=400, 166 | ) 167 | 168 | parser.add_argument( 169 | "--neuron.global_model_name", 170 | type=str, 171 | help="The model to be trained", 172 | default="llama-4b-ws-4", 173 | ) 174 | 175 | parser.add_argument( 176 | "--neuron.global_tokenizer_name", 177 | type=str, 178 | help="The HF repo_id of the tokenizer to be used for training", 179 | default="dstrbtd/llama-1b", 180 | ) 181 | 182 | parser.add_argument( 183 | "--neuron.master_ss58_address", 184 | type=str, 185 | help="The ss58 address for the master validator UID", 186 | default="5EnC86fRRRoaXUZvkrDFYpAihuyEAp3wGkY5r3Gak1kPTDVP", 187 | ) 188 | 189 | parser.add_argument( 190 | "--neuron.min_group_size", 191 | type=int, 192 | help="The minimum group size for an all reduce", 193 | default=32, 194 | ) 195 | 196 | parser.add_argument( 197 | "--neuron.local_batch_size_train", 198 | type=int, 199 | help="The default batch size", 200 | default=2, 201 | ) 202 | 203 | parser.add_argument( 204 | "--neuron.global_batch_size_train", 205 | type=int, 206 | help="The hivemind global target_batch_size", 207 | default=35200, 208 | ) 209 | 210 | parser.add_argument( 211 | "--neuron.upload_steps", 212 | type=int, 213 | help="The frequency of uploads per inner step", 214 | default=5, 215 | ) 216 | 217 | parser.add_argument( 218 | "--neuron.local_batch_size_train_effective", 219 | type=int, 220 | help="Amount of micro batches for gradient accumulation", 221 | default=512, 222 | ) 223 | 224 | parser.add_argument( 225 | "--neuron.run_id", 226 | type=str, 227 | help="The DHT run_id", 228 | default=f"v{__version__.replace('.','_')}_r{__run__}", 229 | ) 230 | 231 | parser.add_argument( 232 | "--neuron.show_all_rank_logs", 233 | action="store_true", 234 | help="Set to true to show logs of all ranks", 235 | default=False, 236 | ) 237 | 238 | parser.add_argument( 239 | "--neuron.dont_wandb_log", 240 | action="store_true", 241 | help="Toggles wandb logging for the project", 242 | default=False, 243 | ) 244 | 245 | parser.add_argument( 246 | "--neuron.wandb_project", 247 | type=str, 248 | help="The wandb project to log to", 249 | default="distributed_training", 250 | ) 251 | 252 | parser.add_argument( 253 | "--neuron.wandb_entity", 254 | type=str, 255 | help="The wandb project to log to", 256 | default="kmfoda", 257 | ) 258 | 259 | parser.add_argument( 260 | "--neuron.influxdb_bucket", 261 | type=str, 262 | help="The influxdb bucket", 263 | default="distributed-training-metrics", 264 | ) 265 | 266 | parser.add_argument( 267 | "--neuron.influxdb_url", 
268 | type=str, 269 | help="The influxdb url", 270 | default="http://161.97.156.125:8086", 271 | ) 272 | 273 | parser.add_argument( 274 | "--neuron.influxdb_token", 275 | type=str, 276 | help="The influxdb token", 277 | default="JCDOYKFbiC13zdgbTQROpyvB69oaUWvO4pRw_c3AEYhTjU998E_X_oIJJOVAW24nAE0WYxMwIgdFSLZg8aeV-g==", 278 | ) 279 | 280 | parser.add_argument( 281 | "--neuron.influxdb_org", 282 | type=str, 283 | help="The influxdb org", 284 | default="distributed-training", 285 | ) 286 | 287 | parser.add_argument( 288 | "--neuron.use_dct", 289 | action="store_true", 290 | help="If true uses DCT when compressing gradients", 291 | default=False, 292 | ) 293 | 294 | parser.add_argument( 295 | "--neuron.momentum_decay", 296 | type=float, 297 | help="The decay rate applied to the gradient momentum buffer", 298 | default=0.999, 299 | ) 300 | 301 | parser.add_argument( 302 | "--neuron.target_chunk", 303 | type=int, 304 | help="The target chunk size for the DCT transform", 305 | default=64, 306 | ) 307 | 308 | parser.add_argument( 309 | "--neuron.quantization_bins", 310 | type=int, 311 | help="The number of bins used for 8-bit quantization", 312 | default=256, 313 | ) 314 | 315 | parser.add_argument( 316 | "--neuron.quantization_range", 317 | type=int, 318 | help="The quantization range in standard deviations", 319 | default=6, 320 | ) 321 | 322 | parser.add_argument( 323 | "--neuron.topk_compression", 324 | type=int, 325 | help="The number of top-k values kept when compressing gradients", 326 | default=32, 327 | ) 328 | 329 | if neuron_type == "validator": 330 | parser.add_argument( 331 | "--neuron.uid_api_url", 332 | type=str, 333 | help="The url for the UID api.", 334 | default="http://161.97.156.125:8002/uid", 335 | ) 336 | 337 | parser.add_argument( 338 | "--neuron.uid_api_get_token", 339 | type=str, 340 | help="The token for the UID get api.", 341 | default=os.getenv("API_GET_TOKEN", None), 342 | ) 343 | 344 | parser.add_argument( 345 | "--neuron.uid_api_post_token", 346 | type=str, 347 | help="The token for the UID post api.", 348 | default=os.getenv("API_POST_TOKEN", None), 349 | ) 350 | 351 | parser.add_argument( 352 | "--neuron.uid_isalive_limit", 353 | type=int, 354 | help="The maximum number of uids to call concurrently", 355 | default=25, 356 | ) 357 | 358 | parser.add_argument( 359 | "--neuron.weight_update_interval", 360 | type=int, 361 | help="The number of steps before updating the model's weights", 362 | default=900, 363 | ) 364 | 365 | parser.add_argument( 366 | "--neuron.num_concurrent_forwards", 367 | type=int, 368 | help="The number of concurrent forwards running at any time.", 369 | default=1, 370 | ) 371 | 372 | parser.add_argument( 373 | "--neuron.sample_size", 374 | type=int, 375 | help="The number of miners to query in a single step.", 376 | default=25, 377 | ) 378 | 379 | parser.add_argument( 380 | "--neuron.disable_set_weights", 381 | action="store_true", 382 | help="Disables setting weights.", 383 | default=False, 384 | ) 385 | 386 | parser.add_argument( 387 | "--neuron.moving_average_alpha", 388 | type=float, 389 | help="Moving average alpha parameter, how much to add of the new observation.", 390 | default=0.6, 391 | ) 392 | 393 | parser.add_argument( 394 | "--neuron.axon_off", 395 | "--axon_off", 396 | action="store_true", 397 | # Note: the validator needs to serve an Axon with their IP or they may 398 | # be blacklisted by the firewall of serving peers on the network. 
399 | help="Set this flag to not attempt to serve an Axon.", 400 | default=False, 401 | ) 402 | 403 | parser.add_argument( 404 | "--neuron.vpermit_tao_limit", 405 | type=int, 406 | help="The maximum number of TAO allowed to query a validator with a vpermit.", 407 | default=40960, 408 | ) 409 | 410 | parser.add_argument( 411 | "--neuron.openskill_beta", 412 | type=int, 413 | help="The value of the beta used in the openskill model.", 414 | default=7, 415 | ) 416 | 417 | parser.add_argument( 418 | "--neuron.openskill_tau", 419 | type=int, 420 | help="The value of the tau used in the openskill model.", 421 | default=0.1, 422 | ) 423 | 424 | parser.add_argument( 425 | "--neuron.assigned_loss_score_moving_average_alpha", 426 | type=float, 427 | help="The value of the alpha for the assinged loss score moving average.", 428 | default=0.05, 429 | ) 430 | 431 | else: 432 | parser.add_argument( 433 | "--blacklist.force_validator_permit", 434 | action="store_true", 435 | help="If set, we will force incoming requests to have a permit.", 436 | default=False, 437 | ) 438 | 439 | parser.add_argument( 440 | "--blacklist.allow_non_registered", 441 | action="store_true", 442 | help="If set, miners will accept queries from non registered entities. (Dangerous!)", 443 | default=False, 444 | ) 445 | 446 | 447 | def config(cls): 448 | """ 449 | Returns the configuration object specific to this miner or validator after adding relevant arguments. 450 | """ 451 | parser = argparse.ArgumentParser() 452 | bt.subtensor.add_args(parser) 453 | bt.logging.add_args(parser) 454 | bt.axon.add_args(parser) 455 | cls.add_args(parser) 456 | return bt.config(parser) 457 | -------------------------------------------------------------------------------- /distributed_training/base/neuron.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2025 dstrbtd.ai 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import time 19 | import os 20 | import pathlib 21 | import copy 22 | import boto3 23 | import threading 24 | from abc import ABC, abstractmethod 25 | 26 | import bittensor as bt 27 | 28 | from distributed_training import __spec_version__ as spec_version 29 | from botocore.config import Config 30 | 31 | # Sync calls set weights and also resyncs the metagraph. 
32 | from distributed_training.utils.config import ( 33 | add_args, 34 | check_config, 35 | config, 36 | R2Access, 37 | R2Config, 38 | ) 39 | from distributed_training.utils.logger import setup_logging 40 | from distributed_training.utils.misc import ttl_get_block 41 | from dotenv import load_dotenv 42 | 43 | 44 | import torch, torch.distributed as dist 45 | import datetime as dt 46 | 47 | load_dotenv() 48 | 49 | 50 | class BaseNeuron(ABC): 51 | """ 52 | Base class for Bittensor miners. This class is abstract and should be inherited by a subclass. It contains the core logic for all neurons; validators and miners. 53 | 54 | In addition to creating a wallet, subtensor, and metagraph, this class also handles the synchronization of the network state via a basic checkpointing mechanism based on epoch length. 55 | """ 56 | 57 | neuron_type: str = "BaseNeuron" 58 | 59 | @classmethod 60 | def check_config(cls, config: "bt.Config"): 61 | check_config(cls, config) 62 | 63 | @classmethod 64 | def add_args(cls, parser): 65 | add_args(cls, parser) 66 | 67 | @classmethod 68 | def config(cls): 69 | return config(cls) 70 | 71 | subtensor: "bt.subtensor" 72 | wallet: "bt.wallet" 73 | metagraph: "bt.metagraph" 74 | spec_version: int = spec_version 75 | 76 | @property 77 | def block(self): 78 | self.current_block = ttl_get_block(self) 79 | return self.current_block 80 | 81 | def set_current_block_across_ranks(self): 82 | current_block_tensor = ( 83 | torch.tensor([self.current_block]) if self.master else torch.tensor([0]) 84 | ) 85 | dist.broadcast(current_block_tensor, src=0, group=self.gloo_group) 86 | self.current_block = current_block_tensor[0].item() 87 | 88 | def __init__(self, config=None): 89 | base_config = copy.deepcopy(config or BaseNeuron.config()) 90 | self.config = self.config() 91 | self.config.merge(base_config) 92 | self.check_config(self.config) 93 | 94 | # Set up logging with the provided configuration and directory. 95 | bt.logging.set_config(config=self.config.logging) 96 | self.logger = bt.logging 97 | 98 | # If a gpu is required, set the device to cuda:N (e.g. cuda:0) 99 | self.device = self.config.neuron.device 100 | 101 | # Log the configuration for reference. 102 | self.logger.info(self.config) 103 | 104 | # Build Bittensor objects 105 | # These are core Bittensor classes to interact with the network. 106 | self.logger.info("Setting up bittensor objects.") 107 | 108 | # Set distributed variables 109 | self.world_size = int(os.getenv("WORLD_SIZE", 1)) 110 | self.local_rank = int(os.getenv("LOCAL_RANK", 0)) 111 | torch.cuda.set_device(self.local_rank) 112 | self.master = self.local_rank == 0 113 | 114 | if self.master: 115 | # The wallet holds the cryptographic key pairs for the miner. 116 | self.wallet = bt.wallet(config=self.config) 117 | self.logger.info(f"Wallet: {self.wallet}") 118 | 119 | if not dist.is_initialized(): 120 | if not dist.is_initialized(): 121 | dist.init_process_group( 122 | backend="nccl", 123 | init_method="tcp://127.0.0.1:29500", 124 | rank=self.local_rank, 125 | world_size=self.world_size, 126 | # timeout=dt.timedelta(seconds=1800), 127 | ) 128 | if not hasattr(self, "gloo_group"): 129 | self.gloo_group = dist.new_group( 130 | backend="gloo", 131 | ) 132 | 133 | if self.master: 134 | # The subtensor is our connection to the Bittensor blockchain. 
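# Only the master rank constructs chain-facing objects (wallet, subtensor,
# metagraph); non-master ranks receive derived scalars such as the UID via
# dist.broadcast further below.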
135 | self.subtensor = bt.subtensor(config=self.config) 136 | self.logger.info(f"Subtensor: {self.subtensor}") 137 | 138 | # The metagraph holds the state of the network, letting us know about other validators and miners. 139 | self.metagraph = self.subtensor.metagraph(self.config.netuid) 140 | self.logger.info(f"Metagraph: {self.metagraph}") 141 | 142 | # Check if the miner is registered on the Bittensor network before proceeding further. 143 | self.check_registered() 144 | 145 | # Each miner gets a unique identity (UID) in the network for differentiation. 146 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 147 | self.logger.info( 148 | f"Running neuron on subnet: {self.config.netuid} with uid {self.uid} using network: {self.subtensor.chain_endpoint}" 149 | ) 150 | else: 151 | self.uid = 0 152 | 153 | uid = torch.tensor([self.uid], device="cpu") 154 | dist.barrier(group=self.gloo_group) 155 | dist.broadcast(uid, src=0, group=self.gloo_group) 156 | dist.barrier(group=self.gloo_group) 157 | self.uid = uid[0].item() 158 | 159 | master_uid = ( 160 | torch.tensor( 161 | [ 162 | self.metagraph.hotkeys.index( 163 | self.config.neuron.master_ss58_address, 164 | ) 165 | ] 166 | ) 167 | if self.master 168 | else torch.tensor([0]) 169 | ) 170 | dist.broadcast(master_uid, src=0, group=self.gloo_group) 171 | self.master_uid = master_uid[0].item() 172 | 173 | # Setup Logging 174 | setup_logging(self, config=self.config) 175 | 176 | # Create the R2 data model 177 | r2 = R2Config( 178 | bucket_name=f"{self.config.neuron.global_model_name.split('/')[-1]}-{self.uid:03d}" 179 | if "miner" in self.__class__.__name__.lower() 180 | else self.config.neuron.global_model_name, 181 | account_id=os.getenv("R2_ACCOUNT_ID"), 182 | read=R2Access( 183 | access_key_id=os.getenv("R2_READ_ACCESS_KEY_ID"), 184 | secret_access_key=os.getenv("R2_READ_SECRET_ACCESS_KEY"), 185 | ), 186 | write=R2Access( 187 | access_key_id=os.getenv("R2_WRITE_ACCESS_KEY_ID"), 188 | secret_access_key=os.getenv("R2_WRITE_SECRET_ACCESS_KEY"), 189 | ), 190 | ) 191 | self.config.r2 = r2 192 | 193 | # Save directory 194 | self.output_dir = os.path.join(os.getcwd(), self.config.r2.bucket_name) 195 | os.makedirs(self.output_dir, exist_ok=True) 196 | os.makedirs( 197 | os.path.join(os.getcwd(), self.config.neuron.global_model_name), 198 | exist_ok=True, 199 | ) 200 | 201 | # Init Step 202 | self.step = 0 203 | 204 | # Initialize the all_reduce, download and upload variables. 
205 | self.allreduce_timeout = 600 206 | self.upload_state_duration = 1800 207 | self.all_reduce_success_status = True 208 | self.should_all_reduce = False 209 | self.retry_limit = 100 210 | self.retry_delay = 60 211 | 212 | # Create different r2 sessions 213 | r2_config = Config( 214 | retries={"max_attempts": 10, "mode": "adaptive"}, # or "standard" 215 | connect_timeout=30, 216 | read_timeout=120, 217 | max_pool_connections=50, 218 | ) 219 | self.session = boto3.session.Session() 220 | self.r2 = { 221 | "local": self.session.client( 222 | "s3", 223 | endpoint_url=f"https://{self.config.r2.account_id}.r2.cloudflarestorage.com", 224 | aws_access_key_id=self.config.r2.read.access_key_id, 225 | aws_secret_access_key=self.config.r2.read.secret_access_key, 226 | region_name="auto", 227 | config=r2_config, 228 | ) 229 | } 230 | self.r2["write"] = boto3.client( 231 | "s3", 232 | endpoint_url=f"https://{self.config.r2.account_id}.r2.cloudflarestorage.com", 233 | aws_access_key_id=self.config.r2.write.access_key_id, 234 | aws_secret_access_key=self.config.r2.write.secret_access_key, 235 | region_name="auto", 236 | config=r2_config, 237 | ) 238 | commitment = None 239 | while commitment is None: 240 | try: 241 | if self.master: 242 | commitment = [ 243 | self.subtensor.get_commitment( 244 | self.config.netuid, self.master_uid 245 | ) 246 | ] 247 | else: 248 | commitment = [ 249 | self.config.r2.account_id 250 | + self.config.r2.read.access_key_id 251 | + self.config.r2.read.secret_access_key 252 | ] 253 | dist.broadcast_object_list(commitment, src=0, group=self.gloo_group) 254 | global_account_id = commitment[0][:32] 255 | global_access_key_id = commitment[0][32:64] 256 | global_secret_access_key = commitment[0][64:] 257 | self.r2["global"] = self.session.client( 258 | "s3", 259 | endpoint_url=f"https://{global_account_id}.r2.cloudflarestorage.com", 260 | aws_access_key_id=global_access_key_id, 261 | aws_secret_access_key=global_secret_access_key, 262 | region_name="auto", 263 | config=r2_config, 264 | ) 265 | except Exception as e: 266 | self.logger.info(f"Error getting commitment: {str(e)}") 267 | time.sleep(15) 268 | 269 | self.reload_state_event = threading.Event() 270 | 271 | # @abstractmethod # miner is not using this anymore 272 | async def forward(self, synapse: bt.Synapse) -> bt.Synapse: 273 | ... 274 | 275 | @abstractmethod 276 | def run(self): 277 | ... 278 | 279 | def sync(self): 280 | """ 281 | Wrapper for synchronizing the state of the network for the given miner or validator. 282 | """ 283 | if self.master: 284 | try: 285 | # Ensure miner or validator hotkey is still registered on the network. 286 | self.check_registered() 287 | 288 | if self.should_sync_metagraph(): 289 | self.resync_metagraph() 290 | 291 | if self.should_set_weights(): 292 | self.logger.info("Should Set Weights") 293 | self.set_weights() 294 | 295 | if self.should_sync_metagraph(): 296 | self.metagraph.last_update[self.uid] = self.block 297 | 298 | if (self.step != 0) and (self.neuron_type != "MinerNeuron"): 299 | self.save_state() 300 | except Exception as e: 301 | self.logger.debug(f"Sync failed with error {e}") 302 | 303 | def check_registered(self): 304 | # --- Check for registration. 305 | if not self.subtensor.is_hotkey_registered( 306 | netuid=self.config.netuid, 307 | hotkey_ss58=self.wallet.hotkey.ss58_address, 308 | ): 309 | self.logger.error( 310 | f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
311 | f" Please register the hotkey using `btcli subnets register` before trying again" 312 | ) 313 | exit() 314 | 315 | def should_sync_metagraph(self): 316 | """ 317 | Check if enough epoch blocks have elapsed since the last checkpoint to sync. 318 | """ 319 | return ( 320 | self.block - self.metagraph.last_update[self.uid] 321 | ) > self.config.neuron.epoch_length 322 | 323 | def should_set_weights(self) -> bool: 324 | # Don't set weights on initialization. 325 | if self.step == 0: 326 | return False 327 | 328 | # Check if enough epoch blocks have elapsed since the last epoch. 329 | if self.config.neuron.disable_set_weights: 330 | return False 331 | 332 | # Define appropriate logic for when set weights. 333 | return ( 334 | self.block - self.metagraph.last_update[self.uid] 335 | ) > self.config.neuron.epoch_length and self.neuron_type != "MinerNeuron" # don't set weights if you're a miner 336 | 337 | def save_state(self): 338 | self.logger.warning( 339 | "save_state() not implemented for this neuron. You can implement this function to save model checkpoints or other useful data." 340 | ) 341 | 342 | def load_state(self): 343 | self.logger.warning( 344 | "load_state() not implemented for this neuron. You can implement this function to load model checkpoints or other useful data." 345 | ) 346 | 347 | def print_memory_usage(self): 348 | def cg_read(p): 349 | try: 350 | return pathlib.Path(p).read_text().strip() 351 | except FileNotFoundError: 352 | return None 353 | 354 | memory_used = 0 355 | memory_limit = 0 356 | memory_used_gb = 0 357 | memory_limit_gb = 0 358 | 359 | # Memory limit (bytes) — cgroup v2 then v1 360 | memory_limit = cg_read("/sys/fs/cgroup/memory.max") or cg_read( 361 | "/sys/fs/cgroup/memory/memory.limit_in_bytes" 362 | ) 363 | if memory_limit and memory_limit != "max": 364 | memory_limit_gb = int(memory_limit) / 1024**3 365 | self.logger.debug(f"Memory limit: {memory_limit_gb:.1f} GB") 366 | else: 367 | self.logger.debug("Memory limit: Unlimited Or Not Set") 368 | 369 | memory_used = cg_read("/sys/fs/cgroup/memory.current") or cg_read( 370 | "/sys/fs/cgroup/memory/memory.usage_in_bytes" 371 | ) 372 | if memory_used and memory_used != "max": 373 | memory_used_gb = int(memory_used) / 1024**3 374 | self.logger.debug(f"Memory Used: {memory_used_gb:.1f} GB") 375 | else: 376 | self.logger.debug("Memory Used: Unlimited Or Not Set") 377 | 378 | if self.master: 379 | self.logger.debug( 380 | f"CPU Memory Usage: {memory_used_gb:.1f}GBs out of {memory_limit_gb:.1f}GBs" 381 | ) 382 | 383 | return ( 384 | f"CPU Memory Usage: {memory_used_gb:.1f}GBs out of {memory_limit_gb:.1f}GBs" 385 | ) 386 | -------------------------------------------------------------------------------- /eval/eval_loss.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | import random 4 | import datetime 5 | from zoneinfo import ZoneInfo 6 | from typing import Callable, Dict 7 | 8 | from dotenv import load_dotenv 9 | import boto3 10 | import os 11 | import gc 12 | import sys 13 | import torch 14 | import shutil 15 | import json 16 | import torch.distributed as dist 17 | from distributed_training import __run__ 18 | from distributed_training.data.dataset import DatasetLoader 19 | from transformers import AutoTokenizer, AutoModelForCausalLM 20 | from huggingface_hub import HfApi, snapshot_download 21 | from huggingface_hub.constants import HF_HUB_CACHE 22 | from pathlib import Path 23 | from influxdb_client import InfluxDBClient, Point, 
WritePrecision 24 | from tabulate import tabulate 25 | from torch.distributed._tensor import DeviceMesh 26 | from torch.distributed._composable.fsdp import ( 27 | fully_shard, 28 | MixedPrecisionPolicy, 29 | ) 30 | from botocore.config import Config 31 | from distributed_training.utils.r2 import ( 32 | upload_folder_to_r2, 33 | r2_download, 34 | log_peerid_to_r2, 35 | ) 36 | import logging 37 | 38 | load_dotenv() 39 | # === CONFIG === 40 | INFLUXDB_URL = os.getenv("INFLUXDB_URL") 41 | INFLUXDB_TOKEN = os.getenv("INFLUXDB_TOKEN") 42 | INFLUXDB_ORG = os.getenv("INFLUXDB_ORG") 43 | INFLUXDB_BUCKET = os.getenv("INFLUXDB_BUCKET") 44 | INFLUXDB_MEASUREMENT = "evaluation_metrics" 45 | BUCKET = "llama-4b-ws-4" 46 | DATASET_ID = "HuggingFaceFW/fineweb-edu" 47 | DATASET_SKIP_PROBABILITY = 0.9 48 | EVAL_DURATION_MINUTES = 15 49 | EVAL_TYPES = ["fineweb", "lm_eval"] # Add lm-eval harness tasks in the future 50 | R2 = boto3.session.Session().client( 51 | "s3", 52 | endpoint_url=f"https://{os.getenv('R2_ACCOUNT_ID')}.r2.cloudflarestorage.com", 53 | aws_access_key_id=os.getenv("R2_READ_ACCESS_KEY_ID"), 54 | aws_secret_access_key=os.getenv("R2_READ_SECRET_ACCESS_KEY"), 55 | region_name="auto", 56 | config=Config( 57 | retries={"max_attempts": 10, "mode": "adaptive"}, # or "standard" 58 | connect_timeout=30, 59 | read_timeout=120, 60 | max_pool_connections=50, 61 | ), 62 | ) 63 | __run__ = "1" 64 | 65 | # === LOGGER SETUP === 66 | logging.basicConfig( 67 | level=logging.INFO, # allow INFO through 68 | format="%(asctime)s %(levelname)s [%(name)s:%(lineno)d] %(message)s", 69 | force=True, # override any prior config 70 | ) 71 | logger = logging.getLogger(__name__) 72 | logger.setLevel(logging.INFO) 73 | h = logging.StreamHandler(sys.stdout) 74 | h.setLevel(logging.INFO) 75 | h.setFormatter(logging.Formatter("%(asctime)s %(levelname)s [%(name)s] %(message)s")) 76 | logger.addHandler(h) 77 | logger.propagate = False 78 | 79 | # === INFLUXDB SETUP === 80 | influx = InfluxDBClient(url=INFLUXDB_URL, token=INFLUXDB_TOKEN, org=INFLUXDB_ORG) 81 | write_api = influx.write_api() 82 | query_api = influx.query_api() 83 | delete_api = influx.delete_api() 84 | 85 | 86 | class Dummy: 87 | def __init__(self): 88 | # Set distributed variables 89 | self.world_size = int(os.getenv("WORLD_SIZE", 1)) 90 | self.local_rank = int(os.getenv("LOCAL_RANK", 0)) 91 | torch.cuda.set_device(self.local_rank) 92 | self.master = self.local_rank == 0 93 | 94 | if not dist.is_initialized(): 95 | if not dist.is_initialized(): 96 | dist.init_process_group( 97 | backend="nccl", 98 | init_method="tcp://127.0.0.1:29500", 99 | rank=self.local_rank, 100 | world_size=self.world_size, 101 | ) 102 | if not hasattr(self, "gloo_group"): 103 | self.gloo_group = dist.new_group( 104 | backend="gloo", 105 | ) 106 | self.logger = logger 107 | 108 | 109 | SELF = Dummy() 110 | 111 | 112 | def tag_exists(tag: str, task: str) -> bool: 113 | if task == "lm_eval": 114 | task = "mmlu_stem.acc" 115 | query = f""" 116 | from(bucket: "{INFLUXDB_BUCKET}") 117 | |> range(start: -365d) 118 | |> filter(fn: (r) => r._measurement == "{INFLUXDB_MEASUREMENT}" and r.tag == "{tag}" and r.task == "{task}") 119 | |> limit(n:1) 120 | """ 121 | result = query_api.query(org=INFLUXDB_ORG, query=query) 122 | return len(result) > 0 123 | 124 | 125 | def log_score( 126 | tag: str, 127 | task: str, 128 | score: float, 129 | output_dir: str = None, 130 | ): 131 | if task == "fineweb": 132 | point = ( 133 | Point(INFLUXDB_MEASUREMENT) 134 | .tag("tag", tag) 135 | .tag("task", task) 136 | 
.field("score", score) 137 | .time(datetime.datetime.now(datetime.timezone.utc), WritePrecision.NS) 138 | ) 139 | write_api.write(bucket=INFLUXDB_BUCKET, org=INFLUXDB_ORG, record=point) 140 | else: 141 | directory = f"{os.getcwd()}/{output_dir}/{BUCKET.replace('/', '__')}" 142 | json_file = f"{directory}/{os.listdir(directory)[0]}" 143 | new_output_dir = ( 144 | f"{os.path.dirname(os.path.abspath(__file__))}/{output_dir}.json" 145 | ) 146 | 147 | # Load JSON 148 | with open(json_file, "r") as f: 149 | data = json.load(f) 150 | 151 | timestamp = int(data.get("date", 0) * 1e9) # Influx expects ns 152 | for task, values in data["results"].items(): 153 | for metric, score in values.items(): 154 | if metric == "alias": 155 | continue # skip alias itself 156 | # else: 157 | # print(task+"."+metric.replace(",none", ""), score) 158 | try: 159 | score = float(score) 160 | point = ( 161 | Point(INFLUXDB_MEASUREMENT) # measurement 162 | .tag("tag", tag) 163 | .tag("task", f"{task}.{metric.replace(',none', '')}") 164 | .field("score", score) 165 | .time(timestamp, WritePrecision.NS) 166 | ) 167 | 168 | write_api.write( 169 | bucket=INFLUXDB_BUCKET, org=INFLUXDB_ORG, record=point 170 | ) 171 | 172 | except (TypeError, ValueError) as e: 173 | print(f"An error occurred: {e}") 174 | continue # skip non-numeric 175 | 176 | # ---- PRINT SUMMARY TABLE ---- 177 | results = data["results"] 178 | 179 | def get_score(task, metric): 180 | val = results.get(task, {}).get(metric, None) 181 | return f"{val*100:.1f}" if isinstance(val, float) else "N/A" # convert to % 182 | 183 | # Collect rows 184 | rows = [ 185 | [ 186 | "DSTRBTD-1.10B", 187 | "FineWebEdu", 188 | f'~{int(80*int(tag.split(".")[1])*100*512*1025/1e9)}B', # 32 peers per outer step, 65 outer steps, 100 inner steps, 512 samples per inner_step, 1024 Tokens Per Sample 189 | get_score("hellaswag", "acc_norm,none"), 190 | get_score("piqa", "acc_norm,none"), 191 | get_score("arc_easy", "acc,none"), 192 | ], 193 | [ 194 | "TEMPLAR-1.21B", 195 | "FineWebEdu", 196 | "100B-200B", 197 | 51.0, 198 | 71.4, 199 | 59.2, 200 | ], 201 | [ 202 | "DEM0-1.18B", 203 | "Dolmo", 204 | "100B", 205 | "48.0", 206 | "71.0", 207 | "55.0", 208 | ], 209 | [ 210 | "DILOCO-1.30B", 211 | "Dolmo", 212 | "26B", 213 | "45.0", 214 | "68.4", 215 | "39.0", 216 | ], 217 | ] 218 | 219 | print( 220 | tabulate( 221 | rows, 222 | headers=[ 223 | "Model", 224 | "Dataset", 225 | "Tokens", 226 | "HellaSwag acc_norm", 227 | "PIQA acc_norm", 228 | "ARC-E acc", 229 | ], 230 | tablefmt="fancy_grid", 231 | ) 232 | ) 233 | 234 | shutil.rmtree(f"{os.getcwd()}/{output_dir}") 235 | with open(new_output_dir, "w") as f: 236 | json.dump(data, f, indent=4) # indent=4 for pretty printing 237 | 238 | 239 | async def fetch_training_data(tokenizer): 240 | """Async function to fetch training data""" 241 | retry_limit = 10 242 | retry_delay = 60 243 | attempt = 0 244 | local_batch_size_train = 4 245 | if dist.get_rank() == 0: 246 | current_block = random.randint(6193881 * 2, 6193881 * 4) 247 | uid = random.randint(300, 1000000) 248 | tensor = torch.tensor([current_block, uid], dtype=torch.long, device="cuda") 249 | else: 250 | tensor = torch.zeros(2, dtype=torch.long, device="cuda") 251 | 252 | # Broadcast from rank 0 to all others 253 | dist.broadcast(tensor, src=0) 254 | current_block = int(tensor[0].item()) 255 | uid = int(tensor[1].item()) 256 | # print(SELF.local_rank, f"Fetched block {current_block} with uid {uid}") 257 | while attempt < retry_limit: 258 | try: 259 | pages = await DatasetLoader.next_pages( 260 | 
offset=current_block, 261 | n_pages=5, 262 | seed=uid, 263 | ) 264 | random.seed(uid) 265 | random.shuffle(pages) 266 | 267 | dataset = await DatasetLoader.create( 268 | batch_size=local_batch_size_train, 269 | sequence_length=1024, 270 | pages_info=pages, 271 | tokenizer=tokenizer, 272 | ) 273 | 274 | dataset_length = torch.tensor(len(dataset.buffer)) 275 | dist.all_reduce(dataset_length, op=dist.ReduceOp.MIN, group=SELF.gloo_group) 276 | dataset.buffer = dataset.buffer[:dataset_length] 277 | 278 | return dataset 279 | except Exception as e: 280 | print(f"Error fetching training data: {str(e)}") 281 | attempt += 1 282 | print(f"Failed to fetch data, retrying. Attempt {attempt}/{retry_limit}") 283 | if attempt < retry_limit: 284 | time.sleep(retry_delay * attempt) # Wait before the next retry 285 | else: 286 | print("Maximum retry limit reached. Unable to fetch data.") 287 | raise 288 | 289 | 290 | # === EVALUATORS === 291 | def evaluate_fineweb( 292 | device: str, 293 | tag: str, 294 | max_minutes: int = EVAL_DURATION_MINUTES, 295 | ) -> float: 296 | """ 297 | Stream and evaluate a fixed-time sample of fineweb-edu on average LM loss. 298 | 299 | Args: 300 | model: HuggingFace model 301 | tokenizer: Matching tokenizer 302 | device: cuda or cpu 303 | max_minutes: Time budget in minutes 304 | max_seq_length: Max input length for tokenization 305 | 306 | Returns: 307 | Average loss 308 | """ 309 | prefix = f"epoch-{tag.split('.')[1]}/" 310 | output_dir = os.path.join(os.getcwd(), BUCKET) 311 | _ = r2_download( 312 | SELF, 313 | r2=R2, 314 | bucket=BUCKET, 315 | key=f"{prefix}model.safetensors", 316 | donwload_on_all_ranks=False, 317 | destination=output_dir, 318 | ) 319 | _ = r2_download( 320 | SELF, 321 | r2=R2, 322 | bucket=BUCKET, 323 | key=f"{prefix}config.json", 324 | donwload_on_all_ranks=False, 325 | destination=output_dir, 326 | ) 327 | dist.barrier(device_ids=[SELF.local_rank]) 328 | tokenizer = AutoTokenizer.from_pretrained("dstrbtd/llama-1b") 329 | tokenizer.pad_token = tokenizer.eos_token 330 | model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16) 331 | 332 | mp_policy = MixedPrecisionPolicy( 333 | param_dtype=torch.bfloat16, # match your autocast compute dtype 334 | reduce_dtype=torch.bfloat16, 335 | output_dtype=torch.bfloat16, # required by FSDP2 policy 336 | ) 337 | 338 | # Build a 1D device mesh over all ranks 339 | mesh = DeviceMesh("cuda", list(range(dist.get_world_size()))) 340 | # Keep a plain HF module and enable FSDP2 on it 341 | fully_shard(model, mesh=mesh, mp_policy=mp_policy) 342 | 343 | model.eval() 344 | total_loss = 0.0 345 | n_batches = 0 346 | start_time = time.time() 347 | 348 | loop = asyncio.new_event_loop() 349 | 350 | while (time.time() - start_time) <= (max_minutes * 30): 351 | # Use streaming mode 352 | dataset = loop.run_until_complete(fetch_training_data(tokenizer)) 353 | 354 | with torch.no_grad(): 355 | for i, batch in enumerate(dataset): 356 | if random.random() > (1 - DATASET_SKIP_PROBABILITY): 357 | continue 358 | 359 | inputs, labels = batch 360 | inputs = inputs.to(device) 361 | labels = labels.to(device) 362 | 363 | if inputs is None or len(inputs) == 0: 364 | print(f"Empty batch at index {i}, skipping") 365 | continue 366 | 367 | with torch.autocast(device_type=device.type, dtype=torch.bfloat16): 368 | outputs = model(input_ids=inputs, labels=inputs) 369 | total_loss += outputs.loss.item() 370 | n_batches += 1 371 | if n_batches % 20 == 0 and SELF.master: 372 | SELF.logger.info(total_loss / n_batches) 373 | 374 | 
del model, tokenizer 375 | gc.collect() 376 | torch.cuda.empty_cache() 377 | 378 | # local aggregates 379 | local_loss = torch.tensor([total_loss], dtype=torch.float64, device="cuda") 380 | local_count = torch.tensor([n_batches], dtype=torch.int64, device="cuda") 381 | 382 | logger.info(f"{SELF.local_rank},{local_loss},{local_count}") 383 | # sum across all ranks (in-place; now identical on every rank) 384 | dist.all_reduce(local_loss, op=dist.ReduceOp.SUM) 385 | dist.all_reduce(local_count, op=dist.ReduceOp.SUM) 386 | logger.info(f"{SELF.local_rank},{local_loss},{local_count}") 387 | global_total_loss = float(local_loss.item()) 388 | global_n_batches = int(local_count.item()) 389 | 390 | score = ( 391 | (global_total_loss / global_n_batches) if global_n_batches > 0 else float("inf") 392 | ) 393 | logger.info(f"{SELF.local_rank},{score}") 394 | 395 | if SELF.master: 396 | log_score(tag, "fineweb", score) 397 | dist.barrier(device_ids=[SELF.local_rank]) 398 | return score 399 | 400 | 401 | def evaluate_with_lm_harness( 402 | device: str, 403 | tag: str, 404 | ) -> float: 405 | """ 406 | Evaluate model using lm-eval-harness (e.g. HellaSwag, ARC). 407 | """ 408 | output_dir = f"{REPO_ID.split('/')[1].replace('-', '_')}_{tag.replace('.','_')}_{datetime.datetime.now(ZoneInfo('Africa/Cairo')).strftime('%Y_%m_%dT%H_%M_%S')}" 409 | tasks = [ 410 | "hellaswag", 411 | "arc_challenge", 412 | "arc_easy", 413 | "openbookqa", 414 | "winogrande", 415 | "piqa", 416 | "mmlu", 417 | ] 418 | 419 | cmd_parts = [ 420 | "lm-eval", 421 | "--model hf", 422 | f"--model_args pretrained={REPO_ID},revision={tag}", 423 | f"--tasks {','.join(tasks)}", 424 | f"--device {device}", 425 | f"--batch_size 4", 426 | f"--output_path {output_dir}", 427 | ] 428 | 429 | # command = " ".join(cmd_parts) + " >/dev/null 2>&1" 430 | command = " ".join(cmd_parts) 431 | start_time = time.time() 432 | print(f"Running command: {command}") 433 | exit_code = os.system(command) 434 | score = 0 435 | # exit_code = 0 436 | # breakpoint() 437 | if exit_code == 0: 438 | log_score(tag, "lm_eval", score, output_dir) 439 | # breakpoint() 440 | benchmark_runtime = time.time() - start_time 441 | # breakpoint() 442 | return score 443 | 444 | 445 | # === EVALUATION REGISTRY === 446 | 447 | 448 | def get_evaluator(task: str) -> Callable: 449 | if task == "fineweb": 450 | return evaluate_fineweb 451 | elif task in ["hellaswag", "arc_easy", "arc_challenge", "lm_eval"]: 452 | return evaluate_with_lm_harness 453 | else: 454 | raise ValueError(f"Unsupported evaluation task: {task}") 455 | 456 | 457 | # === MAIN LOOP === 458 | 459 | 460 | def evaluate_all_tags_once(): 461 | result = R2.list_objects_v2(Bucket=BUCKET, Prefix="", Delimiter="/") 462 | 463 | # Extract subfolders like epoch-0/, epoch-1/, etc. 464 | folders = [ 465 | o.get("Prefix").rstrip("/").split("/")[-1] 466 | for o in result.get("CommonPrefixes", []) 467 | if o.get("Prefix").startswith("epoch-") 468 | ] 469 | 470 | # Sort by epoch number (epoch-0, epoch-1, ...) 
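# e.g. ["epoch-0", "epoch-10", "epoch-2"] sorts numerically to
# ["epoch-0", "epoch-2", "epoch-10"] rather than lexicographically.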
471 | epochs = sorted(folders, key=lambda x: int(x.split("-")[1])) 472 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 473 | 474 | for epoch in epochs: 475 | epoch = epoch.split("-")[1] 476 | tag = f"{__run__}.{epoch}.0" 477 | try: 478 | if tag.split(".")[0] != __run__: 479 | continue 480 | else: 481 | print(f"\n=== [TAG] {tag} ===") 482 | 483 | for task in EVAL_TYPES: 484 | if tag_exists(tag, task): 485 | print(f"[✓] {task}: already evaluated") 486 | continue 487 | 488 | if task != "fineweb": 489 | continue 490 | 491 | print(f"[⏳] Evaluating {task}...") 492 | evaluator = get_evaluator(task) 493 | score = evaluator(device, tag) 494 | print(f"[✅] {task}: {score:.4f}") 495 | 496 | except Exception as e: 497 | print(f"[⚠️] Error evaluating tag {tag}: {e}") 498 | 499 | # finally: 500 | # cache_dir = HF_HUB_CACHE 501 | # cache_dir = Path(cache_dir).expanduser().resolve() 502 | # for cache in cache_dir.iterdir(): 503 | # if os.path.isdir(cache): 504 | # shutil.rmtree(str(cache)) 505 | 506 | 507 | # === Optional Continuous Mode === 508 | 509 | 510 | def monitor_repo(poll_interval_sec: int = 18000): 511 | print("[🔁] Starting continuous monitoring...") 512 | while True: 513 | evaluate_all_tags_once() 514 | print(f"[⏳] Sleeping for {poll_interval_sec}s...") 515 | time.sleep(poll_interval_sec) 516 | 517 | 518 | # === Entry === 519 | 520 | if __name__ == "__main__": 521 | monitor_repo() 522 | -------------------------------------------------------------------------------- /distributed_training/utils/compression.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2025 dstrbtd.ai 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | # Adapted from NousResearch (https://github.com/bloc97/DeMo) and Templar (https://github.com/tplr-ai/templar) 19 | 20 | import math 21 | from typing import Generic, Literal, TypeAlias, TypeVar, cast, overload 22 | 23 | import torch 24 | import torch.fft 25 | from einops import rearrange 26 | 27 | # ---------- type aliases ---------- # 28 | IdxT: TypeAlias = torch.Tensor # int16 indices 29 | ValT: TypeAlias = torch.Tensor # (possibly quantised) values 30 | ShapeT: TypeAlias = tuple[int, ...] 
# original tensor shape 31 | TotK: TypeAlias = int # size of the last dim 32 | Shape4D = tuple[int, int, int, int] # y, x, h, w 33 | 34 | # (shift, scale, offset, lookup table, original dtype) 35 | QuantParamsT: TypeAlias = tuple[torch.Tensor, float, int, torch.Tensor, torch.dtype] 36 | 37 | # Boolean flag that propagates the chosen quantisation mode 38 | Q = TypeVar("Q", Literal[True], Literal[False]) 39 | 40 | 41 | class TransformDCT: 42 | @torch.no_grad() 43 | def __init__(self, model, target_chunk, norm="ortho"): 44 | self.target_chunk = target_chunk 45 | 46 | self.shape_dict = dict() 47 | self.f_dict = dict() 48 | self.b_dict = dict() 49 | 50 | # Get all variants of model tensor sizes 51 | # Generate all possible valid DCT sizes for model tensors 52 | for _, p in model.items(): 53 | # if not p.requires_grad: 54 | # continue 55 | for s in p.shape: 56 | # Get the closest smallest divisor to the targeted DCT size 57 | sc = _get_smaller_split(s, self.target_chunk) 58 | self.shape_dict[s] = sc 59 | 60 | # Pregenerate DCT basis matrices 61 | if sc not in self.f_dict: 62 | I = torch.eye(sc) # noqa: E741 63 | self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device) 64 | self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device) 65 | 66 | @torch.no_grad() 67 | def einsum_2d(self, x, b, d=None): 68 | if d is None: 69 | return torch.einsum("...ij, jb -> ...ib", x, b) 70 | else: 71 | # Note: b-c axis output is transposed to chunk DCT in 2D 72 | return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d) 73 | 74 | @torch.no_grad() 75 | def einsum_2d_t(self, x, b, d=None): 76 | if d is None: 77 | return torch.einsum("...ij, jb -> ...ib", x, b) 78 | else: 79 | # Note: b-c axis output is transposed to chunk DCT in 2D 80 | return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d) 81 | 82 | @torch.no_grad() 83 | def encode(self, x: torch.Tensor, *, use_dct: bool = False) -> torch.Tensor: 84 | if len(x.shape) > 1: # 2D weights 85 | n1 = self.shape_dict[x.shape[0]] 86 | n2 = self.shape_dict[x.shape[1]] 87 | n1w = self.f_dict[n1].to(x.device) 88 | n2w = self.f_dict[n2].to(x.device) 89 | self.f_dict[n1] = n1w 90 | self.f_dict[n2] = n2w 91 | 92 | x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2) 93 | if use_dct: 94 | x = self.einsum_2d(x, n1w, n2w) 95 | 96 | else: # 1D weights 97 | n1 = self.shape_dict[x.shape[0]] 98 | n1w = self.f_dict[n1].to(x.device) 99 | self.f_dict[n1] = n1w 100 | 101 | x = rearrange(x, "(x w) -> x w", w=n1) 102 | if use_dct: 103 | x = self.einsum_2d(x, n1w) 104 | 105 | return x 106 | 107 | @torch.no_grad() 108 | def decode(self, x: torch.Tensor, *, use_dct: bool = False): 109 | if len(x.shape) > 2: # 2D weights 110 | if use_dct: 111 | n1 = x.shape[2] 112 | n2 = x.shape[3] 113 | n1w = self.b_dict[n1].to(x.device) 114 | n2w = self.b_dict[n2].to(x.device) 115 | self.b_dict[n1] = n1w 116 | self.b_dict[n2] = n2w 117 | 118 | x = self.einsum_2d_t(x, n1w, n2w) 119 | x = rearrange(x, "y h x w -> (y h) (x w)") 120 | 121 | else: # 1D weights 122 | if use_dct: 123 | n1 = x.shape[1] 124 | n1w = self.b_dict[n1].to(x.device) 125 | self.b_dict[n1] = n1w 126 | 127 | x = self.einsum_2d_t(x, n1w) 128 | x = rearrange(x, "x w -> (x w)") 129 | 130 | return x 131 | 132 | 133 | class CompressDCT(Generic[Q]): 134 | """DCT-style sparsifier/compressor with optional 8-bit quantisation.""" 135 | 136 | use_quantization: Q 137 | n_bins: int 138 | range_in_sigmas: int 139 | 140 | # ------------------------------------------------------------------ # 141 | # Constructor – two overloads so each instance 
"remembers" its mode 142 | # ------------------------------------------------------------------ # 143 | @overload 144 | def __init__( 145 | self: "CompressDCT[Literal[True]]", 146 | *, 147 | use_quantization: Literal[True] = True, 148 | quantization_bins: int = 256, 149 | quantization_range: int = 6, 150 | ) -> None: 151 | ... 152 | 153 | @overload 154 | def __init__( 155 | self: "CompressDCT[Literal[False]]", 156 | *, 157 | use_quantization: Literal[False] = False, 158 | quantization_bins: int = 256, 159 | quantization_range: int = 6, 160 | ) -> None: 161 | ... 162 | 163 | @torch.no_grad() 164 | def __init__( 165 | self, 166 | *, 167 | use_quantization: bool = False, 168 | quantization_bins: int = 256, 169 | quantization_range: int = 6, 170 | ) -> None: 171 | self.use_quantization = cast(Q, use_quantization) 172 | if self.use_quantization: 173 | self.n_bins = quantization_bins 174 | self.range_in_sigmas = ( 175 | quantization_range # Quantization range in standard deviations 176 | ) 177 | 178 | def _clamp_topk(self, x, topk): 179 | if topk > x.shape[-1]: 180 | topk = x.shape[-1] 181 | if topk < 1: 182 | topk = 1 183 | return int(topk) 184 | 185 | # ------------------------------------------------------------------ # 186 | # compress – returns a 5-tuple *or* a 4-tuple, depending on the mode 187 | # ------------------------------------------------------------------ # 188 | @overload 189 | def compress( 190 | self: "CompressDCT[Literal[True]]", 191 | x: torch.Tensor, 192 | topk: int, 193 | ) -> tuple[IdxT, ValT, ShapeT, TotK, QuantParamsT]: 194 | ... 195 | 196 | @overload 197 | def compress( 198 | self: "CompressDCT[Literal[False]]", 199 | x: torch.Tensor, 200 | topk: int, 201 | ) -> tuple[IdxT, ValT, ShapeT, TotK]: 202 | ... 203 | 204 | @torch.no_grad() 205 | def compress(self, x: torch.Tensor, topk: int, quantize: bool = False): # type: ignore[override] 206 | xshape = x.shape 207 | if len(x.shape) > 2: # 2D weights 208 | x = rearrange(x, "y x h w -> y x (h w)") 209 | 210 | # Limit topk to max size 211 | totalk = x.shape[-1] 212 | topk = self._clamp_topk(x, topk) 213 | 214 | idx_int64 = torch.topk( 215 | x.abs(), k=topk, dim=-1, largest=True, sorted=False 216 | ).indices 217 | val = torch.gather(x, dim=-1, index=idx_int64) 218 | 219 | # Cast idx to int16 for saving or transmission 220 | idx = idx_int64.to(torch.int16) 221 | 222 | # Apply 8-bit quantization if enabled 223 | if self.use_quantization and quantize: 224 | val, quant_params = self._quantize_values(val) 225 | return idx, val, xshape, totalk, quant_params 226 | 227 | return idx, val, xshape, totalk 228 | 229 | @torch.no_grad() 230 | def decompress( 231 | self, 232 | p: torch.Tensor, 233 | idx: torch.Tensor, 234 | val: torch.Tensor, 235 | xshape: ShapeT, 236 | totalk: int, 237 | quantize_params: QuantParamsT | None = None, 238 | ) -> torch.Tensor: 239 | if self.use_quantization and quantize_params is not None: 240 | val = self._dequantize_values(val, quantize_params) 241 | 242 | x = torch.zeros(xshape, device=p.device, dtype=p.dtype) 243 | 244 | if len(xshape) > 2: # 2D weights 245 | x = rearrange(x, "y x h w -> y x (h w)") 246 | 247 | # Cast back to int64 before using scatter/gather 248 | idx_int64 = idx.to(torch.int64) 249 | x.scatter_reduce_( 250 | dim=-1, index=idx_int64, src=val, reduce="mean", include_self=False 251 | ).reshape(xshape) 252 | 253 | if len(x.shape) > 2: # 2D weights 254 | xshape4 = cast(Shape4D, xshape) 255 | h_dim = xshape4[2] 256 | x = rearrange(x, "y x (h w) -> y x h w", h=h_dim) 257 | 258 | return x 259 | 
260 | @torch.no_grad() 261 | def batch_decompress( 262 | self, 263 | p: torch.Tensor, 264 | idx: torch.Tensor | list[torch.Tensor], 265 | val: torch.Tensor | list[torch.Tensor], 266 | xshape: ShapeT, 267 | totalk: int, 268 | quantize_params: QuantParamsT | list[QuantParamsT] | None = None, 269 | *, 270 | block_norms: torch.Tensor | None = None, 271 | normalise: bool = False, 272 | clip_norm: bool = True, 273 | ) -> torch.Tensor: 274 | if not isinstance(idx, list): 275 | idx = [idx] 276 | if not isinstance(val, list): 277 | val = [val] 278 | 279 | if quantize_params is not None and not isinstance(quantize_params, list): 280 | quantize_params = [quantize_params] * len(val) # type: ignore[list-item] 281 | 282 | processed_vals: list[torch.Tensor] = [] 283 | dequant_vals = None 284 | norms = None 285 | clip_norm_val = None 286 | if self.use_quantization and quantize_params: 287 | dequant_vals = [ 288 | self._dequantize_values(v, quantize_params[i]) 289 | for i, v in enumerate(val) 290 | ] 291 | if clip_norm: 292 | # If caller already supplied per-block norms, trust them. 293 | if block_norms is not None: 294 | norms = block_norms.to(p.device) 295 | else: 296 | vals_for_norm = dequant_vals if dequant_vals is not None else val 297 | norms = torch.stack( 298 | [torch.norm(sparse_vals, p=2) for sparse_vals in vals_for_norm] 299 | ) 300 | median_norm = torch.median(norms) 301 | clip_norm_val = torch.clamp(median_norm, min=1, max=10) 302 | 303 | vals = dequant_vals if dequant_vals is not None else val 304 | for i, v in enumerate(vals): 305 | v = v.to(p.device) 306 | 307 | if normalise: 308 | eps = 1e-8 309 | if len(v.shape) == 3: # 2D weights 310 | l2_norm = torch.norm(v, p=2, dim=2, keepdim=True) 311 | v = v / (l2_norm + eps) 312 | elif len(v.shape) == 2: # 1D weights (biases) 313 | l2_norm = torch.norm(v, p=2, dim=1, keepdim=True) 314 | v = v / (l2_norm + eps) 315 | elif len(v.shape) == 1: # Single values 316 | l2_norm = torch.norm(v, p=2) 317 | if l2_norm > eps: 318 | v = v / l2_norm 319 | elif clip_norm and norms is not None and clip_norm_val is not None: 320 | current_norm = norms[i] 321 | clip_factor = torch.clamp(clip_norm_val / (current_norm + 1e-8), max=1) 322 | v = v * clip_factor 323 | processed_vals.append(v) 324 | 325 | # Concatenate everything 326 | idx_concat = torch.cat([i.to(p.device) for i in idx], dim=-1) 327 | val_concat = torch.cat(processed_vals, dim=-1).to(p.dtype) 328 | 329 | # Use decompress without quantization (since we already dequantized) 330 | return self.decompress( 331 | p, idx_concat, val_concat, xshape, totalk, quantize_params=None 332 | ) 333 | 334 | @torch.no_grad() 335 | def _quantize_values(self, val: torch.Tensor) -> tuple[torch.Tensor, QuantParamsT]: 336 | offset = self.n_bins // 2 # 128 for 8-bit 337 | shift = val.mean() 338 | centered = val - shift 339 | 340 | std = centered.norm() / math.sqrt(centered.numel() - 1) 341 | scale = self.range_in_sigmas * std / self.n_bins 342 | if scale == 0 or torch.isnan(scale) or torch.isinf(scale): 343 | scale = torch.tensor(1.0, dtype=centered.dtype, device=val.device) 344 | 345 | centered_fp32 = centered.to(torch.float32) 346 | qval = ( 347 | (centered_fp32 / scale + offset) 348 | .round() 349 | .clamp(0, self.n_bins - 1) 350 | .to(torch.uint8) 351 | ) 352 | 353 | device = qval.device 354 | sums = torch.zeros(self.n_bins, dtype=torch.float32, device=device) 355 | counts = torch.zeros(self.n_bins, dtype=torch.float32, device=device) 356 | 357 | sums.scatter_add_(0, qval.flatten().long(), centered_fp32.flatten()) 358 | 
counts.scatter_add_( 359 | 0, qval.flatten().long(), torch.ones_like(centered_fp32.flatten()) 360 | ) 361 | 362 | lookup = torch.where(counts > 0, sums / counts, torch.zeros_like(sums)) 363 | qparams: QuantParamsT = (shift, float(scale), offset, lookup, val.dtype) 364 | return qval, qparams 365 | 366 | @torch.no_grad() 367 | def _dequantize_values( 368 | self, val: torch.Tensor, qparams: QuantParamsT 369 | ) -> torch.Tensor: 370 | shift, _, _, lookup, orig_dtype = qparams 371 | lookup = lookup.to(val.device) if isinstance(lookup, torch.Tensor) else lookup 372 | deq = lookup[val.long()] + shift 373 | return deq.to(orig_dtype) 374 | 375 | 376 | # Code modified and sourced from https://github.com/zh217/torch-dct 377 | def _dct_fft_impl(v): 378 | return torch.view_as_real(torch.fft.fft(v, dim=1)) 379 | 380 | 381 | def _idct_irfft_impl(V): 382 | return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 383 | 384 | 385 | def _dct(x, norm=None): 386 | """ 387 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 388 | 389 | For the meaning of the parameter `norm`, see: 390 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 391 | 392 | :param x: the input signal 393 | :param norm: the normalization, None or 'ortho' 394 | :return: the DCT-II of the signal over the last dimension 395 | """ 396 | x_shape = x.shape 397 | N = x_shape[-1] 398 | x = x.contiguous().view(-1, N) 399 | 400 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 401 | 402 | Vc = _dct_fft_impl(v) 403 | 404 | k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N) 405 | W_r = torch.cos(k) 406 | W_i = torch.sin(k) 407 | 408 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 409 | 410 | if norm == "ortho": 411 | V[:, 0] /= math.sqrt(N) * 2 412 | V[:, 1:] /= math.sqrt(N / 2) * 2 413 | 414 | V = 2 * V.view(*x_shape) 415 | 416 | return V 417 | 418 | 419 | def _idct(X, norm=None): 420 | """ 421 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 422 | 423 | Our definition of idct is that idct(dct(x)) == x 424 | 425 | For the meaning of the parameter `norm`, see: 426 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 427 | 428 | :param X: the input signal 429 | :param norm: the normalization, None or 'ortho' 430 | :return: the inverse DCT-II of the signal over the last dimension 431 | """ 432 | 433 | x_shape = X.shape 434 | N = x_shape[-1] 435 | 436 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 437 | 438 | if norm == "ortho": 439 | X_v[:, 0] *= math.sqrt(N) * 2 440 | X_v[:, 1:] *= math.sqrt(N / 2) * 2 441 | 442 | k = ( 443 | torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] 444 | * math.pi 445 | / (2 * N) 446 | ) 447 | W_r = torch.cos(k) 448 | W_i = torch.sin(k) 449 | 450 | V_t_r = X_v 451 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 452 | 453 | V_r = V_t_r * W_r - V_t_i * W_i 454 | V_i = V_t_r * W_i + V_t_i * W_r 455 | 456 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 457 | 458 | v = _idct_irfft_impl(V) 459 | x = v.new_zeros(v.shape) 460 | x[:, ::2] += v[:, : N - (N // 2)] 461 | x[:, 1::2] += v.flip([1])[:, : N // 2] 462 | 463 | return x.view(*x_shape) 464 | 465 | 466 | def _get_prime_divisors(n): 467 | divisors = [] 468 | while n % 2 == 0: 469 | divisors.append(2) 470 | n //= 2 471 | while n % 3 == 0: 472 | divisors.append(3) 473 | n //= 3 474 | i = 5 475 | while i * i <= n: 476 | for k in (i, i + 2): 477 | while n % k == 0: 478 | divisors.append(k) 
520 | 
--------------------------------------------------------------------------------
/distributed_training/averaging/avg_handler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import random
3 | import time
4 | import math
5 | import torch
6 | import distributed_training
7 | import bittensor as bt
8 | import numpy as np
9 | 
10 | from typing import Any, Dict, List, Optional, Tuple
11 | from distributed_training.averaging.exceptions import AllReduceError, ModelStateError
12 | from distributed_training.protocol import AllReduce
13 | from distributed_training.data.dataset import DatasetLoader
14 | from distributed_training.utils.dendrite import (
15 |     async_dendrite_forward,
16 | )
17 | from distributed_training.utils.r2 import r2_download
18 | from transformers import AutoModelForCausalLM
19 | 
20 | 
21 | class AveragingHandler:
22 |     """Handles averaging round and outer step for both validators and miners."""
23 | 
24 |     def __init__(
25 |         self,
26 |         model,
27 |         optimizer,
28 |         outer_optimizer,
29 |         grad_averager,
30 |         retry_limit,
31 |         retry_delay,
32 |         uid,
33 |         local_batch_size_train,
34 |         local_batch_size_train_effective,
35 |         tokenizer,
36 |         device,
37 |         logger,
38 |         parameters_list=None,
39 |     ):
40 |         self.model = model
41 |         self.inner_optimizer = optimizer
42 |         self.outer_optimizer = outer_optimizer
43 |         self.grad_averager = grad_averager
44 |         self.test_loss_loop = asyncio.new_event_loop()
45 |         self.retry_limit = retry_limit
46 |         self.retry_delay = retry_delay
47 |         self.uid = uid
48 |         self.local_batch_size_train = local_batch_size_train
49 |         self.local_batch_size_train_effective = local_batch_size_train_effective
50 |         self.tokenizer = tokenizer
51 |         self.device = device
52 |         self.number_of_local_steps = (
53 |             self.local_batch_size_train_effective // self.local_batch_size_train
54 |         )
55 |         self.logger = logger
56 |         self.parameters_list = parameters_list
57 |         self.master = True
58 | 
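    # Worked example (editor's note): with local_batch_size_train=4 and
    # local_batch_size_train_effective=32, number_of_local_steps = 32 // 4 = 8,
    # i.e. eight micro-batches are accumulated locally before each averaging
    # round exchanges the resulting pseudo-gradient.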
59 |     def _get_weights_sample(self) -> List[float]:
60 |         """Get a sample of model weights for validation."""
61 |         p = list(self.model.parameters())[-2]
62 |         if hasattr(p, "to_local"):  # sharded/DTensor
63 |             return p.to_local()[-10:].tolist()
64 |         return p.detach().flatten()[-10:].cpu().tolist()
65 | 
66 |     async def _validate_weight_update(
67 |         self, initial_weights: List[float], block: int
68 |     ) -> None:
69 |         """Validate model weight updates, raising ModelStateError on failure."""
70 |         final_weights = self._get_weights_sample()
71 |         self.logger.info(f"Final Weights Sample: {final_weights}")
72 | 
73 |         if final_weights == initial_weights:
74 |             raise ModelStateError("Weights unchanged after update")
75 | 
76 |         if np.isnan(final_weights).any():
77 |             raise ModelStateError("NaN values detected in weights after update")
78 | 
79 |         # TODO Re-introduce
80 |         # if await self._test_model_loss(block):
81 |         #     raise ModelStateError("NaN values detected in loss generated by new model")
82 | 
83 |     async def fetch_training_data(self, block):
84 |         """Async function to fetch training data"""
85 |         attempt = 0
86 |         while attempt < self.retry_limit:
87 |             try:
88 |                 pages = await DatasetLoader.next_pages(
89 |                     offset=block,
90 |                     n_pages=5,
91 |                     seed=self.uid,
92 |                 )
93 |                 random.seed(self.uid)
94 |                 random.shuffle(pages)
95 | 
96 |                 dataset = await DatasetLoader.create(
97 |                     batch_size=4,
98 |                     sequence_length=1024,
99 |                     pages_info=pages,
100 |                     tokenizer=self.tokenizer,
101 |                 )
102 | 
103 |                 return dataset
104 |             except Exception as e:
105 |                 self.logger.error(f"Error fetching training data: {str(e)}")
106 |                 attempt += 1
107 |                 self.logger.warning(
108 |                     f"Failed to fetch data, retrying. Attempt {attempt}/{self.retry_limit}"
109 |                 )
110 |                 if attempt < self.retry_limit:
111 |                     await asyncio.sleep(self.retry_delay * attempt)  # Non-blocking wait before the next retry
112 |                 else:
113 |                     self.logger.error(
114 |                         "Maximum retry limit reached. Unable to fetch data."
115 |                     )
116 |                     raise
117 | 
118 |     async def _test_model_loss(self, block) -> bool:
119 |         dataset = await self.fetch_training_data(block)
120 |         for inputs, labels in dataset:
121 |             inputs = inputs.to(self.device)
122 |             with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
123 |                 outputs = self.model(input_ids=inputs, labels=inputs)
124 |                 loss = outputs.loss / self.number_of_local_steps
125 | 
126 |         return math.isnan(loss.item())
127 | 
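    # Retry timing example (editor's note): fetch_training_data backs off
    # linearly, waiting retry_delay * attempt seconds between tries. With
    # retry_delay=5 and retry_limit=3, a persistently failing fetch waits 5 s,
    # then 10 s, then re-raises the last exception.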
128 |     async def run_validator_allreduce(
129 |         self,
130 |         timeout: int,
131 |         wallet,
132 |         metagraph,
133 |         peerids_to_uids,
134 |         miner_uids,
135 |         master,
136 |         block,
137 |         bandwidth=None,
138 |         min_group_size: int = None,
139 |         request_timeout: float = None,
140 |         allreduce_timeout: float = None,
141 |         next_chunk_timeout: float = None,
142 |         min_matchmaking_time: float = None,
143 |     ) -> Tuple[bool, Dict[str, Any], Optional[List[float]]]:
144 |         """
145 |         Process allreduce specifically for validator.
146 |         Returns:
147 |             Tuple[bool, Dict[str, Any], Optional[List[float]]]: (success, results, initial_weights)
148 |             - success: True if allreduce completed successfully, False otherwise
149 |             - results: peers and bandwidth info if successful, empty dict if failed
150 |             - initial_weights: weights sample taken after averaging, None on failure
151 |         """
152 |         gradient_averaging_step = None
153 |         results = {}
154 |         all_reduce_success_status = True
155 |         initial_weights = None
156 | 
157 |         try:
158 |             # Used for load balancing and scoring
159 |             if bandwidth is not None:
160 |                 self.grad_averager.bandwidth = bandwidth["download"]
161 | 
162 |             self.logger.info(":wait: Starting Pseudo Gradient Averaging..")
163 |             # Start gradient averaging without waiting
164 |             start_ar_time = time.perf_counter()
165 |             gradient_averaging_step = self.grad_averager.step(
166 |                 timeout=timeout,
167 |                 wait=False,
168 |                 gather=0,
169 |                 peerids_to_uids=peerids_to_uids,
170 |             )
171 | 
172 |             if master:
173 |                 # Send AllReduce query to pause miner training and perform global sync
174 |                 self.logger.info(
175 |                     ":wait: AllReduce Query Sent Out. Waiting for AllReduce to finish.."
176 |                 )
177 |                 await async_dendrite_forward(
178 |                     wallet=wallet,
179 |                     axons=[metagraph.axons[uid] for uid in miner_uids],
180 |                     synapse=AllReduce(
181 |                         completion=False,
182 |                         min_group_size=min_group_size,
183 |                         request_timeout=request_timeout,
184 |                         allreduce_timeout=allreduce_timeout,
185 |                         next_chunk_timeout=next_chunk_timeout,
186 |                         min_matchmaking_time=min_matchmaking_time,
187 |                         timeout=timeout,
188 |                     ),
189 |                     connection_limit=len(miner_uids),
190 |                     timeout=timeout,
191 |                 )
192 |                 self.logger.info("AllReduce Query Responses Received..")
193 | 
194 |             start_time = time.perf_counter()
195 | 
196 |             while not gradient_averaging_step.done() and (
197 |                 (time.perf_counter() - start_time) <= timeout
198 |             ):
199 |                 await asyncio.sleep(1)
200 | 
201 |             if gradient_averaging_step.done():
202 |                 end_ar_time = time.perf_counter()
203 | 
204 |                 self.logger.info(
205 |                     ":white_heavy_check_mark: Finished Averaging Pseudo Gradients"
206 |                 )
207 |                 self.grad_averager.notify_used_averaged_gradients()
208 | 
209 |                 (
210 |                     gathered,
211 |                     failed_peers,
212 |                     participating_peers,
213 |                     modes,
214 |                     bandwidths,
215 |                 ) = gradient_averaging_step.result()
216 | 
217 |                 initial_weights = self._get_weights_sample()
218 |                 self.logger.info(f"Initial Weights Sample: {initial_weights}")
219 | 
220 |                 all_reduce_success_status = True
221 |                 results = {
222 |                     "gathered": gathered,
223 |                     "failed_peers": failed_peers,
224 |                     "participating_peers": participating_peers,
225 |                     "modes": modes,
226 |                     "bandwidths": bandwidths,
227 |                     "duration": end_ar_time - start_ar_time,
228 |                 }
229 |             else:
230 |                 all_reduce_success_status = False
231 | 
232 |         except Exception as e:
233 |             self.logger.error(f"Unexpected error during Averaging Process: {str(e)}")
234 |             all_reduce_success_status = False
235 | 
236 |         finally:
237 |             if gradient_averaging_step:
238 |                 gradient_averaging_step.cancel()
239 |                 self.logger.info("Gradient Step Cleaned Up")
240 |             if all_reduce_success_status:
241 |                 self.logger.success("Averaging Round Finished Successfully")
242 |             return all_reduce_success_status, results, initial_weights
243 | 
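    # Call-site sketch (editor's illustration; variable names and the timeout
    # value are assumptions, not taken from the original source):
    #
    #     success, results, initial_weights = await handler.run_validator_allreduce(
    #         timeout=420,
    #         wallet=wallet,
    #         metagraph=metagraph,
    #         peerids_to_uids=peerids_to_uids,
    #         miner_uids=miner_uids,
    #         master=True,
    #         block=block,
    #     )
    #     if success:
    #         # results carries "failed_peers", "participating_peers", "modes",
    #         # "bandwidths" and "duration" for the scoring step below.
    #         ...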
244 |     def calculate_allreduce_scores(
245 |         self,
246 |         participating_peers: list,
247 |         failed_peers: list,
248 |         alive_uids: list,
249 |         peerids_to_uids: dict,
250 |         event: dict,
251 |         metagraph,
252 |         modes: list = None,
253 |         bandwidths: list = None,
254 |     ) -> tuple:
255 |         """
256 |         Calculate scores based on AllReduce participation, modes, and bandwidths.
257 | 
258 |         Args:
259 |             participating_peers (list): Peers that participated in AllReduce
260 |             failed_peers (list): Peers that failed during AllReduce
261 |             alive_uids (list): UIDs that responded as alive when queried
262 |             peerids_to_uids (dict): Mapping of peer IDs to UIDs
263 |             event (dict): Event dict, updated in place with participation counts
264 |             metagraph: Metagraph used to enumerate all UIDs
265 |             modes/bandwidths (list, optional): Per-participant mode and bandwidth
266 |         Returns: tuple of (rewards, status_dict, event, successful_peers_count)
267 |         """
268 |         # Convert peer IDs to UIDs
269 |         participating_uids = []
270 |         uid_modes = {}
271 |         uid_bandwidths = {}
272 | 
273 |         for idx, peer in enumerate(participating_peers):
274 |             uid = peerids_to_uids.get(str(peer), "UNKNOWN")
275 |             participating_uids.append(uid)
276 |             if modes is not None:
277 |                 uid_modes[uid] = modes[idx]
278 |             if bandwidths is not None:
279 |                 uid_bandwidths[uid] = bandwidths[idx]
280 | 
281 |         failed_uids = [
282 |             peerids_to_uids.get(str(failed_peer), "UNKNOWN") for failed_peer in failed_peers
283 |         ]
284 | 
285 |         # Calculate participation metrics
286 |         successful_peers_count = len(participating_peers) - len(failed_peers)
287 | 
288 |         # Update event metrics
289 |         event.update(
290 |             {
291 |                 "failed_peers_count": len(failed_peers),
292 |                 "participating_peers_count": len(participating_peers),
293 |                 "successful_peers_count": successful_peers_count,
294 |             }
295 |         )
296 | 
297 |         # Find max bandwidth for normalization if any valid bandwidths exist
298 |         valid_bandwidths = (
299 |             [bandwidth for bandwidth in bandwidths if bandwidth is not None]
300 |             if bandwidths
301 |             else []
302 |         )
303 |         if valid_bandwidths:
304 |             max_bandwidth = max(valid_bandwidths)
305 |         else:
306 |             max_bandwidth = None
307 | 
308 |         # Initialize scores dictionary
309 |         scores = {}
310 |         status_dict = {}
311 |         for uid in range(metagraph.n):  # Iterate over every UID in the metagraph
312 |             str_uid = str(uid)
313 |             if uid in participating_uids and uid not in failed_uids:
314 |                 # Base score for successful participation
315 |                 base_score = 1.0
316 |                 final_score = base_score
317 |                 status = "SUCCESS"
318 | 
319 |                 # Apply mode penalty if modes are provided
320 |                 if modes is not None and uid in uid_modes:
321 |                     if uid_modes[uid] == "AveragingMode.CLIENT":
322 |                         final_score = 0.0
323 |                         status = "WRONG_MODE"
324 | 
325 |                 # Apply bandwidth bonus if bandwidths are provided
326 |                 if (
327 |                     bandwidths is not None
328 |                     and uid in uid_bandwidths
329 |                     and status != "WRONG_MODE"
330 |                 ):
331 |                     if uid_bandwidths[uid] is None:
332 |                         final_score = 0.0
333 |                     else:
334 |                         bandwidth_bonus = 0.5 * (uid_bandwidths[uid] / max_bandwidth)
335 |                         final_score += bandwidth_bonus
336 |                         self.logger.info(
337 |                             f"UID {uid} score breakdown - Base: {base_score:.2f}, Bandwidth bonus: {bandwidth_bonus:.2f}"
338 |                         )
339 | 
340 |                 scores[str_uid] = final_score
341 |                 status_dict[str_uid] = status
342 | 
343 |             elif uid in failed_uids:
344 |                 scores[str_uid] = 0.0
345 |                 status_dict[str_uid] = "FAIL"
346 |             elif uid in alive_uids:
347 |                 # If UID is alive but not participating, assign a score of 0
348 |                 scores[str_uid] = 0.0
349 |                 status_dict[str_uid] = "NON_PARTICIPATING"
350 |             else:
351 |                 scores[str_uid] = 0.0
352 |                 status_dict[str_uid] = "NOT_ALIVE"
353 | 
354 |         # Create rewards tensor
355 |         rewards = torch.tensor(list(scores.values()))
356 | 
357 |         # Log participation and scoring details
358 |         self.logger.info(f"Failed UIDs: {failed_uids}")
359 |         self.logger.info(f"Participating UIDs: {participating_uids}")
360 |         if modes is not None:
361 |             self.logger.info(f"Modes by UID: {uid_modes}")
362 |         if bandwidths is not None:
363 |             self.logger.info(f"Bandwidths by UID: {uid_bandwidths}")
364 |         self.logger.info(f"AllReduce UID Scores: {scores}")
365 |         self.logger.info(f"AllReduce UID Rewards: {rewards}")
366 | 
367 |         return (
368 |             rewards,
369 |             status_dict,
370 |             event,
371 |             successful_peers_count,
372 |         )
373 | 
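    # Scoring walkthrough (editor's note): a successful participant with the
    # top reported bandwidth scores 1.0 + 0.5 * (bw / max_bw) = 1.5, one at
    # half the top bandwidth scores 1.25, a peer reporting
    # "AveragingMode.CLIENT" is zeroed with status "WRONG_MODE", and peers in
    # failed_peers score 0.0 with status "FAIL".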
374 |     async def run_miner_allreduce(
375 |         self,
376 |         synapse,
377 |         local_progress,
378 |         start_time,
379 |         block,
380 |         bandwidth=None,
381 |     ) -> Tuple[distributed_training.protocol.AllReduce, Optional[List[float]]]:
382 |         """Process allreduce specifically for miner."""
383 |         initial_weights = None
384 |         gradient_averaging_step = None
385 |         try:
386 |             # Gradient clipping (torch.nn.utils.clip_grad_norm_) is currently disabled here
387 | 
388 |             # Used for load balancing and scoring
389 |             if bandwidth is not None:
390 |                 self.grad_averager.bandwidth = bandwidth["download"]
391 |             self.logger.info(":wait: Starting Pseudo Gradient Averaging..")
392 |             gradient_averaging_step = self.grad_averager.step(
393 |                 timeout=(synapse.timeout - 20),
394 |                 wait=False,
395 |                 gather=local_progress.samples_accumulated,
396 |             )
397 | 
398 |             while not gradient_averaging_step.done() and (
399 |                 (time.perf_counter() - start_time) <= (synapse.timeout - 20)
400 |             ):
401 |                 await asyncio.sleep(1)
402 | 
403 |             if gradient_averaging_step.done():
404 |                 self.logger.info(
405 |                     ":white_heavy_check_mark: Finished Averaging Pseudo Gradients"
406 |                 )
407 |                 self.grad_averager.notify_used_averaged_gradients()
408 |                 self.logger.info("Returned from notify_used_averaged_gradients")
409 | 
410 |                 initial_weights = self._get_weights_sample()
411 |                 self.logger.info(f"Initial Weights Sample: {initial_weights}")
412 | 
413 |                 synapse.completion = True
414 |             else:
415 |                 synapse.completion = False
416 | 
417 |         except Exception as e:
418 |             self.logger.error(f"Unexpected Error During Averaging Process: {str(e)}")
419 |             synapse.completion = False
420 |             raise AllReduceError(
421 |                 f"Unexpected Error During Averaging Process: {str(e)}"
422 |             ) from e
423 | 
424 |         finally:
425 |             if gradient_averaging_step:
426 |                 gradient_averaging_step.cancel()
427 |                 self.logger.info("Gradient Step Cleaned Up")
428 |             if synapse.completion:
429 |                 self.logger.success("Averaging Round Finished Successfully")
430 |             return synapse, initial_weights
431 | 
432 |     # TODO: Test if this is necessary and, if it is, make it FSDP compliant
433 |     def update_main_param_after_outer_step(self):
434 |         """Update the main parameters with the inner optimizer step"""
435 |         opt_parameters = [
436 |             param
437 |             for group in self.inner_optimizer.param_groups
438 |             for param in group["params"]
439 |         ]
440 |         for main_param, opt_param in zip(
441 |             tuple(self.state_averager.main_parameters[i] for i in self.parameters_list),
442 |             opt_parameters,
443 |         ):
444 |             main_param.data.copy_(opt_param.data, non_blocking=True)
445 | 
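    # Editor's note on copy directions: update_main_param_after_outer_step
    # copies inner-optimizer params into the averager's main parameters, while
    # reset_main_parameters below copies the freshly downloaded epoch-start
    # checkpoint into the outer-optimizer params (the opposite direction).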
446 |     def reset_main_parameters(self, r2, model_name, prefix, use_cache, output_dir):
447 |         """Reset the optimizer parameters to the parameters at the start of the epoch"""
448 |         try:
449 |             if use_cache:
450 |                 _ = r2_download(
451 |                     self,
452 |                     r2=r2,
453 |                     bucket=model_name,
454 |                     key=f"{prefix}model.safetensors",
455 |                     donwload_on_all_ranks=False,
456 |                     run_on_all_ranks=False,
457 |                     destination=output_dir,
458 |                 )
459 |                 _ = r2_download(
460 |                     self,
461 |                     r2=r2,
462 |                     bucket=model_name,
463 |                     key=f"{prefix}config.json",
464 |                     donwload_on_all_ranks=False,
465 |                     run_on_all_ranks=False,
466 |                     destination=output_dir,
467 |                 )
468 |             main_parameters = AutoModelForCausalLM.from_pretrained(
469 |                 output_dir,  # directory containing model files
470 |                 trust_remote_code=False,
471 |             )
472 |             opt_parameters = [
473 |                 param
474 |                 for group in self.outer_optimizer.param_groups
475 |                 for param in group["params"]
476 |             ]
477 |             for main_param, opt_param in zip(
478 |                 tuple(main_parameters.parameters()), opt_parameters
479 |             ):
480 |                 opt_param.data.copy_(main_param.data, non_blocking=True)
481 |         except Exception as e:
482 |             self.logger.error(f"Failed To Reset Optimizer Parameters With Error: {e}")
483 | 
--------------------------------------------------------------------------------