├── .gitattributes ├── .gitignore ├── README.md ├── hivetrain ├── __init__.py ├── averaging_logic.py ├── btt_connector.py ├── chain_manager.py ├── config │ ├── __init__.py │ ├── base_subnet_config.py │ ├── config.py │ ├── hivetrain_config.py │ └── mlflow_config.py ├── docs │ ├── test.py │ ├── training_loop_architecture │ └── training_loop_architecture.pdf ├── hf_manager.py ├── new_training_manager.py ├── training_manager.py ├── utils │ ├── auto_update.py │ ├── bootstrap_server.py │ ├── bootstrap_stress.py │ ├── dummy_miner.py │ ├── generate_wallets.py │ ├── mlflow_utils.py │ └── ports.txt └── validation_logic.py ├── neurons ├── averager.py ├── miner.py └── validator.py ├── requirements.txt ├── run_miner.sh ├── run_validator.sh ├── setup.py └── template └── __init__.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.psd filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | ve/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | 163 | testing/ 164 | 165 | data 166 | wallets 167 | lightning_logs 168 | .scale_batch_size* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > There is no passion to be found playing small - in settling for a life that is less than the one you are capable of living. Nelson Mandela. 2 | 3 | # Distributed Training Framework 4 | 5 | ## Introduction 6 | 7 | This project introduces a cutting-edge approach to distributed deep learning, utilizing the Bittensor network. Our method incentivizes participants by rewarding the generation of optimal weights that contribute significantly to minimizing the overall loss of the base model. 8 | 9 | To streamline the process and reduce communication overhead between miners, we integrate Hugging Face as a central hub. This serves as an intermediary, facilitating efficient miner-validator communications without the complexities of direct exchanges. 10 | 11 | Key Components: 12 | * Miners: Miners are responsible for training a model. Each miner trains a weight-delta. A weight-delta is the difference between the weights of the trained model and the base model. This delta is then uploaded to Hugging Face, from where it can be accessed by validators. 13 | * Validators: Validators assess the loss reduction achieved by each miner on a randomized test set. They download the weight deltas from Hugging Face and evaluate them based on their impact on the model’s performance, focusing on metrics such as loss reduction and accuracy. Better-performing miners that improve on the base model are assigned better scores. 14 | * Averager: We also introduce an averager node, a centralized node run by the subnet owner.
The averager is responsible for creating the averaged model that becomes the base model for miners and validators; this is repeated every averaging interval. The averager performs a weighted average of the parameters, producing an averaged model. Currently, the weights of the weighted average are themselves parameterized, allowing the process to be optimized to find the best averaged model (a minimal illustrative sketch appears below, after the "New arguments" section). 15 | 16 | ## Clone the Repo 17 | 18 | ``` 19 | git clone https://github.com/bit-current/DistributedTraining 20 | ``` 21 | 22 | ## Move into the Repo 23 | 24 | ``` 25 | cd DistributedTraining 26 | ``` 27 | 28 | ## Remove Previous Hivetrain Installation 29 | 30 | ``` 31 | pip uninstall hivetrain 32 | ``` 33 | 34 | ## Install Repo + Requirements 35 | 36 | ``` 37 | pip install -e . 38 | ``` 39 | 40 | ## Hugging Face 41 | Continue setting up by following these steps: 42 | 43 | ### 1. Create a Hugging Face Account 44 | If you don't already have a Hugging Face account, you'll need to create one: 45 | 46 | Visit [Hugging Face](https://huggingface.co/) to sign up. 47 | ### 2. Create a Hugging Face Model Repository (For miners only) 48 | Once you have your Hugging Face account, you need to create a model repository: 49 | * Navigate to your profile by clicking on your username in the top right corner. 50 | * Click on "New Model" (you may find this button under the "Models" section if you have existing models). 51 | * Fill in the repository name, description, and set the visibility to public. 52 | * Click on "Create Model" to establish your new model repository. 53 | ### 3. Generate a Token for the Repository (For miners and validators) 54 | To allow programmatic communication with Hugging Face, you will need to generate an authentication token: 55 | 56 | * From your Hugging Face account, go to "Settings" by clicking on your profile icon. 57 | * Select the "Access Tokens" tab from the sidebar. 58 | * Click on "New Token". 59 | * Name your token and select "write" access so you can upload changes. 60 | * Click on "Create Token". 61 | 62 | ### 4. Create a New .env File to Store Your Hugging Face Token 63 | Create a .env file in the DistributedTraining directory and store your new token there: 64 | ``` 65 | HF_TOKEN="your_huggingface_token_here" 66 | ``` 67 | or, in a terminal, enter: 68 | 69 | ``` 70 | echo "HF_TOKEN=your_huggingface_token_here" >> .env 71 | ``` 72 | ### 5. Install git-lfs to Handle Uploads of Large Files 73 | 74 | ``` 75 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 76 | sudo apt install git-lfs 77 | ``` 78 | ## Load Wallets and Register to Subnet 79 | 80 | ``` 81 | btcli regen_coldkey --mnemonic your super secret mnemonic 82 | btcli regen_hotkey --mnemonic your super secret mnemonic 83 | btcli s register --netuid 25 84 | ``` 85 | 86 | ## New arguments 87 | ```storage.averaged_model_repo_id```: The repo used by the averager. Currently this is ```Hivetrain/averaging_run_1```. It changes with each training run; check the Discord channel for updates. 88 | ```storage.my_repo_id```: The repo ID of the repository used by a **miner only** to upload the miner's trained model weight delta.
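## Illustrative Sketch: Weight Deltas and Averaging

The snippet below is a minimal, hypothetical sketch of the scheme described under Key Components, not code from this repository: a miner computes a weight-delta against the base model, and the averager folds scored deltas back into the base weights. The function names (`compute_weight_delta`, `weighted_average`) and inputs (`trained_model`, `base_model`, `deltas`, `scores`) are placeholders; the real logic lives in files such as the miner/validator neurons and `hivetrain/averaging_logic.py`.

```
import torch

def compute_weight_delta(trained_model, base_model):
    """Miner side: delta = trained weights - base weights (this is what gets uploaded)."""
    return {
        name: trained_param.detach() - base_param.detach()
        for (name, trained_param), (_, base_param) in zip(
            trained_model.named_parameters(), base_model.named_parameters()
        )
    }

def weighted_average(base_model, deltas, scores):
    """Averager side: combine scored deltas into a new set of base weights.

    `deltas` is a list of name->tensor mappings, `scores` a matching list of
    averaging weights (e.g. derived from validator scores) that sum to 1.
    """
    averaged = {name: p.detach().clone() for name, p in base_model.named_parameters()}
    for delta, score in zip(deltas, scores):
        for name, d in delta.items():
            averaged[name] += score * d
    return averaged
```

In the production averager the averaging weights are themselves parameterized and optimized rather than fixed (see `hivetrain/averaging_logic.py`), and deltas move between participants via Hugging Face rather than in memory.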
89 | 90 | ## Miner Run Command 91 | 92 | ``` 93 | python neurons/miner.py --netuid 25 --wallet.name wallet_name --wallet.hotkey hotkey_name --storage.my_repo_id your_hf_username/your_repo --storage.averaged_model_repo_id Hivetrain/averaging_run_1 94 | ``` 95 | 96 | ## Validator 97 | 98 | ### Validators need to have at least 1000 TAO to set weights on the main net and 10 TAO on the test net 99 | 100 | ``` 101 | python neurons/validator.py --netuid 25 --wallet.name wallet_name --wallet.hotkey hotkey_name --storage.averaged_model_repo_id Hivetrain/averaging_run_1 102 | ``` 103 | 104 | ## Bug Reporting and Contributions 105 | 106 | - **Reporting Issues:** Use the GitHub Issues tab to report bugs, providing detailed steps to reproduce along with relevant logs or error messages. 107 | - **Contributing:** Contributions are welcome! Fork the repo, make changes, and submit a pull request. Break it in as many ways as possible to help make the system resilient. 108 | 109 | ## Communication and Support 110 | 111 | - Join our [Project Discord](#) and the [Bittensor Discord](#) to discuss the project, seek help, and collaborate with the community. 112 | 113 | ## License 114 | 115 | Licensed under the MIT License - see the LICENSE file for details. 116 | 117 | ## Acknowledgments 118 | 119 | - Thanks to the PyTorch team for their deep learning library. 120 | - Gratitude to Bittensor for enabling decentralized computing and finance with TAO rewards. 121 | -------------------------------------------------------------------------------- /hivetrain/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.4" 2 | version_split = __version__.split(".") 3 | __spec_version__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2])) 4 | 5 | from . import btt_connector 6 | from . import chain_manager 7 | from . import validation_logic 8 | from . import averaging_logic 9 | from . import hf_manager 10 | from . 
import training_manager 11 | -------------------------------------------------------------------------------- /hivetrain/averaging_logic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import math 6 | import mlflow 7 | import mlflow.pytorch 8 | from huggingface_hub import hf_hub_download 9 | from hivetrain.config.mlflow_config import MLFLOW_UI_URL, CURRENT_MODEL_NAME 10 | from hivetrain.utils.mlflow_utils import VERSION, initialize_mlflow, log_model_metrics 11 | from hivetrain.config.mlflow_config import ( 12 | MLFLOW_UI_URL, 13 | CURRENT_MODEL_NAME, 14 | MLFLOW_ACTIVE, 15 | ) 16 | 17 | from copy import deepcopy 18 | from hivetrain.btt_connector import BittensorNetwork, sync 19 | from bittensor import logging 20 | from transformers import TrainingArguments, Trainer, AdamW 21 | from torch import nn, optim 22 | import torch.nn.functional as F 23 | import numpy as np 24 | from tqdm import tqdm 25 | 26 | 27 | class Averager: 28 | def __init__( 29 | self, 30 | model, 31 | local_dir, 32 | repo_id, 33 | hf_manager, 34 | chain_manager, 35 | bittensor_network, 36 | hf_token=os.environ.get("HF_TOKEN"), 37 | ): 38 | self.model = model 39 | self.local_dir = local_dir 40 | self.repo_id = repo_id 41 | self.hf_token = hf_token 42 | self.scored_gradients = None 43 | self.last_sync_time = 0 44 | self.bittensor_network = bittensor_network 45 | self.chain_manager = chain_manager 46 | self.hf_manager = hf_manager 47 | 48 | 49 | # initialize mlflow 50 | if MLFLOW_ACTIVE: 51 | initialize_mlflow( 52 | role="averager", 53 | device=self.device, 54 | version=VERSION, 55 | mlflow_ui_url=MLFLOW_UI_URL, 56 | current_model_name=CURRENT_MODEL_NAME) 57 | 58 | 59 | 60 | def receive_gradients( 61 | self, repo_id="your_username/your_repo_name", gradient_file_name="gradients.pt" 62 | ): 63 | try: 64 | # Download the gradients file from Hugging Face Hub 65 | gradient_file_path = hf_hub_download( 66 | repo_id=repo_id, filename=gradient_file_name, use_auth_token=True 67 | ) 68 | 69 | # Load the gradients directly using torch.load 70 | aggregated_gradients = torch.load(gradient_file_path) 71 | 72 | if self.have_nans(aggregated_gradients): 73 | return None 74 | 75 | return aggregated_gradients 76 | except Exception as e: 77 | logging.debug(f"Error receiving gradients from Hugging Face: {e}") 78 | return None 79 | 80 | def receive_and_score_gradients(self): 81 | # Get validators uids 82 | self.bittensor_network.sync(lite=False) # scope issue FIXME? 83 | 84 | validator_uids = self.bittensor_network.get_validator_uids( 85 | vpermit_tao_limit=1024 86 | ) 87 | 88 | if isinstance(self.bittensor_network.metagraph.W, list): 89 | self.validator_combined_weights = [] 90 | weights_list = [] 91 | for validator_uid in validator_uids: 92 | weights = self.bittensor_network.metagraph.W[validator_uid] 93 | if sum(weights) == 0: 94 | continue 95 | else: 96 | weights_list.append(weights) 97 | self.validator_combined_weights = torch.mean( 98 | torch.tensor(weights_list), axis=0 99 | ) 100 | else: 101 | self.validator_combined_weights = torch.mean( 102 | self.bittensor_network.metagraph.W[validator_uids, :], axis=0 103 | ) 104 | # n = len(self.bittensor_network.metagraph.hotkeys) #FIXME I am only for testing NOPROD 105 | # self.validator_combined_weights = torch.full((n,1), 1/n, dtype=torch.float32) #FIXME I am only for testing NOPROD 106 | # Get average of validator weights weighted by their stake? 
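# One possible answer to the question above, shown only as a commented-out
# sketch (it is not used by the current flow): weight each validator's weight
# row by its relative stake (metagraph.S) instead of taking a plain mean.
# Assumes the dense-tensor branch above, i.e. that `metagraph.W` and
# `metagraph.S` are indexable tensors and `validator_uids` is the list
# computed earlier in this method.
# stakes = torch.tensor(
#     [float(self.bittensor_network.metagraph.S[uid]) for uid in validator_uids]
# )
# stake_fracs = stakes / stakes.sum()
# self.validator_combined_weights = (
#     stake_fracs.unsqueeze(1)
#     * self.bittensor_network.metagraph.W[validator_uids, :]
# ).sum(dim=0)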
107 | self.miner_gradients = [] 108 | self.miner_weights = [] 109 | self.miner_hotkeys = [] 110 | for uid, hotkey in enumerate(self.bittensor_network.metagraph.hotkeys): 111 | try: 112 | repo_id = self.chain_manager.retrieve_hf_repo(hotkey) 113 | gradient = self.receive_gradients(repo_id=repo_id) 114 | self.miner_gradients.append(gradient) 115 | except Exception as e: 116 | logging.debug(f"Receiving gradients failed due to: {e}") 117 | self.miner_gradients.append(None) 118 | self.miner_weights.append(0.0) 119 | self.miner_hotkeys.append(hotkey) 120 | 121 | @staticmethod 122 | def have_nans(aggregated_gradients): 123 | for tensor in aggregated_gradients.values(): 124 | if torch.isnan(tensor).any(): 125 | logging.debug("NaN values detected in the aggregated gradients.") 126 | return True 127 | return False 128 | 129 | def average_gradients(self, beta=1.0): 130 | self.miner_gradients = [ 131 | gradients for gradients in self.miner_gradients if gradients is not None 132 | ] 133 | assert len(self.miner_gradients) > 0 134 | 135 | averaged_gradients = { 136 | name: torch.zeros_like(grad) 137 | for name, grad in self.miner_gradients[0].items() 138 | } 139 | 140 | for score, gradients in zip( 141 | self.validator_combined_weights, self.miner_gradients 142 | ): 143 | logging.info("Averaging Gradient") 144 | for name, grad in gradients.items(): 145 | averaged_gradients[name] += grad * score * beta 146 | 147 | return averaged_gradients 148 | 149 | def apply_averaged_gradients(self, averaged_gradients, alpha=0.00001): 150 | with torch.no_grad(): 151 | for name, param in self.model.named_parameters(): 152 | if name in averaged_gradients: 153 | param -= alpha * averaged_gradients[name] 154 | 155 | def save_model(self): 156 | self.model.save_pretrained(self.local_dir) 157 | 158 | def push_to_hf_hub(self, commit_message="Pushing model to Hub"): 159 | training_args = TrainingArguments( 160 | output_dir=self.local_dir, # Local directory to save the model 161 | per_device_train_batch_size=1, # Dummy argument, won't actually be used for training here 162 | per_device_eval_batch_size=1, # Dummy argument, necessary to specify but won't be used 163 | push_to_hub=True, # Enable pushing to hub 164 | push_to_hub_model_id=self.repo_id, # Repository ID on the Hugging Face Hub 165 | push_to_hub_organization=None, # Specify organization name here if applicable 166 | push_to_hub_token=self.hf_token, # Hugging Face authentication token 167 | ) 168 | 169 | # Initialize the Trainer 170 | trainer = Trainer( 171 | model=self.model, # Your PyTorch model 172 | args=training_args, 173 | ) 174 | 175 | # Push the model to the Hugging Face Hub 176 | trainer.push_to_hub(commit_message=commit_message) 177 | 178 | def run_periodic_averaging(self, t): 179 | while True: 180 | logging.info("Averaging Beggining") 181 | start_time = time.time() 182 | 183 | if self.hf_manager.check_for_new_submissions(): 184 | logging.info( 185 | "Model updated from Hugging Face. Continuing training with new model..." 
186 | ) 187 | self.model = self.hf_manager.update_model(self.model) 188 | 189 | self.receive_and_score_gradients() 190 | averaged_gradients = self.average_gradients() 191 | self.apply_averaged_gradients(averaged_gradients) 192 | self.save_model() 193 | self.push_to_hf_hub(commit_message="Updated model with new gradients") 194 | elapsed_time = time.time() - start_time 195 | time_to_wait = max(0, t - elapsed_time) 196 | time.sleep(time_to_wait) 197 | logging.info("Averaging Done") 198 | 199 | 200 | class DeltaAverager(Averager): 201 | def __init__( 202 | self, 203 | model, 204 | local_dir, 205 | repo_id, 206 | hf_manager, 207 | chain_manager, 208 | bittensor_network, 209 | hf_token=os.environ.get("HF_TOKEN"), 210 | ): 211 | self.model = model 212 | self.local_dir = local_dir 213 | self.repo_id = repo_id 214 | self.hf_token = hf_token 215 | self.scored_gradients = None 216 | self.last_sync_time = 0 217 | self.bittensor_network = bittensor_network 218 | self.chain_manager = chain_manager 219 | self.hf_manager = hf_manager 220 | 221 | def average_gradients(self, beta=1.0): 222 | self.non_none_miner_gradients = [ 223 | gradients for gradients in self.miner_gradients if gradients is not None 224 | ] 225 | assert len(self.non_none_miner_gradients) > 0 226 | averaged_gradients = { 227 | name: torch.zeros_like(grad) 228 | for name, grad in self.non_none_miner_gradients[0].items() 229 | } 230 | 231 | for score, gradients in zip( 232 | self.validator_combined_weights, self.miner_gradients 233 | ): 234 | logging.info("Averaging Gradient") 235 | if gradients is not None: 236 | for (name, grad), (param_name, param) in zip( 237 | gradients.items(), self.model.named_parameters() 238 | ): 239 | averaged_gradients[name] += (param + grad) * score * beta 240 | for name, param in averaged_gradients.items(): 241 | averaged_gradients[name] = param / len(self.non_none_miner_gradients) 242 | return averaged_gradients 243 | 244 | def apply_averaged_gradients(self, averaged_gradients, alpha=0.00001): 245 | with torch.no_grad(): 246 | for name, param in self.model.named_parameters(): 247 | if name in averaged_gradients: 248 | param.copy_(averaged_gradients[name])  # copy in place; rebinding `param` would not update the model 249 | 250 | def run_periodic_averaging(self, t): 251 | while True: 252 | logging.info("Averaging Beginning") 253 | start_time = time.time() 254 | 255 | if self.hf_manager.check_for_new_submissions(): 256 | logging.info( 257 | "Model updated from Hugging Face. Continuing training with new model..."
258 | ) 259 | self.model = self.hf_manager.update_model(self.model) 260 | 261 | self.receive_and_score_gradients() 262 | averaged_weights = self.average_gradients() 263 | self.apply_averaged_gradients(averaged_weights) 264 | self.save_model() 265 | self.push_to_hf_hub(commit_message="Updated model with new gradients") 266 | elapsed_time = time.time() - start_time 267 | time_to_wait = max(0, t - elapsed_time) 268 | time.sleep(time_to_wait) 269 | logging.info("Averaging Done") 270 | 271 | 272 | class LocalAverager(DeltaAverager): 273 | def __init__( 274 | self, 275 | model, 276 | local_dir, 277 | hf_manager, 278 | chain_manager, 279 | bittensor_network=None, 280 | hf_token=os.environ.get("HF_TOKEN"), 281 | ): 282 | super().__init__( 283 | model, 284 | local_dir, 285 | hf_manager=hf_manager, 286 | chain_manager=chain_manager, 287 | repo_id=None, 288 | bittensor_network=bittensor_network, 289 | hf_token=hf_token, 290 | ) 291 | # No need for repo_id or hf_token in the local version 292 | 293 | def receive_gradients(self, repo_id=None, gradient_file_name="gradients.pt"): 294 | """ 295 | Overrides the receive_gradients method to fetch gradients from a local directory. 296 | """ 297 | if repo_id is None: 298 | return None 299 | try: 300 | gradient_file_path = os.path.join(repo_id, gradient_file_name) 301 | if not os.path.exists(gradient_file_path): 302 | logging.warning(f"Gradient file not found: {gradient_file_path}") 303 | return None 304 | 305 | # Load the gradients directly using torch.load 306 | aggregated_gradients = torch.load(gradient_file_path) 307 | 308 | if self.have_nans(aggregated_gradients): 309 | return None 310 | 311 | return aggregated_gradients 312 | except Exception as e: 313 | logging.error(f"Error receiving gradients locally: {e}") 314 | return None 315 | 316 | def push_to_hf_hub(self, commit_message="Pushing model to Hub"): 317 | """ 318 | Overrides the push_to_hf_hub method to simply save the model locally. 319 | """ 320 | self.save_model() 321 | logging.info( 322 | f"Model saved locally in {self.local_dir} instead of pushing to Hugging Face Hub." 323 | ) 324 | 325 | def save_model(self): 326 | """ 327 | Saves the model to the specified local directory. 
328 | """ 329 | os.makedirs(self.local_dir, exist_ok=True) 330 | model_save_path = os.path.join(self.local_dir, "averaged_model.pt") 331 | torch.save(self.model.state_dict(), model_save_path) 332 | logging.info(f"Model saved locally at {model_save_path}.") 333 | 334 | 335 | class ParameterizedAverager(DeltaAverager): 336 | # __init__(self, model, local_dir, repo_id,hf_manager, chain_manager,bittensor_network, hf_token=os.environ.get("HF_TOKEN")) 337 | def __init__( 338 | self, 339 | model, 340 | local_dir, 341 | gradients_dir, 342 | device, 343 | repo_id=None, 344 | hf_manager=None, 345 | chain_manager=None, 346 | bittensor_network=None, 347 | hf_token=os.environ.get("HF_TOKEN"), 348 | check_update_interval=300, 349 | ): 350 | DeltaAverager.__init__( 351 | self, 352 | model, 353 | local_dir=local_dir, 354 | repo_id=repo_id, 355 | hf_manager=hf_manager, 356 | chain_manager=chain_manager, 357 | bittensor_network=bittensor_network, 358 | hf_token=hf_token, 359 | ) 360 | self.device = device 361 | self.last_pull_time = 0 362 | self.check_update_interval = check_update_interval 363 | self.gradients_dir = gradients_dir 364 | 365 | def get_model_paths(self, gradient_file_name="gradients.pt"): 366 | self.model_paths = [] 367 | for uid, hotkey in enumerate(self.bittensor_network.metagraph.hotkeys): 368 | try: 369 | repo_id = self.chain_manager.retrieve_hf_repo(hotkey) 370 | # gradient = self.receive_gradients(repo_id=repo_id) 371 | if repo_id is not None: 372 | self.model_paths.append( 373 | {"repo_id": repo_id, "hotkey": hotkey, "uid": uid} 374 | ) 375 | except Exception as e: 376 | logging.debug(f"Receiving gradients failed due to: {e}") 377 | 378 | def store_weight_delta(self, weight_delta, hotkey): 379 | """ 380 | Save the weight_delta state_dict to a local directory at regular intervals. 381 | """ 382 | os.makedirs(self.gradients_dir, exist_ok=True) 383 | save_path = os.path.join(self.gradients_dir, f"weight_delta_{hotkey}.pt") 384 | torch.save(weight_delta, save_path) 385 | 386 | def load_weight_delta(self, hotkey): 387 | """ 388 | Load the cached weight_delta state_dict from the local directory. 
389 | """ 390 | load_path = os.path.join(self.gradients_dir, f"weight_delta_{hotkey}.pt") 391 | if os.path.exists(load_path): 392 | return torch.load(load_path, map_location=self.device) 393 | else: 394 | return None 395 | 396 | def cache_params_locally(self): 397 | self.num_models = 0 398 | self.active_hotkeys = [] 399 | for model_path in self.model_paths: 400 | if model_path is None: 401 | continue 402 | else: 403 | #try: 404 | weight_delta = self.hf_manager.receive_gradients(model_path["repo_id"]) 405 | false_model_flag = False 406 | for name, param in weight_delta.items(): 407 | if weight_delta[name].shape != self.model.state_dict()[name].shape: 408 | false_model_flag = True 409 | if false_model_flag: 410 | continue 411 | self.store_weight_delta(weight_delta, model_path["hotkey"]) 412 | if weight_delta is None: 413 | raise ValueError(f"Failed to receive gradients at: {model_path['repo_id']}") 414 | self.num_models +=1 415 | self.active_hotkeys.append(model_path["hotkey"]) 416 | #except Exception as e: 417 | # logging.warning(f"Failed to get model at {model_path['hotkey']}: {e}") 418 | if len(self.active_hotkeys) > 0: 419 | assert len(self.active_hotkeys) == self.num_models 420 | # if time.time() - self.last_cache_time > self.caching_interval: 421 | 422 | def get_averaged_params(self): 423 | if self.weights is None: 424 | self.weights = nn.functional.softmax( 425 | torch.ones( 426 | (self.num_models, len(list(self.model.parameters()))), 427 | device=self.device, 428 | ), 429 | dim=0, 430 | ) 431 | averaged_gradients = { 432 | name: torch.zeros_like(grad) for name, grad in self.model.named_parameters() 433 | } 434 | for params, weight in zip(self.lazy_load_params(), self.weights): 435 | if params is None or torch.all(weight == 0): 436 | continue 437 | for j, (name_reconstructed_model, param_reconstructed_model) in enumerate( 438 | params.items() 439 | ): 440 | param_reconstructed_model = param_reconstructed_model.to(self.device) 441 | averaged_gradients[name_reconstructed_model] = averaged_gradients[name_reconstructed_model].to(self.device) 442 | try: 443 | averaged_gradients[name_reconstructed_model] += (param_reconstructed_model * weight[j]) #FIXME make weights per param 444 | except Exception as e: 445 | #logging.warning(f"Skipping parameter due to: {e}") 446 | pass 447 | 448 | return averaged_gradients 449 | 450 | def lazy_load_params(self): 451 | for uid, hotkey in enumerate(self.active_hotkeys): 452 | weight_delta = self.load_weight_delta(hotkey) 453 | # if weight_delta is None: 454 | # yield None 455 | # continue 456 | base_model = torch.load( 457 | os.path.join( 458 | self.hf_manager.get_local_model_directory(), "averaged_model.pt" 459 | ), 460 | map_location=self.device, 461 | ) # self.model.state_dict() 462 | for name, delta_param in weight_delta.items(): 463 | weight_delta[name] = weight_delta[name].to(self.device) 464 | base_model[name] = base_model[name].to(self.device) 465 | try: 466 | weight_delta[name] = weight_delta[name] + base_model[name] 467 | except Exception as e: 468 | #logging.warning(f"Error loading param: {e}") 469 | pass 470 | yield weight_delta 471 | 472 | def get_averaged_model(self): 473 | averaged_params = self.get_averaged_params() 474 | for param, averaged_param in zip( 475 | self.model.parameters(), averaged_params.values() 476 | ): 477 | param.data.copy_(averaged_param.data) 478 | self.model.to(self.device) 479 | return self.model 480 | 481 | def save_model(self): 482 | """ 483 | Saves the model to the specified local directory. 
484 | """ 485 | os.makedirs(self.hf_manager.get_local_model_directory(), exist_ok=True) 486 | model_save_path = os.path.join(self.hf_manager.get_local_model_directory(), "averaged_model.pt") 487 | torch.save(self.model.state_dict(), model_save_path) 488 | logging.info(f"Model saved locally at {model_save_path}.") 489 | 490 | def meta_learning(self, val_loader, meta_epochs, lr): 491 | criterion = nn.CrossEntropyLoss() 492 | self.weights = None # nn.Parameter(nn.functional.softmax(torch.ones(self.num_models, device=self.device), dim=0)) 493 | for epoch in range(meta_epochs): 494 | for epoch in range(meta_epochs): 495 | total_loss = 0 496 | correct_predictions = 0 497 | total_samples = 0 498 | 499 | for batch_count, batch in enumerate(val_loader): 500 | averaged_model = self.get_averaged_model() 501 | 502 | outputs = averaged_model( 503 | input_ids=batch["input_ids"].to(self.device), 504 | attention_mask=batch["attention_mask"].to(self.device), 505 | labels=batch["labels"].to(self.device), 506 | ) 507 | val_loss = outputs.loss 508 | total_loss += val_loss.item() * batch["input_ids"].size(0) 509 | total_samples += batch["input_ids"].size(0) 510 | 511 | val_loss.backward() 512 | with torch.no_grad(): 513 | grad_weights = torch.zeros_like(self.weights) 514 | 515 | for i, model in enumerate(self.lazy_load_params()): 516 | for j, (model_param, main_param) in enumerate( 517 | zip(model.values(), averaged_model.parameters()) 518 | ): 519 | if main_param.grad is not None: 520 | grad_weights[i, j] += torch.sum( 521 | main_param.grad * (model_param - main_param) 522 | ) 523 | 524 | for main_param in averaged_model.parameters(): 525 | main_param.grad.zero_() 526 | 527 | #grad_weights = torch.clamp(grad_weights,min=-1,max=1) 528 | self.weights.data -= (lr * grad_weights) 529 | #if (batch_count * epoch+1) % 1000: 530 | # logging.info(f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {val_loss.item():.4f}, Weights: {torch.mean(self.weights,dim=1)}") 531 | 532 | average_loss = total_loss / total_samples 533 | perplexity = math.exp(average_loss) 534 | 535 | if MLFLOW_ACTIVE: 536 | step = int(time.time()) 537 | log_model_metrics(step=step, loss_averaged = average_loss, perplexity_averaged = perplexity) 538 | 539 | logging.info(f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {average_loss:.4f},Perplexity: {perplexity}, Weights: {torch.mean(self.weights,dim=1)}") 540 | 541 | return self.get_averaged_model() 542 | 543 | 544 | def run_periodic_averaging(self, val_loader, meta_epochs, lr, t): 545 | while True: 546 | logging.info("Averaging Beginning") 547 | start_time = time.time() 548 | 549 | if time.time() - self.last_pull_time >= self.check_update_interval: 550 | if self.hf_manager.check_for_new_submissions( 551 | self.hf_manager.model_repo_id 552 | ): 553 | logging.info( 554 | "Averaged model updated on Hugging Face. Pulling latest model..." 
555 | ) 556 | self.hf_manager.pull_latest_model() 557 | time.sleep(10) # just to give enough time for pull 558 | self.model = self.hf_manager.update_model(self.model) 559 | self.model = self.model.to(self.device) 560 | optimizer = AdamW( 561 | self.model.parameters(), lr=5e-5 562 | ) # Reinitialize the optimizer 563 | self.base_weights = { 564 | name: param.clone() 565 | for name, param in self.model.named_parameters() 566 | } 567 | 568 | self.last_pull_time = time.time() 569 | 570 | self.get_model_paths() 571 | # self.num_models = len(self.model_paths) 572 | self.cache_params_locally() 573 | 574 | self.model = self.meta_learning(val_loader, meta_epochs, lr) 575 | # self.apply_averaged_gradients(averaged_weights) 576 | self.save_model() 577 | self.hf_manager.push_to_hf_hub(path_to_model= "averaged_model.pt") 578 | 579 | elapsed_time = time.time() - start_time 580 | time_to_wait = max(0, t - elapsed_time) 581 | time.sleep(time_to_wait) 582 | 583 | logging.info("Averaging Done") 584 | 585 | 586 | class LocalParameterizedAverager(LocalAverager): 587 | def __init__( 588 | self, 589 | model, 590 | local_dir, 591 | device, 592 | hf_manager, 593 | chain_manager=None, 594 | bittensor_network=None, 595 | hf_token=os.environ.get("HF_TOKEN"), 596 | ): 597 | LocalAverager.__init__( 598 | self, 599 | model, 600 | local_dir, 601 | hf_manager=hf_manager, 602 | chain_manager=chain_manager, 603 | bittensor_network=bittensor_network, 604 | hf_token=hf_token, 605 | ) 606 | self.device = device 607 | 608 | def get_model_paths(self, gradient_file_name="gradients.pt"): 609 | self.model_paths = [] 610 | for uid, hotkey in enumerate(self.bittensor_network.metagraph.hotkeys): 611 | try: 612 | repo_id = self.chain_manager.retrieve_hf_repo(hotkey) 613 | # gradient = self.receive_gradients(repo_id=repo_id) 614 | if repo_id is not None: 615 | self.model_paths.append(repo_id) 616 | except Exception as e: 617 | logging.debug(f"Receiving gradients failed due to: {e}") 618 | 619 | def get_averaged_params(self): 620 | if self.weights is None: 621 | self.weights = nn.functional.softmax( 622 | torch.ones( 623 | (self.num_models, len(list(self.model.parameters()))), 624 | device=self.device, 625 | ), 626 | dim=0, 627 | ) 628 | averaged_gradients = { 629 | name: torch.zeros_like(grad) for name, grad in self.model.named_parameters() 630 | } 631 | for params, weight in zip(self.lazy_load_params(), self.weights): 632 | for j, (name_reconstructed_model, param_reconstructed_model) in enumerate( 633 | params.items() 634 | ): 635 | averaged_gradients[name_reconstructed_model] += ( 636 | param_reconstructed_model * weight[j] 637 | ) # FIXME make weights per param 638 | 639 | return averaged_gradients 640 | 641 | def lazy_load_params(self): 642 | for model_path in self.model_paths: 643 | if model_path is None: 644 | yield None 645 | else: 646 | weight_delta = torch.load( 647 | os.path.join(model_path, "gradients.pt"), map_location=self.device 648 | ) 649 | base_model = torch.load( 650 | os.path.join(self.local_dir, "averaged_model.pt"), 651 | map_location=self.device, 652 | ) # self.model.state_dict() 653 | for name, delta_param in weight_delta.items(): 654 | weight_delta[name] = weight_delta[name].to(self.device) 655 | base_model[name] = base_model[name].to(self.device) 656 | weight_delta[name] = weight_delta[name] + base_model[name] 657 | yield weight_delta 658 | 659 | def get_averaged_model(self): 660 | averaged_params = self.get_averaged_params() 661 | for param, averaged_param in zip( 662 | self.model.parameters(), 
averaged_params.values() 663 | ): 664 | param.data.copy_(averaged_param.data) 665 | self.model.to(self.device) 666 | return self.model 667 | 668 | def save_model(self): 669 | """ 670 | Saves the model to the specified local directory. 671 | """ 672 | os.makedirs(self.local_dir, exist_ok=True) 673 | model_save_path = os.path.join(self.local_dir, "averaged_model.pt") 674 | torch.save(self.model.state_dict(), model_save_path) 675 | logging.info(f"Model saved locally at {model_save_path}.") 676 | 677 | def meta_learning(self, val_loader, meta_epochs, lr): 678 | criterion = nn.CrossEntropyLoss() 679 | # optimizer = optim.SGD([self.weights], lr=lr) 680 | self.weights = None # nn.Parameter(nn.functional.softmax(torch.ones(self.num_models, device=self.device), dim=0)) 681 | for epoch in range(meta_epochs): 682 | # Outer loop: Update averaging weights 683 | 684 | # val_loss = evaluate_model(averaged_model, val_loader, criterion, self.device) 685 | 686 | # self.model.eval() 687 | 688 | for epoch in range(meta_epochs): 689 | total_loss = 0 690 | correct_predictions = 0 691 | total_samples = 0 692 | 693 | for batch_count, batch in enumerate(val_loader): 694 | averaged_model = self.get_averaged_model() 695 | 696 | images, labels = batch 697 | outputs = averaged_model(images) 698 | val_loss = F.cross_entropy(outputs, labels) 699 | total_loss += val_loss.item() 700 | _, predicted = torch.max(outputs.data, 1) 701 | correct_predictions += (predicted == labels).sum().item() 702 | total_samples += labels.size(0) 703 | 704 | val_loss.backward() 705 | with torch.no_grad(): 706 | grad_weights = torch.zeros_like(self.weights) 707 | 708 | # for main_param in averaged_model.parameters(): 709 | # if main_param.grad is not None: 710 | # main_param.grad = torch.clamp(main_param.grad,min=-0.1,max=0.1) 711 | 712 | for i, model in enumerate(self.lazy_load_params()): 713 | for j, (model_param, main_param) in enumerate( 714 | zip(model.values(), averaged_model.parameters()) 715 | ): 716 | if main_param.grad is not None: 717 | grad_weights[i, j] += torch.sum( 718 | main_param.grad * (model_param - main_param) 719 | ) 720 | 721 | for main_param in averaged_model.parameters(): 722 | main_param.grad.zero_() 723 | 724 | # grad_weights = torch.clamp(grad_weights,min=-1,max=1) 725 | self.weights.data -= lr * grad_weights 726 | # if (batch_count * epoch+1) % 100: 727 | # logging.info(f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {val_loss.item():.4f}, Weights: {self.weights}") 728 | 729 | average_loss = total_loss / total_samples 730 | accuracy = correct_predictions / total_samples 731 | logging.info( 732 | f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {average_loss:.4f},Accuracy: {accuracy}, Weights: {self.weights}" 733 | ) 734 | 735 | return self.get_averaged_model() 736 | 737 | def run_periodic_averaging(self, val_loader, meta_epochs, lr, t): 738 | while True: 739 | logging.info("Averaging Beginning") 740 | start_time = time.time() 741 | 742 | if self.hf_manager.check_for_new_submissions(): 743 | logging.info( 744 | "Model updated from Hugging Face. Continuing training with new model..." 
745 | ) 746 | self.model = self.hf_manager.update_model(self.model) 747 | 748 | self.get_model_paths() 749 | self.num_models = len(self.model_paths) 750 | 751 | self.model = self.meta_learning(val_loader, meta_epochs, lr) 752 | # self.apply_averaged_gradients(averaged_weights) 753 | self.save_model() 754 | self.push_to_hf_hub(commit_message="Updated model with new gradients") 755 | 756 | elapsed_time = time.time() - start_time 757 | time_to_wait = max(0, t - elapsed_time) 758 | time.sleep(time_to_wait) 759 | 760 | logging.info("Averaging Done") 761 | 762 | 763 | class LocalLLMParameterizedAverager(LocalParameterizedAverager): 764 | def meta_learning(self, val_loader, meta_epochs, lr): 765 | criterion = nn.CrossEntropyLoss() 766 | # optimizer = optim.SGD([self.weights], lr=lr) 767 | self.weights = None # nn.Parameter(nn.functional.softmax(torch.ones(self.num_models, device=self.device), dim=0)) 768 | for epoch in range(meta_epochs): 769 | # Outer loop: Update averaging weights 770 | 771 | # val_loss = evaluate_model(averaged_model, val_loader, criterion, self.device) 772 | 773 | # self.model.eval() 774 | 775 | for epoch in range(meta_epochs): 776 | total_loss = 0 777 | total_samples = 0 778 | 779 | for batch_num, batch in enumerate( 780 | val_loader 781 | ): # FIXME turn me into a generator? 782 | averaged_model = self.get_averaged_model() 783 | averaged_model.train() 784 | optimizer = optim.SGD( 785 | averaged_model.parameters(), lr=1010 786 | ) # This is only used to clear grads 787 | outputs = averaged_model( 788 | input_ids=batch["input_ids"], 789 | attention_mask=batch["attention_mask"], 790 | labels=batch["labels"], 791 | ) 792 | val_loss = outputs.loss 793 | total_loss += val_loss.item() * batch["input_ids"].size(0) 794 | total_samples += batch["input_ids"].size(0) 795 | 796 | val_loss.backward() 797 | with torch.no_grad(): 798 | grad_weights = torch.zeros_like(self.weights) 799 | 800 | for i, model in enumerate(self.lazy_load_params()): 801 | for j, (model_param, main_param) in enumerate( 802 | zip(model.values(), averaged_model.parameters()) 803 | ): 804 | if main_param.grad is not None: 805 | grad_weights[i, j] += torch.sum( 806 | main_param.grad * (model_param - main_param) 807 | ) 808 | 809 | for main_param in averaged_model.parameters(): 810 | main_param.grad.zero_() 811 | optimizer.zero_grad() 812 | # grad_weights = torch.clamp(grad_weights,min=-1,max=1) 813 | self.weights.data -= lr * grad_weights 814 | if (batch_num * epoch + 1) % 100: 815 | average_loss = total_loss / total_samples 816 | # logging.info(f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {average_loss}, Weights: {self.weights}") 817 | 818 | average_loss = total_loss / total_samples 819 | try: 820 | perplexity = math.exp(average_loss) 821 | except: 822 | perplexity = 999999 823 | logging.info( 824 | f"Meta-Epoch [{epoch+1}/{meta_epochs}], Validation Loss: {average_loss:.4f}, Perplexity: {perplexity}, Weights: {self.weights}" 825 | ) 826 | 827 | return self.get_averaged_model() 828 | 829 | 830 | class GeneticAverager(nn.Module): 831 | def __init__( 832 | self, 833 | model, 834 | local_dir, 835 | device, 836 | hf_manager, 837 | chain_manager=None, 838 | bittensor_network=None, 839 | hf_token=os.environ.get("HF_TOKEN"), 840 | ): 841 | super(LocalParameterizedAverager, self).__init__() 842 | self.model = model 843 | self.local_dir = local_dir 844 | self.device = device 845 | self.hf_manager = hf_manager 846 | self.chain_manager = chain_manager 847 | self.bittensor_network = bittensor_network 848 | 
self.hf_token = hf_token 849 | self.population_size = 10 850 | self.num_generations = 10 851 | self.sigma = 0.1 # Standard deviation for Gaussian noise 852 | 853 | def get_model_paths(self): 854 | self.model_paths = [] 855 | for uid, hotkey in enumerate(self.bittensor_network.metagraph.hotkeys): 856 | try: 857 | repo_id = self.chain_manager.retrieve_hf_repo(hotkey) 858 | # gradient = self.receive_gradients(repo_id=repo_id) 859 | if repo_id is not None: 860 | self.model_paths.append(repo_id) 861 | except Exception as e: 862 | logging.debug(f"Receiving gradients failed due to: {e}") 863 | 864 | def lazy_load_params(self): 865 | for model_path in self.model_paths: 866 | if model_path is None: 867 | yield None 868 | else: 869 | weight_delta = torch.load( 870 | os.path.join(model_path, "gradients.pt"), map_location="cpu" 871 | ) 872 | yield weight_delta 873 | 874 | def get_averaged_params(self, weights): 875 | averaged_params = { 876 | name: torch.zeros_like(param, device=self.device) 877 | for name, param in self.model.named_parameters() 878 | } 879 | for params, weight in zip(self.lazy_load_params(), weights): 880 | for (name_weight_delta, param_weight_delta), ( 881 | name_base_model, 882 | param_base_model, 883 | ) in zip(params.items(), self.model.named_parameters()): 884 | averaged_params[name_base_model] += ( 885 | param_weight_delta + param_base_model 886 | ) * weight 887 | return averaged_params 888 | 889 | def get_averaged_model(self, weights): 890 | averaged_params = self.get_averaged_params(weights) 891 | for name, param in self.model.named_parameters(): 892 | param.data.copy_(averaged_params[name].data / len(self.model_paths)) 893 | return self.model 894 | 895 | def evaluate_population(self, val_loader, population): 896 | # Evaluate all individuals in the population 897 | fitness_scores = [] 898 | for weights in tqdm(population): 899 | model = self.get_averaged_model(weights) 900 | total_loss = 0 901 | for images, labels in val_loader: 902 | images, labels = images.to(self.device), labels.to(self.device) 903 | outputs = model(images) 904 | loss = nn.functional.cross_entropy(outputs, labels) 905 | total_loss += loss.item() 906 | average_loss = total_loss / len(val_loader) 907 | fitness_scores.append( 908 | -average_loss 909 | ) # Negative because lower loss is better 910 | return fitness_scores 911 | 912 | def evolve_population(self, population, val_loader): 913 | fitness_scores = self.evaluate_population(val_loader, population) 914 | sorted_indices = np.argsort(fitness_scores)[::-1] # Descending order of fitness 915 | best_individuals = [ 916 | population[i] for i in sorted_indices[: len(population) // 2] 917 | ] # Select top 50% 918 | 919 | # Reproduce with mutation 920 | new_population = [] 921 | while len(new_population) < len(population): 922 | parent = random.choice(best_individuals) 923 | child = parent + torch.randn_like(parent) * self.sigma # Gaussian mutation 924 | new_population.append(child) 925 | return new_population 926 | 927 | def run_evolution(self, val_loader): 928 | # Initialize population 929 | population = [ 930 | torch.rand(len(self.model_paths), device=self.device) 931 | for _ in range(self.population_size) 932 | ] 933 | 934 | for generation in tqdm(range(self.num_generations)): 935 | population = self.evolve_population(population, val_loader) 936 | best_weights = population[0] 937 | best_fitness = self.evaluate_population(val_loader, [best_weights])[0] 938 | print(f"Generation {generation}, Best Fitness: {best_fitness}") 939 | 940 | return 
self.get_averaged_model(best_weights) 941 | 942 | def run_periodic_averaging(self, val_loader, t=40): 943 | while True: 944 | logging.info("Averaging Beginning") 945 | start_time = time.time() 946 | 947 | if self.hf_manager.check_for_new_submissions(): 948 | logging.info( 949 | "Model updated from Hugging Face. Continuing training with new model..." 950 | ) 951 | self.model = self.hf_manager.update_model(self.model) 952 | 953 | self.get_model_paths() 954 | self.num_models = len(self.model_paths) 955 | self.weights = nn.Parameter( 956 | nn.functional.softmax( 957 | torch.ones(self.num_models, device=self.device), dim=0 958 | ) 959 | ) 960 | 961 | self.model = self.run_evolution(val_loader) 962 | # self.apply_averaged_gradients(averaged_weights) 963 | self.save_model() 964 | self.push_to_hf_hub(commit_message="Updated model with new gradients") 965 | 966 | elapsed_time = time.time() - start_time 967 | time_to_wait = max(0, t - elapsed_time) 968 | time.sleep(time_to_wait) 969 | 970 | logging.info("Averaging Done") 971 | -------------------------------------------------------------------------------- /hivetrain/btt_connector.py: -------------------------------------------------------------------------------- 1 | import bittensor as bt 2 | import copy 3 | import math 4 | import numpy as np 5 | import bittensor 6 | import torch 7 | import time 8 | from typing import List, Tuple 9 | import bittensor.utils.networking as net 10 | import threading 11 | import logging 12 | from . import __spec_version__ 13 | from bittensor import logging 14 | logger = logging 15 | #logger = logging.getLogger('waitress') 16 | #logger.setLevel(logging.DEBUG) 17 | 18 | 19 | def initialize_bittensor_objects(): 20 | global wallet, subtensor, metagraph, config 21 | base_config = copy.deepcopy(config) 22 | # check_config(base_config) 23 | 24 | if base_config.mock: 25 | wallet = bt.MockWallet(config=base_config) 26 | subtensor = MockSubtensor(base_config.netuid, wallet=wallet) 27 | metagraph = MockMetagraph(base_config.netuid, subtensor=subtensor) 28 | else: 29 | wallet = bt.wallet(config=base_config) 30 | subtensor = bt.subtensor(config=base_config) 31 | metagraph = subtensor.metagraph(base_config.netuid) 32 | 33 | 34 | # def check_registered(netuid): 35 | 36 | # if not BittensorNetwork.subtensor.is_hotkey_registered(netuid=netuid, hotkey_ss58=BittensorNetwork.wallet.hotkey.ss58_address): 37 | # print(f"Wallet: {wallet} is not registered on netuid {netuid}. Please register the hotkey before trying again") 38 | # exit() 39 | 40 | def resync_metagraph(lite): 41 | global metagraph, config, subtensor 42 | # Fetch the latest state of the metagraph from the Bittensor network 43 | print("Resynchronizing metagraph...") 44 | # Update the metagraph with the latest information from the network 45 | metagraph = BittensorNetwork.subtensor.metagraph(BittensorNetwork.config.netuid, lite=lite) 46 | print("Metagraph resynchronization complete.") 47 | 48 | def should_sync_metagraph(last_sync_time,sync_interval): 49 | current_time = time.time() 50 | return (current_time - last_sync_time) > sync_interval 51 | 52 | def sync(last_sync_time, sync_interval, config, lite=False): 53 | if should_sync_metagraph(last_sync_time,sync_interval): 54 | # Assuming resync_metagraph is a method to update the metagraph with the latest state from the network. 55 | # This method would need to be defined or adapted from the BaseNeuron implementation. 
56 | try: 57 | resync_metagraph(lite) 58 | last_sync_time = time.time() 59 | except Exception as e: 60 | logger.warn(f"Failed to resync metagraph: {e}") 61 | return last_sync_time 62 | else: 63 | return last_sync_time 64 | 65 | 66 | 67 | # def serve_on_subtensor(external_ip, external_port, netuid, max_retries=5, wait_for_inclusion=True, wait_for_finalization=False): 68 | # retry_count = 0 69 | # check_registered(netuid) 70 | # while retry_count < max_retries: 71 | # try: 72 | # breakpoint() 73 | # serve_success = BittensorNetwork.subtensor.serve( 74 | # wallet=BittensorNetwork.wallet, 75 | # ip=external_ip, 76 | # port=external_port, 77 | # netuid=netuid, 78 | # protocol=4, 79 | # wait_for_inclusion=wait_for_inclusion, 80 | # wait_for_finalization=wait_for_finalization, 81 | # prompt=False, 82 | # ) 83 | # if serve_success: 84 | # print(f"Serving on IP: {external_ip}, Port: {external_port}") 85 | # break 86 | # else: 87 | # print("Failed to serve on Subtensor network. Retrying...") 88 | # except Exception as e: 89 | # print(f"Error serving on Subtensor network: {e}") 90 | 91 | # retry_count += 1 92 | # sleep_time = math.pow(2, retry_count) # Exponential backoff 93 | # print(f"Retry {retry_count}/{max_retries}. Retrying in {sleep_time} seconds.") 94 | # time.sleep(sleep_time) 95 | 96 | # if retry_count == max_retries: 97 | # print("Max retries reached. Failed to serve on Subtensor network.") 98 | 99 | def serve_extrinsic( 100 | subtensor: "bittensor.subtensor", 101 | wallet: "bittensor.wallet", 102 | ip: str, 103 | port: int, 104 | protocol: int, 105 | netuid: int, 106 | placeholder1: int = 0, 107 | placeholder2: int = 0, 108 | wait_for_inclusion: bool = False, 109 | wait_for_finalization=True, 110 | prompt: bool = False, 111 | ) -> bool: 112 | r"""Subscribes a Bittensor endpoint to the subtensor chain. 113 | 114 | Args: 115 | wallet (bittensor.wallet): 116 | Bittensor wallet object. 117 | ip (str): 118 | Endpoint host port i.e., ``192.122.31.4``. 119 | port (int): 120 | Endpoint port number i.e., ``9221``. 121 | protocol (int): 122 | An ``int`` representation of the protocol. 123 | netuid (int): 124 | The network uid to serve on. 125 | placeholder1 (int): 126 | A placeholder for future use. 127 | placeholder2 (int): 128 | A placeholder for future use. 129 | wait_for_inclusion (bool): 130 | If set, waits for the extrinsic to enter a block before returning ``true``, or returns ``false`` if the extrinsic fails to enter the block within the timeout. 131 | wait_for_finalization (bool): 132 | If set, waits for the extrinsic to be finalized on the chain before returning ``true``, or returns ``false`` if the extrinsic fails to be finalized within the timeout. 133 | prompt (bool): 134 | If ``true``, the call waits for confirmation from the user before proceeding. 135 | Returns: 136 | success (bool): 137 | Flag is ``true`` if extrinsic was finalized or uncluded in the block. If we did not wait for finalization / inclusion, the response is ``true``. 
138 | """ 139 | # Decrypt hotkey 140 | wallet.hotkey 141 | params: "bittensor.AxonServeCallParams" = { 142 | "version": bittensor.__version_as_int__, 143 | "ip": net.ip_to_int(ip), 144 | "port": port, 145 | "ip_type": net.ip_version(ip), 146 | "netuid": netuid, 147 | "hotkey": wallet.hotkey.ss58_address, 148 | "coldkey": wallet.coldkeypub.ss58_address, 149 | "protocol": protocol, 150 | "placeholder1": placeholder1, 151 | "placeholder2": placeholder2, 152 | } 153 | bittensor.logging.debug("Checking axon ...") 154 | neuron = subtensor.get_neuron_for_pubkey_and_subnet( 155 | wallet.hotkey.ss58_address, netuid=netuid 156 | ) 157 | neuron_up_to_date = not neuron.is_null and params == { 158 | "version": neuron.axon_info.version, 159 | "ip": net.ip_to_int(neuron.axon_info.ip), 160 | "port": neuron.axon_info.port, 161 | "ip_type": neuron.axon_info.ip_type, 162 | "netuid": neuron.netuid, 163 | "hotkey": neuron.hotkey, 164 | "coldkey": neuron.coldkey, 165 | "protocol": neuron.axon_info.protocol, 166 | "placeholder1": neuron.axon_info.placeholder1, 167 | "placeholder2": neuron.axon_info.placeholder2, 168 | } 169 | output = params.copy() 170 | output["coldkey"] = wallet.coldkeypub.ss58_address 171 | output["hotkey"] = wallet.hotkey.ss58_address 172 | if neuron_up_to_date: 173 | bittensor.logging.debug( 174 | f"Axon already served on: AxonInfo({wallet.hotkey.ss58_address},{ip}:{port}) " 175 | ) 176 | return True 177 | 178 | if prompt: 179 | output = params.copy() 180 | output["coldkey"] = wallet.coldkeypub.ss58_address 181 | output["hotkey"] = wallet.hotkey.ss58_address 182 | if not Confirm.ask( 183 | "Do you want to serve axon:\n [bold white]{}[/bold white]".format( 184 | json.dumps(output, indent=4, sort_keys=True) 185 | ) 186 | ): 187 | return False 188 | 189 | bittensor.logging.debug( 190 | f"Serving axon with: AxonInfo({wallet.hotkey.ss58_address},{ip}:{port}) -> {subtensor.network}:{netuid}" 191 | ) 192 | params["ip"] = net.int_to_ip(params["ip"]) 193 | success, error_message = subtensor._do_serve_axon( 194 | wallet=wallet, 195 | call_params=params, 196 | wait_for_finalization=wait_for_finalization, 197 | wait_for_inclusion=wait_for_inclusion, 198 | ) 199 | 200 | if wait_for_inclusion or wait_for_finalization: 201 | if success == True: 202 | bittensor.logging.debug( 203 | f"Axon served with: AxonInfo({wallet.hotkey.ss58_address},{ip}:{port}) on {subtensor.network}:{netuid} " 204 | ) 205 | return True 206 | else: 207 | bittensor.logging.debug( 208 | f"Axon failed to served with error: {error_message} " 209 | ) 210 | return False 211 | else: 212 | return True 213 | 214 | def serve_axon(netuid,host_address,external_address, host_port, external_port): 215 | """Serve axon to enable external connections.""" 216 | 217 | logger.info("serving ip to chain...") 218 | try: 219 | axon = bt.axon( 220 | config=BittensorNetwork.config, 221 | wallet=BittensorNetwork.wallet, 222 | # port=host_port, 223 | # ip=host_address, 224 | # external_ip=external_address, 225 | # external_port=external_port 226 | ) 227 | axon.external_ip = external_address 228 | axon.external_port = external_port 229 | try: 230 | # BittensorNetwork.subtensor.serve_axon( 231 | # netuid=netuid, 232 | # axon=axon, 233 | # ) 234 | serve_success = BittensorNetwork.subtensor.serve( 235 | wallet=BittensorNetwork.wallet, 236 | ip=external_address, 237 | port=external_port, 238 | netuid=netuid, 239 | protocol=4, 240 | wait_for_inclusion=True, 241 | wait_for_finalization=True, 242 | prompt=False, 243 | ) 244 | if serve_success: 245 | print("success") 
246 | else: 247 | print("ARGH") 248 | logger.info( 249 | f"Served Axon {axon} on network: {BittensorNetwork.config.subtensor.chain_endpoint} with netuid: {BittensorNetwork.config.netuid}" 250 | ) 251 | except Exception as e: 252 | logger.error(f"Failed to serve Axon with exception: {e}") 253 | pass 254 | 255 | except Exception as e: 256 | logger.error( 257 | f"Failed to create Axon initialize with exception: {e}" 258 | ) 259 | pass 260 | return axon 261 | 262 | 263 | 264 | class BittensorNetwork: 265 | _instance = None 266 | _lock = threading.Lock() # Singleton lock 267 | _weights_lock = threading.Lock() # Lock for set_weights 268 | _anomaly_lock = threading.Lock() # Lock for detect_metric_anomaly 269 | _config_lock = threading.Lock() # Lock for modifying config 270 | _rate_limit_lock = threading.Lock() 271 | metrics_data = {} 272 | model_checksums = {} 273 | request_counts = {} # Track request counts 274 | blacklisted_addresses = {} # Track blacklisted addresses 275 | last_sync_time = 0 276 | sync_interval = 600 277 | 278 | 279 | def __new__(cls): 280 | with cls._lock: 281 | if cls._instance is None: 282 | cls._instance = super(BittensorNetwork, cls).__new__(cls) 283 | cls.wallet = None 284 | cls.subtensor = None 285 | cls.metagraph = None 286 | cls.config = None 287 | return cls._instance 288 | 289 | @classmethod 290 | def initialize(cls, config, ignore_regs=False): 291 | with cls._lock: 292 | cls.wallet = bt.wallet(config=config) 293 | cls.subtensor = bt.subtensor(config=config) 294 | cls.metagraph = cls.subtensor.metagraph(config.netuid) 295 | cls.config = config 296 | if not cls.subtensor.is_hotkey_registered(netuid=config.netuid, hotkey_ss58=cls.wallet.hotkey.ss58_address) and not ignore_regs: 297 | print(f"Wallet: {config.wallet} is not registered on netuid {config.netuid}. Please register the hotkey before trying again") 298 | exit() 299 | cls.uid = cls.metagraph.hotkeys.index( 300 | cls.wallet.hotkey.ss58_address 301 | ) 302 | else: 303 | cls.uid = 0 304 | cls.device="cpu" 305 | cls.base_scores = torch.zeros( 306 | cls.metagraph.n, dtype=torch.float32, device=cls.device 307 | ) 308 | # Additional initialization logic here 309 | 310 | @classmethod 311 | def set_weights(cls, scores): 312 | try: 313 | #chain_weights = torch.zeros(cls.subtensor.subnetwork_n(netuid=cls.metagraph.netuid)) 314 | uids = [] 315 | for uid, public_address in enumerate(cls.metagraph.hotkeys): 316 | try: 317 | alpha = 0.333333 # T=5 (2/(5+1)) 318 | cls.base_scores[uid] = alpha * scores.get(public_address, 0) + (1 - alpha) * cls.base_scores[uid].to(cls.device) 319 | uids.append(uid) 320 | except KeyError: 321 | continue 322 | uids = torch.tensor(uids) 323 | logger.info(f"raw_weights {cls.base_scores}") 324 | logger.info(f"raw_weight_uids {uids}") 325 | # Process the raw weights to final_weights via subtensor limitations. 326 | ( 327 | processed_weight_uids, 328 | processed_weights, 329 | ) = bt.utils.weight_utils.process_weights_for_netuid( 330 | uids=uids.to("cpu"), 331 | weights=cls.base_scores.to("cpu"), 332 | netuid=cls.config.netuid, 333 | subtensor=cls.subtensor, 334 | metagraph=cls.metagraph, 335 | ) 336 | logger.info(f"processed_weights {processed_weights}") 337 | logger.info(f"processed_weight_uids {processed_weight_uids}") 338 | 339 | # Convert to uint16 weights and uids. 
340 | ( 341 | uint_uids, 342 | uint_weights, 343 | ) = bt.utils.weight_utils.convert_weights_and_uids_for_emit( 344 | uids=processed_weight_uids, weights=processed_weights 345 | ) 346 | logger.info("Sending weights to subtensor") 347 | result = cls.subtensor.set_weights( 348 | wallet=cls.wallet, 349 | netuid=cls.metagraph.netuid, 350 | uids=uint_uids, 351 | weights=uint_weights, 352 | wait_for_inclusion=False, 353 | version_key=__spec_version__ 354 | ) 355 | except Exception as e: 356 | logger.info(f"Error setting weights: {e}") 357 | 358 | @classmethod 359 | def get_validator_uids( 360 | cls, vpermit_tao_limit: int = 1024 361 | ): 362 | """ 363 | Check availability of all UIDs in a given subnet, returning their IP, port numbers, and hotkeys 364 | if they are serving and have at least vpermit_tao_limit stake, along with a list of strings 365 | formatted as 'ip:port' for each validator. 366 | 367 | Args: 368 | metagraph (bt.metagraph.Metagraph): Metagraph object. 369 | vpermit_tao_limit (int): Validator permit tao limit. 370 | 371 | Returns: 372 | Tuple[List[dict], List[str]]: A tuple where the first element is a list of dicts with details 373 | of available UIDs, including their IP, port, and hotkeys, and the 374 | second element is a list of strings formatted as 'ip:port'. 375 | """ 376 | validator_uids = [] # List to hold 'ip:port' strings 377 | for uid in range(len(cls.metagraph.S)): 378 | if cls.metagraph.S[uid] >= vpermit_tao_limit: 379 | validator_uids.append(uid) 380 | return validator_uids 381 | 382 | @classmethod 383 | def should_set_weights(cls) -> bool: 384 | with cls._lock: # Assuming last_update modification is protected elsewhere with the same lock 385 | return (cls.subtensor.get_current_block() - cls.metagraph.last_update[cls.uid]) > cls.config.neuron.epoch_length 386 | 387 | @classmethod 388 | def detect_metric_anomaly(cls, metric="loss", OUTLIER_THRESHOLD=2, MEDIAN_ABSOLUTE_DEVIATION=True): 389 | from scipy.stats import median_abs_deviation 390 | with cls._anomaly_lock: 391 | if not cls.metrics_data: 392 | return {} 393 | 394 | logger.info(f"Metrics Data: {cls.metrics_data}") 395 | aggregated_metrics = {} 396 | for public_address, data in cls.metrics_data.items(): 397 | if metric in data: 398 | if public_address in aggregated_metrics:#FIXME no need for an if condition 399 | aggregated_metrics[public_address].append(data[metric]) 400 | else: 401 | aggregated_metrics[public_address] = [data[metric]] 402 | 403 | if MEDIAN_ABSOLUTE_DEVIATION: 404 | # Use Median Absolute Deviation for outlier detection 405 | values = [np.median(vals) for vals in aggregated_metrics.values()] 406 | median = np.median(values) 407 | deviation = median_abs_deviation(values, scale='normal') 408 | is_outlier = {} 409 | for addr, vals in aggregated_metrics.items(): 410 | try: 411 | is_outlier[addr] = (abs(np.median(vals) - median) / deviation) > OUTLIER_THRESHOLD 412 | except: 413 | is_outlier[addr] = True 414 | 415 | else: 416 | # Use Mean and Standard Deviation for outlier detection 417 | average_metrics = {addr: np.nanmean(vals) for addr, vals in aggregated_metrics.items()} 418 | losses = np.array(list(average_metrics.values())) 419 | mean_loss = np.mean(losses) 420 | std_loss = np.std(losses) 421 | is_outlier = {addr: abs(avg_loss - mean_loss) / std_loss > OUTLIER_THRESHOLD 422 | for addr, avg_loss in average_metrics.items()} 423 | 424 | scores = {public_address: 0 if is_outlier.get(public_address, False) else 1 for public_address in aggregated_metrics} 425 | logger.info(f"Scores calculated: 
{scores}") 426 | return scores 427 | 428 | 429 | @classmethod 430 | def run_evaluation(cls): 431 | #global model_checksums, metrics_data 432 | logger.info("Evaluating miners") 433 | # checksum_frequencies = {} 434 | # for public_address, checksum in cls.model_checksums.items(): 435 | # checksum_frequencies[public_address] = checksum_frequencies.get(public_address, 0) + 1 436 | 437 | # model_scores = {} 438 | # try: 439 | # most_common_checksum = max(checksum_frequencies, key=checksum_frequencies.get) 440 | # model_scores = {public_address: (1 if checksum == most_common_checksum else 0) for public_address, checksum in cls.model_checksums.items()} 441 | # logger.info("Model scores based on checksum consensus:", model_scores) 442 | 443 | # except ValueError: 444 | # pass 445 | 446 | with cls._weights_lock: 447 | if BittensorNetwork.should_set_weights(): 448 | scores = BittensorNetwork.detect_metric_anomaly() 449 | BittensorNetwork.set_weights(scores) 450 | 451 | cls.model_checksums.clear() 452 | cls.metrics_data.clear() 453 | 454 | @classmethod 455 | def rate_limiter(cls, public_address, n=10, t=60): 456 | """ 457 | Check if a public_address has exceeded n requests in t seconds. 458 | If exceeded, add to blacklist. 459 | """ 460 | with cls._rate_limit_lock: 461 | current_time = time.time() 462 | if public_address in cls.blacklisted_addresses: 463 | # Check if the blacklist period is over 464 | if current_time - cls.blacklisted_addresses[public_address] > t: 465 | del cls.blacklisted_addresses[public_address] 466 | else: 467 | return False # Still blacklisted 468 | 469 | request_times = cls.request_counts.get(public_address, []) 470 | # Filter out requests outside of the time window 471 | request_times = [rt for rt in request_times if current_time - rt <= t] 472 | request_times.append(current_time) 473 | cls.request_counts[public_address] = request_times 474 | 475 | if len(request_times) > n: 476 | logger.info(f"Blacklisted {public_address} for making {len(request_times)} in {t} seconds") 477 | cls.blacklisted_addresses[public_address] = current_time 478 | return False # Too many requests, added to blacklist 479 | 480 | return True # Request allowed 481 | @classmethod 482 | def resync_metagraph(cls,lite=True): 483 | 484 | # Fetch the latest state of the metagraph from the Bittensor network 485 | print("Resynchronizing metagraph...") 486 | # Update the metagraph with the latest information from the network 487 | cls.metagraph = cls.subtensor.metagraph(cls.config.netuid, lite=lite) 488 | print("Metagraph resynchronization complete.") 489 | 490 | @staticmethod 491 | def should_sync_metagraph(last_sync_time,sync_interval): 492 | current_time = time.time() 493 | return (current_time - last_sync_time) > sync_interval 494 | 495 | @classmethod 496 | def sync(cls, lite=True): 497 | if cls.should_sync_metagraph(cls.last_sync_time,cls.sync_interval): 498 | # Assuming resync_metagraph is a method to update the metagraph with the latest state from the network. 499 | # This method would need to be defined or adapted from the BaseNeuron implementation. 
500 | try: 501 | cls.resync_metagraph(lite) 502 | cls.last_sync_time = time.time() 503 | except Exception as e: 504 | logger.warn(f"Failed to resync metagraph: {e}") 505 | else: 506 | logger.info("Metagraph Sync Interval not yet passed") 507 | 508 | 509 | import json 510 | import os 511 | import threading 512 | import time 513 | 514 | class LocalMetagraph: 515 | def __init__(self): 516 | self._hotkeys = [] 517 | self._network_state = 'initial' 518 | self._weights = [] 519 | 520 | 521 | class Wallet: 522 | def __init__(self, hotkey): 523 | self.hotkey = hotkey 524 | 525 | class Hotkey: 526 | def __init__(self, ss58_address): 527 | self.ss58_address = ss58_address 528 | 529 | 530 | class LocalBittensorNetwork: 531 | _instance = None 532 | _lock = threading.Lock() 533 | _weights_lock = threading.Lock() 534 | _anomaly_lock = threading.Lock() 535 | _config_lock = threading.Lock() 536 | _rate_limit_lock = threading.Lock() 537 | _data_directory = 'bittensor_network' 538 | _metagraph_file = os.path.join(_data_directory, 'metagraph.json') 539 | _weights_file = os.path.join(_data_directory, 'weights.json') 540 | _metagraph = None 541 | metrics_data = {} 542 | model_checksums = {} 543 | request_counts = {} 544 | blacklisted_addresses = {} 545 | last_sync_time = 0 546 | sync_interval = 600 547 | subtensor=None 548 | wallet=None 549 | last_update = 0 550 | update_interval = 600 551 | 552 | def __new__(cls): 553 | with cls._lock: 554 | if cls._instance is None: 555 | cls._instance = super(BittensorNetwork, cls).__new__(cls) 556 | return cls._instance 557 | 558 | @classmethod 559 | def _load_data(cls, filepath): 560 | if os.path.exists(filepath): 561 | with open(filepath, 'r') as file: 562 | return json.load(file) 563 | else: 564 | return None 565 | 566 | @classmethod 567 | def _save_data(cls, data, filepath): 568 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 569 | with open(filepath, 'w') as file: 570 | json.dump(data, file, indent=4) 571 | 572 | @classmethod 573 | def initialize(cls, config): 574 | cls.config = config 575 | cls.sync_interval = config.get('sync_interval', 600) # Example of using config to set sync interval 576 | 577 | cls._wallet_file = os.path.join(cls._data_directory, f'wallet_{config.wallet.hotkey}.json') 578 | 579 | wallet_data = cls._load_data(cls._wallet_file) 580 | metagraph_data = cls._load_data(cls._metagraph_file) 581 | 582 | if not wallet_data or not metagraph_data: 583 | print("Data files not found, initializing new simulation.") 584 | wallet_data = {'hotkey': config.wallet.hotkey} 585 | metagraph_data = {'hotkeys': [], 'network_state': 'initial', 'weights': [], "stake": []} 586 | 587 | for i in range(100): 588 | hotkey = f'simulated_hotkey_{i}' 589 | metagraph_data['hotkeys'].append(hotkey) 590 | metagraph_data["weights"].append([0 for _ in range(100)]) 591 | if i > 90: 592 | metagraph_data["stake"].append(10000) 593 | else: 594 | metagraph_data["stake"].append(10) 595 | 596 | 597 | 598 | cls._save_data(wallet_data, cls._wallet_file) 599 | cls._save_data(metagraph_data, cls._metagraph_file) 600 | 601 | cls.metagraph = LocalMetagraph() 602 | cls.metagraph.hotkeys = metagraph_data['hotkeys'] 603 | cls.metagraph.network_state = metagraph_data['network_state'] 604 | cls.metagraph.weights = metagraph_data['weights'] 605 | cls.metagraph.W = metagraph_data['weights'] 606 | cls.wallet=Wallet(hotkey=Hotkey(ss58_address=cls.config.wallet.hotkey)) 607 | 608 | @classmethod 609 | def set_weights(cls, scores): 610 | # Simulated set_weights method 611 | assert len(scores) == 
len(cls.metagraph.weights) 612 | normalized_scores = (torch.tensor(scores) / sum(scores)).numpy().tolist() 613 | 614 | # Loop over the normalized tensor elements 615 | #for uid, score in enumerate(normalized_scores): 616 | my_hotkey = cls.wallet.hotkey.ss58_address 617 | my_uid = cls.metagraph.hotkeys.index(my_hotkey) 618 | 619 | cls.metagraph.weights[my_uid] = normalized_scores 620 | 621 | # Save the updated metagraph data 622 | metagraph_data = { 623 | 'hotkeys': cls.metagraph.hotkeys, 624 | 'network_state': cls.metagraph.network_state, 625 | 'weights': cls.metagraph.weights 626 | } 627 | cls._save_data(metagraph_data, cls._metagraph_file) 628 | cls.last_update = time.time() 629 | 630 | @staticmethod 631 | def should_sync_metagraph(last_sync_time,sync_interval): 632 | current_time = time.time() 633 | return (current_time - last_sync_time) > sync_interval 634 | 635 | @classmethod 636 | def should_set_weights(cls) -> bool: 637 | return (time.time() - cls.last_sync_time) > cls.update_interval 638 | 639 | @classmethod 640 | def sync(cls, lite=True): 641 | 642 | if cls.should_sync_metagraph(cls.last_sync_time, cls.sync_interval ): 643 | print("Syncing metagraph...") 644 | metagraph_data = cls._load_data(cls._metagraph_file) 645 | 646 | if metagraph_data: 647 | print("Metagraph synced:", metagraph_data) 648 | cls.metagraph.hotkeys = metagraph_data['hotkeys'] 649 | cls.metagraph.network_state = metagraph_data['network_state'] 650 | cls.metagraph.weights = metagraph_data['weights'] 651 | else: 652 | print("Failed to load metagraph data.") 653 | 654 | cls.last_sync_time = time.time() 655 | 656 | 657 | @classmethod 658 | def run_evaluation(cls): 659 | if LocalBittensorNetwork.should_sync_metagraph(): 660 | LocalBittensorNetwork.sync() 661 | 662 | with cls._weights_lock: 663 | print("Evaluating miners...") 664 | 665 | @classmethod 666 | def get_validator_uids( 667 | cls, vpermit_tao_limit: int = 1024 668 | ): 669 | 670 | 671 | return [i for i in range(91,100)] 672 | 673 | 674 | # @property 675 | # def metagraph(self): 676 | # return self._metagraph 677 | -------------------------------------------------------------------------------- /hivetrain/chain_manager.py: -------------------------------------------------------------------------------- 1 | #Thanks SN9 2 | 3 | import multiprocessing 4 | import functools 5 | import bittensor as bt 6 | import os 7 | import lzma 8 | import base64 9 | import multiprocessing 10 | from typing import Optional, Any 11 | from bittensor.btlogging import logging 12 | 13 | 14 | def _wrapped_func(func: functools.partial, queue: multiprocessing.Queue): 15 | try: 16 | result = func() 17 | queue.put(result) 18 | except (Exception, BaseException) as e: 19 | # Catch exceptions here to add them to the queue. 20 | queue.put(e) 21 | 22 | def run_in_subprocess(func: functools.partial, ttl: int, mode="fork") -> Any: 23 | """Runs the provided function on a subprocess with 'ttl' seconds to complete. 24 | 25 | Args: 26 | func (functools.partial): Function to be run. 27 | ttl (int): How long to try for in seconds. 28 | 29 | Returns: 30 | Any: The value returned by 'func' 31 | """ 32 | ctx = multiprocessing.get_context(mode) 33 | queue = ctx.Queue() 34 | process = ctx.Process(target=_wrapped_func, args=[func, queue]) 35 | 36 | process.start() 37 | 38 | process.join(timeout=ttl) 39 | 40 | if process.is_alive(): 41 | process.terminate() 42 | process.join() 43 | raise TimeoutError(f"Failed to {func.func.__name__} after {ttl} seconds") 44 | 45 | # Raises an error if the queue is empty. 
This is fine. It means our subprocess timed out. 46 | result = queue.get(block=False) 47 | 48 | # If we put an exception on the queue then raise instead of returning. 49 | if isinstance(result, Exception): 50 | raise result 51 | if isinstance(result, BaseException): 52 | raise Exception(f"BaseException raised in subprocess: {str(result)}") 53 | 54 | return result 55 | 56 | 57 | class ChainMultiAddressStore: 58 | """Chain based implementation for storing and retrieving multiaddresses.""" 59 | 60 | def __init__( 61 | self, 62 | subtensor: bt.subtensor, 63 | subnet_uid: int, 64 | wallet: Optional[bt.wallet] = None, 65 | 66 | ): 67 | self.subtensor = subtensor 68 | self.wallet = wallet 69 | self.subnet_uid = subnet_uid 70 | 71 | def store_hf_repo(self, hf_repo: str): 72 | """Stores compressed multiaddress on this subnet for a specific wallet.""" 73 | if self.wallet is None: 74 | raise ValueError("No wallet available to write to the chain.") 75 | 76 | # Compress the multiaddress 77 | 78 | # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs. 79 | partial = functools.partial( 80 | self.subtensor.commit, 81 | self.wallet, 82 | self.subnet_uid, 83 | hf_repo, 84 | ) 85 | run_in_subprocess(partial, 60) 86 | 87 | def retrieve_hf_repo(self, hotkey: str) -> Optional[str]: 88 | """Retrieves and decompresses multiaddress on this subnet for specific hotkey""" 89 | # Wrap calls to the subtensor in a subprocess with a timeout to handle potential hangs. 90 | partial = functools.partial( 91 | bt.extrinsics.serving.get_metadata, self.subtensor, self.subnet_uid, hotkey 92 | ) 93 | 94 | try: 95 | metadata = run_in_subprocess(partial, 60) 96 | except: 97 | metadata = None 98 | logging.warning(f"Failed to retreive multiaddress for: {hotkey}") 99 | 100 | 101 | if not metadata: 102 | return None 103 | 104 | commitment = metadata["info"]["fields"][0] 105 | hex_data = commitment[list(commitment.keys())[0]][2:] 106 | multiaddress = bytes.fromhex(hex_data).decode() 107 | 108 | try: 109 | return multiaddress 110 | except: 111 | # If the data format is not correct or decompression fails, return None. 112 | bt.logging.trace( 113 | f"Failed to parse the data on the chain for hotkey {hotkey}." 
114 | ) 115 | return None 116 | 117 | # Synchronous test cases for ChainMultiAddressStore 118 | 119 | 120 | import json 121 | import os 122 | from typing import Optional 123 | 124 | class LocalAddressStore: 125 | """Simulated local storage for storing and retrieving multiaddresses, using a file for persistence.""" 126 | 127 | def __init__(self, 128 | subtensor: bt.subtensor, 129 | subnet_uid: int, 130 | wallet: Optional[bt.wallet] = None, 131 | ): 132 | self.storage_file = "storage.json" 133 | self.wallet = wallet 134 | # Ensure the storage file exists 135 | if not os.path.exists(self.storage_file): 136 | with open(self.storage_file, 'w') as file: 137 | json.dump({}, file) 138 | 139 | def _load_storage(self): 140 | """Loads the storage content from a file.""" 141 | with open(self.storage_file, 'r') as file: 142 | return json.load(file) 143 | 144 | def _save_storage(self, storage): 145 | """Saves the updated storage content to a file.""" 146 | with open(self.storage_file, 'w') as file: 147 | json.dump(storage, file) 148 | 149 | def store_hf_repo(self, hf_repo: str): 150 | """Stores the Hugging Face repository link for a specific wallet.""" 151 | if self.wallet is None: 152 | raise ValueError("No wallet available to write to the storage.") 153 | 154 | storage = self._load_storage() 155 | storage[self.wallet.hotkey.ss58_address] = hf_repo 156 | self._save_storage(storage) 157 | print(f"Stored {hf_repo} for {self.wallet.hotkey.ss58_address}") 158 | 159 | def retrieve_hf_repo(self, hotkey: str) -> Optional[str]: 160 | """Retrieves the Hugging Face repository link for a specific wallet.""" 161 | storage = self._load_storage() 162 | hf_repo = storage.get(hotkey) 163 | if hf_repo: 164 | print(f"Retrieved {hf_repo}") 165 | return hf_repo 166 | else: 167 | print(f"Failed to retrieve repository for: {hotkey}") 168 | return None 169 | 170 | 171 | def test_store_multiaddress(): 172 | """Verifies that the ChainMultiAddressStore can store data on the chain.""" 173 | multiaddress = "/ip4/198.51.100.0/tcp/4242/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N" 174 | 175 | # Use a different subnet that does not leverage chain storage to avoid conflicts. 176 | subtensor = bt.subtensor() 177 | 178 | # Uses .env configured wallet/hotkey/uid for the test. 179 | coldkey = os.getenv("TEST_COLDKEY") 180 | hotkey = os.getenv("TEST_HOTKEY") 181 | net_uid = int(os.getenv("TEST_SUBNET_UID")) 182 | 183 | wallet = bt.wallet(name=coldkey, hotkey=hotkey) 184 | 185 | address_store = ChainMultiAddressStore(subtensor, wallet, net_uid) 186 | 187 | # Store the multiaddress on chain. 188 | address_store.store_multiaddress(hotkey, multiaddress) 189 | 190 | print(f"Finished storing multiaddress for {hotkey} on the chain.") 191 | 192 | 193 | def test_retrieve_multiaddress(): 194 | """Verifies that the ChainMultiAddressStore can retrieve data from the chain.""" 195 | expected_multiaddress = "/ip4/198.51.100.0/tcp/4242/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N" 196 | 197 | # Use a different subnet that does not leverage chain storage to avoid conflicts. 198 | subtensor = bt.subtensor() 199 | 200 | # Uses .env configured hotkey/uid for the test. 201 | net_uid = int(os.getenv("TEST_SUBNET_UID")) 202 | hotkey = os.getenv("TEST_HOTKEY") 203 | 204 | address_store = ChainMultiAddressStore(subtensor, None, net_uid) 205 | 206 | # Retrieve the multiaddress from the chain. 
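(A short aside before the retrieval call below.) Both the store and retrieve paths in this module go through `run_in_subprocess`, which executes the chain call in a child process with a hard TTL so a hanging subtensor request cannot stall the neuron; the result, or any exception raised, is passed back through a queue. A minimal usage sketch with a stand-in function instead of a real subtensor call (`slow_commit` and its arguments are purely illustrative):

```python
import functools
import time

from hivetrain.chain_manager import run_in_subprocess

# Stand-in for a chain call such as subtensor.commit(wallet, subnet_uid, hf_repo).
def slow_commit(repo: str, delay: float) -> str:
    time.sleep(delay)
    return f"committed {repo}"

if __name__ == "__main__":
    # Finishes well inside the 5-second budget.
    fast = functools.partial(slow_commit, "example-user/weight-deltas", 0.1)
    print(run_in_subprocess(fast, ttl=5))

    # Exceeds the budget: the child is terminated and TimeoutError is raised.
    stuck = functools.partial(slow_commit, "example-user/weight-deltas", 60)
    try:
        run_in_subprocess(stuck, ttl=2)
    except TimeoutError as err:
        print(err)
```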
207 | retrieved_multiaddress = address_store.retrieve_multiaddress(hotkey) 208 | 209 | print(f"Retrieved multiaddress matches expected: {expected_multiaddress == retrieved_multiaddress}") 210 | 211 | 212 | def test_roundtrip_multiaddress(): 213 | """Verifies that the ChainMultiAddressStore can roundtrip data on the chain.""" 214 | multiaddress = "/ip4/198.51.100.0/tcp/4242/p2p/QmYyQSo1c1Ym7orWxLYvCrM2EmxFTANf8wXmmE7DWjhx5N" 215 | 216 | # Use a different subnet that does not leverage chain storage to avoid conflicts. 217 | subtensor = bt.subtensor() 218 | 219 | # Uses .env configured wallet/hotkey/uid for the test. 220 | coldkey = os.getenv("TEST_COLDKEY") 221 | hotkey = os.getenv("TEST_HOTKEY") 222 | net_uid = int(os.getenv("TEST_SUBNET_UID")) 223 | 224 | wallet = bt.wallet(name=coldkey, hotkey=hotkey) 225 | 226 | address_store = ChainMultiAddressStore(subtensor, wallet, net_uid) 227 | 228 | # Store the multiaddress on chain. 229 | address_store.store_multiaddress(hotkey, multiaddress) 230 | 231 | # Retrieve the multiaddress from the chain. 232 | retrieved_multiaddress = address_store.retrieve_multiaddress(hotkey) 233 | 234 | print(f"Expecting matching multiaddress: {multiaddress == retrieved_multiaddress}") 235 | 236 | 237 | 238 | if __name__ == "__main__": 239 | # Can only commit data every ~20 minutes. 240 | # asyncio.run(test_roundtrip_model_metadata()) 241 | # asyncio.run(test_store_model_metadata()) 242 | test_retrieve_model_metadata() 243 | 244 | -------------------------------------------------------------------------------- /hivetrain/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_subnet_config import * 2 | from .config import Configurator 3 | from .hivetrain_config import * 4 | -------------------------------------------------------------------------------- /hivetrain/config/base_subnet_config.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2023 Yuma Rao 3 | # Copyright © 2023 Opentensor Foundation 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 6 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 7 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 8 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 11 | # the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 14 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 15 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 16 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 17 | # DEALINGS IN THE SOFTWARE. 
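This module supplies the shared argparse groups (`add_neuron_args`, `add_miner_args`, `add_validator_args`, all defined below) that `Configurator.combine_configs` in `config.py` merges with Bittensor's own wallet, subtensor, logging, and axon arguments into a single `bt.config`. A minimal sketch of that assembly, equivalent in spirit to the commented-out `config(cls)` helper at the bottom of this file:

```python
import argparse
import bittensor as bt

from hivetrain.config.base_subnet_config import (
    add_neuron_args,
    add_miner_args,
    add_validator_args,
)

def build_config() -> "bt.config":
    """Collect Bittensor and subnet-specific CLI arguments into one config object."""
    parser = argparse.ArgumentParser(description="Hivetrain neuron configuration")
    bt.wallet.add_args(parser)
    bt.subtensor.add_args(parser)
    bt.logging.add_args(parser)
    bt.axon.add_args(parser)
    add_neuron_args(parser)
    add_miner_args(parser)
    add_validator_args(parser)
    return bt.config(parser)

# e.g. python neuron.py --netuid 1 --neuron.device cpu --wallet.name default
```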
18 | 19 | import os 20 | import torch 21 | import argparse 22 | import bittensor as bt 23 | 24 | 25 | def check_config(config: "bt.Config"): 26 | r"""Checks/validates the config namespace object.""" 27 | bt.logging.check_config(config) 28 | 29 | full_path = os.path.expanduser( 30 | "{}/{}/{}/netuid{}/{}".format( 31 | config.logging.logging_dir, # TODO: change from ~/.bittensor/miners to ~/.bittensor/neurons 32 | config.wallet.name, 33 | config.wallet.hotkey, 34 | config.netuid, 35 | config.neuron.name, 36 | ) 37 | ) 38 | print("full path:", full_path) 39 | config.neuron.full_path = os.path.expanduser(full_path) 40 | if not os.path.exists(config.neuron.full_path): 41 | os.makedirs(config.neuron.full_path, exist_ok=True) 42 | 43 | if not config.neuron.dont_save_events: 44 | # Add custom event logger for the events. 45 | logger.level("EVENTS", no=38, icon="📝") 46 | logger.add( 47 | os.path.join(config.neuron.full_path, "events.log"), 48 | rotation=config.neuron.events_retention_size, 49 | serialize=True, 50 | enqueue=True, 51 | backtrace=False, 52 | diagnose=False, 53 | level="EVENTS", 54 | format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", 55 | ) 56 | 57 | 58 | def add_neuron_args(parser): 59 | """ 60 | Adds relevant arguments to the parser for operation. 61 | """ 62 | 63 | parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1) 64 | 65 | parser.add_argument( 66 | "--neuron.device", 67 | type=str, 68 | help="Device to run on.", 69 | default="cuda" if torch.cuda.is_available() else "cpu", 70 | ) 71 | 72 | parser.add_argument( 73 | "--neuron.epoch_length", 74 | type=int, 75 | help="The default epoch length (how often we set weights, measured in 12 second blocks).", 76 | default=100, 77 | ) 78 | 79 | parser.add_argument( 80 | "--mock", 81 | action="store_true", 82 | help="Mock neuron and all network components.", 83 | default=False, 84 | ) 85 | 86 | parser.add_argument( 87 | "--neuron.events_retention_size", 88 | type=str, 89 | help="Events retention size.", 90 | default="2 GB", 91 | ) 92 | 93 | parser.add_argument( 94 | "--neuron.dont_save_events", 95 | action="store_true", 96 | help="If set, we dont save events to a log file.", 97 | default=False, 98 | ) 99 | 100 | parser.add_argument( 101 | "--neuron.initial_peers", 102 | type=str, 103 | nargs="+", 104 | help="If set, we dont save events to a log file.", 105 | default=None, 106 | ) 107 | 108 | 109 | def add_miner_args(parser): 110 | """Add miner specific arguments to the parser.""" 111 | 112 | parser.add_argument( 113 | "--blacklist.force_validator_permit", 114 | action="store_true", 115 | help="If set, we will force incoming requests to have a permit.", 116 | default=False, 117 | ) 118 | 119 | parser.add_argument( 120 | "--blacklist.allow_non_registered", 121 | action="store_true", 122 | help="If set, miners will accept queries from non registered entities. 
(Dangerous!)", 123 | default=False, 124 | ) 125 | 126 | 127 | 128 | 129 | def add_validator_args(parser): 130 | """Add validator specific arguments to the parser.""" 131 | 132 | 133 | parser.add_argument( 134 | "--neuron.timeout", 135 | type=float, 136 | help="The timeout for each forward call in seconds.", 137 | default=10, 138 | ) 139 | 140 | parser.add_argument( 141 | "--neuron.num_concurrent_forwards", 142 | type=int, 143 | help="The number of concurrent forwards running at any time.", 144 | default=1, 145 | ) 146 | 147 | parser.add_argument( 148 | "--neuron.sample_size", 149 | type=int, 150 | help="The number of miners to query in a single step.", 151 | default=50, 152 | ) 153 | 154 | parser.add_argument( 155 | "--neuron.disable_set_weights", 156 | action="store_true", 157 | help="Disables setting weights.", 158 | default=False, 159 | ) 160 | 161 | parser.add_argument( 162 | "--neuron.moving_average_alpha", 163 | type=float, 164 | help="Moving average alpha parameter, how much to add of the new observation.", 165 | default=0.1, 166 | ) 167 | 168 | parser.add_argument( 169 | "--neuron.axon_off", 170 | "--axon_off", 171 | action="store_true", 172 | # Note: the validator needs to serve an Axon with their IP or they may 173 | # be blacklisted by the firewall of serving peers on the network. 174 | help="Set this flag to not attempt to serve an Axon.", 175 | default=False, 176 | ) 177 | 178 | parser.add_argument( 179 | "--neuron.vpermit_tao_limit", 180 | type=int, 181 | help="The maximum number of TAO allowed to query a validator with a vpermit.", 182 | default=4096, 183 | ) 184 | 185 | 186 | 187 | 188 | # def config(cls): 189 | # """ 190 | # Returns the configuration object specific to this miner or validator after adding relevant arguments. 191 | # """ 192 | # parser = argparse.ArgumentParser() 193 | # bt.wallet.add_args(parser) 194 | # bt.subtensor.add_args(parser) 195 | # bt.logging.add_args(parser) 196 | # bt.axon.add_args(parser) 197 | # cls.add_args(parser) 198 | # return bt.config(parser) -------------------------------------------------------------------------------- /hivetrain/config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import bittensor as bt 5 | #from loguru import logger 6 | from argparse import ArgumentParser 7 | import bittensor as bt 8 | from .hivetrain_config import add_meta_miner_args, add_orchestrator_args, add_torch_miner_args #s, add_validator_args 9 | from .base_subnet_config import add_neuron_args, add_validator_args, add_miner_args 10 | 11 | 12 | # def check_config(cls, config: "bt.Config"): 13 | # r"""Checks/validates the config namespace object.""" 14 | # bt.logging.check_config(config) 15 | 16 | # full_path = os.path.expanduser( 17 | # "{}/{}/{}/netuid{}/{}".format( 18 | # config.logging.logging_dir, # TODO: change from ~/.bittensor/miners to ~/.bittensor/neurons 19 | # config.wallet.name, 20 | # config.wallet.hotkey, 21 | # config.netuid, 22 | # config.neuron.name, 23 | # ) 24 | # ) 25 | # print("full path:", full_path) 26 | # config.neuron.full_path = os.path.expanduser(full_path) 27 | # if not os.path.exists(config.neuron.full_path): 28 | # os.makedirs(config.neuron.full_path, exist_ok=True) 29 | 30 | # if not config.neuron.dont_save_events: 31 | # # Add custom event logger for the events. 
32 | # logger.level("EVENTS", no=38, icon="📝") 33 | # logger.add( 34 | # os.path.join(config.neuron.full_path, "events.log"), 35 | # rotation=config.neuron.events_retention_size, 36 | # serialize=True, 37 | # enqueue=True, 38 | # backtrace=False, 39 | # diagnose=False, 40 | # level="EVENTS", 41 | # format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", 42 | # ) 43 | 44 | class Configurator: 45 | @staticmethod 46 | def combine_configs(): 47 | parser = ArgumentParser(description="Unified Configuration for Bittensor") 48 | bt.wallet.add_args(parser) 49 | bt.subtensor.add_args(parser) 50 | bt.logging.add_args(parser) 51 | bt.axon.add_args(parser) 52 | 53 | add_torch_miner_args(parser) 54 | add_meta_miner_args(parser) 55 | add_orchestrator_args(parser) 56 | add_neuron_args(parser) 57 | add_miner_args(parser) 58 | add_validator_args(parser) 59 | args = parser.parse_args() 60 | return bt.config(parser) -------------------------------------------------------------------------------- /hivetrain/config/hivetrain_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import bittensor as bt 5 | import time 6 | def add_meta_miner_args(parser): 7 | #parser.add_argument("--meta-miner.log-activity", type=bool, help="Display logging message every request") 8 | parser.add_argument("--miner.batch-size", type=int, default=64, help="Batch size per forward/backward pass") 9 | parser.add_argument("--miner.epochs", type=int, default=100, help="Number of epochs to train") 10 | 11 | parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference/training") 12 | 13 | parser.add_argument('--miner.send_interval', type=int, help='URLs of the validators for local testing only') 14 | parser.add_argument('--storage.gradient_dir', type=str, help='Local path to gradients/weight deltas') 15 | parser.add_argument('--storage.model_dir', type=str, help='Local path of averaged model') 16 | parser.add_argument('--storage.my_repo_id', type=str, help='Miner repo on HuggingFace for storing weights') 17 | parser.add_argument('--storage.averaged_model_repo_id', type=str, help='Huggingface repo for storing final model') 18 | 19 | 20 | 21 | def add_torch_miner_args(parser): 22 | parser.add_argument('--rank', type=int, help='Rank of process/node in training run') 23 | parser.add_argument('--world-size', type=int, help='Number of processes/nodes in training run') 24 | parser.add_argument('--store-address', type=str,default="127.0.0.1", help='IP/URL of the TCPStore')#FIXME add the main from btt 25 | parser.add_argument('--store-port', type=int,default=4999, help='Port of the test TCPStore')#FIXME add the main from btt 26 | parser.add_argument( 27 | "--initial_peers", 28 | action="append", 29 | help="Add a peer. 
Can be used multiple times to pass multiple peers.", 30 | nargs="*", 31 | default=[], 32 | ) 33 | 34 | parser.add_argument( 35 | "--batch_size", 36 | type=int, 37 | help="The largest batch size able to fit on your GPU.", 38 | default=1, 39 | const=1, 40 | nargs="?", 41 | ) 42 | 43 | parser.add_argument( 44 | "--save_every", 45 | type=int, 46 | help="Save the model every X global steps.", 47 | default=0, 48 | const=0, 49 | nargs="?", 50 | ) 51 | 52 | 53 | 54 | def add_orchestrator_args(parser): 55 | parser.add_argument('--port', type=int, default=5000) 56 | parser.add_argument('--host-address', type=str, default="127.0.0.1") 57 | 58 | # def add_validator_args(parser): 59 | # parser.add_argument('--port', type=int, default=5000, help="Port for the validator") 60 | # parser.add_argument('--host-address', type=str, default="127.0.0.1", help="Host address for the validator") -------------------------------------------------------------------------------- /hivetrain/config/mlflow_config.py: -------------------------------------------------------------------------------- 1 | MLFLOW_UI_URL = "http://ml9in6up.clj5khk.gcp.restack.it" 2 | CURRENT_MODEL_NAME = "openai-community/gpt2" 3 | MLFLOW_ACTIVE = False 4 | -------------------------------------------------------------------------------- /hivetrain/docs/test.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | 3 | dot = Digraph(comment='Architectural Diagram') 4 | 5 | dot.node('A', 'TrainingLoop') 6 | dot.node('B', 'LocalTrainingLoop') 7 | dot.node('C', 'torch') 8 | dot.node('D', 'transformers') 9 | dot.node('E', 'huggingface_hub') 10 | dot.node('F', 'time') 11 | 12 | dot.edges(['AB', 'AC', 'AD', 'AE', 'AF']) 13 | dot.edge('B', 'C', constraint='false') 14 | 15 | dot.render('training_loop_architecture', view=True) -------------------------------------------------------------------------------- /hivetrain/docs/training_loop_architecture: -------------------------------------------------------------------------------- 1 | // Architectural Diagram 2 | digraph { 3 | A [label=TrainingLoop] 4 | B [label=LocalTrainingLoop] 5 | C [label=torch] 6 | D [label=transformers] 7 | E [label=huggingface_hub] 8 | F [label=time] 9 | A -> B 10 | A -> C 11 | A -> D 12 | A -> E 13 | A -> F 14 | B -> C [constraint=false] 15 | } 16 | -------------------------------------------------------------------------------- /hivetrain/docs/training_loop_architecture.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bit-current/DistributedTraining/8d5dd201bcf9253f496cf9ace2c529cbf44f1be2/hivetrain/docs/training_loop_architecture.pdf -------------------------------------------------------------------------------- /hivetrain/hf_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import hashlib 4 | from bittensor import logging 5 | from dotenv import load_dotenv 6 | from huggingface_hub import HfApi, Repository, HfFolder 7 | from huggingface_hub import hf_hub_download, scan_cache_dir 8 | import subprocess 9 | 10 | load_dotenv() 11 | 12 | class HFManager: 13 | """ 14 | Manages interactions with the Hugging Face Hub for operations such as cloning, pushing and pulling models or weights/gradients. 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | local_dir=".",#gradients local 20 | hf_token=os.getenv("HF_TOKEN"), 21 | my_repo_id=None,#gradients HF 22 | averaged_model_repo_id=None,#averaged HF 23 | model_dir=None,#averaged local 24 | device="cuda" 25 | ): 26 | 27 | # Initializes the HFManager with the necessary repository and authentication details. 28 | self.my_repo_id = my_repo_id 29 | self.model_repo_id = averaged_model_repo_id 30 | self.hf_token = hf_token 31 | self.device = device 32 | #self.local_dir = local_dir 33 | 34 | # Define the local directory structure based on repository IDs but only do clone personal repo if miner 35 | if self.my_repo_id != None: 36 | self.gradient_repo = Repository( 37 | local_dir=os.path.join(local_dir, my_repo_id.split("/")[-1]), 38 | clone_from=my_repo_id, 39 | use_auth_token=hf_token, 40 | ) 41 | self.local_gradient_dir = os.path.join(local_dir, my_repo_id.split("/")[-1]) 42 | 43 | self.model_dir = ( 44 | model_dir 45 | if model_dir 46 | else os.path.join(local_dir, averaged_model_repo_id.split("/")[-1]) 47 | ) 48 | self.model_repo = Repository( 49 | local_dir=self.model_dir, clone_from=averaged_model_repo_id 50 | ) 51 | 52 | self.api = HfApi() 53 | # Get the latest commit SHA for synchronization checks 54 | self.latest_model_commit_sha = self.get_latest_commit_sha(self.model_repo_id) 55 | 56 | 57 | @staticmethod 58 | def clear_hf_cache(): 59 | # Get the cache directory 60 | hf_cache_info = scan_cache_dir() 61 | commit_hashes = [ 62 | revision.commit_hash 63 | for repo in hf_cache_info.repos 64 | for revision in repo.revisions 65 | ] 66 | 67 | # Check if the cache directory exists 68 | delete_strategy = scan_cache_dir().delete_revisions(*commit_hashes) 69 | 70 | logging.info("Will free " + delete_strategy.expected_freed_size_str) 71 | delete_strategy.execute() 72 | 73 | @staticmethod 74 | def git_prune_and_refresh(repo_path): 75 | """ 76 | Change to the specified repository directory, execute 'git lfs prune', and revert to the original directory. 77 | """ 78 | original_dir = os.getcwd() 79 | try: 80 | os.chdir(repo_path) 81 | subprocess.run(['git', 'config', 'pull.rebase', 'true'], check=True) 82 | subprocess.run(['git', 'pull', '--force'], check=True) 83 | subprocess.run(['git', 'lfs', 'prune'], check=True) 84 | except subprocess.CalledProcessError as e: 85 | print(f"Failed to prune Git LFS objects: {e}") 86 | finally: 87 | os.chdir(original_dir) 88 | 89 | 90 | 91 | def push_changes(self, file_to_send): 92 | """ 93 | Stages, commits, squashes, and pushes changes to the configured repository. 94 | Also prunes unnecessary Git LFS objects to free up storage. 
95 | """ 96 | try: 97 | # Stage the changes 98 | self.gradient_repo.git_add(file_to_send) 99 | 100 | 101 | # Commit with a unified message 102 | self.gradient_repo.git_commit("Squashed commits - update model gradients") 103 | 104 | # Push the changes to the repository 105 | self.gradient_repo.git_push() 106 | 107 | self.api.super_squash_history(repo_id=self.my_repo_id) 108 | 109 | # Prune unneeded Git LFS objects and pull the squashed version locally 110 | self.git_prune_and_refresh(self.local_gradient_dir) # Clean up unused LFS objects 111 | 112 | 113 | except Exception as e: 114 | print(f"Failed to push changes: {e}") 115 | 116 | def push_to_hf_hub(self, path_to_model, commit_message="Pushing model to Hub"): 117 | try: 118 | # Stage the changes 119 | self.model_repo.git_add(path_to_model) 120 | 121 | # Squash commits into a single one before pushing 122 | 123 | # Commit with a unified message 124 | self.model_repo.git_commit("Squashed commits - update model gradients") 125 | 126 | self.model_repo.git_push() 127 | 128 | self.api.super_squash_history(repo_id=self.model_repo_id) 129 | 130 | # Prune unneeded Git LFS objects and pull the squashed version locally 131 | self.git_prune_and_refresh(self.model_dir) 132 | 133 | # Push the changes to the repository 134 | 135 | except Exception as e: 136 | print(f"Failed to push changes: {e}") 137 | 138 | def get_latest_commit_sha(self, repo): 139 | """ 140 | Fetches the latest commit SHA of the specified repository from the Hugging Face Hub. 141 | """ 142 | try: 143 | repo_info = self.api.repo_info(repo) 144 | latest_commit_sha = repo_info.sha 145 | # print(latest_commit_sha) 146 | return latest_commit_sha 147 | except Exception as e: 148 | logging.info(f"Failed to fetch latest commit SHA: {e}") 149 | return None 150 | 151 | def check_for_new_submissions(self, repo): 152 | """ 153 | Compares the current commit SHA with the latest to determine if there are new submissions. 154 | """ 155 | current_commit_sha = self.get_latest_commit_sha(repo) 156 | if current_commit_sha != self.latest_model_commit_sha: 157 | self.latest_model_commit_sha = current_commit_sha 158 | return True 159 | return False 160 | 161 | def update_model(self, model, model_file_name="averaged_model.pt"): 162 | """ 163 | Loads an updated model from a .pt file and updates the in-memory model's parameters. 164 | """ 165 | model_path = os.path.join(self.model_dir, model_file_name) 166 | if os.path.exists(model_path): 167 | model_state_dict = torch.load(model_path) 168 | model.load_state_dict(model_state_dict) 169 | model.train() 170 | logging.info(f"Model updated from local path: {model_path}") 171 | return model 172 | else: 173 | raise FileNotFoundError(f"{model_file_name} not found in the repository.") 174 | 175 | def get_local_gradient_directory(self): 176 | """Return the local directory of the repository.""" 177 | return self.local_gradient_dir 178 | 179 | def get_local_model_directory(self): 180 | """Return the local directory of the repository.""" 181 | return self.model_dir 182 | 183 | def pull_latest_model(self): 184 | self.model_repo.git_pull() 185 | 186 | def receive_gradients(self, miner_repo_id, weights_file_name="weight_diff.pt"): 187 | try: #TODO Add some garbage collection. 
188 | # Download the gradients file from Hugging Face Hub 189 | weights_file_path = hf_hub_download( 190 | repo_id=miner_repo_id, filename=weights_file_name, use_auth_token=True 191 | ) 192 | # Load the gradients directly using torch.load 193 | miner_weights = torch.load(weights_file_path, map_location=self.device) 194 | os.remove(weights_file_path) 195 | return miner_weights 196 | except Exception as e: 197 | logging.debug(f"Error receiving gradients from Hugging Face: {e}") 198 | 199 | 200 | class LocalHFManager: 201 | def __init__(self, my_repo_id="local_models"): 202 | self.my_repo_id = my_repo_id 203 | # Ensure the local directory exists 204 | os.makedirs(self.my_repo_id, exist_ok=True) 205 | self.model_hash_file = os.path.join(self.my_repo_id, "model_hash.txt") 206 | # Initialize model hash value 207 | self.last_known_hash = None 208 | 209 | def set_model_hash(self, hash_value): 210 | """Sets and saves the latest model hash to the hash file.""" 211 | with open(self.model_hash_file, "w") as file: 212 | file.write(hash_value) 213 | print(f"Set latest model hash to: {hash_value}") 214 | 215 | def check_for_new_submissions(self): 216 | """Checks if a new or updated model is available.""" 217 | model_file_path = os.path.join(self.my_repo_id, "averaged_model.pt") 218 | if not os.path.exists(model_file_path): 219 | print("No model available.") 220 | return False 221 | 222 | with open(model_file_path, "rb") as file: 223 | file_hash = hashlib.sha256(file.read()).hexdigest() 224 | 225 | if self.last_known_hash is None or self.last_known_hash != file_hash: 226 | print("New or updated model found. Updating model...") 227 | self.last_known_hash = file_hash 228 | return True 229 | return False 230 | 231 | def update_model(self, model): 232 | """Updates an existing model's state dict from a .pt file.""" 233 | model_file_path = os.path.join(self.my_repo_id, "averaged_model.pt") 234 | if os.path.exists(model_file_path): 235 | model_state_dict = torch.load(model_file_path) 236 | model.load_state_dict(model_state_dict) 237 | model.train() # Or model.eval(), depending on your use case 238 | return model 239 | print(f"Model updated from local path: {model_file_path}") 240 | else: 241 | print(f"Model file not found: {model_file_path}") 242 | -------------------------------------------------------------------------------- /hivetrain/new_training_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | from transformers import AdamW #FIXME replace me with LAMB 3 | from huggingface_hub import Repository 4 | import torch 5 | import math 6 | #from dotenv import load_dotenv 7 | from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer 8 | from bittensor import logging 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.optim import SGD 12 | import torch.optim as optim 13 | import os 14 | import hashlib 15 | 16 | #load_dotenv() 17 | token = os.getenv("HF_TOKEN") 18 | 19 | 20 | class TrainingLoopNew: 21 | def __init__(self, model, device, hf_manager, train_loader,test_loader, send_interval = 120, check_update_interval = 60, learning_rate=5e-5,): 22 | self.model = model.to(device) 23 | self.device = device 24 | self.hf_manager = hf_manager 25 | self.train_loader = train_loader 26 | self.test_loader = test_loader 27 | self.send_interval = send_interval 28 | self.check_update_interval = check_update_interval 29 | self.learning_rate = learning_rate 30 | 31 | def train(self, epochs, n_steps): 32 | total_loss = 0 33 | 
total_examples = 0 34 | step_counter = 0 # Initialize step counter that persists across epochs 35 | test_counter = 0 36 | criterion = nn.CrossEntropyLoss() 37 | optimizer = optim.Adam(self.model.parameters(), lr=0.001) 38 | self.model.train() 39 | self.base_weights = {name: param.clone() for name, param in self.model.named_parameters()} 40 | 41 | self.last_pull_time = time.time() 42 | self.last_send_time = time.time() 43 | 44 | for epoch in range(epochs): 45 | 46 | print("************** NEW EPOCH") 47 | for batch_idx, (data, target) in enumerate(self.train_loader): 48 | 49 | if time.time() - self.last_pull_time >= self.check_update_interval and self.hf_manager.check_for_new_submissions(self.hf_manager.model_repo_id): 50 | logging.info("Averaged model updated on Hugging Face. Pulling latest model...") 51 | print("********Averaged model updated on Hugging Face. Pulling latest model...") 52 | self.hf_manager.pull_latest_model() 53 | time.sleep(10) #just to give enough time for pull 54 | self.model = self.hf_manager.update_model(self.model) 55 | optimizer = optim.Adam(self.model.parameters(), lr=5e-5) # Reinitialize the optimizer 56 | self.base_weights = {name: param.clone() for name, param in self.model.named_parameters()} 57 | self.last_pull_time = time.time() 58 | 59 | data, target = data.to(self.device), target.to(self.device) 60 | optimizer.zero_grad() 61 | 62 | output = self.model(data) 63 | loss = criterion(output, target) 64 | loss.backward() 65 | optimizer.step() 66 | optimizer.zero_grad() 67 | 68 | total_loss += loss.item() 69 | total_examples += len(data) 70 | 71 | average_loss = total_loss / total_examples 72 | #logging.info(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}") 73 | 74 | 75 | # Check if it's time to step the optimizer and reset gradients 76 | if (step_counter + 1) % n_steps == 0: 77 | test_counter += 1 78 | 79 | test_loss, test_accuracy = self.test() 80 | train_loss = total_loss / total_examples 81 | logging.info(f"Train Loss: {train_loss} At {step_counter} accumulated gradients") 82 | print("***Train Loss: {train_loss} At {step_counter} accumulated gradients") 83 | 84 | logging.info(f"Test Loss: {test_loss} At {step_counter} accumulated gradients") 85 | logging.info(f"Test Accuracy: {test_accuracy} At {step_counter} accumulated gradients") 86 | print((f"Test Accuracy: {test_accuracy} At {step_counter} accumulated gradients")) 87 | 88 | #return train_loss, test_loss, test_accuracy 89 | self.model.train() 90 | 91 | step_counter += 1 # Increment step counter after processing each batch 92 | # Periodic actions such as logging and sending gradients 93 | if time.time() - self.last_send_time >= self.send_interval: 94 | average_loss = total_loss / total_examples 95 | perplexity = math.exp(average_loss) 96 | logging.info(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}") 97 | 98 | try: 99 | logging.info(f"Attempting to send weights") 100 | logging.info(f"********* Attempting to send weights") 101 | # Periodically save gradients 102 | model_gradients_path = os.path.join(self.hf_manager.get_local_gradient_directory(), 'weight_diff.pt') 103 | self.weight_diffs = {name: param.data - self.base_weights[name] for name, param in self.model.named_parameters() if param.requires_grad} 104 | torch.save(self.weight_diffs, model_gradients_path) 105 | self.hf_manager.push_changes(['weight_diff.pt']) 106 | except Exception as e: 107 | logging.warning(f"Sending gradients failed: {e}") 108 | continue 109 | 110 | logging.info(f"Model hash is: 
{self.calculate_model_hash()}") 111 | print(f"Model hash is: {self.calculate_model_hash()}") 112 | self.last_send_time = time.time() 113 | 114 | if batch_idx % 50 == 0: # For example, save every 50 batches 115 | print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(self.train_loader.dataset)} ({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item():.6f}") 116 | 117 | 118 | def calculate_model_hash(self): 119 | model_hash = hashlib.sha256() 120 | for name, param in self.model.named_parameters(): 121 | model_hash.update(name.encode('utf-8')) 122 | model_hash.update(param.data.cpu().numpy().tobytes()) 123 | return model_hash.hexdigest() 124 | 125 | def test(self): 126 | self.model.eval() 127 | test_loss = 0 128 | correct_predictions = 0 129 | total_test_samples = 0 130 | 131 | with torch.no_grad(): 132 | for batch in self.test_loader: 133 | 134 | images, labels = batch 135 | outputs = self.model(images) 136 | loss = F.cross_entropy(outputs, labels) 137 | test_loss += loss.item() 138 | _, predicted = torch.max(outputs.data, 1) 139 | correct_predictions += (predicted == labels).sum().item() 140 | total_test_samples += labels.size(0) 141 | 142 | average_test_loss = test_loss / total_test_samples 143 | accuracy = correct_predictions / total_test_samples 144 | return average_test_loss, accuracy 145 | 146 | 147 | @staticmethod 148 | def normalize_gradients(parameter, threshold=1.0): 149 | """ 150 | Normalize the gradients to avoid exploding or vanishing gradients. 151 | 152 | Args: 153 | parameters (iterable): Iterable of model parameters (typically model.parameters() in PyTorch). 154 | threshold (float): The maximum norm value for gradients. Defaults to 1.0. 155 | """ 156 | param_norm = parameter.norm(2) 157 | 158 | # Normalize if the total norm exceeds the threshold 159 | if param_norm > threshold: 160 | return parameter.data.mul_(threshold / param_norm) 161 | else: 162 | return parameter 163 | 164 | def calculate_model_hash(self): 165 | model_hash = hashlib.sha256() 166 | for name, param in self.model.named_parameters(): 167 | model_hash.update(name.encode('utf-8')) 168 | model_hash.update(param.data.cpu().numpy().tobytes()) 169 | return model_hash.hexdigest() 170 | 171 | 172 | # Define the CNN model 173 | class SimpleCNN(nn.Module): 174 | def __init__(self): 175 | super(SimpleCNN, self).__init__() 176 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 177 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 178 | self.conv2_drop = nn.Dropout2d() 179 | self.fc1 = nn.Linear(320, 50) 180 | self.fc2 = nn.Linear(50, 10) 181 | 182 | def forward(self, x): 183 | x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2)) 184 | x = nn.functional.relu(nn.functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 185 | x = x.view(-1, 320) 186 | x = nn.functional.relu(self.fc1(x)) 187 | x = nn.functional.dropout(x, training=self.training) 188 | x = self.fc2(x) 189 | return nn.functional.log_softmax(x, dim=1) 190 | -------------------------------------------------------------------------------- /hivetrain/training_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import math 5 | import hashlib 6 | import mlflow 7 | import mlflow.pytorch 8 | from hivetrain.config import Configurator 9 | from hivetrain.btt_connector import BittensorNetwork 10 | from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer 11 | from bittensor import logging 12 | import torch.nn as nn 13 | import torch.nn.functional as 
F 14 | from torch.optim import SGD 15 | from hivetrain.config.mlflow_config import ( 16 | MLFLOW_UI_URL, 17 | CURRENT_MODEL_NAME, 18 | MLFLOW_ACTIVE, 19 | ) 20 | from hivetrain.utils.mlflow_utils import initialize_mlflow, log_model_metrics, VERSION 21 | 22 | args = Configurator.combine_configs() 23 | BittensorNetwork.initialize(args, ignore_regs=True) 24 | MY_HOTKEY = BittensorNetwork.wallet.hotkey.ss58_address 25 | 26 | 27 | 28 | class TrainingLoop: 29 | def __init__( 30 | self, 31 | device, 32 | model_name, 33 | data_loader, 34 | learning_rate=5e-5, 35 | check_update_interval=300, 36 | send_interval=300, 37 | hf_manager=None, 38 | ): 39 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 40 | self.model = AutoModelForCausalLM.from_pretrained(model_name) 41 | self.model = self.model.to(device) 42 | self.device = device 43 | 44 | self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 45 | self.model.resize_token_embeddings(len(self.tokenizer)) 46 | self.model.train() 47 | self.hf_manager = hf_manager 48 | self.learning_rate = learning_rate 49 | 50 | self.data_loader = data_loader 51 | self.optimizer = AdamW(self.model.parameters(), lr=self.learning_rate) 52 | self.check_update_interval = check_update_interval 53 | self.send_interval = send_interval 54 | self.last_pull_time = 0 55 | 56 | # initialize mlflow 57 | if MLFLOW_ACTIVE: 58 | initialize_mlflow( 59 | role="miner", 60 | device=self.device, 61 | version=VERSION, 62 | mlflow_ui_url=MLFLOW_UI_URL, 63 | current_model_name=CURRENT_MODEL_NAME, 64 | my_hotkey=MY_HOTKEY, 65 | learning_rate=self.learning_rate, 66 | send_interval=self.send_interval, 67 | check_update_interval=self.check_update_interval, 68 | ) 69 | else: 70 | logging.info("****************MLFLOW IS INACTIVE************") 71 | 72 | def train(self, epochs): 73 | self.last_send_time = time.time() 74 | self.optimizer.zero_grad() 75 | self.aggregated_gradients = { 76 | name: torch.zeros_like(param) 77 | for name, param in self.model.named_parameters() 78 | if param.requires_grad 79 | } 80 | for epoch in range(epochs): 81 | logging.info(f"Starting Epoch: {epoch}") 82 | # Check for new submissions at the start of each epoch 83 | total_loss = 0 84 | total_examples = 0 85 | 86 | current_time = time.time() 87 | if ( 88 | current_time - self.last_pull_time >= self.check_update_interval 89 | and self.hf_manager.check_for_new_submissions( 90 | self.hf_manager.model_repo 91 | ) 92 | ): 93 | logging.info( 94 | "Averaged model updated on Hugging Face. Pulling latest model..." 
95 | ) 96 | self.hf_manager.pull_latest_model() 97 | self.model = self.hf_manager.update_model(self.model) 98 | self.optimizer = SGD( 99 | self.model.parameters(), lr=5e-5 100 | ) # Reinitialize the optimizer 101 | self.last_pull_time = current_time 102 | 103 | for step, batch in enumerate(self.data_loader): 104 | outputs = self.model( 105 | input_ids=batch["input_ids"], 106 | attention_mask=batch["attention_mask"], 107 | labels=batch["input_ids"], 108 | ) 109 | loss = outputs.loss 110 | loss.backward() 111 | 112 | # Update loss and example counts 113 | total_loss += loss.item() * batch["input_ids"].size(0) 114 | total_examples += batch["input_ids"].size(0) 115 | 116 | for name, param in self.model.named_parameters(): 117 | if param.requires_grad and param.grad is not None: 118 | self.aggregated_gradients[name] += param.grad 119 | 120 | self.optimizer.step() 121 | self.optimizer.zero_grad() 122 | 123 | if step % 500 == 0: 124 | if MLFLOW_ACTIVE: 125 | log_model_metrics(step=step, train_loss=loss.item()) 126 | try: 127 | mlflow.log_param("Version of Code", VERSION) 128 | except Exception as e: 129 | logging.error(f"Failed to log metrics to MLflow: {e}") 130 | 131 | # Example of a condition to periodically send gradients 132 | 133 | if time.time() - self.last_send_time >= self.send_interval: 134 | average_loss = total_loss / total_examples 135 | perplexity = math.exp(average_loss) 136 | logging.info( 137 | f"Epoch: {epoch}, Examples: {total_examples}, Loss: {average_loss:.4f}, Perplexity: {perplexity:.4f}" 138 | ) 139 | 140 | try: 141 | logging.info(f"Attempting to send gradients") 142 | # Periodically save gradients 143 | model_gradients_path = os.path.join( 144 | self.hf_manager.get_local_gradient_dir(), "gradients.pt" 145 | ) 146 | torch.save(self.model.state_dict(), model_gradients_path) 147 | self.hf_manager.push_changes("gradients.pt") 148 | log_model_metrics( 149 | step=step, gradient_staleness=self.get_gradient_staleness() 150 | ) 151 | except Exception as e: 152 | logging.warning(f"Sending gradients failed: {e}") 153 | continue 154 | self.last_send_time = time.time() 155 | 156 | def get_gradient_staleness(self): 157 | """ 158 | Calculates the staleness of the gradient by measuring the time elapsed since the last gradient update. 159 | 160 | Returns: 161 | float: The staleness of the gradient in seconds. Returns 0.0 if this is the first call (no previous updates). 162 | """ 163 | current_time = time.time() 164 | if self.last_send_time == 0: 165 | return 0.0 166 | else: 167 | staleness = current_time - self.last_send_time 168 | return staleness 169 | 170 | 171 | class MNISTDeltaTrainHugging(TrainingLoop): 172 | def __init__(self): 173 | super(MNISTDeltaTrainHugging, self).__init__() 174 | self.model = FeedforwardNN() 175 | self.model.train() 176 | 177 | self.optimizer = SGD(self.model.parameters(), lr=self.learning_rate) 178 | 179 | self.last_send_time = time.time() 180 | 181 | @staticmethod 182 | def normalize_gradients(parameter, threshold=1.0): 183 | """ 184 | Normalize the gradients to avoid exploding or vanishing gradients. 185 | 186 | Args: 187 | parameters (iterable): Iterable of model parameters (typically model.parameters() in PyTorch). 188 | threshold (float): The maximum norm value for gradients. Defaults to 1.0. 
189 | """ 190 | param_norm = parameter.norm(2) 191 | 192 | # Normalize if the total norm exceeds the threshold 193 | if param_norm > threshold: 194 | return parameter.data.mul_(threshold / param_norm) 195 | else: 196 | return parameter 197 | 198 | def calculate_model_hash(self): 199 | model_hash = hashlib.sha256() 200 | for name, param in self.model.named_parameters(): 201 | model_hash.update(name.encode("utf-8")) 202 | model_hash.update(param.data.cpu().numpy().tobytes()) 203 | return model_hash.hexdigest() 204 | 205 | def train(self, epochs, hf_manager, n_steps): 206 | step_counter = 0 # Initialize step counter that persists across epochs 207 | test_counter = 0 208 | test_losses = [] 209 | test_accuracies = [] 210 | training_losses = [] 211 | logging.info( 212 | "Model updated from Hugging Face. Continuing training with new model..." 213 | ) 214 | # self.model = hf_manager.update_model(self.model) 215 | self.model = FeedforwardNN() 216 | 217 | self.optimizer = SGD( 218 | self.model.parameters(), lr=0.1 219 | ) # Reinitialize the optimizer 220 | self.base_weights = { 221 | name: param.clone() for name, param in self.model.named_parameters() 222 | } 223 | 224 | for epoch in range(epochs): 225 | logging.info(f"Starting Epoch: {epoch}") 226 | total_loss = 0 227 | total_examples = 0 228 | 229 | for batch_idx, (data, target) in enumerate(self.data_loader): 230 | if ( 231 | hf_manager.check_for_new_submissions() 232 | ): # FIXME add this in other training manager classes 233 | logging.info( 234 | "Model updated from Hugging Face. Continuing training with new model..." 235 | ) 236 | self.model = hf_manager.update_model(self.model) 237 | self.optimizer = SGD( 238 | self.model.parameters(), lr=0.001 239 | ) # Reinitialize the optimizer 240 | self.base_weights = { 241 | name: param.clone() 242 | for name, param in self.model.named_parameters() 243 | } 244 | # self.optimizer.zero_grad() # Ensure gradients are reset after model update 245 | 246 | output = self.model(data) 247 | loss = F.cross_entropy(output, target) 248 | loss.backward() 249 | 250 | self.optimizer.step() 251 | self.optimizer.zero_grad() 252 | 253 | total_loss += loss.item() 254 | total_examples += len(data) 255 | 256 | average_loss = total_loss / total_examples 257 | # logging.info(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}") 258 | 259 | # Check if it's time to step the optimizer and reset gradients 260 | if (step_counter + 1) % n_steps == 0: 261 | test_counter += 1 262 | 263 | test_loss, test_accuracy = self.test() 264 | # test_losses.append(test_loss) 265 | # test_accuracies.append(test_accuracy) 266 | train_loss = total_loss / total_examples 267 | # training_losses.append(train_loss) 268 | logging.info( 269 | f"Train Loss: {train_loss} At {step_counter} accumulated gradients" 270 | ) 271 | logging.info( 272 | f"Test Loss: {test_loss} At {step_counter} accumulated gradients" 273 | ) 274 | logging.info( 275 | f"Test Accuracy: {test_accuracy} At {step_counter} accumulated gradients" 276 | ) 277 | 278 | # return train_loss, test_loss, test_accuracy 279 | 280 | self.model.train() 281 | 282 | step_counter += 1 # Increment step counter after processing each batch 283 | 284 | # Periodic actions such as logging and sending gradients 285 | if time.time() - self.last_send_time >= self.send_interval: 286 | average_loss = total_loss / total_examples 287 | logging.info( 288 | f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}" 289 | ) 290 | 291 | # Logic to send aggregated gradients 292 | self.weight_diffs = { 
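                    # Weight delta: the trained parameter minus the base-model snapshot taken at the last update;
                    # this is the quantity that gets saved via store_gradients() and shared for validator evaluation.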
293 | name: param.data - self.base_weights[name] 294 | for name, param in self.model.named_parameters() 295 | if param.requires_grad 296 | } 297 | self.store_gradients(self.weight_diffs, self.gradients_dir) 298 | 299 | logging.info(f"Model hash is: {self.calculate_model_hash()}") 300 | 301 | self.last_send_time = time.time() 302 | # total_loss = 0 303 | # total_examples = 0 # Reset for the next interval 304 | 305 | def test(self): 306 | self.model.eval() 307 | test_loss = 0 308 | correct_predictions = 0 309 | total_test_samples = 0 310 | 311 | with torch.no_grad(): 312 | for batch in self.test_loader: 313 | images, labels = batch 314 | outputs = self.model(images) 315 | loss = F.cross_entropy(outputs, labels) 316 | test_loss += loss.item() 317 | _, predicted = torch.max(outputs.data, 1) 318 | correct_predictions += (predicted == labels).sum().item() 319 | total_test_samples += labels.size(0) 320 | 321 | average_test_loss = test_loss / total_test_samples 322 | accuracy = correct_predictions / total_test_samples 323 | return average_test_loss, accuracy 324 | 325 | 326 | class LocalTrainingLoop(TrainingLoop): 327 | @staticmethod 328 | def store_gradients( 329 | aggregated_gradients, local_dir, gradient_file_name="gradients.pt" 330 | ): 331 | """ 332 | Saves gradients to a file in a specified local directory. 333 | """ 334 | # Ensure the local directory exists 335 | os.makedirs(local_dir, exist_ok=True) 336 | 337 | # Construct the full path to the gradient file 338 | gradient_file_path = os.path.join(local_dir, gradient_file_name) 339 | 340 | # Save gradients to the file 341 | torch.save(aggregated_gradients, gradient_file_path) 342 | print(f"Gradients saved locally at: {gradient_file_path}") 343 | 344 | 345 | class DeltaLoop(TrainingLoop): 346 | def train(self, epochs): 347 | self.last_send_time = time.time() 348 | self.optimizer.zero_grad() 349 | self.base_weights = { 350 | name: param.clone() for name, param in self.model.named_parameters() 351 | } 352 | self.model.to(self.device) 353 | for epoch in range(epochs): 354 | logging.info(f"Starting Epoch: {epoch}") 355 | # Check for new submissions at the start of each epoch 356 | 357 | total_loss = 0 358 | total_examples = 0 359 | 360 | for step, batch in enumerate(self.data_loader): 361 | if time.time() - self.last_pull_time >= self.check_update_interval: 362 | if self.hf_manager.check_for_new_submissions( 363 | self.hf_manager.model_repo_id 364 | ): 365 | logging.info( 366 | "Averaged model updated on Hugging Face. Pulling latest model..." 
367 | ) 368 | self.hf_manager.pull_latest_model() 369 | time.sleep(10) # just to give enough time for pull 370 | self.model = self.hf_manager.update_model(self.model) 371 | self.optimizer = AdamW( 372 | self.model.parameters(), lr=5e-5 373 | ) # Reinitialize the optimizer 374 | self.base_weights = { 375 | name: param.clone() 376 | for name, param in self.model.named_parameters() 377 | } 378 | self.last_pull_time = time.time() 379 | 380 | outputs = self.model( 381 | input_ids=batch["input_ids"].to(self.device), 382 | attention_mask=batch["attention_mask"].to(self.device), 383 | labels=batch["input_ids"].to(self.device), 384 | ) 385 | loss = outputs.loss 386 | loss.backward() 387 | # Update loss and example counts 388 | total_loss += loss.item() * batch["input_ids"].size(0) 389 | total_examples += batch["input_ids"].size(0) 390 | 391 | self.optimizer.step() 392 | self.optimizer.zero_grad() 393 | 394 | if step % 1000 == 0: 395 | if MLFLOW_ACTIVE: 396 | log_model_metrics(step=step, train_loss=loss.item()) 397 | try: 398 | mlflow.log_param( 399 | "Version of Code", VERSION 400 | ) # just to make sure the version is updated frequently 401 | except Exception as e: 402 | logging.error(f"Failed to log the code version to MLflow: {e}") 403 | 404 | # Example of a condition to periodically send gradients 405 | if time.time() - self.last_send_time >= self.send_interval: 406 | average_loss = total_loss / total_examples 407 | perplexity = math.exp(average_loss) 408 | logging.info(f"Epoch: {epoch}, Loss: {average_loss:.4f}") 409 | 410 | try: 411 | logging.info(f"Attempting to send weights") 412 | # Periodically save gradients 413 | model_gradients_path = os.path.join( 414 | self.hf_manager.get_local_gradient_directory(), 415 | "weight_diff.pt", 416 | ) 417 | self.weight_diffs = { 418 | name: param.data - self.base_weights[name] 419 | for name, param in self.model.named_parameters() 420 | if param.requires_grad 421 | } 422 | torch.save(self.weight_diffs, model_gradients_path) 423 | self.hf_manager.push_changes("weight_diff.pt") 424 | self.last_send_time = time.time() 425 | log_model_metrics( 426 | step=step, gradient_staleness=self.get_gradient_staleness() 427 | ) 428 | except Exception as e: 429 | logging.warning(f"Sending gradients failed: {e}") 430 | self.last_send_time = time.time() 431 | continue 432 | if MLFLOW_ACTIVE: 433 | mlflow.end_run() 434 | 435 | 436 | class LocalDeltaLoop(DeltaLoop, LocalTrainingLoop): 437 | pass 438 | 439 | 440 | class FeedforwardNN(nn.Module): 441 | def __init__(self): 442 | super(FeedforwardNN, self).__init__() 443 | self.flatten = nn.Flatten() 444 | self.fc1 = nn.Linear(28 * 28, 512) # Flatten 28x28 images to a 784 vector 445 | self.fc2 = nn.Linear(512, 512) 446 | self.fc3 = nn.Linear(512, 128) 447 | self.fc4 = nn.Linear(128, 128) 448 | self.fc5 = nn.Linear(128, 10) # MNIST has 10 classes 449 | 450 | def forward(self, x): 451 | x = self.flatten(x) 452 | x = F.relu(self.fc1(x)) 453 | x = F.relu(self.fc2(x)) 454 | x = F.relu(self.fc3(x)) 455 | x = F.relu(self.fc4(x)) 456 | x = self.fc5( 457 | x 458 | ) # No activation, as we'll use CrossEntropyLoss which includes Softmax 459 | return x 460 | 461 | 462 | class MNISTTrain(LocalTrainingLoop): 463 | def __init__( 464 | self, 465 | model_name, 466 | data_loader, 467 | gradients_dir, 468 | test_loader, 469 | averaging_dir="averaged_model", 470 | learning_rate=5e-5, 471 | send_interval=30, 472 | ): 473 | self.model = FeedforwardNN() 474 | self.model.train() 475 | 476 | self.data_loader = data_loader 477 | self.test_loader = test_loader 478 | 479 | self.optimizer =
SGD(self.model.parameters(), lr=learning_rate) 480 | self.send_interval = send_interval 481 | self.gradients_dir = gradients_dir 482 | self.averaging_dir = averaging_dir 483 | 484 | def save_model(self): 485 | """ 486 | Saves the model to the specified local directory. 487 | """ 488 | os.makedirs(self.averaging_dir, exist_ok=True) 489 | model_save_path = os.path.join(self.averaging_dir, "averaged_model.pt") 490 | torch.save(self.model.state_dict(), model_save_path) 491 | logging.info(f"Model saved locally at {model_save_path}.") 492 | 493 | @staticmethod 494 | def normalize_gradients(parameter, threshold=1.0): 495 | """ 496 | Normalize the gradients to avoid exploding or vanishing gradients. 497 | 498 | Args: 499 | parameters (iterable): Iterable of model parameters (typically model.parameters() in PyTorch). 500 | threshold (float): The maximum norm value for gradients. Defaults to 1.0. 501 | """ 502 | param_norm = parameter.norm(2) 503 | 504 | # Normalize if the total norm exceeds the threshold 505 | if param_norm > threshold: 506 | return parameter.data.mul_(threshold / param_norm) 507 | else: 508 | return parameter 509 | 510 | def train(self, epochs, hf_manager, n_steps): 511 | self.last_send_time = time.time() 512 | step_counter = 0 # Initialize step counter that persists across epochs 513 | test_counter = 0 514 | test_losses = [] 515 | test_accuracies = [] 516 | training_losses = [] 517 | logging.info( 518 | "Model updated from Hugging Face. Continuing training with new model..." 519 | ) 520 | # self.model = hf_manager.update_model(self.model) 521 | self.model = FeedforwardNN() 522 | 523 | self.optimizer = SGD( 524 | self.model.parameters(), lr=0.1 525 | ) # Reinitialize the optimizer 526 | self.optimizer.zero_grad() # Ensure gradients are reset after model update 527 | self.aggregated_gradients = ( 528 | {} 529 | ) # Initialize an empty dictionary for storing aggregated gradients 530 | for ( 531 | name, 532 | param, 533 | ) in self.model.named_parameters(): # Iterate over all parameters of the model 534 | if ( 535 | param.requires_grad 536 | ): # Check if the parameter requires gradients and has gradients computed 537 | self.aggregated_gradients[name] = torch.zeros_like( 538 | param 539 | ) # Create a zero tensor with the same shape as the parameter 540 | 541 | for epoch in range(epochs): 542 | logging.info(f"Starting Epoch: {epoch}") 543 | total_loss = 0 544 | total_examples = 0 545 | 546 | # if hf_manager.check_for_new_submissions(): 547 | # logging.info("Model updated from Hugging Face. 
Continuing training with new model...") 548 | # self.model = hf_manager.update_model(self.model) 549 | # self.optimizer = SGD(self.model.parameters(), lr=5e-5) # Reinitialize the optimizer 550 | # self.optimizer.zero_grad() # Ensure gradients are reset after model update 551 | 552 | for batch_idx, (data, target) in enumerate(self.data_loader): 553 | output = self.model(data) 554 | loss = F.cross_entropy(output, target) 555 | loss.backward() 556 | 557 | for name, param in self.model.named_parameters(): 558 | if param.grad is not None and param.requires_grad: 559 | self.aggregated_gradients[name] += self.normalize_gradients( 560 | param.grad, threshold=0.1 561 | ) 562 | 563 | self.optimizer.zero_grad() 564 | 565 | total_loss += loss.item() 566 | total_examples += len(data) 567 | 568 | average_loss = total_loss / total_examples 569 | # logging.info(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}") 570 | 571 | # Check if it's time to step the optimizer and reset gradients 572 | if (step_counter + 1) % n_steps == 0: 573 | test_counter += 1 574 | 575 | # for param in self.model.parameters(): 576 | # if param.grad is not None: 577 | # param.grad /= (n_steps//10) 578 | self.optimizer.zero_grad() 579 | 580 | for name, param in self.model.named_parameters(): 581 | if param.grad is not None: 582 | param.grad = self.aggregated_gradients[name] 583 | 584 | self.optimizer.step() 585 | 586 | test_loss, test_accuracy = self.test() 587 | # test_losses.append(test_loss) 588 | # test_accuracies.append(test_accuracy) 589 | train_loss = total_loss / total_examples 590 | # training_losses.append(train_loss) 591 | logging.info( 592 | f"Train Loss: {train_loss} At {step_counter} accumulated gradients" 593 | ) 594 | logging.info( 595 | f"Test Loss: {test_loss} At {step_counter} accumulated gradients" 596 | ) 597 | logging.info( 598 | f"Test Accuracy: {test_accuracy} At {step_counter} accumulated gradients" 599 | ) 600 | 601 | return train_loss, test_loss, test_accuracy 602 | 603 | self.model.train() 604 | 605 | step_counter += 1 # Increment step counter after processing each batch 606 | 607 | # Periodic actions such as logging and sending gradients 608 | if time.time() - self.last_send_time >= self.send_interval: 609 | average_loss = total_loss / total_examples 610 | logging.info( 611 | f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}" 612 | ) 613 | 614 | # Logic to send aggregated gradients 615 | self.weight_diffs = { 616 | name: param.data - self.base_weights[name] 617 | for name, param in self.model.named_parameters() 618 | if param.requires_grad 619 | } 620 | self.store_gradients(self.weight_diffs, self.gradients_dir) 621 | 622 | self.last_send_time = time.time() 623 | # total_loss = 0 624 | # total_examples = 0 # Reset for the next interval 625 | 626 | def test(self): 627 | self.model.eval() 628 | test_loss = 0 629 | correct_predictions = 0 630 | total_test_samples = 0 631 | 632 | with torch.no_grad(): 633 | for batch in self.test_loader: 634 | images, labels = batch 635 | outputs = self.model(images) 636 | loss = F.cross_entropy(outputs, labels) 637 | test_loss += loss.item() 638 | _, predicted = torch.max(outputs.data, 1) 639 | correct_predictions += (predicted == labels).sum().item() 640 | total_test_samples += labels.size(0) 641 | 642 | average_test_loss = test_loss / total_test_samples 643 | accuracy = correct_predictions / total_test_samples 644 | return average_test_loss, accuracy 645 | 646 | 647 | class MNISTDeltaTrain(LocalTrainingLoop): 648 | def __init__( 649 | self, 650 
| model_name, 651 | data_loader, 652 | gradients_dir, 653 | test_loader, 654 | averaging_dir="averaged_model", 655 | learning_rate=5e-5, 656 | send_interval=30, 657 | ): 658 | self.model = FeedforwardNN() 659 | self.model.train() 660 | 661 | self.data_loader = data_loader 662 | self.test_loader = test_loader 663 | 664 | self.optimizer = SGD(self.model.parameters(), lr=learning_rate) 665 | self.send_interval = send_interval 666 | self.gradients_dir = gradients_dir 667 | self.averaging_dir = averaging_dir 668 | 669 | self.last_send_time = time.time() 670 | 671 | def save_model(self): 672 | """ 673 | Saves the model to the specified local directory. 674 | """ 675 | os.makedirs(self.averaging_dir, exist_ok=True) 676 | model_save_path = os.path.join(self.averaging_dir, "averaged_model.pt") 677 | torch.save(self.model.state_dict(), model_save_path) 678 | logging.info(f"Model saved locally at {model_save_path}.") 679 | 680 | @staticmethod 681 | def normalize_gradients(parameter, threshold=1.0): 682 | """ 683 | Normalize the gradients to avoid exploding or vanishing gradients. 684 | 685 | Args: 686 | parameters (iterable): Iterable of model parameters (typically model.parameters() in PyTorch). 687 | threshold (float): The maximum norm value for gradients. Defaults to 1.0. 688 | """ 689 | param_norm = parameter.norm(2) 690 | 691 | # Normalize if the total norm exceeds the threshold 692 | if param_norm > threshold: 693 | return parameter.data.mul_(threshold / param_norm) 694 | else: 695 | return parameter 696 | 697 | def calculate_model_hash(self): 698 | model_hash = hashlib.sha256() 699 | for name, param in self.model.named_parameters(): 700 | model_hash.update(name.encode("utf-8")) 701 | model_hash.update(param.data.cpu().numpy().tobytes()) 702 | return model_hash.hexdigest() 703 | 704 | def train(self, epochs, hf_manager, n_steps): 705 | step_counter = 0 # Initialize step counter that persists across epochs 706 | test_counter = 0 707 | test_losses = [] 708 | test_accuracies = [] 709 | training_losses = [] 710 | logging.info( 711 | "Model updated from Hugging Face. Continuing training with new model..." 712 | ) 713 | # self.model = hf_manager.update_model(self.model) 714 | self.model = FeedforwardNN() 715 | 716 | self.optimizer = SGD( 717 | self.model.parameters(), lr=0.1 718 | ) # Reinitialize the optimizer 719 | self.base_weights = { 720 | name: param.clone() for name, param in self.model.named_parameters() 721 | } 722 | 723 | for epoch in range(epochs): 724 | logging.info(f"Starting Epoch: {epoch}") 725 | total_loss = 0 726 | total_examples = 0 727 | 728 | for batch_idx, (data, target) in enumerate(self.data_loader): 729 | if ( 730 | hf_manager.check_for_new_submissions() 731 | ): # FIXME add this in other training manager classes 732 | time.sleep(3) 733 | logging.info( 734 | "Model updated from Hugging Face. Continuing training with new model..." 
735 | ) 736 | self.model = hf_manager.update_model(self.model) 737 | self.optimizer = SGD( 738 | self.model.parameters(), lr=0.001 739 | ) # Reinitialize the optimizer 740 | self.base_weights = { 741 | name: param.clone() 742 | for name, param in self.model.named_parameters() 743 | } 744 | # self.optimizer.zero_grad() # Ensure gradients are reset after model update 745 | 746 | output = self.model(data) 747 | loss = F.cross_entropy(output, target) 748 | loss.backward() 749 | 750 | self.optimizer.step() 751 | self.optimizer.zero_grad() 752 | 753 | total_loss += loss.item() 754 | total_examples += len(data) 755 | 756 | average_loss = total_loss / total_examples 757 | # logging.info(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}") 758 | 759 | # Check if it's time to step the optimizer and reset gradients 760 | if (step_counter + 1) % n_steps == 0: 761 | test_counter += 1 762 | 763 | test_loss, test_accuracy = self.test() 764 | # test_losses.append(test_loss) 765 | # test_accuracies.append(test_accuracy) 766 | train_loss = total_loss / total_examples 767 | # training_losses.append(train_loss) 768 | logging.info( 769 | f"Train Loss: {train_loss} At {step_counter} accumulated gradients" 770 | ) 771 | logging.info( 772 | f"Test Loss: {test_loss} At {step_counter} accumulated gradients" 773 | ) 774 | logging.info( 775 | f"Test Accuracy: {test_accuracy} At {step_counter} accumulated gradients" 776 | ) 777 | 778 | # return train_loss, test_loss, test_accuracy 779 | 780 | self.model.train() 781 | 782 | step_counter += 1 # Increment step counter after processing each batch 783 | 784 | # Periodic actions such as logging and sending gradients 785 | if time.time() - self.last_send_time >= self.send_interval: 786 | average_loss = total_loss / total_examples 787 | logging.info( 788 | f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {average_loss:.4f}" 789 | ) 790 | 791 | # Logic to send aggregated gradients 792 | self.weight_diffs = { 793 | name: param.data - self.base_weights[name] 794 | for name, param in self.model.named_parameters() 795 | if param.requires_grad 796 | } 797 | self.store_gradients(self.weight_diffs, self.gradients_dir) 798 | 799 | logging.info(f"Model hash is: {self.calculate_model_hash()}") 800 | 801 | self.last_send_time = time.time() 802 | # total_loss = 0 803 | # total_examples = 0 # Reset for the next interval 804 | 805 | 806 | # =============================== 807 | -------------------------------------------------------------------------------- /hivetrain/utils/auto_update.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | import requests 5 | 6 | def run_script(repo_dir): 7 | original_dir = os.getcwd() # Save the original working directory 8 | 9 | repo_url = "https://github.com/username/repo.git" 10 | 11 | if os.path.exists(repo_dir): 12 | # Remove the existing directory 13 | subprocess.run(["rm", "-rf", repo_dir]) 14 | # Clone a fresh copy of the repository (uses repo_url defined above) 15 | subprocess.run(["git", "clone", repo_url, repo_dir]) 16 | # Change to the repository directory 17 | os.chdir(repo_dir) 18 | 19 | # Install the package using pip 20 | subprocess.run(["pip", "uninstall", "-y", "hivetrain"]) 21 | subprocess.run(["pip", "install", "-e", "."]) 22 | 23 | # Stop the existing PM2 process (if running) 24 | subprocess.run(["pm2", "stop", "script"]) 25 | 26 | # Start the script using PM2 27 | subprocess.run(["pm2", "start", "script.py", "--name", "script"]) 28 | 29 | # Revert back to the original working directory 30 | os.chdir(original_dir) 31 | 32 | def
get_latest_commit_sha(repo_owner, repo_name, file_path): 33 | url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents/{file_path}" 34 | response = requests.get(url) 35 | 36 | if response.status_code == 200: 37 | data = response.json() 38 | return data["sha"] 39 | else: 40 | print("Failed to fetch the latest commit SHA.") 41 | return None 42 | 43 | def monitor_repo(): 44 | repo_owner = "username" 45 | repo_name = "repo" 46 | file_path = "__init__.py" 47 | repo_dir = "repo" 48 | 49 | current_sha = get_latest_commit_sha(repo_owner, repo_name, file_path) 50 | 51 | while True: 52 | latest_sha = get_latest_commit_sha(repo_owner, repo_name, file_path) 53 | 54 | if current_sha is not None and latest_sha is not None and current_sha != latest_sha: 55 | print("__init__.py file updated. Running script...") 56 | run_script(repo_dir) 57 | current_sha = latest_sha 58 | 59 | # Sleep for a certain interval before checking again 60 | time.sleep(60) # Check every 60 seconds, adjust as needed 61 | 62 | # Start monitoring the repository 63 | if __name__ == "__main__": 64 | monitor_repo() -------------------------------------------------------------------------------- /hivetrain/utils/bootstrap_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from flask import Flask, jsonify 3 | import hivemind 4 | import time 5 | import random 6 | import threading 7 | import logging 8 | import sys 9 | 10 | logging.basicConfig(level=logging.ERROR) # Set default logging level to ERROR for all loggers 11 | 12 | logger = logging.getLogger('bootstrap') 13 | logger.setLevel(logging.INFO) 14 | 15 | # Create handler for stdout 16 | handler = logging.StreamHandler() 17 | handler.setLevel(logging.INFO) 18 | 19 | # Create formatter and add it to the handler 20 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 21 | handler.setFormatter(formatter) 22 | 23 | # Add the handler to the logger 24 | logger.addHandler(handler) 25 | 26 | from waitress import serve 27 | 28 | app = Flask(__name__) 29 | 30 | # List to store interconnected DHTs 31 | dht_list = [] 32 | 33 | # Global variable to store the last checked timestamp 34 | last_checked = 0#time.time() 35 | 36 | # Create a lock for thread-safe access to shared resources 37 | lock = threading.Lock() 38 | 39 | def check_and_manage_dhts(): 40 | global last_checked 41 | 42 | # Check the status of each DHT in the list 43 | 44 | for dht in dht_list: 45 | try: 46 | # Attempt to connect to the DHT by creating a new disposable DHT 47 | logger.info("Is DHT Alive") 48 | test_dht = hivemind.DHT(initial_peers=[str(dht.get_visible_maddrs()[0])], start=True) 49 | test_dht.shutdown() 50 | logger.info("DHT Alive") 51 | except Exception as e: 52 | logger.info(f"DHT failed. 
{e}") 53 | # If the connection fails, mark the DHT as non-responsive 54 | dht.terminate() 55 | dht_list.remove(dht) 56 | 57 | # Create new DHTs if needed 58 | if len(dht_list) < 10: 59 | initial_peers = [dht.get_visible_maddrs()[0] for dht in dht_list] 60 | logger.info(f"Replacing {10 - len(dht_list)} DHTs") 61 | for _ in range(10 - len(dht_list)): 62 | new_dht = hivemind.DHT( 63 | host_maddrs=[f"/ip4/{args.host_address}/tcp/0", f"/ip4/{args.host_address}/udp/0/quic"], 64 | #announce_maddrs=[f"/ip4/{args.host_address}", f"/ip4/{args.host_address}"], 65 | initial_peers=initial_peers, 66 | start=True 67 | ) 68 | new_dht.wait_until_ready() 69 | dht_list.append(new_dht) 70 | 71 | # Update the last checked timestamp 72 | last_checked = time.time() 73 | 74 | 75 | @app.before_request 76 | def before_request(): 77 | global last_checked 78 | 79 | with lock: 80 | # Check if more than 10 minutes have passed since the last check 81 | if (time.time() - last_checked > 100) and len(dht_list) > 0: # 600 seconds = 10 minutes 82 | logger.info("Checking DHT Status") 83 | check_and_manage_dhts() 84 | 85 | @app.route('/return_dht_address') 86 | def return_dht_address(): 87 | # Check if there are any available DHTs 88 | global dht_list 89 | if dht_list: 90 | # Choose a random DHT from the list 91 | logger.debug(f"Request Received. Available DHTs") 92 | random_dht = random.choice(dht_list) 93 | initial_peers = [str(multiaddr).replace(args.host_address, args.external_address) for multiaddr in random_dht.get_visible_maddrs()] 94 | return jsonify({"initial_peers":initial_peers}) 95 | else: 96 | # If no DHTs are available, create a new one and return its address 97 | with lock: 98 | logger.info("Initializing 1st DHT") 99 | new_dht = hivemind.DHT( 100 | host_maddrs=[f"/ip4/{args.host_address}/tcp/0", f"/ip4/{args.host_address}/udp/0/quic"], 101 | #announce_maddrs=[f"/ip4/{args.host_address}"], 102 | start=True 103 | ) 104 | dht_list.append(new_dht) 105 | initial_peers = [str(multiaddr).replace(args.host_address, args.external_address) for multiaddr in new_dht.get_visible_maddrs()] 106 | return jsonify({"initial_peers":initial_peers}) 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser(description='DHT Manager') 110 | parser.add_argument('--host_address', type=str, default="0.0.0.0", help='Machine\'s internal IP') 111 | parser.add_argument('--host_port', type=int, default=5000, help='Port number (default: 5000)') 112 | parser.add_argument('--external_address', type=str, default="20.20.20.20", help='Machine\'s external IP') 113 | args = parser.parse_args() 114 | #app.run(host=args.host_address, port=args.host_port,threaded=True) 115 | serve(app, host=args.host_address, port=args.host_port) 116 | -------------------------------------------------------------------------------- /hivetrain/utils/bootstrap_stress.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import aiohttp 4 | import random 5 | import time 6 | 7 | async def ping_server(session, server_url): 8 | start_time = time.time() 9 | async with session.get(server_url) as response: 10 | if response.status == 200: 11 | dht_address = await response.text() 12 | end_time = time.time() 13 | latency = end_time - start_time 14 | print(f"Received DHT address: {dht_address}, Latency: {latency:.4f} seconds") 15 | else: 16 | print(f"Request failed with status code: {response.status}") 17 | 18 | async def stress_test(server_url, num_requests, concurrent_requests): 19 | async 
with aiohttp.ClientSession() as session: 20 | tasks = [] 21 | for _ in range(num_requests): 22 | task = asyncio.create_task(ping_server(session, server_url)) 23 | tasks.append(task) 24 | if len(tasks) >= concurrent_requests: 25 | await asyncio.gather(*tasks) 26 | tasks = [] 27 | if tasks: 28 | await asyncio.gather(*tasks) 29 | 30 | async def run_stress_test(server_url, num_requests, concurrent_requests, duration): 31 | start_time = time.time() 32 | while time.time() - start_time < duration: 33 | await stress_test(server_url, num_requests, concurrent_requests) 34 | print(f"Completed iteration at {time.strftime('%Y-%m-%d %H:%M:%S')}") 35 | end_time = time.time() 36 | print(f"\nStress test completed in {end_time - start_time:.2f} seconds") 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser(description='Stress Test Client') 40 | parser.add_argument('--server_url', type=str, default="http://localhost:5000/return_dht_address", 41 | help='Server URL to ping') 42 | parser.add_argument('--num_requests', type=int, default=100, help='Total number of requests per iteration') 43 | parser.add_argument('--concurrent_requests', type=int, default=50, 44 | help='Number of concurrent requests to send') 45 | parser.add_argument('--duration', type=int, default=300, help='Duration of the stress test in seconds (default: 1800 = 30 minutes)') 46 | args = parser.parse_args() 47 | 48 | asyncio.run(run_stress_test(args.server_url, args.num_requests, args.concurrent_requests, args.duration)) -------------------------------------------------------------------------------- /hivetrain/utils/dummy_miner.py: -------------------------------------------------------------------------------- 1 | import ipaddress 2 | import logging 3 | import os 4 | import random 5 | import re 6 | import sys 7 | import time 8 | import bittensor as bt 9 | from bittensor import metagraph 10 | 11 | import requests 12 | 13 | from hivetrain.btt_connector import ( 14 | BittensorNetwork, 15 | ) 16 | from hivetrain.config import Configurator 17 | from hivetrain import __spec_version__ 18 | import logging 19 | 20 | logging.getLogger("lightning.pytorch").setLevel(logging.INFO) 21 | logger = logging.getLogger("lightning.pytorch") 22 | 23 | args = Configurator.combine_configs() 24 | 25 | class ValidationCommunicator: 26 | """Periodically send dummy requests to validators.""" 27 | 28 | def __init__(self, args, sync_interval=600): 29 | BittensorNetwork.initialize(args) 30 | 31 | self.wallet = BittensorNetwork.wallet 32 | self.subtensor = BittensorNetwork.subtensor 33 | self.metagraph = BittensorNetwork.metagraph 34 | self.sync_interval = sync_interval 35 | self.last_sync_time = 0 36 | self.validator_urls = ["127.0.0.1:8888"] 37 | 38 | def start(self): 39 | while True: 40 | current_time = int(time.time()) 41 | if self.should_sync_metagraph(): 42 | self.resync_metagraph() 43 | timestamp = str(current_time) 44 | message, signature, public_address = self.create_signed_message(timestamp) 45 | 46 | for url in self.validator_urls: 47 | try: 48 | response = requests.post( 49 | f"http://{url}/validate_metrics", 50 | json={"metrics": {"loss": random.random()}, 51 | "message": message, "signature": signature, "public_address": public_address, "miner_version": __spec_version__}, 52 | timeout=3 53 | ) 54 | if response.status_code == 200: 55 | logger.info(f"Dummy metrics reported successfully to validator {url}") 56 | else: 57 | logger.warn(f"Error @ validator {url} --- Error: {response.json()['error']}") 58 | except Exception as e: 59 | 
logger.warn(f"Failed to confirm reception at {url}: {str(e)}") 60 | 61 | time.sleep(60) # Sleep for 60 seconds before sending the next request 62 | 63 | def create_signed_message(self, message): 64 | signature = self.wallet.hotkey.sign( 65 | message 66 | ).hex() 67 | public_address = self.wallet.hotkey.ss58_address 68 | return message, signature, public_address 69 | 70 | def resync_metagraph(self): 71 | try: 72 | self.metagraph.sync(subtensor=self.subtensor) 73 | logger.info("Syncing Metagraph Successful") 74 | except Exception as e: 75 | logger.warning(f"Failed to sync metagraph: {e}") 76 | 77 | def should_sync_metagraph(self): 78 | return (time.time() - self.last_sync_time) > self.sync_interval 79 | 80 | if __name__ == "__main__": 81 | communicator = ValidationCommunicator(args) 82 | communicator.start() -------------------------------------------------------------------------------- /hivetrain/utils/generate_wallets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from substrateinterface import Keypair 3 | from typing import List 4 | import bittensor as bt 5 | from tqdm import tqdm 6 | from hivetrain.config import Configurator 7 | import time 8 | 9 | def generate_multiple_wallets(n: int, main_wallet_mnemonic: str, subtensor: bt.subtensor, reg_amount: int = 0.0001, 10 | netuid: int = 100) -> List[dict]: 11 | """ 12 | Generates multiple wallets, each with a hot and cold keypair, without encryption or user prompts. 13 | 14 | Args: 15 | - n (int): The number of wallets to generate. 16 | 17 | Returns: 18 | - List[dict]: A list of dictionaries, each containing the mnemonic, hotkey, and coldkey for each wallet. 19 | """ 20 | wallets = [] 21 | core_tao_wallet = bt.wallet(name="faucet_source", hotkey="faucet_hot") 22 | core_tao_wallet.regen_coldkey(use_password=False, overwrite=True, mnemonic=main_wallet_mnemonic) 23 | for wallet_number in tqdm(range(n)): 24 | # Generate mnemonics for hot and cold keys 25 | wallet_of_tao = bt.wallet(name=f"test_coldkey_{wallet_number}", hotkey=f"test_hotkey_{wallet_number}") 26 | wallet_of_tao.new_coldkey(use_password=False, overwrite=True) 27 | wallet_of_tao.new_hotkey(overwrite=True) 28 | subtensor.transfer(core_tao_wallet, wallet_of_tao.coldkey.ss58_address, reg_amount, 29 | wait_for_inclusion=True, 30 | wait_for_finalization=True, 31 | keep_alive=True, 32 | prompt=False) 33 | time.sleep(0.5) #Make sure subnet hyperparams allow lots of regs 34 | subtensor.register(wallet_of_tao, 35 | netuid, 36 | wait_for_inclusion=True, 37 | wait_for_finalization=True, 38 | prompt=False) 39 | 40 | 41 | return wallets 42 | 43 | if __name__ == '__main__': 44 | # Example: Generate 3 wallets 45 | config = Configurator.combine_configs() 46 | MAIN_WALLET_MNEMONIC = "ENTER THE MNEMONIC HERE" 47 | wallets = generate_multiple_wallets(3, MAIN_WALLET_MNEMONIC, config.subtensor, 0.00000001) 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /hivetrain/utils/mlflow_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import psutil 4 | import time 5 | import re 6 | import logging 7 | import mlflow 8 | from hivetrain.config.mlflow_config import MLFLOW_UI_URL 9 | import requests 10 | from requests.adapters import HTTPAdapter 11 | from requests.sessions import Session 12 | from urllib3.util.retry import Retry 13 | 14 | 15 | def get_gpu_utilization(): 16 | """ 17 | Retrieves the GPU utilization as a percentage of the 
GPU's capacity. 18 | 19 | Returns: 20 | float: The GPU utilization percentage if CUDA is available, otherwise 0.0. 21 | 22 | """ 23 | # Check if the CUDA is available on the current device using PyTorch's built-in function 24 | if torch.cuda.is_available(): 25 | # Retrieve and return the GPU utilization percentage 26 | utilization = torch.cuda.utilization() 27 | return utilization 28 | else: 29 | # Return 0.0 if CUDA is not available, indicating no GPU activity 30 | return 0.0 31 | 32 | 33 | def get_cpu_utilization(): 34 | """ 35 | Returns the current system-wide CPU utilization as a percentage. 36 | 37 | Returns: 38 | float: The percentage of CPU utilization. 39 | """ 40 | cpu_usage = psutil.cpu_percent(interval=1) 41 | return cpu_usage 42 | 43 | 44 | def get_memory_usage(): 45 | """ 46 | Retrieves the current memory usage of the process that runs a Python script. 47 | 48 | Returns: 49 | int: The current memory usage (RSS) of the process in megabytes. 50 | 51 | """ 52 | # Create a process instance for the current process using psutil 53 | process = psutil.Process() 54 | # Retrieve memory info from the current process 55 | memory_info = process.memory_info() 56 | # Return the Resident Set Size (RSS) which indicates how much memory in megabytes the process is using in RAM 57 | return round(memory_info.rss / 1024**2, 2) 58 | 59 | 60 | def get_network_bandwidth(): 61 | """ 62 | Retrieves the total network bandwidth usage by calculating the sum of bytes sent and received. 63 | 64 | Returns: 65 | int: The total number of bytes sent and received over the network. 66 | 67 | """ 68 | net_io_counters = psutil.net_io_counters() 69 | return net_io_counters.bytes_sent + net_io_counters.bytes_recv 70 | 71 | 72 | def get_version_from_file(): 73 | file_path = os.path.join(os.getcwd(), "template/__init__.py") 74 | # Read the specified file and search for version information 75 | with open(file_path, "r") as file: 76 | content = file.read() 77 | # Regex to match __version__ = 'x.y.z' 78 | version_match = re.search(r"__version__\s*=\s*['\"]([^'\"]+)['\"]", content) 79 | if version_match: 80 | return version_match.group(1) 81 | else: 82 | return None 83 | 84 | 85 | def initialize_mlflow( 86 | role, 87 | device, 88 | version, 89 | mlflow_ui_url, 90 | current_model_name, 91 | my_hotkey = None, 92 | learning_rate=None, 93 | send_interval=None, 94 | check_update_interval=None, 95 | ): 96 | try: 97 | os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true" 98 | os.environ["MLFLOW_START_RETRY_ATTEMPT_MAX"] = "2" 99 | mlflow.set_tracking_uri(mlflow_ui_url) 100 | mlflow.set_experiment(current_model_name) 101 | 102 | if role == "miner": 103 | run_name = f"miner_{my_hotkey}" 104 | mlflow.start_run(run_name=run_name) 105 | mlflow.log_param("device", device) 106 | mlflow.log_param("Version of Code", version) 107 | mlflow.log_param("learning_rate", learning_rate) 108 | mlflow.log_param("send_interval", send_interval) 109 | mlflow.log_param("check_update_interval", check_update_interval) 110 | elif role == "validator": 111 | run_name = f"validator_{my_hotkey}" 112 | mlflow.start_run(run_name=run_name) 113 | mlflow.log_param("device", device) 114 | mlflow.log_param("Version of Code", version) 115 | mlflow.log_param("check_update_interval", check_update_interval) 116 | else: 117 | run_name = f"AVERAGER" 118 | mlflow.start_run(run_name=run_name) 119 | mlflow.log_param("device", device) 120 | mlflow.log_param("Version of Code", version) 121 | except Exception as e: 122 | logging.error(f"Failed to initialize and log 
parameters to MLflow: {e}") 123 | return None 124 | 125 | 126 | def log_model_metrics(step, **metrics): 127 | """ 128 | Logs given metrics to MLflow with the provided step count/int(time). 129 | 130 | Args: 131 | step (int): The step count or timestamp at which metrics are logged, providing a timeline for metrics. 132 | **metrics (dict): Arbitrary number of keyword arguments where keys are the metric names and values are their respective values to log. 133 | """ 134 | try: 135 | for metric_name, metric_value in metrics.items(): 136 | mlflow.log_metric(metric_name, metric_value, step=step) 137 | print(f"Logged metrics to MLflow at Step {step}: {metrics}") 138 | except Exception as e: 139 | logging.error(f"Failed to log metrics to MLflow: {e}") 140 | print(f"Error logging to MLflow, but training continues: {e}") 141 | 142 | 143 | def setup_mlflow_session( 144 | mlflow_tracking_uri, retries=2, backoff_factor=1, status_forcelist=(500, 502, 504) 145 | ): 146 | """ 147 | Sets up a custom requests session for MLflow with specified retry logic. 148 | 149 | Args: 150 | mlflow_tracking_uri (str): The MLflow server's tracking URI. 151 | retries (int): The number of retries for requests. 152 | backoff_factor (float): A backoff factor to apply between attempts. 153 | status_forcelist (tuple): A set of HTTP status codes that we should force a retry on. 154 | """ 155 | session = requests.Session() 156 | retry = Retry( 157 | total=retries, 158 | read=retries, 159 | connect=retries, 160 | backoff_factor=backoff_factor, 161 | status_forcelist=status_forcelist, 162 | ) 163 | adapter = HTTPAdapter(max_retries=retry) 164 | session.mount("http://", adapter) 165 | session.mount("https://", adapter) 166 | 167 | mlflow.set_tracking_uri(mlflow_tracking_uri) 168 | mlflow.utils.rest_utils.http_request(session=session) 169 | 170 | 171 | def create_mlflow_session(): 172 | """Creates a requests session for MLflow with custom retry behavior.""" 173 | session = Session() 174 | retries = Retry(total=2, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) 175 | session.mount("http://", HTTPAdapter(max_retries=retries)) 176 | session.mount("https://", HTTPAdapter(max_retries=retries)) 177 | return session 178 | 179 | 180 | VERSION = get_version_from_file() 181 | -------------------------------------------------------------------------------- /hivetrain/utils/ports.txt: -------------------------------------------------------------------------------- 1 | 32155 2 | 32225 3 | 32229 4 | 32450 5 | 32669 6 | 32717 7 | 32845 8 | 32858 9 | 32861 10 | 32897 11 | 32944 12 | 32982 -------------------------------------------------------------------------------- /hivetrain/validation_logic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import torch 4 | import time 5 | import math 6 | import mlflow 7 | import mlflow.pytorch 8 | from bittensor import logging 9 | import logging 10 | from copy import deepcopy 11 | from hivetrain.btt_connector import BittensorNetwork 12 | from hivetrain.config import Configurator 13 | from hivetrain.config.mlflow_config import ( 14 | MLFLOW_UI_URL, 15 | CURRENT_MODEL_NAME, 16 | MLFLOW_ACTIVE, 17 | ) 18 | from hivetrain.utils.mlflow_utils import initialize_mlflow, log_model_metrics, VERSION 19 | from torch.optim import AdamW 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | 24 | args = Configurator.combine_configs() 25 | BittensorNetwork.initialize(args, ignore_regs=True) 26 | MY_HOTKEY = 
BittensorNetwork.wallet.hotkey.ss58_address 27 | 28 | 29 | class ModelValidator: 30 | def __init__( 31 | self, 32 | device, 33 | model, 34 | optimizer, 35 | data_loader, 36 | check_update_interval=300, 37 | bittensor_network=None, 38 | chain_manager=None, 39 | hf_manager=None, 40 | interval=300, 41 | ): 42 | self.device = device 43 | self.model = model 44 | self.model = self.model.to(device) 45 | self.optimizer = optimizer 46 | self.data_loader = data_loader 47 | self.interval = interval # Validation interval in seconds 48 | self.base_loss, self.base_perplexity = self.evaluate_model() 49 | self.bittensor_network = bittensor_network 50 | self.scores = { 51 | hotkey: 0.0 for hotkey in self.bittensor_network.metagraph.hotkeys 52 | } 53 | self.normalized_scores = { 54 | hotkey: 0.0 for hotkey in self.bittensor_network.metagraph.hotkeys 55 | } 56 | self.chain_manager = chain_manager 57 | self.hf_manager = hf_manager 58 | self.last_pull_time = 0 59 | self.check_update_interval = check_update_interval 60 | 61 | if MLFLOW_ACTIVE: 62 | initialize_mlflow( 63 | role="validator", 64 | device=self.device, 65 | version=VERSION, 66 | mlflow_ui_url=MLFLOW_UI_URL, 67 | current_model_name=CURRENT_MODEL_NAME, 68 | my_hotkey=MY_HOTKEY, 69 | check_update_interval=self.check_update_interval, 70 | ) 71 | 72 | def update_model_weights(self, gradients, alpha=5e-4): 73 | with torch.no_grad(): 74 | for name, param in self.model.named_parameters(): 75 | if name in gradients: 76 | param -= gradients[name] * alpha 77 | 78 | def evaluate_model(self, metric="perplexity"): 79 | self.model.eval() 80 | total_loss = 0 81 | total_samples = 0 82 | with torch.no_grad(): 83 | for batch_num, batch in enumerate( 84 | self.data_loader 85 | ): # FIXME turn me into a generator? 86 | outputs = self.model( 87 | input_ids=batch["input_ids"].to(self.device), 88 | attention_mask=batch["attention_mask"].to(self.device), 89 | labels=batch["labels"].to(self.device), 90 | ) 91 | loss = outputs.loss 92 | total_loss += loss.item() * batch["input_ids"].size(0) 93 | total_samples += batch["input_ids"].size(0) 94 | 95 | average_loss = total_loss / total_samples 96 | perplexity = math.exp(average_loss) if metric == "perplexity" else None 97 | return average_loss, perplexity 98 | 99 | def validate_and_score(self): 100 | """Check if the model is changed on HF , Check if HF commit hash is updated? 101 | If true pull""" 102 | 103 | logging.info("!Receiving Gradients from chain") 104 | self.bittensor_network.sync(lite=True) # FIXME too prone to issues 105 | 106 | if time.time() - self.last_pull_time >= self.check_update_interval: 107 | if self.hf_manager.check_for_new_submissions(self.hf_manager.model_repo_id): 108 | logging.info( 109 | "Averaged model updated on Hugging Face. Pulling latest model..." 
110 | ) 111 | self.hf_manager.pull_latest_model() 112 | time.sleep(10) # Give enough time for pull 113 | self.model = self.hf_manager.update_model(self.model) 114 | self.model = self.model.to(self.device) 115 | self.optimizer = AdamW( 116 | self.model.parameters(), lr=5e-5 117 | ) # Reinitialize the optimizer 118 | self.base_weights = { 119 | name: param.clone() for name, param in self.model.named_parameters() 120 | } 121 | self.last_pull_time = time.time() 122 | 123 | self.original_state_dict = deepcopy(self.model.state_dict()) 124 | 125 | total_scores = 0 126 | for uid, hotkey_address in enumerate(self.bittensor_network.metagraph.hotkeys): 127 | hf_repo = self.chain_manager.retrieve_hf_repo(hotkey_address) 128 | gradients = self.hf_manager.receive_gradients(hf_repo) 129 | if gradients is not None: 130 | logging.info(f"Receiving Gradients from: {hotkey_address}") 131 | logging.info(f"Updating Model Weights") 132 | self.update_model_weights(gradients) 133 | logging.info(f"The model hash: {self.calculate_model_hash()}") 134 | logging.info(f"Evaluating model") 135 | loss, perplexity = self.evaluate_model() 136 | loss_score = max(0, self.base_loss - loss) 137 | perplexity_score = max(0, self.base_perplexity - perplexity) 138 | total_scores += perplexity_score 139 | self.model.load_state_dict(self.original_state_dict) 140 | 141 | if MLFLOW_ACTIVE: 142 | metrics = { 143 | f"loss_{hotkey_address}": loss, 144 | f"perplexity_{hotkey_address}": perplexity, 145 | f"loss_score_{hotkey_address}": loss_score, 146 | f"perplexity_score_{hotkey_address}": perplexity_score, 147 | } 148 | 149 | # Log metrics with dynamic names 150 | log_model_metrics(step=int(time.time()), **metrics) 151 | 152 | else: 153 | loss = 99999999.0 154 | perplexity = 99999999.0 155 | loss_score = 0.0 156 | perplexity_score = 0.0 157 | 158 | current_time = int(time.time()) 159 | if MLFLOW_ACTIVE: 160 | metrics = { 161 | f"loss_{hotkey_address}": loss, 162 | f"perplexity_{hotkey_address}": perplexity, 163 | f"loss_score_{hotkey_address}": loss_score, 164 | f"perplexity_score_{hotkey_address}": perplexity_score, 165 | } 166 | log_model_metrics(step=int(current_time), **metrics) 167 | 168 | self.scores[hotkey_address] = perplexity_score 169 | # log validator performance 170 | 171 | if uid == 1: 172 | if MLFLOW_ACTIVE: 173 | try: 174 | mlflow.log_param("Version of Code", VERSION) 175 | except Exception as e: 176 | logging.error(f"Failed to log the code version to MLflow: {e}") 177 | 178 | # Reset the model to its original state 179 | logging.info(f"Loss: {loss}, Perplexity: {perplexity}") 180 | logging.info( 181 | f"Loss Score: {loss_score}, Perplexity Score: {perplexity_score}" 182 | ) 183 | time.sleep(0.1) 184 | 185 | # normalize scores 186 | for uid, hotkey_address in enumerate(self.bittensor_network.metagraph.hotkeys): 187 | self.normalized_scores[hotkey_address] = (max(0, self.scores[hotkey_address] / total_scores) if total_scores > 0 else 0.0) 188 | if self.bittensor_network.should_set_weights(): 189 | self.bittensor_network.set_weights(self.normalized_scores) 190 | 191 | def start_periodic_validation(self): 192 | while True: 193 | self.validate_and_score() 194 | self.hf_manager.clear_hf_cache() 195 | logging.info(f"One round done, sleeping for: {self.interval}") 196 | time.sleep(self.interval) 197 | 198 | def calculate_model_hash(self): 199 | model_hash = hashlib.sha256() 200 | for name, param in self.model.named_parameters(): 201 | model_hash.update(name.encode("utf-8")) 202 | model_hash.update(param.data.cpu().numpy().tobytes()) 203 | return model_hash.hexdigest() 204 | 205 | 206 | class
LocalValidator(ModelValidator): 207 | def __init__( 208 | self, 209 | model, 210 | optimizer, 211 | data_loader, 212 | bittensor_network=None, 213 | chain_manager=None, 214 | hf_manager=None, 215 | interval=3600, 216 | local_gradient_dir="local_gradients", 217 | ): 218 | super().__init__( 219 | model, 220 | optimizer, 221 | data_loader, 222 | bittensor_network, 223 | chain_manager, 224 | hf_manager, 225 | interval, 226 | ) 227 | self.local_gradient_dir = local_gradient_dir 228 | # Ensure the local directory exists 229 | os.makedirs(self.local_gradient_dir, exist_ok=True) 230 | 231 | def receive_gradients(self, repo_id=None, gradient_file_name="gradients.pt"): 232 | """ 233 | Overrides the receive_gradients method to fetch gradients from a local directory. 234 | """ 235 | try: 236 | if repo_id == None: 237 | return None 238 | gradient_file_path = os.path.join(repo_id, gradient_file_name) 239 | if not os.path.exists(gradient_file_path): 240 | logging.warning(f"Gradient file not found: {gradient_file_path}") 241 | return None 242 | 243 | # Load the gradients directly using torch.load 244 | aggregated_gradients = torch.load(gradient_file_path) 245 | return aggregated_gradients 246 | except Exception as e: 247 | logging.error(f"Error receiving gradients locally: {e}") 248 | return None 249 | 250 | 251 | class DeltaValidator(ModelValidator): 252 | def update_model_weights(self, weight_deltas, alpha=5e-4): 253 | with torch.no_grad(): 254 | for name, param in self.model.named_parameters(): 255 | if name in weight_deltas: 256 | try: 257 | param.data = weight_deltas[name] + param.data 258 | except Exception as e: 259 | logging.warning(f"Error loading gradients: {e}") 260 | 261 | class LocalDeltaValidator(DeltaValidator, LocalValidator): 262 | pass 263 | 264 | 265 | class MNISTValidator(LocalValidator): 266 | def __init__( 267 | self, 268 | model, 269 | optimizer, 270 | data_loader, 271 | bittensor_network=None, 272 | chain_manager=None, 273 | hf_manager=None, 274 | interval=300, 275 | local_gradient_dir="local_gradients", 276 | ): 277 | super().__init__( 278 | model, 279 | optimizer, 280 | data_loader, 281 | bittensor_network, 282 | chain_manager, 283 | hf_manager, 284 | interval, 285 | ) 286 | ( 287 | self.base_loss, 288 | self.base_accuracy, 289 | ) = self.evaluate_model() # Redefine to use accuracy for MNIST 290 | 291 | def evaluate_model(self, *args, **kwargs): 292 | """Evaluate the model on the MNIST validation dataset.""" 293 | self.model.eval() 294 | total_loss = 0 295 | correct_predictions = 0 296 | total_samples = 0 297 | 298 | with torch.no_grad(): 299 | for batch in self.data_loader: 300 | images, labels = batch 301 | outputs = self.model(images) 302 | loss = F.cross_entropy(outputs, labels) 303 | total_loss += loss.item() 304 | _, predicted = torch.max(outputs.data, 1) 305 | correct_predictions += (predicted == labels).sum().item() 306 | total_samples += labels.size(0) 307 | 308 | average_loss = total_loss / total_samples 309 | accuracy = correct_predictions / total_samples 310 | return average_loss, accuracy 311 | 312 | 313 | class MNISTDeltaValidator(MNISTValidator): 314 | def update_model_weights(self, weight_deltas, alpha=5e-4): 315 | with torch.no_grad(): 316 | for name, param in self.model.named_parameters(): 317 | if name in weight_deltas: 318 | param.data = weight_deltas[name] + param.data 319 | -------------------------------------------------------------------------------- /neurons/averager.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import io 3 | import argparse 4 | import ipaddress 5 | import os 6 | import random 7 | import re 8 | import sys 9 | import time 10 | from functools import partial 11 | from math import isnan 12 | import bittensor as bt 13 | from bittensor import metagraph 14 | 15 | import numpy as np 16 | import requests 17 | import torch 18 | from datasets import load_dataset 19 | from torch.optim import AdamW 20 | from torch.utils.data import DataLoader, Dataset, IterableDataset 21 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 22 | 23 | from huggingface_hub import Repository, HfFolder 24 | from hivetrain.averaging_logic import ParameterizedAverager 25 | from hivetrain.btt_connector import BittensorNetwork 26 | from hivetrain.config import Configurator 27 | from hivetrain.chain_manager import ChainMultiAddressStore 28 | from hivetrain.training_manager import FeedforwardNN 29 | from hivetrain.hf_manager import HFManager 30 | 31 | # Assuming `model` is your PyTorch model, `scores` contain your keys and their respective scores, 32 | # and `get_weights(key)` is a function to retrieve serialized gradients 33 | from torchvision import transforms, datasets 34 | from datasets import load_dataset 35 | from bittensor import logging 36 | 37 | logging.enable_debug() 38 | 39 | args = Configurator.combine_configs() 40 | BittensorNetwork.initialize(args, ignore_regs=True) 41 | 42 | my_hotkey = BittensorNetwork.wallet.hotkey.ss58_address 43 | #my_uid = BittensorNetwork.metagraph.hotkeys.index(my_hotkey) 44 | 45 | address_store = ChainMultiAddressStore(BittensorNetwork.subtensor, args.netuid,BittensorNetwork.wallet) 46 | 47 | # Define your model's local directory and repository ID 48 | #local_dir = "./save_me"#args.averager.save_directory #TODO add me to config :) 49 | #repo_id = "test_me"#args.averager.hf_repository #TODO add me to config :) 50 | 51 | batch_size = args.batch_size 52 | epochs = 30_000_000_000_000_000 53 | learning_rate = 5e-5 54 | send_interval = 120 # Every 60 seconds 55 | 56 | # Load model and tokenizer 57 | # Load the Wikitext dataset 58 | dataset = load_dataset("wikitext", "wikitext-103-v1") 59 | 60 | # Assuming you want to use the 'train' split of the dataset 61 | texts = dataset['test']['text'][:100] 62 | 63 | # Load model and tokenizer 64 | model_name = "openai-community/gpt2" 65 | tokenizer = AutoTokenizer.from_pretrained(model_name) 66 | tokenizer.add_special_tokens({'pad_token': '[PAD]'}) 67 | model = AutoModelForCausalLM.from_pretrained(model_name) 68 | model.resize_token_embeddings(len(tokenizer)) 69 | model.train() 70 | 71 | class WikitextDataset(Dataset): 72 | def __init__(self, texts, tokenizer, max_length=512): 73 | self.tokenizer = tokenizer 74 | self.texts = texts 75 | self.max_length = max_length 76 | 77 | def __len__(self): 78 | return len(self.texts) 79 | 80 | def __getitem__(self, idx): 81 | encoding = self.tokenizer(self.texts[idx], return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length) 82 | input_ids = encoding['input_ids'].squeeze() # Remove batch dimension 83 | attention_mask = encoding['attention_mask'].squeeze() 84 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': input_ids.clone()} 85 | 86 | def custom_collate_fn(batch): 87 | input_ids = torch.stack([item['input_ids'] for item in batch]) 88 | attention_mask = torch.stack([item['attention_mask'] for item in batch]) 89 | labels = input_ids.clone() # Copy input_ids to labels 90 | return {'input_ids': input_ids, 
'attention_mask': attention_mask, 'labels': labels} 91 | 92 | # Create the dataset and data loader 93 | wikitext_dataset = WikitextDataset(texts, tokenizer) 94 | test_loader = DataLoader(wikitext_dataset, batch_size=batch_size, collate_fn=custom_collate_fn) 95 | 96 | 97 | #__init__(self, model, local_dir, bittensor_network=None) 98 | #model, device, chain_manager=None,bittensor_network=None, hf_token=hf_token 99 | 100 | hf_manager = HFManager(my_repo_id = None, averaged_model_repo_id= args.storage.averaged_model_repo_id) 101 | device = "cuda" if torch.cuda.is_available() else "cpu" 102 | averager = ParameterizedAverager(model=model,device=device,hf_manager=hf_manager, local_dir=args.storage.model_dir, gradients_dir=args.storage.gradient_dir ,chain_manager=address_store,bittensor_network=BittensorNetwork, hf_token=os.environ.get("HF_TOKEN")) 103 | #averager.run_periodic_averaging(test_loader,20,300) 104 | #val_loader,meta_epochs, lr, t 105 | #averager.save_model() 106 | averager.run_periodic_averaging(test_loader, 7,0.01,1200) 107 | # # Push the model to the Hugging Face Hub 108 | # push_to_hf_hub(local_dir=local_dir, repo_id=repo_id, hf_token=args.averager.hf_token, commit_message=f"Updated model SN25 with {222}")#FIXME add block numbers 109 | -------------------------------------------------------------------------------- /neurons/miner.py: -------------------------------------------------------------------------------- 1 | from bittensor import logging 2 | import torch 3 | from torchvision import datasets, transforms 4 | from torch.utils.data import DataLoader, Dataset, IterableDataset 5 | from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer 6 | from datasets import load_dataset 7 | from hivetrain.btt_connector import BittensorNetwork 8 | from hivetrain.chain_manager import ChainMultiAddressStore 9 | from hivetrain.config import Configurator 10 | from hivetrain.hf_manager import HFManager 11 | from hivetrain.training_manager import DeltaLoop 12 | 13 | 14 | logging.enable_debug() 15 | logging.info("Starting !") 16 | 17 | 18 | def flatten_list(nested_list): 19 | """Flatten a nested list.""" 20 | if nested_list and isinstance(nested_list[0], list): 21 | # Assumes only one level of nesting 22 | return [item for sublist in nested_list for item in sublist] 23 | return nested_list 24 | 25 | 26 | # # set some basic configuration values 27 | # inital_peers_request = requests.get(args.miner.bootstrapping_server) 28 | # initial_peers = inital_peers_request.json()["initial_peers"] 29 | # assert not (initial_peers is None) 30 | args = Configurator.combine_configs() 31 | 32 | BittensorNetwork.initialize(args) 33 | my_hotkey = BittensorNetwork.wallet.hotkey.ss58_address 34 | my_uid = BittensorNetwork.metagraph.hotkeys.index(my_hotkey) 35 | 36 | address_store = ChainMultiAddressStore( 37 | BittensorNetwork.subtensor, args.netuid, BittensorNetwork.wallet 38 | ) 39 | 40 | current_address_in_store = address_store.retrieve_hf_repo(my_hotkey) 41 | logging.info(f"Current value in store:{current_address_in_store}") 42 | if current_address_in_store != args.storage.my_repo_id: 43 | logging.info(f"Storing new value: {args.storage.my_repo_id}") 44 | address_store.store_hf_repo(args.storage.my_repo_id) 45 | 46 | # Parameters 47 | 48 | batch_size = args.batch_size 49 | epochs = 30_000_000_000_000_000 50 | learning_rate = 5e-5 51 | send_interval = 600 # Every 60 seconds 52 | 53 | # Load the Wikitext dataset 54 | dataset = load_dataset("wikitext", "wikitext-103-v1") 55 | 56 | # Assuming you want to use 
57 | texts = dataset["train"]["text"]
58 | 
59 | # Load model and tokenizer
60 | model_name = "openai-community/gpt2"
61 | tokenizer = AutoTokenizer.from_pretrained(model_name)
62 | tokenizer.add_special_tokens({'pad_token': '[PAD]'})
63 | # model = AutoModelForCausalLM.from_pretrained(model_name)
64 | # model.resize_token_embeddings(len(tokenizer))
65 | # model.train()
66 | 
67 | 
68 | 
69 | class WikitextDataset(Dataset):
70 |     def __init__(self, texts, tokenizer, max_length=64):
71 |         self.tokenizer = tokenizer
72 |         self.texts = texts
73 |         self.max_length = max_length
74 | 
75 |     def __len__(self):
76 |         return len(self.texts)
77 | 
78 |     def __getitem__(self, idx):
79 |         encoding = self.tokenizer(
80 |             self.texts[idx],
81 |             return_tensors="pt",
82 |             padding="max_length",
83 |             truncation=True,
84 |             max_length=self.max_length,
85 |         )
86 |         input_ids = encoding["input_ids"].squeeze()  # Remove batch dimension
87 |         attention_mask = encoding["attention_mask"].squeeze()
88 |         return {
89 |             "input_ids": input_ids,
90 |             "attention_mask": attention_mask,
91 |             "labels": input_ids.clone(),
92 |         }
93 | 
94 | 
95 | def custom_collate_fn(batch):
96 |     input_ids = torch.stack([item["input_ids"] for item in batch])
97 |     attention_mask = torch.stack([item["attention_mask"] for item in batch])
98 |     labels = input_ids.clone()  # Copy input_ids to labels
99 |     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
100 | 
101 | 
102 | # Create the dataset and data loader
103 | wikitext_dataset = WikitextDataset(texts, tokenizer)
104 | data_loader = DataLoader(
105 |     wikitext_dataset, batch_size=batch_size, collate_fn=custom_collate_fn
106 | )
107 | # Optimizer
108 | #optimizer = AdamW(model.parameters(), lr=learning_rate)
109 | 
110 | 
111 | hf_manager = HFManager(
112 |     my_repo_id=args.storage.my_repo_id,
113 |     averaged_model_repo_id=args.storage.averaged_model_repo_id,
114 | )
115 | #device = "cuda" if torch.cuda.is_available() else "cpu"
116 | device = args.device
117 | 
118 | training_loop = DeltaLoop(
119 |     device,
120 |     model_name,
121 |     data_loader,
122 |     send_interval=800,
123 |     learning_rate=5e-4,
124 |     hf_manager=hf_manager,
125 | )
126 | training_loop.train(epochs=30_000_000_000_000_000)  # run until stopped
127 | 
--------------------------------------------------------------------------------
/neurons/validator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import bittensor as bt
3 | from torch.utils.data import DataLoader, Dataset, IterableDataset
4 | 
5 | from hivetrain.btt_connector import (
6 |     BittensorNetwork,
7 |     # get_validator_uids_and_addresses,
8 |     serve_axon,
9 | )
10 | from bittensor.btlogging import logging
11 | from torch.optim import AdamW, SGD
12 | from torchvision import datasets, transforms
13 | from torch.utils.data import DataLoader, Dataset, IterableDataset
14 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
15 | from datasets import load_dataset
16 | 
17 | from hivetrain.chain_manager import ChainMultiAddressStore
18 | from hivetrain.config import Configurator
19 | from hivetrain import __spec_version__
20 | from hivetrain.validation_logic import DeltaValidator
21 | from hivetrain.hf_manager import HFManager
22 | from hivetrain.training_manager import FeedforwardNN
23 | 
24 | logging.enable_debug()
25 | 
26 | args = Configurator.combine_configs()
27 | 
28 | BittensorNetwork.initialize(args)
29 | my_hotkey = BittensorNetwork.wallet.hotkey.ss58_address
30 | my_uid = BittensorNetwork.metagraph.hotkeys.index(my_hotkey)
31 | 
32 | address_store = ChainMultiAddressStore(
33 |     BittensorNetwork.subtensor, args.netuid, BittensorNetwork.wallet
34 | )
35 | 
36 | 
37 | batch_size = args.batch_size
38 | epochs = 30_000_000_000_000_000
39 | learning_rate = 5e-5
40 | receive_interval = 1800  # Run validation every 1800 seconds (30 minutes)
41 | 
42 | # Load the Wikitext dataset, focusing on the test split:
43 | # the validator scores miner deltas on held-out text rather than
44 | # on the train split used by the miners.
45 | dataset = load_dataset("wikitext", "wikitext-103-v1")
46 | 
47 | # Get test data (using the first 100 texts for this example)
48 | 
49 | texts = dataset["test"]["text"][:100]  # FIXME 400?
50 | 
51 | # Load the tokenizer and model
52 | model_name = "openai-community/gpt2"
53 | tokenizer = AutoTokenizer.from_pretrained(model_name)
54 | tokenizer.add_special_tokens({"pad_token": "[PAD]"})
55 | model = AutoModelForCausalLM.from_pretrained(model_name)
56 | model.resize_token_embeddings(len(tokenizer))
57 | model.train()
58 | optimizer = AdamW(model.parameters(), lr=5e-5)
59 | 
60 | 
61 | # Create a custom dataset class for Wikitext
62 | class WikitextDataset(Dataset):
63 |     def __init__(self, texts, tokenizer, max_length=512):
64 |         self.tokenizer = tokenizer
65 |         self.texts = texts
66 |         self.max_length = max_length
67 | 
68 |     def __len__(self):
69 |         return len(self.texts)
70 | 
71 |     def __getitem__(self, idx):
72 |         encoding = self.tokenizer(
73 |             self.texts[idx],
74 |             return_tensors="pt",
75 |             padding="max_length",
76 |             truncation=True,
77 |             max_length=self.max_length,
78 |         )
79 |         input_ids = encoding["input_ids"].squeeze()  # Remove batch dimension
80 |         attention_mask = encoding["attention_mask"].squeeze()
81 |         return {
82 |             "input_ids": input_ids,
83 |             "attention_mask": attention_mask,
84 |             "labels": input_ids.clone(),
85 |         }
86 | 
87 | 
88 | # Define a collate function for data batching
89 | def custom_collate_fn(batch):
90 |     input_ids = torch.stack([item["input_ids"] for item in batch])
91 |     attention_mask = torch.stack([item["attention_mask"] for item in batch])
92 |     labels = input_ids.clone()  # Copy input_ids to labels
93 |     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
94 | 
95 | 
96 | # Create the test set and DataLoader
97 | test_set = WikitextDataset(texts, tokenizer)
98 | test_loader = DataLoader(test_set, batch_size=8, collate_fn=custom_collate_fn)
99 | # Wire the model, optimizer, and data loader into the validator below
100 | # def __init__(self, model, optimizer, data_loader, bittensor_network=None, chain_manager=None, interval=3600, local_gradient_dir="local_gradients"):
101 | hf_manager = HFManager(
102 |     my_repo_id=None, averaged_model_repo_id=args.storage.averaged_model_repo_id
103 | )
104 | device = "cuda" if torch.cuda.is_available() else "cpu"
105 | validator = DeltaValidator(
106 |     device=device,
107 |     model=model,
108 |     optimizer=optimizer,
109 |     data_loader=test_loader,
110 |     bittensor_network=BittensorNetwork,
111 |     hf_manager=hf_manager,
112 |     interval=receive_interval,
113 |     chain_manager=address_store,
114 | )
115 | validator.start_periodic_validation()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bittensor==6.10.1
2 | requests
3 | substrate-interface
4 | torch
5 | 
torchvision 6 | datasets 7 | transformers 8 | mlflow 9 | python-dotenv 10 | psutil 11 | pynvml 12 | -------------------------------------------------------------------------------- /run_miner.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/bin/bash 3 | # Inspired by https://github.com/surcyf123/smart-scrape/blob/main/run.sh 4 | 5 | # Initialize variables 6 | script="neurons/miner.py" 7 | autoRunLoc=$(readlink -f "$0") 8 | proc_name="distributed_training_miner" 9 | args=() 10 | version_location="template/__init__.py" 11 | version="__version__ " 12 | repo="bit-current/DistributedTraining" 13 | branch="main" 14 | repo_url="https://github.com/$repo.git" 15 | 16 | 17 | old_args=$@ 18 | echo "=====old_args" 19 | 20 | # Check if pm2 is installed 21 | if ! command -v pm2 &> /dev/null 22 | then 23 | echo "pm2 could not be found. To install see: https://pm2.keymetrics.io/docs/usage/quick-start/" 24 | exit 1 25 | fi 26 | 27 | # Returns the difference between 28 | # two versions as a numerical value. 29 | get_version_difference() { 30 | local tag1="$1" 31 | local tag2="$2" 32 | 33 | # Extract the version numbers from the tags 34 | local version1=$(echo "$tag1" | sed 's/v//') 35 | local version2=$(echo "$tag2" | sed 's/v//') 36 | 37 | # Split the version numbers into an array 38 | IFS='.' read -ra version1_arr <<< "$version1" 39 | IFS='.' read -ra version2_arr <<< "$version2" 40 | 41 | # Calculate the numerical difference 42 | local diff=0 43 | for i in "${!version1_arr[@]}"; do 44 | local num1=${version1_arr[$i]} 45 | local num2=${version2_arr[$i]} 46 | 47 | # Compare the numbers and update the difference 48 | if (( num1 > num2 )); then 49 | diff=$((diff + num1 - num2)) 50 | elif (( num1 < num2 )); then 51 | diff=$((diff + num2 - num1)) 52 | fi 53 | done 54 | 55 | strip_quotes $diff 56 | } 57 | 58 | read_version_value() { 59 | # Read each line in the file 60 | while IFS= read -r line; do 61 | # Check if the line contains the variable name 62 | if [[ "$line" == *"$version"* ]]; then 63 | # Extract the value of the variable 64 | local value=$(echo "$line" | awk -F '=' '{print $2}' | tr -d ' ') 65 | strip_quotes $value 66 | return 0 67 | fi 68 | done < "$version_location" 69 | 70 | echo "" 71 | } 72 | 73 | 74 | check_variable_value_on_github() { 75 | local repo="$1" 76 | local file_path="$2" 77 | local variable_name="$3" 78 | local branch="$4" 79 | 80 | # Simplified URL construction to include the branch directly 81 | local url="https://api.github.com/repos/$repo/contents/$file_path?ref=$branch" 82 | 83 | # Fetch file content from GitHub and decode from Base64 in one go 84 | local variable_value=$(curl -s "$url" | jq -r '.content' | base64 --decode | grep "$variable_name" | cut -d '=' -f 2 | tr -d '[:space:]' | tr -d "'\"") 85 | 86 | if [[ -z "$variable_value" ]]; then 87 | echo "Error: Variable '$variable_name' not found in the file '$file_path' on branch '$branch'." 
88 | return 1 89 | else 90 | echo "$variable_value" 91 | fi 92 | } 93 | 94 | check_package_installed() { 95 | local package_name="$1" 96 | os_name=$(uname -s) 97 | 98 | if [[ "$os_name" == "Linux" ]]; then 99 | # Use dpkg-query to check if the package is installed 100 | if dpkg-query -W -f='${Status}' "$package_name" 2>/dev/null | grep -q "installed"; then 101 | return 1 102 | else 103 | return 0 104 | fi 105 | elif [[ "$os_name" == "Darwin" ]]; then 106 | if brew list --formula | grep -q "^$package_name$"; then 107 | return 1 108 | else 109 | return 0 110 | fi 111 | else 112 | echo "Unknown operating system" 113 | return 0 114 | fi 115 | } 116 | 117 | strip_quotes() { 118 | local input="$1" 119 | 120 | # Remove leading and trailing quotes using parameter expansion 121 | local stripped="${input#\"}" 122 | stripped="${stripped%\"}" 123 | 124 | echo "$stripped" 125 | } 126 | 127 | # reclone and install packages 128 | 129 | check_and_clone() { 130 | cd .. 131 | rm -rf $(basename "$repo_url" .git) 132 | if git clone "$repo_url"; then 133 | echo "Successfully cloned repository." 134 | cd $(basename "$repo_url" .git) || exit 135 | # Additional setup after cloning, if necessary 136 | pip install -e . 137 | pm2 restart "$proc_name" # Restart the PM2 process 138 | else 139 | echo "Failed to clone the repository. Please check the URL and your internet connection." 140 | exit 1 141 | fi 142 | } 143 | 144 | 145 | 146 | # enforce what's on main branch 147 | enforce_main() { 148 | git stash 149 | git fetch origin "$branch" 150 | git reset --hard "origin/$branch" 151 | git clean -df 152 | # Additional commands after enforcing main, if necessary 153 | pip install -e . 154 | pm2 restart "$proc_name" # Restart the PM2 process 155 | } 156 | 157 | # Loop through all command line arguments 158 | while [[ $# -gt 0 ]]; do 159 | arg="$1" 160 | 161 | # Check if the argument starts with a hyphen (flag) 162 | if [[ "$arg" == -* ]]; then 163 | # Check if the argument has a value 164 | if [[ $# -gt 1 && "$2" != -* ]]; then 165 | if [[ "$arg" == "--script" ]]; then 166 | script="$2"; 167 | shift 2 168 | else 169 | # Add '=' sign between flag and value 170 | args+=("'$arg'"); 171 | args+=("'$2'"); 172 | shift 2 173 | fi 174 | else 175 | # Add '=True' for flags with no value 176 | args+=("'$arg'"); 177 | shift 178 | fi 179 | else 180 | # Argument is not a flag, add it as it is 181 | args+=("'$arg '"); 182 | shift 183 | fi 184 | done 185 | 186 | # Check if script argument was provided 187 | if [[ -z "$script" ]]; then 188 | echo "The --script argument is required." 189 | exit 1 190 | fi 191 | 192 | local_branch=$(git branch --show-current) # get current branch. 193 | echo watching branch: $local_branch 194 | echo pm2 process name: $proc_name 195 | 196 | # Get the current version locally. 197 | local_version=$(read_version_value) 198 | 199 | # Check if script is already running with pm2 200 | if pm2 status | grep -q $proc_name; then 201 | echo "The script is already running with pm2. Stopping and restarting..." 
202 | pm2 delete $proc_name 203 | fi 204 | 205 | # Run the Python script with the arguments using pm2 206 | echo "Running $script with the following pm2 config:" 207 | 208 | # Join the arguments with commas using printf 209 | joined_args=$(printf "%s," "${args[@]}") 210 | 211 | # Remove the trailing comma 212 | joined_args=${joined_args%,} 213 | 214 | # Create the pm2 config file 215 | echo "module.exports = { 216 | apps : [{ 217 | name : '$proc_name', 218 | script : '$script', 219 | interpreter: 'python3', 220 | min_uptime: '5m', 221 | max_restarts: '5', 222 | args: [$joined_args] 223 | }] 224 | }" > app.config.js 225 | 226 | # Print configuration to be used 227 | cat app.config.js 228 | pm2 start app.config.js 229 | 230 | # Check if packages are installed. 231 | check_package_installed "jq" 232 | if [ "$?" -eq 1 ]; then 233 | while true; do 234 | # First ensure that this is a git installation 235 | if [ -d "./.git" ]; then 236 | # Fetch remote changes without applying them 237 | git fetch origin "$branch" 238 | # check value on github remotely 239 | remote_version=$(check_variable_value_on_github "$repo" "$version_location" "$version" "$branch") 240 | 241 | if [ "$local_version" != "$remote_version" ]; then 242 | echo "Version mismatch detected. Local version: $local_version, Remote version: $remote_version." 243 | 244 | if [ "$local_branch" = "$branch" ]; then 245 | # Case 3: On main branch, and versions differ. Delete local and reclone. 246 | echo "On main branch with version mismatch. Recloning..." 247 | check_and_clone 248 | else 249 | # Case 2: On a different branch, enforce main. 250 | echo "On branch $local_branch, enforcing main branch changes..." 251 | enforce_main 252 | fi 253 | 254 | local_version=$(read_version_value) 255 | echo "Repository reset to the latest version." 256 | # Restart autorun script 257 | echo "Restarting script..." 258 | ./$(basename $0) $old_args && exit 259 | else 260 | echo "**Skipping update **" 261 | echo "$local_version is the same as or more than $remote_version. You are likely running locally." 262 | fi 263 | else 264 | echo "The installation does not appear to be done through Git. Please install from source at https://github.com/opentensor/validators and rerun this script." 265 | fi 266 | # wait for 3hrs and then check for changes again 267 | sleep 1800 268 | done 269 | else 270 | echo "Missing package 'jq'. Please install it for your system first." 271 | fi 272 | 273 | 274 | -------------------------------------------------------------------------------- /run_validator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Inspired by https://github.com/surcyf123/smart-scrape/blob/main/run.sh 3 | 4 | # Initialize variables 5 | script="neurons/validator.py" 6 | autoRunLoc=$(readlink -f "$0") 7 | proc_name="distributed_training_validator" 8 | args=() 9 | version_location="./template/__init__.py" 10 | version="__version__" 11 | repo="bit-current/DistributedTraining" 12 | branch="main" 13 | repo_url="https://github.com/$repo.git" 14 | 15 | old_args=$@ 16 | 17 | # Check if pm2 is installed 18 | if ! command -v pm2 &> /dev/null 19 | then 20 | echo "pm2 could not be found. To install see: https://pm2.keymetrics.io/docs/usage/quick-start/" 21 | exit 1 22 | fi 23 | 24 | 25 | # Returns the difference between 26 | # two versions as a numerical value. 
27 | get_version_difference() { 28 | local tag1="$1" 29 | local tag2="$2" 30 | 31 | # Extract the version numbers from the tags 32 | local version1=$(echo "$tag1" | sed 's/v//') 33 | local version2=$(echo "$tag2" | sed 's/v//') 34 | 35 | # Split the version numbers into an array 36 | IFS='.' read -ra version1_arr <<< "$version1" 37 | IFS='.' read -ra version2_arr <<< "$version2" 38 | 39 | # Calculate the numerical difference 40 | local diff=0 41 | for i in "${!version1_arr[@]}"; do 42 | local num1=${version1_arr[$i]} 43 | local num2=${version2_arr[$i]} 44 | 45 | # Compare the numbers and update the difference 46 | if (( num1 > num2 )); then 47 | diff=$((diff + num1 - num2)) 48 | elif (( num1 < num2 )); then 49 | diff=$((diff + num2 - num1)) 50 | fi 51 | done 52 | 53 | strip_quotes $diff 54 | } 55 | 56 | read_version_value() { 57 | # Read each line in the file 58 | while IFS= read -r line; do 59 | # Check if the line contains the variable name 60 | if [[ "$line" == *"$version"* ]]; then 61 | # Extract the value of the variable 62 | local value=$(echo "$line" | awk -F '=' '{print $2}' | tr -d ' ') 63 | strip_quotes $value 64 | return 0 65 | fi 66 | done < "$version_location" 67 | 68 | echo "" 69 | } 70 | check_variable_value_on_github() { 71 | local repo="$1" 72 | local file_path="$2" 73 | local variable_name="$3" 74 | local branch="$4" 75 | 76 | # Simplified URL construction to include the branch directly 77 | local url="https://api.github.com/repos/$repo/contents/$file_path?ref=$branch" 78 | 79 | # Fetch file content from GitHub and decode from Base64 in one go 80 | local variable_value=$(curl -s "$url" | jq -r '.content' | base64 --decode | grep "$variable_name" | cut -d '=' -f 2 | tr -d '[:space:]' | tr -d "'\"") 81 | 82 | if [[ -z "$variable_value" ]]; then 83 | echo "Error: Variable '$variable_name' not found in the file '$file_path' on branch '$branch'." 84 | return 1 85 | else 86 | echo "$variable_value" 87 | fi 88 | } 89 | 90 | check_package_installed() { 91 | local package_name="$1" 92 | os_name=$(uname -s) 93 | 94 | if [[ "$os_name" == "Linux" ]]; then 95 | # Use dpkg-query to check if the package is installed 96 | if dpkg-query -W -f='${Status}' "$package_name" 2>/dev/null | grep -q "installed"; then 97 | return 1 98 | else 99 | return 0 100 | fi 101 | elif [[ "$os_name" == "Darwin" ]]; then 102 | if brew list --formula | grep -q "^$package_name$"; then 103 | return 1 104 | else 105 | return 0 106 | fi 107 | else 108 | echo "Unknown operating system" 109 | return 0 110 | fi 111 | } 112 | 113 | strip_quotes() { 114 | local input="$1" 115 | 116 | # Remove leading and trailing quotes using parameter expansion 117 | local stripped="${input#\"}" 118 | stripped="${stripped%\"}" 119 | 120 | echo "$stripped" 121 | } 122 | 123 | # Delete and reclone main branch 124 | check_and_clone() { 125 | cd .. 126 | rm -rf $(basename "$repo_url" .git) 127 | if git clone "$repo_url"; then 128 | echo "Successfully cloned repository." 129 | cd $(basename "$repo_url" .git) || exit 130 | # Additional setup after cloning, if necessary 131 | pip install -e . 132 | pm2 restart "$proc_name" # Restart the PM2 process 133 | else 134 | echo "Failed to clone the repository. Please check the URL and your internet connection." 
135 | exit 1 136 | fi 137 | } 138 | 139 | # enforce changes on main to local branch 140 | enforce_main() { 141 | git stash 142 | git fetch origin "$branch" 143 | git reset --hard "origin/$branch" 144 | git clean -df 145 | # Additional commands after enforcing main, if necessary 146 | pip install -e . 147 | pm2 restart "$proc_name" # Restart the PM2 process 148 | } 149 | 150 | # Loop through all command line arguments 151 | while [[ $# -gt 0 ]]; do 152 | arg="$1" 153 | 154 | # Check if the argument starts with a hyphen (flag) 155 | if [[ "$arg" == -* ]]; then 156 | # Check if the argument has a value 157 | if [[ $# -gt 1 && "$2" != -* ]]; then 158 | if [[ "$arg" == "--script" ]]; then 159 | script="$2"; 160 | shift 2 161 | else 162 | # Add '=' sign between flag and value 163 | args+=("'$arg'"); 164 | args+=("'$2'"); 165 | shift 2 166 | fi 167 | else 168 | # Add '=True' for flags with no value 169 | args+=("'$arg'"); 170 | shift 171 | fi 172 | else 173 | # Argument is not a flag, add it as it is 174 | args+=("'$arg '"); 175 | shift 176 | fi 177 | done 178 | 179 | # Check if script argument was provided 180 | if [[ -z "$script" ]]; then 181 | echo "The --script argument is required." 182 | exit 1 183 | fi 184 | 185 | local_branch=$(git branch --show-current) # get current branch. 186 | echo watching branch: $local_branch 187 | echo pm2 process name: $proc_name 188 | 189 | # Get the current version locally. 190 | local_version=$(read_version_value) 191 | 192 | # Check if script is already running with pm2 193 | if pm2 status | grep -q $proc_name; then 194 | echo "The script is already running with pm2. Stopping and restarting..." 195 | pm2 delete $proc_name 196 | fi 197 | 198 | # Run the Python script with the arguments using pm2 199 | echo "Running $script with the following pm2 config:" 200 | 201 | # Join the arguments with commas using printf 202 | joined_args=$(printf "%s," "${args[@]}") 203 | 204 | # Remove the trailing comma 205 | joined_args=${joined_args%,} 206 | 207 | # Create the pm2 config file 208 | echo "module.exports = { 209 | apps : [{ 210 | name : '$proc_name', 211 | script : '$script', 212 | interpreter: 'python3', 213 | min_uptime: '5m', 214 | max_restarts: '5', 215 | args: [$joined_args] 216 | }] 217 | }" > app.config.js 218 | 219 | # Print configuration to be used 220 | cat app.config.js 221 | 222 | pm2 start app.config.js 223 | 224 | # Check if packages are installed. 225 | check_package_installed "jq" 226 | 227 | if [ "$?" -eq 1 ]; then 228 | while true; do 229 | 230 | # First ensure that this is a git installation 231 | if [ -d "./.git" ]; then 232 | # Fetch remote changes without applying them 233 | git fetch origin "$branch" 234 | # check value on github remotely 235 | remote_version=$(check_variable_value_on_github "$repo" "$version_location" "$version" "$branch") 236 | 237 | if [ "$local_version" != "$remote_version" ]; then 238 | echo "Version mismatch detected. Local version: $local_version, Remote version: $remote_version." 239 | 240 | if [ "$local_branch" = "$branch" ]; then 241 | # Case 1: On main branch, and versions differ. Delete local and reclone. 242 | echo "On main branch with version mismatch. Recloning..." 243 | check_and_clone 244 | else 245 | # Case 1: On a different branch, enforce main. 246 | echo "On branch $local_branch, enforcing main branch changes..." 247 | enforce_main 248 | fi 249 | 250 | local_version=$(read_version_value) 251 | echo "Repository reset to the latest version." 252 | # Restart autorun script 253 | echo "Restarting script..." 
254 | ./$(basename $0) $old_args && exit
255 | else
256 | echo "** Skipping update **"
257 | echo "Local version $local_version matches or exceeds remote version $remote_version. You are likely running locally."
258 | fi
259 | else
260 | echo "The installation does not appear to be done through Git. Please install from source at https://github.com/bit-current/DistributedTraining and rerun this script."
261 | fi
262 | sleep 1800
263 | done
264 | else
265 | echo "Missing package 'jq'. Please install it for your system first."
266 | fi
267 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name='hivetrain',
5 |     version='0.3.2',
6 |     author='Hivetrain',
7 |     author_email='test@test.com',
8 |     description='Incentivized distributed training on the Bittensor network',
9 |     long_description=open('README.md').read(),
10 |     long_description_content_type='text/markdown',
11 |     url='https://github.com/bit-current/DistributedTraining',
12 |     packages=find_packages(),
13 |     include_package_data=True,
14 |     install_requires=open('requirements.txt').read().splitlines(),
15 |     classifiers=[
16 |         'Programming Language :: Python :: 3',
17 |         'License :: OSI Approved :: MIT License',
18 |         'Operating System :: OS Independent',
19 |     ],
20 |     python_requires='>=3.6',
21 | )
22 | 
--------------------------------------------------------------------------------
/template/__init__.py:
--------------------------------------------------------------------------------
1 | # For backward compatibility with Auto-Update
2 | 
3 | # The MIT License (MIT)
4 | # Copyright © 2023 Yuma Rao
5 | # TODO(developer): Set your name
6 | # Copyright © 2023
7 | 
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
9 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
10 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
11 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
12 | 
13 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
14 | # the Software.
15 | 
16 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
17 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 | # DEALINGS IN THE SOFTWARE.
21 | 
22 | # version nomenclature = __training_type__.__model__.__other_changes__
23 | 
24 | __version__ = "0.0.34"
25 | 
26 | version_split = __version__.split(".")
27 | __spec_version__ = (100 * int(version_split[0])) + (10 * int(version_split[1])) + (1 * int(version_split[2]))
28 | 
--------------------------------------------------------------------------------
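
A note on the version arithmetic above: the run scripts compare the __version__ string in template/__init__.py against the copy on GitHub to decide when to re-clone, while neurons/validator.py imports a __spec_version__ integer from the hivetrain package. The sketch below reproduces the arithmetic from template/__init__.py; the helper name spec_version_of is illustrative only, and it is an assumption here that hivetrain/__init__.py derives its spec version the same way.

    def spec_version_of(version: str) -> int:
        # Same arithmetic as template/__init__.py:
        # 100 * major + 10 * minor + 1 * patch
        major, minor, patch = (int(part) for part in version.split("."))
        return 100 * major + 10 * minor + patch

    # With the current __version__ = "0.0.34": 100*0 + 10*0 + 34 == 34
    assert spec_version_of("0.0.34") == 34

Because the multipliers are 100/10/1, a large patch component can outweigh a minor bump ("0.0.34" maps to 34 while "0.1.0" maps to 10), so the scheme presumably assumes the nomenclature in the file's comment keeps later components small.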