├── .gitignore ├── README.md ├── common.py ├── dataset.py ├── docker_run.sh ├── hparams.json ├── hparams.py ├── miner.py ├── requirements.txt ├── run.sh ├── start.sh ├── tests ├── eval.py ├── legacy_miner.py ├── legacy_validator.py ├── test.ipynb ├── tests3.ipynb └── val.ipynb ├── tools ├── clean.py └── print.py └── validator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | models/ 19 | lm-evaluation-harness/ 20 | venv/ 21 | wandb/ 22 | wandb/** 23 | .env 24 | *.pyc 25 | __pycache__/ 26 | *.pyo 27 | *.pyd 28 | .Python 29 | env/ 30 | venv/ 31 | ENV/ 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | *.manifest 49 | *.spec 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | test_loaders.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ______ _____ _______ ______ _______ _______ __ _ __ _ 2 | |_____] | | | | ____/ | | | |_____| | \ | | \ | 3 | |_____] |_____| |_____ | /_____ | | | | | | \_| | \_| 4 | 5 | --- 6 | 7 | # BOLTZMANN: Bittensor Incentivized Scalable Training with Reward Optimization 8 | 9 | --- 10 | ```bash 11 | /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/unconst/boltzmann/master/run.sh)" 12 | ``` 13 | 14 | --- 15 | 16 | This repository implements an incentive system for decentralized training on the Bittensor network. 17 | 18 | ## Overview 19 | 20 | **BOLTZMANN** is a framework where **miners** collaboratively train a shared model by processing specific subsets of data, and **validators** ensure the quality and integrity of these contributions. The incentive landscape is designed to reward miners for effectively training on their designated data subsets, promoting efficient and collaborative model improvement in a trustless environment. 21 | 22 | ## How It Works 23 | 24 | ### Miners 25 | 26 | - **Model Synchronization**: Miners start by downloading the latest model state, which is a subset of model parameters (called **slices**) aggregated from other miners. 27 | - **Training**: They receive a designated subset of data (pages) from the dataset for each window (a fixed number of blocks). They train the model on this subset, performing a single gradient update. 28 | - **Uploading Deltas**: After training, miners compute the **delta** (the change in their model parameters) and upload this delta to an S3 bucket associated with their identity. 29 | - **Window Progression**: Miners proceed to the next window, repeating the process with new data subsets. 30 | 31 | ### Validators 32 | 33 | - **Model Synchronization**: Validators synchronize their model state by downloading the latest aggregated slices and applying miners' deltas from the previous window. 34 | - **Fetching Deltas**: Validators download the deltas uploaded by miners corresponding to the last window. 35 | - **Evaluation**: Validators evaluate the miners' contributions by comparing the miners' deltas to the gradients computed locally on the same data subsets. 36 | - **Scoring**: The reward for each miner is calculated based on the **cosine similarity** between the miner's delta and the validator's locally computed gradient. This incentivizes miners to provide genuine, high-quality updates that improve the model. 37 | 38 | ## Incentive Mechanism Explained 39 | 40 | The incentive mechanism in **boltzmann** ensures that miners are rewarded for authentic and beneficial contributions to the model's training. By basing rewards on the **cosine similarity** between the miners' updates and the validators' gradients, we promote alignment of miners' efforts with the overall training objectives. 41 | 42 | ### Key Points 43 | 44 | - **Alignment of Objectives**: Miners are motivated to perform authentic training on their assigned data subsets because providing updates that closely match the true gradient direction maximizes their rewards. 45 | - **Positive Contributions**: By submitting deltas that positively impact the model's performance on the evaluation data, miners increase their rewards. 46 | - **Discouraging Malicious Behavior**: Contributions that deviate significantly from the true gradient (e.g., random or adversarial updates) result in lower or negative rewards. 47 | - **Data Subset Specialization**: Miners are evaluated based on their performance on specific data subsets, encouraging them to specialize and optimize their training for those subsets. 48 | - **Fairness**: By not revealing which model slices need to be uploaded until the end of the window, all miners are on a level playing field, preventing exploitation of the system. 49 | 50 | ### Mathematical Details 51 | 52 | #### Notations 53 | 54 | - **$\theta$**: Current model parameters. 55 | - **$\delta_i$**: Delta (model update) contributed by miner **$i$**. 56 | - **$g_i$**: Gradient of the loss with respect to the model parameters on the data subset assigned to miner **$i$**. 57 | - **$\hat{g}_i$**: Validator's locally computed gradient on the same data subset. 58 | - **$s_i$**: Cosine similarity score between **$\delta_i$** and **$\hat{g}_i$**. 59 | - **$R_i$**: Reward assigned to miner **$i$**. 60 | 61 | #### Cosine Similarity Calculation 62 | 63 | The cosine similarity between the miner's delta and the validator's gradient is calculated as: 64 | 65 | $$ 66 | s_i = \frac{\delta_i \cdot \hat{g}_i}{\|\delta_i\| \|\hat{g}_i\|} 67 | $$ 68 | 69 | - **$\delta_i \cdot \hat{g}_i$**: Dot product of the miner's delta and the validator's gradient. 70 | - **$\|\delta_i\|$** and **$\|\hat{g}_i\|$**: Euclidean norms of the miner's delta and the validator's gradient, respectively. 71 | 72 | #### Reward Calculation 73 | 74 | The reward for miner **$i$** is directly proportional to the cosine similarity score: 75 | 76 | $$ 77 | R_i = \alpha \cdot s_i 78 | $$ 79 | 80 | Where **$\alpha$** is a scaling factor determined by the network's economic parameters. 81 | 82 | - A higher cosine similarity **$s_i$** indicates that the miner's update is closely aligned with the true gradient, resulting in a higher reward. 83 | - If **$s_i$** is negative, it indicates that the miner's update is detrimental to the model's performance on the validation data, leading to a lower or negative reward. 84 | 85 | #### Intuition Behind the Mechanism 86 | 87 | - **Positive Reinforcement**: Miners are rewarded for updates that point in the same direction as the true gradient, improving the model. 88 | - **Penalty for Divergence**: Miners submitting random or harmful updates receive lower rewards due to low or negative cosine similarity. 89 | - **Efficient Collaboration**: This mechanism encourages miners to focus on genuine training rather than attempting to game the system. 90 | 91 | ## Installation Guide 92 | 93 | ### Prerequisites 94 | 95 | - **Python 3.8 or higher** 96 | - **Pip** for package management 97 | - **Git** for cloning the repository 98 | - **AWS Account** with access to S3 (Simple Storage Service) 99 | - **Compute Resources**: 100 | - **Miner**: High-performance GPU (e.g., NVIDIA A100 or better) recommended for training. 101 | - **Validator**: Similar compute requirements as miners due to the need to recompute gradients. 102 | 103 | ### Setup Instructions 104 | 105 | 1. **Clone the Repository** 106 | 107 | ```bash 108 | git clone https://github.com/unconst/boltzmann.git 109 | cd boltzmann 110 | ``` 111 | 112 | 2. **Set Up AWS Credentials** 113 | 114 | Configure your AWS credentials to allow read and write access to your S3 bucket: 115 | 116 | ```bash 117 | export AWS_ACCESS_KEY_ID=your_access_key_id 118 | export AWS_SECRET_ACCESS_KEY=your_secret_access_key 119 | ``` 120 | 121 | Ensure that your S3 bucket has the necessary permissions for read and write operations. 122 | 123 | 3. **Install Dependencies** 124 | 125 | It's recommended to use a virtual environment: 126 | 127 | ```bash 128 | python3 -m venv venv 129 | source venv/bin/activate 130 | ``` 131 | 132 | Install the required Python packages: 133 | 134 | ```bash 135 | pip install -r requirements.txt 136 | ``` 137 | 138 | 4. **Configure Environment Variables** 139 | 140 | Create a `.env` file or export the environment variables required by the project. 141 | 142 | 5. **Register on Bittensor Subnet** 143 | 144 | The system runs on a Bittensor subnet. You need to register your miner and validator. 145 | 146 | ```bash 147 | # Replace <> with your actual wallet names and hotkeys. 148 | btcli register --wallet.name --wallet.hotkey --subtensor.network test --netuid 220 149 | ``` 150 | 151 | ## Running the Miner and Validator 152 | 153 | ### Run the Miner 154 | 155 | ```bash 156 | python3 miner.py \ 157 | --wallet.name \ 158 | --wallet.hotkey \ 159 | --subtensor.network test \ 160 | --netuid 220 \ 161 | --bucket \ 162 | --device cuda 163 | ``` 164 | 165 | ### Run the Validator 166 | 167 | ```bash 168 | python3 validator.py \ 169 | --wallet.name \ 170 | --wallet.hotkey \ 171 | --subtensor.network test \ 172 | --netuid 220 \ 173 | --bucket \ 174 | --device cuda 175 | ``` 176 | 177 | ## Hardware Requirements 178 | 179 | Given the computational intensity of training and validating neural networks, it is highly recommended to use machines equipped with high-performance GPUs like NVIDIA A100 or better. Adequate CPU resources and memory are also necessary to handle data loading and preprocessing tasks. 180 | 181 | ## Contributing 182 | 183 | Contributions to the **boltzmann** project are welcome. Please open issues and submit pull requests for improvements and fixes. 184 | 185 | ## License 186 | 187 | This project is licensed under the MIT License © 2024 Chakana.tech. See the [LICENSE](LICENSE) file for details. 188 | 189 | --- 190 | 191 | **Note**: The mathematical formulations and mechanisms described are integral to ensuring the security and efficiency of the decentralized training process. By participating as a miner or validator, you contribute to a collaborative effort to advance decentralized machine learning. 192 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import os 19 | import io 20 | import sys 21 | import uuid 22 | import time 23 | import fcntl 24 | import torch 25 | import uvloop 26 | import hashlib 27 | import asyncio 28 | import logging 29 | import tempfile 30 | import aiofiles 31 | import numpy as np 32 | import aiobotocore 33 | import bittensor as bt 34 | import botocore.config 35 | from typing import List, Dict 36 | from dotenv import dotenv_values 37 | from types import SimpleNamespace 38 | from rich.logging import RichHandler 39 | from filelock import FileLock, Timeout 40 | from aiobotocore.session import get_session 41 | from rich.highlighter import NullHighlighter 42 | 43 | # Configure loguru logger 44 | FORMAT = "%(message)s" 45 | logging.basicConfig( 46 | level=logging.INFO, 47 | format=FORMAT, 48 | datefmt="[%X]", 49 | handlers=[ 50 | RichHandler( 51 | markup=True, 52 | rich_tracebacks=True, 53 | highlighter=NullHighlighter(), 54 | show_level=False, 55 | show_time=False, 56 | show_path=False 57 | ) 58 | ] 59 | ) 60 | logger = logging.getLogger("rich") 61 | logger.setLevel(logging.INFO) 62 | def debug(): 63 | logger.setLevel(logging.DEBUG) 64 | def trace(): 65 | logger.setLevel(logging.TRACE) 66 | # Log helper. 67 | def T(): return time.time() 68 | def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" 69 | 70 | # Load environment variables 71 | env_config = {**dotenv_values(".env"), **os.environ} 72 | AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') 73 | AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') 74 | 75 | # Configure the S3 client 76 | client_config = botocore.config.Config( 77 | max_pool_connections=256, 78 | ) 79 | 80 | # Set uvloop as the event loop policy 81 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) 82 | 83 | # Define a semaphore to limit concurrent downloads (adjust as needed) 84 | semaphore = asyncio.Semaphore(1000) 85 | 86 | async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: 87 | # Attempt to acquire the lock with a timeout of 1 second. 88 | lock: FileLock = FileLock(f"{filename}.lock") 89 | with lock.acquire(timeout=5): 90 | pass 91 | return torch.load( 92 | filename, 93 | map_location=torch.device(device), 94 | weights_only = True, 95 | ) 96 | 97 | async def apply_slices_to_model(model: torch.nn.Module, window: int, seed: str, compression: int, key:str = 'slice') -> List[str]: 98 | """ 99 | Applies slices from a specific window to the given model. 100 | 101 | Args: 102 | model (torch.nn.Module): The PyTorch model to which the slices will be applied. 103 | window (int): The window identifier. 104 | seed (str): The seed used for generating indices. 105 | compression (int): The compression factor. 106 | 107 | Returns: 108 | List[str]: A list of all the slice files that were applied. 109 | """ 110 | # First get the indices associated with the window given the model. 111 | indices_dict = await get_indices_for_window(model, seed, compression) 112 | 113 | # Load all the slices associated with this window. 114 | slice_files = await load_files_for_window(window=window, key = key) 115 | 116 | # Dictionary to keep track of the number of slices applied per parameter. 117 | slices_per_param = {name: 0 for name, _ in model.named_parameters()} 118 | 119 | # Dictionary to accumulate the sum of values for each parameter. 120 | param_sums = {name: torch.zeros_like(param.data) for name, param in model.named_parameters()} 121 | 122 | # Iterate over each slice file and compute the sum of values. 123 | for file_i in slice_files: 124 | # Create a file lock to ensure exclusive access to the slice file. 125 | try: 126 | slice_i = await get_slices(file_i, model.device) 127 | for name, param in model.named_parameters(): 128 | if name not in indices_dict or name not in slice_i: 129 | continue 130 | values = slice_i[name].to(param.data.device) 131 | param_indices = indices_dict[name].to(param.data.device) 132 | param_sums[name].view(-1)[param_indices] += values 133 | slices_per_param[name] += 1 134 | del values 135 | del slice_i 136 | except Timeout: 137 | # The lock could not be acquired within the timeout. 138 | logger.error(f"Timeout occurred while trying to acquire lock on {file_i}") 139 | continue 140 | except Exception as e: 141 | logger.exception(f"Error applying slice from {file_i}: {e}") 142 | 143 | # Apply the average to the parameters. 144 | for name, param in model.named_parameters(): 145 | if name not in slices_per_param or name not in indices_dict or slices_per_param[name] == 0: 146 | continue 147 | param_indices = indices_dict[name].to(param.data.device) 148 | avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name] 149 | avg_param = avg_param.to(param.data.dtype) 150 | avg_param = avg_param.to(param.data.device) 151 | param.data.view(-1)[param_indices] = avg_param.clone() 152 | 153 | # Return the list of the files applied. 154 | return slice_files 155 | 156 | async def upload_slice_for_window(bucket: str, model: torch.nn.Module, window: int, seed: str, wallet: 'bt.wallet', compression: int, key:str = 'slice'): 157 | """ 158 | Uploads a compressed slice of a PyTorch model to an S3 bucket. 159 | 160 | Args: 161 | bucket (str): Name of the S3 bucket. 162 | model (torch.nn.Module): The PyTorch model to be sliceed and uploaded. 163 | window (int): The window identifier. 164 | wallet (bt.wallet): The wallet object containing the hotkey. 165 | compression (int): The compression factor. 166 | """ 167 | filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt' 168 | logger.debug(f"Uploading slice to S3: {filename}") 169 | 170 | model_state_dict = model.state_dict() 171 | indices = await get_indices_for_window(model, seed, compression) 172 | 173 | # Apply the slice to the model parameters 174 | for name, param in model.named_parameters(): 175 | model_state_dict[name] = param.data.view(-1)[indices[name].to(model.device)].cpu() 176 | 177 | # Create a temporary file and write the sliceed model state dictionary to it 178 | with tempfile.NamedTemporaryFile(delete=False) as temp_file: 179 | torch.save(model_state_dict, temp_file) 180 | temp_file_name = temp_file.name # Store the temporary file name 181 | 182 | # Upload the file to S3 183 | session = get_session() 184 | async with session.create_client( 185 | 's3', 186 | region_name='us-east-1', 187 | config=client_config, 188 | aws_access_key_id=AWS_ACCESS_KEY_ID, 189 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY 190 | ) as s3_client: 191 | try: 192 | with open(temp_file_name, 'rb') as f: 193 | await s3_client.put_object(Bucket=bucket, Key=filename, Body=f) 194 | # Set the object ACL to public-read 195 | await s3_client.put_object_acl( 196 | Bucket=bucket, 197 | Key=filename, 198 | ACL='public-read' 199 | ) 200 | logger.debug(f"Successfully uploaded slice to S3: {filename}") 201 | except Exception: 202 | logger.exception(f"Failed to upload slice {filename} to S3") 203 | finally: 204 | # Clean up the temporary file 205 | os.remove(temp_file_name) 206 | logger.debug(f"Temporary file {temp_file_name} removed") 207 | 208 | async def upload_master(bucket: str, model: torch.nn.Module, wallet: 'bt.wallet'): 209 | """ 210 | Uploads the master PyTorch model to an S3 bucket. 211 | 212 | Args: 213 | bucket (str): Name of the S3 bucket. 214 | model (torch.nn.Module): The PyTorch model to be uploaded. 215 | wallet (bt.wallet): The wallet object containing the hotkey. 216 | """ 217 | upload_filename = f'master-{wallet.hotkey.ss58_address}.pt' 218 | logger.debug(f"Uploading master model to S3: {upload_filename}") 219 | 220 | session = get_session() 221 | async with session.create_client( 222 | 's3', 223 | region_name='us-east-1', 224 | config=client_config, 225 | aws_access_key_id=AWS_ACCESS_KEY_ID, 226 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY 227 | ) as s3_client: 228 | try: 229 | # Create a temporary file and write the model state dictionary to it 230 | with tempfile.NamedTemporaryFile(delete=False) as temp_file: 231 | torch.save(model.state_dict(), temp_file) 232 | temp_file_name = temp_file.name 233 | 234 | # Upload the file to S3 235 | with open(temp_file_name, 'rb') as f: 236 | await s3_client.put_object(Bucket=bucket, Key=upload_filename, Body=f) 237 | # Set the object ACL to public-read 238 | await s3_client.put_object_acl( 239 | Bucket=bucket, 240 | Key=upload_filename, 241 | ACL='public-read' 242 | ) 243 | logger.debug(f"Successfully uploaded master model to S3: {upload_filename}") 244 | except Exception: 245 | logger.exception(f"Failed to upload master model {upload_filename} to S3") 246 | finally: 247 | # Clean up the temporary file 248 | os.remove(temp_file_name) 249 | logger.debug(f"Temporary file {temp_file_name} removed") 250 | 251 | async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]: 252 | """ 253 | Computes the indices for the given window and compression factor. 254 | 255 | Args: 256 | model (torch.nn.Module): The PyTorch model. 257 | seed (str): The window seed identifier. 258 | compression (int): The compression factor. 259 | 260 | Returns: 261 | Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. 262 | """ 263 | logger.debug(f"Computing indices for window seed {seed} with compression {compression}") 264 | result = {} 265 | # Seed the random number generator with the seed 266 | seed = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32) 267 | rng = np.random.default_rng(seed) 268 | for name, param in model.named_parameters(): 269 | # Randomly select indices based on the compression factor 270 | num_indices = max(1, int(param.numel() // compression)) 271 | indices = rng.choice(param.numel(), size=num_indices, replace=False) 272 | result[name] = torch.from_numpy(indices).long().cpu() 273 | return result 274 | 275 | async def download_file(s3_client, bucket: str, filename: str) -> str: 276 | """ 277 | Downloads a file from S3, using parallel downloads for large files. 278 | 279 | Args: 280 | s3_client: The S3 client. 281 | bucket (str): Name of the S3 bucket. 282 | filename (str): The S3 object key (filename). 283 | 284 | Returns: 285 | str: The path to the downloaded file in the temporary directory. 286 | """ 287 | async with semaphore: 288 | temp_file = os.path.join(tempfile.gettempdir(), filename) 289 | # Check if the file exists. 290 | if os.path.exists(temp_file): 291 | logger.debug(f"File {temp_file} already exists, skipping download.") 292 | return temp_file 293 | lock_file = f"{temp_file}.lock" 294 | lock = FileLock(lock_file) 295 | try: 296 | # Try to acquire both locks with a timeout 297 | with lock.acquire(timeout=1): 298 | # Proceed to download the file 299 | logger.debug(f"Downloading file {filename} to {temp_file}") 300 | head_response = await s3_client.head_object(Bucket=bucket, Key=filename) 301 | object_size = head_response['ContentLength'] 302 | CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB 303 | 304 | response = await s3_client.get_object(Bucket=bucket, Key=filename) 305 | async with aiofiles.open(temp_file, 'wb') as outfile: 306 | while True: 307 | chunk = await response['Body'].read(CHUNK_SIZE) 308 | if not chunk: 309 | break 310 | await outfile.write(chunk) 311 | 312 | logger.debug(f"Successfully downloaded file {filename} to {temp_file}") 313 | return temp_file 314 | 315 | except Timeout: 316 | logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}") 317 | return None 318 | except Exception as e: 319 | logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}") 320 | return None 321 | finally: 322 | # The lock is automatically released when exiting the 'with' block 323 | pass 324 | 325 | async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int): 326 | """ 327 | Handles downloading a single file from S3. 328 | 329 | Args: 330 | s3_client: The S3 client. 331 | bucket (str): Name of the S3 bucket. 332 | filename (str): The S3 object key (filename). 333 | hotkey (str): The hotkey identifier. 334 | window (int): The window identifier. 335 | 336 | Returns: 337 | SimpleNamespace: An object containing file metadata and the path to the downloaded file. 338 | """ 339 | logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}") 340 | temp_file = await download_file(s3_client, bucket, filename) 341 | if temp_file: 342 | return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file) 343 | return None 344 | 345 | async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = 'slice'): 346 | """ 347 | Processes an S3 bucket to download files matching the given windows. 348 | 349 | Args: 350 | s3_client: The S3 client. 351 | bucket (str): Name of the S3 bucket. 352 | windows (List[int]): A list of window identifiers. 353 | 354 | Returns: 355 | List[SimpleNamespace]: A list of file metadata and paths for downloaded files. 356 | """ 357 | logger.debug(f"Processing bucket {bucket} for window {windows}") 358 | files = [] 359 | paginator = s3_client.get_paginator('list_objects_v2') 360 | 361 | for window in windows: 362 | prefix = f'{key}-{window}' 363 | logger.debug(f"Listing objects with prefix {prefix}") 364 | async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): 365 | logger.trace(f"Processing page for prefix {prefix}") 366 | if 'Contents' not in page: 367 | logger.trace(f"No contents found for prefix {prefix}") 368 | continue 369 | download_tasks = [] 370 | for obj in page.get('Contents', []): 371 | filename = obj['Key'] 372 | logger.trace(f"Processing object with key {filename}") 373 | try: 374 | parts = filename.split('-') 375 | slice_window = int(parts[1]) 376 | slice_hotkey = parts[2].split('.')[0] 377 | logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}") 378 | if slice_window == window: 379 | download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window)) 380 | except Exception: 381 | logger.exception(f"Error processing filename {filename}") 382 | continue 383 | # Download the files concurrently 384 | results = await asyncio.gather(*download_tasks) 385 | files.extend([res for res in results if res]) 386 | logger.trace(f"Completed processing page for prefix {prefix}") 387 | logger.trace(f"Completed processing bucket {bucket} for windows {windows}") 388 | return files 389 | 390 | async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]: 391 | """ 392 | Downloads files from multiple S3 buckets for the given windows. 393 | 394 | Args: 395 | buckets (List[str]): A list of S3 bucket names. 396 | windows (List[int]): A list of window identifiers. 397 | 398 | Returns: 399 | Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths. 400 | """ 401 | logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}") 402 | session = get_session() 403 | async with session.create_client( 404 | 's3', 405 | region_name='us-east-1', 406 | config=client_config, 407 | aws_access_key_id=AWS_ACCESS_KEY_ID, 408 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY 409 | ) as s3_client: 410 | tasks = [] 411 | for bucket in set(buckets): 412 | if not bucket: 413 | continue 414 | tasks.append(process_bucket(s3_client, bucket, windows, key)) 415 | results = await asyncio.gather(*tasks) 416 | # Flatten the list of lists 417 | files = [item for sublist in results for item in sublist] 418 | 419 | # Create a dictionary with windows as keys and list of files as values 420 | windows_dict = {} 421 | for file in files: 422 | window = file.window 423 | if window not in windows_dict: 424 | windows_dict[window] = [] 425 | windows_dict[window].append(file) 426 | 427 | logger.debug(f"Downloaded all files grouped by windows: {windows}") 428 | return windows_dict 429 | 430 | async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: 431 | """ 432 | Retrieves the paths to downloaded window files from the temporary directory. 433 | 434 | Args: 435 | window (int): The window identifier. 436 | 437 | Returns: 438 | List[str]: A list of file paths corresponding to the window. 439 | """ 440 | logger.debug(f"Retrieving files for window {window} from temporary directory") 441 | temp_dir = tempfile.gettempdir() 442 | window_files = [] 443 | for filename in os.listdir(temp_dir): 444 | if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"): 445 | window_files.append(os.path.join(temp_dir, filename)) 446 | logger.debug(f"Found file {filename} for window {window}") 447 | return window_files 448 | 449 | async def delete_files_before_window(window_max: int, key:str = 'slice'): 450 | """ 451 | Deletes all files on the local machine which have a window id before a specific value window_max. 452 | 453 | Args: 454 | window_max (int): The maximum window id. Files with window ids less than this value will be deleted. 455 | """ 456 | logger.debug(f"Deleting files with window id before {window_max}") 457 | temp_dir = tempfile.gettempdir() 458 | for filename in os.listdir(temp_dir): 459 | if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ): 460 | try: 461 | parts = filename.split('-') 462 | window_id = int(parts[1]) 463 | if window_id < window_max: 464 | if os.path.exists(filename): 465 | os.remove(filename) 466 | logger.debug(f"Deleted file {filename}") 467 | except Exception as e: 468 | logger.error(f"Error deleting file {filename}: {e}") 469 | 470 | async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'): 471 | """ 472 | Deletes all files in the specified S3 bucket which have a window id before a specific value window_max. 473 | 474 | Args: 475 | bucket (str): The name of the S3 bucket. 476 | window_max (int): The maximum window id. Files with window ids less than this value will be deleted. 477 | """ 478 | logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}") 479 | session = get_session() 480 | async with session.create_client( 481 | 's3', 482 | region_name='us-east-1', 483 | config=client_config, 484 | aws_access_key_id=AWS_ACCESS_KEY_ID, 485 | aws_secret_access_key=AWS_SECRET_ACCESS_KEY 486 | ) as s3_client: 487 | try: 488 | response = await s3_client.list_objects_v2(Bucket=bucket) 489 | if 'Contents' in response: 490 | for obj in response['Contents']: 491 | filename = obj['Key'] 492 | if filename.startswith(f"{key}-") and filename.endswith(".pt"): 493 | try: 494 | parts = filename.split('-') 495 | window_id = int(parts[1]) 496 | if window_id < window_max: 497 | await s3_client.delete_object(Bucket=bucket, Key=filename) 498 | logger.debug(f"Deleted file {filename} from bucket {bucket}") 499 | except Exception as e: 500 | logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}") 501 | except Exception as e: 502 | logger.error(f"Error listing objects in bucket {bucket}: {e}") 503 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import time 19 | import typing 20 | import random 21 | import requests 22 | import asyncio 23 | import aiohttp 24 | import numpy as np 25 | from tqdm import tqdm 26 | from transformers import AutoTokenizer 27 | from torch.utils.data import IterableDataset 28 | 29 | class SubsetLoader(IterableDataset): 30 | """ 31 | Base class for data-specific subset loader classes. 32 | 33 | # TODO: Make this class abstract 34 | """ 35 | def __init__( 36 | self, 37 | batch_size=None, 38 | sequence_length=None, 39 | num_pages=None, 40 | tokenizer: AutoTokenizer=None, 41 | pack_samples: bool=False, 42 | ): 43 | self.batch_size = batch_size 44 | self.sequence_length = sequence_length 45 | self.num_pages = num_pages 46 | self.tokenizer = tokenizer 47 | self.pack_samples = pack_samples 48 | 49 | self.num_rows_per_page = 100 50 | 51 | # Buffer to hold pages loaded from the api 52 | self.buffer = [] 53 | 54 | # Buffer to hold pages already loaded into a batch 55 | self.used_buffer = [] 56 | 57 | # Buffer to hold padded pages 58 | self.padded_buffer = [] 59 | 60 | self.lock = asyncio.Lock() # For thread-safe operations 61 | 62 | async def fetch_data_for_pages(self, pages): 63 | """ 64 | Set the pages to be used to fill the buffer. Then fetch the page data 65 | to the buffer. 66 | """ 67 | 68 | self.pages = pages 69 | 70 | # Empty the buffer if it is not. 71 | self.buffer = [] 72 | 73 | async with aiohttp.ClientSession() as session: 74 | tasks = [self._fetch_data_for_page(page, session) for page in self.pages] 75 | await asyncio.gather(*tasks) 76 | 77 | async def _fetch_data_for_page(self, page, session): 78 | retry_limit = 10 79 | attempt = 0 80 | while attempt < retry_limit: 81 | config_name, page_number, split = page 82 | 83 | # Create the request parameters 84 | params = dict(dataset=self.name, 85 | config=config_name, 86 | split=split, 87 | offset=page_number, 88 | limit=self.num_rows_per_page 89 | ) 90 | 91 | try: 92 | async with session.get(self.rows_base_url, params=params) as response: 93 | response.raise_for_status() 94 | data = await response.json() 95 | 96 | # Prepare the data to append 97 | buffer_to_append = [] 98 | for row in data["rows"]: 99 | content = row["row"]["text"] 100 | input_ids = self.tokenizer(content, truncation=True)["input_ids"] 101 | buffer_to_append.extend(input_ids) 102 | buffer_to_append.append(self.tokenizer.eos_token_id) 103 | 104 | async with self.lock: 105 | self.buffer.extend(buffer_to_append) 106 | self.pages.append((config_name, page_number, split)) 107 | break # Success, exit retry loop 108 | 109 | except aiohttp.ClientResponseError as e: 110 | attempt += 1 111 | if attempt < retry_limit: 112 | await asyncio.sleep(5) 113 | else: 114 | raise 115 | 116 | def _get_pad_size(self, input_ids): 117 | """ 118 | Get the number of tokens to be padded to the sample to match 119 | the max allowed sequence length. 120 | If sample packing is activated, then return 1 121 | """ 122 | 123 | if self.pack_samples: 124 | return 1 125 | 126 | sample_size = len(input_ids) 127 | 128 | remainder = (sample_size % self.sequence_length) 129 | pad_size = (self.sequence_length - remainder) 130 | 131 | # Apply modulo again to guarantee a pad size of 0 if remainder is 0 132 | pad_size = pad_size % self.sequence_length 133 | 134 | return pad_size 135 | 136 | def _refill_padded_buffer(self): 137 | """ 138 | This methods pulls one page from `self.buffer`, pads it and pushs 139 | it to the `self.padded_buffer`. 140 | """ 141 | 142 | while ( 143 | self.buffer 144 | and len(self.padded_buffer) < self.sequence_length 145 | ): 146 | 147 | input_ids = [] 148 | 149 | # search for EOS token index and cut the buffer at it. 150 | EOS_index = self.buffer.index(self.tokenizer.eos_token_id) 151 | input_ids = self.buffer[:EOS_index+1] 152 | self.buffer =self.buffer[EOS_index+1:] 153 | 154 | self.used_buffer += input_ids 155 | 156 | # Add to padded buffer without the EOS token. 157 | self.padded_buffer += input_ids[:-1] 158 | 159 | # Pad 160 | self.padded_buffer += [self.tokenizer.eos_token_id] * self._get_pad_size(input_ids=input_ids[:-1]) 161 | 162 | def __iter__(self): 163 | self.buffer = self.used_buffer + self.buffer 164 | self.padded_buffer = [] 165 | 166 | # Pad and prepare one page for batching 167 | self._refill_padded_buffer() 168 | 169 | return self 170 | 171 | def __next__(self): 172 | batch = [] 173 | 174 | while len(self.padded_buffer) >= self.sequence_length: 175 | batch.append(self.padded_buffer[: self.sequence_length]) 176 | self.padded_buffer = self.padded_buffer[self.sequence_length :] 177 | self._refill_padded_buffer() 178 | 179 | if len(batch) == self.batch_size: 180 | return np.stack(batch) 181 | 182 | raise StopIteration 183 | 184 | 185 | class DatasetLoader(SubsetLoader): 186 | 187 | name: str = "HuggingFaceFW/fineweb-edu-score-2" 188 | rows_base_url: str = "https://datasets-server.huggingface.co/rows" 189 | size_base_url: str = "https://datasets-server.huggingface.co/size" 190 | 191 | retry_limit: int = 10 # Number of retries 192 | retry_delay: int = 5 # Seconds to wait between retries 193 | num_rows_per_page: int = 100 194 | 195 | @staticmethod 196 | async def next_pages(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): 197 | configs_data = await DatasetLoader.fetch_dataset_configs() 198 | rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed 199 | rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps 200 | result = [] 201 | for _ in range(n_pages): 202 | config = rng.choice(list(configs_data.keys())) 203 | choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) 204 | result.append((str(config), int(choice), configs_data[config]['split'])) 205 | return result 206 | 207 | def __init__( 208 | self, 209 | batch_size=None, 210 | sequence_length=None, 211 | num_pages=None, 212 | pages_info=None, 213 | tokenizer: AutoTokenizer = None, 214 | pack_samples: bool = False, 215 | ): 216 | super().__init__(batch_size, 217 | sequence_length, 218 | num_pages, 219 | tokenizer, 220 | pack_samples) 221 | 222 | # Initialize properties 223 | self.configs_data = None 224 | self.pages = [] 225 | self.buffer = [] 226 | self.lock = asyncio.Lock() # For thread-safe operations 227 | 228 | @classmethod 229 | async def create( 230 | cls, 231 | batch_size=None, 232 | sequence_length=None, 233 | num_pages=None, 234 | pages_info=None, 235 | tokenizer: AutoTokenizer = None, 236 | pack_samples: bool = False, 237 | ): 238 | self = cls( 239 | batch_size=batch_size, 240 | sequence_length=sequence_length, 241 | num_pages=num_pages, 242 | tokenizer=tokenizer, 243 | pack_samples=pack_samples 244 | ) 245 | 246 | # Fetch dataset configs asynchronously 247 | self.configs_data = await cls.fetch_dataset_configs() 248 | 249 | if pages_info is not None: 250 | await self._fetch(pages_info) 251 | elif self.num_pages: 252 | await self._fetch_data_to_buffer(self.num_pages) 253 | 254 | return self 255 | 256 | async def _fetch(self, page_info: typing.Tuple[str, int, str]): 257 | self.pages = list(page_info) 258 | num_pages = len(self.pages) 259 | async with aiohttp.ClientSession() as session: 260 | tasks = [self._fetch_data_for_page((config_name, page, split), session) 261 | for (config_name, page, split) in self.pages] 262 | await asyncio.gather(*tasks) 263 | 264 | async def _fetch_data_to_buffer(self, num_pages): 265 | """ 266 | Randomly sample pages and add their data to the buffer. 267 | If a page is inaccessible, another one is sampled. 268 | This method sets the `pages` property. 269 | """ 270 | self.pages = [] 271 | pages_to_fetch = self.get_random_pages(num_pages) 272 | 273 | async with aiohttp.ClientSession() as session: 274 | tasks = [self._fetch_data_for_page(page, session) for page in pages_to_fetch] 275 | await asyncio.gather(*tasks) 276 | 277 | async def fetch_data_to_rows(self, num_pages): 278 | rows = [] 279 | pages_to_fetch = self.get_random_pages(num_pages) 280 | 281 | async with aiohttp.ClientSession() as session: 282 | tasks = [self._fetch_rows_for_page(page, session) for page in pages_to_fetch] 283 | results = await asyncio.gather(*tasks) 284 | for page_rows in results: 285 | rows.extend(page_rows) 286 | 287 | return rows 288 | 289 | 290 | async def _fetch_data_for_page(self, page, session): 291 | """ 292 | Fetches data asynchronously for a single page, processes it without blocking the event loop, 293 | and appends the tokenized data to the buffer. 294 | 295 | Args: 296 | page: A tuple containing the config name, page number, and split. 297 | session: The HTTP session used for making requests. 298 | 299 | Raises: 300 | Exception: If the maximum number of retry attempts is exceeded. 301 | """ 302 | retry_limit = self.retry_limit 303 | attempt = 0 304 | while attempt < retry_limit: 305 | config_name, page_number, split = page 306 | 307 | # Create the request parameters 308 | params = { 309 | 'dataset': self.name, 310 | 'config': config_name, 311 | 'split': split, 312 | 'offset': page_number, 313 | 'limit': self.num_rows_per_page 314 | } 315 | 316 | try: 317 | # Make an asynchronous HTTP GET request to fetch the data 318 | async with session.get(self.rows_base_url, params=params) as response: 319 | response.raise_for_status() # Raise an exception for HTTP errors 320 | data = await response.json() 321 | 322 | # Prepare the data to append 323 | buffer_to_append = [] 324 | 325 | # Asynchronously process each row without blocking the event loop 326 | tasks = [ 327 | self._tokenize_content(row["row"]["text"]) for row in data["rows"] 328 | ] 329 | 330 | # Gather the tokenized results concurrently 331 | row_input_ids = await asyncio.gather(*tasks) 332 | 333 | # Flatten the list of input IDs and append them to the buffer 334 | for input_ids in row_input_ids: 335 | buffer_to_append.extend(input_ids) 336 | 337 | # Safely append the processed data to the shared buffer 338 | async with self.lock: 339 | self.buffer.extend(buffer_to_append) 340 | self.pages.append((config_name, page_number, split)) 341 | break # Success, exit retry loop 342 | 343 | except aiohttp.ClientResponseError as e: 344 | # Handle HTTP client errors with a retry mechanism 345 | attempt += 1 346 | if attempt < retry_limit: 347 | await asyncio.sleep(self.retry_delay) # Wait before retrying 348 | else: 349 | raise Exception(f"Maximum retry attempts exceeded for page {page}") from e 350 | 351 | async def _tokenize_content(self, content): 352 | """ 353 | Asynchronously tokenizes a string of content using the tokenizer in a separate thread. 354 | 355 | Args: 356 | content: The text content to be tokenized. 357 | 358 | Returns: 359 | The list of token IDs for the content, including the EOS token. 360 | """ 361 | # Offload the CPU-bound tokenization to a thread executor to prevent blocking the event loop 362 | input_ids = await asyncio.to_thread( 363 | self.tokenizer.encode, content, truncation=True, max_length=self.sequence_length 364 | ) 365 | input_ids.append(self.tokenizer.eos_token_id) 366 | return input_ids 367 | 368 | async def _fetch_rows_for_page(self, page, session): 369 | retry_limit = self.retry_limit 370 | attempt = 0 371 | while attempt < retry_limit: 372 | config_name, page_number, split = page 373 | 374 | # Create the request parameters 375 | params = dict(dataset=self.name, 376 | config=config_name, 377 | split=split, 378 | offset=page_number, 379 | limit=self.num_rows_per_page 380 | ) 381 | 382 | try: 383 | async with session.get(self.rows_base_url, params=params) as response: 384 | response.raise_for_status() 385 | data = await response.json() 386 | 387 | # Collect the rows 388 | return [row["row"]["text"] for row in data["rows"]] 389 | 390 | except aiohttp.ClientResponseError as e: 391 | attempt += 1 392 | if attempt < retry_limit: 393 | await asyncio.sleep(self.retry_delay) 394 | else: 395 | raise 396 | 397 | def get_random_pages(self, num_pages): 398 | """ 399 | Randomly sample pages. 400 | A page is a row number of a given split of a given dataset dump. 401 | """ 402 | pages = [] 403 | 404 | for _ in range(num_pages): 405 | # Choose a random config 406 | config_name = random.choice(list(self.configs_data.keys())) 407 | 408 | # Choose a random page (row) 409 | page = random.randint(0, 410 | self.configs_data[config_name]['num_rows'] - 1 - self.num_rows_per_page) 411 | 412 | split = self.configs_data[config_name]['split'] 413 | 414 | pages.append((config_name, page, split)) 415 | 416 | return pages 417 | 418 | def get_page_names(self): 419 | """ 420 | This is a utility function that returns the page names that were used. 421 | Each page as a single string instead of a tuple. 422 | """ 423 | page_names = [] 424 | 425 | if hasattr(self, 'pages'): 426 | page_names = [f'{cfg_name}_{num_rows}_{split}' for 427 | cfg_name, num_rows, split in self.pages] 428 | 429 | return page_names 430 | 431 | @staticmethod 432 | async def fetch_dataset_configs() -> typing.Dict[str, typing.Dict]: 433 | """ 434 | Fetch the different dump names, aka configs, aka samples, of the 435 | dataset. 436 | The returned value is a dictionary with dump names as keys and 437 | a dict of the number of rows and the split as values. 438 | """ 439 | # Request parameters 440 | params = dict( 441 | dataset=DatasetLoader.name 442 | ) 443 | 444 | attempt = 0 445 | while attempt < DatasetLoader.retry_limit: 446 | try: 447 | async with aiohttp.ClientSession() as session: 448 | async with session.get(DatasetLoader.size_base_url, params=params) as response: 449 | response.raise_for_status() 450 | 451 | data = await response.json() 452 | 453 | # Extract the configs dict 454 | configs_dict = data['size']['splits'] 455 | 456 | # Now create a dict with config names (except 'default') as 457 | # keys, and the number of rows as values 458 | configs_data = {entry['config']: {'num_rows': entry['num_rows'], 459 | 'split': entry['split']} 460 | for entry in configs_dict 461 | if entry['config'] != 'default' 462 | } 463 | 464 | return configs_data 465 | 466 | except aiohttp.ClientResponseError as e: 467 | attempt += 1 468 | if attempt < DatasetLoader.retry_limit: 469 | await asyncio.sleep(DatasetLoader.retry_delay) 470 | else: 471 | raise 472 | 473 | @staticmethod 474 | async def next_pages_async(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): 475 | configs_data = await DatasetLoader.fetch_dataset_configs() 476 | rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed 477 | rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps 478 | result = [] 479 | for _ in range(n_pages): 480 | config = rng.choice(list(configs_data.keys())) 481 | choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) 482 | result.append((str(config), int(choice), configs_data[config]['split'])) 483 | return result 484 | -------------------------------------------------------------------------------- /docker_run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # The MIT License (MIT) 4 | # © 2024 Chakana.tech 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 12 | # the Software. 13 | 14 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 15 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | set -euo pipefail 21 | 22 | trap 'abort "An unexpected error occurred."' ERR 23 | 24 | # Set up colors and styles 25 | if [[ -t 1 ]]; then 26 | tty_escape() { printf "\033[%sm" "$1"; } 27 | else 28 | tty_escape() { :; } 29 | fi 30 | tty_mkbold() { tty_escape "1;$1"; } 31 | tty_blue="$(tty_mkbold 34)" 32 | tty_red="$(tty_mkbold 31)" 33 | tty_green="$(tty_mkbold 32)" 34 | tty_yellow="$(tty_mkbold 33)" 35 | tty_bold="$(tty_mkbold 39)" 36 | tty_reset="$(tty_escape 0)" 37 | 38 | ohai() { 39 | printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" 40 | } 41 | 42 | pdone() { 43 | printf "${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" 44 | } 45 | 46 | info() { 47 | printf "${tty_green}%s${tty_reset}\n" "$*" 48 | } 49 | 50 | warn() { 51 | printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 52 | } 53 | 54 | error() { 55 | printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 56 | } 57 | 58 | abort() { 59 | error "$@" 60 | exit 1 61 | } 62 | 63 | getc() { 64 | local save_state 65 | save_state="$(/bin/stty -g)" 66 | /bin/stty raw -echo 67 | IFS='' read -r -n 1 -d '' "$@" 68 | /bin/stty "${save_state}" 69 | } 70 | 71 | wait_for_user() { 72 | local c 73 | echo 74 | echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" 75 | getc c 76 | # we test for \r and \n because some stuff does \r instead 77 | if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] 78 | then 79 | exit 1 80 | fi 81 | } 82 | 83 | execute() { 84 | ohai "Running: $*" 85 | if ! "$@"; then 86 | abort "Failed during: $*" 87 | fi 88 | } 89 | 90 | have_root_access() { 91 | if [ "$EUID" -ne 0 ]; then 92 | warn "This script requires root privileges to install packages." 93 | return 1 94 | fi 95 | return 0 96 | } 97 | 98 | clear 99 | echo "" 100 | echo "" 101 | echo " ______ _____ _______ ______ _______ _______ __ _ __ _" 102 | echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" 103 | echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" 104 | echo " " 105 | echo "" 106 | echo "" 107 | 108 | wait_for_user 109 | 110 | # Install Git if not present 111 | if ! command -v git &> /dev/null; then 112 | ohai "Installing git ..." 113 | if have_root_access; then 114 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 115 | ohai "Detected Linux" 116 | if [ -f /etc/os-release ]; then 117 | . /etc/os-release 118 | if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then 119 | ohai "Detected Ubuntu, installing Git..." 120 | execute apt-get update -y 121 | execute apt-get install -y git 122 | else 123 | warn "Unsupported Linux distribution: $ID" 124 | abort "Cannot install Git automatically" 125 | fi 126 | else 127 | warn "Cannot detect Linux distribution" 128 | abort "Cannot install Git automatically" 129 | fi 130 | elif [[ "$OSTYPE" == "darwin"* ]]; then 131 | ohai "Detected macOS, installing Git..." 132 | execute xcode-select --install 133 | else 134 | abort "Unsupported OS type: $OSTYPE" 135 | fi 136 | else 137 | abort "Root access is required to install Git." 138 | fi 139 | else 140 | pdone "Found Git" 141 | fi 142 | 143 | # Check if we are inside the cont repository 144 | if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then 145 | REPO_PATH="." 146 | else 147 | if [ ! -d "cont" ]; then 148 | ohai "Cloning boltzmann ..." 149 | execute git clone https://github.com/unconst/cont 150 | REPO_PATH="cont/" 151 | else 152 | REPO_PATH="cont/" 153 | fi 154 | fi 155 | pdone "Pulled Boltzmann $REPO_PATH" 156 | 157 | # Install Node.js and npm if not present 158 | if ! command -v npm &> /dev/null; then 159 | ohai "Installing Node.js and npm ..." 160 | if have_root_access; then 161 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 162 | ohai "Detected Linux" 163 | execute apt-get update -y 164 | execute apt-get install -y curl 165 | curl -fsSL https://deb.nodesource.com/setup_20.x | bash - 166 | execute apt-get install -y nodejs 167 | elif [[ "$OSTYPE" == "darwin"* ]]; then 168 | ohai "Detected macOS, installing Node.js and npm..." 169 | execute brew install node 170 | else 171 | abort "Unsupported OS type: $OSTYPE" 172 | fi 173 | else 174 | abort "Root access is required to install Node.js and npm." 175 | fi 176 | pdone "Installed Node.js and npm" 177 | else 178 | pdone "Found npm" 179 | fi 180 | 181 | # Install pm2 182 | if ! command -v pm2 &> /dev/null; then 183 | ohai "Installing pm2 ..." 184 | execute npm install pm2 -g 185 | pdone "Installed pm2" 186 | else 187 | pdone "Found pm2" 188 | fi 189 | 190 | # Install Python 3.12 if not installed 191 | if ! command -v python3.12 &> /dev/null; then 192 | ohai "Installing python3.12 ..." 193 | if have_root_access; then 194 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 195 | ohai "Detected Linux" 196 | if [ -f /etc/os-release ]; then 197 | . /etc/os-release 198 | if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then 199 | ohai "Detected Ubuntu, installing Python 3.12..." 200 | execute apt-get update -y 201 | execute apt-get install -y software-properties-common gnupg 202 | 203 | # Add the deadsnakes PPA manually 204 | ohai "Adding deadsnakes PPA manually..." 205 | execute mkdir -p /etc/apt/keyrings 206 | execute curl -fsSL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x6A755776" | gpg --dearmor --batch --yes -o /etc/apt/keyrings/deadsnakes-archive-keyring.gpg 207 | echo "deb [signed-by=/etc/apt/keyrings/deadsnakes-archive-keyring.gpg] http://ppa.launchpad.net/deadsnakes/ppa/ubuntu jammy main" > /etc/apt/sources.list.d/deadsnakes-ppa.list 208 | 209 | execute apt-get update -y 210 | execute apt-get install -y python3.12 python3.12-venv 211 | 212 | else 213 | warn "Unsupported Linux distribution: $ID" 214 | abort "Cannot install Python 3.12 automatically" 215 | fi 216 | else 217 | warn "Cannot detect Linux distribution" 218 | abort "Cannot install Python 3.12 automatically" 219 | fi 220 | elif [[ "$OSTYPE" == "darwin"* ]]; then 221 | ohai "Detected macOS, installing Python 3.12..." 222 | execute brew install python@3.12 223 | else 224 | abort "Unsupported OS type: $OSTYPE" 225 | fi 226 | else 227 | abort "Root access is required to install Python 3.12." 228 | fi 229 | pdone "Installed python3.12" 230 | else 231 | pdone "Found python3.12" 232 | fi 233 | 234 | touch ~/.bash_profile 235 | 236 | # Prompt the user for AWS credentials and inject them into the bash_profile file if not already stored 237 | if ! grep -q "AWS_ACCESS_KEY_ID" ~/.bash_profile || ! grep -q "AWS_SECRET_ACCESS_KEY" ~/.bash_profile || ! grep -q "BUCKET" ~/.bash_profile; then 238 | clear 239 | warn "This script will store your AWS credentials in your ~/.bash_profile file." 240 | warn "This is not secure and is not recommended." 241 | read -p "Do you want to proceed? [y/N]: " proceed 242 | if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then 243 | abort "Aborted by user." 244 | fi 245 | 246 | read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID 247 | read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY 248 | read -p "Enter your S3 Bucket Name: " BUCKET 249 | 250 | echo "export AWS_ACCESS_KEY_ID=\"$AWS_ACCESS_KEY_ID\"" >> ~/.bash_profile 251 | echo "export AWS_SECRET_ACCESS_KEY=\"$AWS_SECRET_ACCESS_KEY\"" >> ~/.bash_profile 252 | echo "export BUCKET=\"$BUCKET\"" >> ~/.bash_profile 253 | fi 254 | 255 | # Source the bash_profile file to apply the changes 256 | source ~/.bash_profile 257 | 258 | pdone "Found AWS credentials" 259 | 260 | # Create a virtual environment if it does not exist 261 | if [ ! -d "$REPO_PATH/venv" ]; then 262 | ohai "Creating virtual environment at $REPO_PATH..." 263 | execute python3.12 -m venv "$REPO_PATH/venv" 264 | fi 265 | pdone "Created venv at $REPO_PATH" 266 | 267 | if [[ -z "${VIRTUAL_ENV:-}" ]]; then 268 | ohai "Activating virtual environment..." 269 | source "$REPO_PATH/venv/bin/activate" 270 | fi 271 | pdone "Activated venv at $REPO_PATH" 272 | 273 | ohai "Installing requirements..." 274 | execute pip install --upgrade pip 275 | execute pip install -r "$REPO_PATH/requirements.txt" 276 | pdone "Installed requirements" 277 | 278 | # Check for GPUs 279 | ohai "Checking for GPUs..." 280 | if ! command -v nvidia-smi &> /dev/null; then 281 | warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." 282 | NUM_GPUS=0 283 | else 284 | NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 285 | ohai "Number of GPUs: $NUM_GPUS" 286 | 287 | if [ "$NUM_GPUS" -gt 0 ]; then 288 | ohai "GPU Memory Information:" 289 | nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do 290 | ohai "$((memory / 1024)) GB" 291 | done 292 | else 293 | warn "No GPUs found on this machine." 294 | fi 295 | fi 296 | 297 | # Check system RAM 298 | ohai "Checking system RAM..." 299 | if command -v free &> /dev/null; then 300 | TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') 301 | ohai "Total System RAM: $TOTAL_RAM GB" 302 | else 303 | warn "Cannot determine system RAM. 'free' command not found." 304 | fi 305 | 306 | # Create the default key 307 | ohai "Creating the coldkey" 308 | if ! python3.12 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then 309 | execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 --no_password 310 | else 311 | ohai "Default key already exists on device." 312 | fi 313 | 314 | # Ensure btcli is installed 315 | if ! command -v btcli &> /dev/null; then 316 | abort "btcli command not found. Please ensure it is installed." 317 | fi 318 | 319 | # Create hotkeys and register them 320 | if [ "$NUM_GPUS" -gt 0 ]; then 321 | for i in $(seq 0 $((NUM_GPUS - 1))); do 322 | # Check if the hotkey file exists on the device 323 | exists_on_device=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) 324 | if [ "$exists_on_device" != "True" ]; then 325 | echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; 326 | else 327 | ohai "Hotkey C$i already exists on device." 328 | fi 329 | 330 | # Check if the hotkey is registered on subnet 220 331 | is_registered=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) 332 | if [[ "$is_registered" != *"True"* ]]; then 333 | ohai "Registering key on subnet 220" 334 | btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; 335 | else 336 | ohai "Key is already registered on subnet 220" 337 | fi 338 | done 339 | else 340 | warn "No GPUs found. Skipping hotkey creation." 341 | fi 342 | 343 | ohai "Logging into wandb..." 344 | execute wandb login 345 | 346 | # Delete items from bucket 347 | PROJECT=${2:-aesop} 348 | ohai "Cleaning bucket $BUCKET..." 349 | execute python3.12 "$REPO_PATH/tools/clean.py" --bucket "$BUCKET" 350 | 351 | # Start all the processes again 352 | if [ "$NUM_GPUS" -gt 0 ]; then 353 | for i in $(seq 0 $((NUM_GPUS - 1))); do 354 | # Adjust GPU index for zero-based numbering 355 | GPU_INDEX=$i 356 | GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") 357 | if [ -z "$GPU_MEMORY" ]; then 358 | warn "Could not get GPU memory for GPU $i" 359 | continue 360 | fi 361 | # Determine batch size based on GPU memory 362 | # This section adjusts the batch size for the miner based on the available GPU memory 363 | # Higher memory allows for larger batch sizes, which can improve performance 364 | if [ "$GPU_MEMORY" -ge 80000 ]; then 365 | # For GPUs with 80GB or more memory, use a batch size of 6 366 | BATCH_SIZE=6 367 | elif [ "$GPU_MEMORY" -ge 40000 ]; then 368 | # For GPUs with 40GB to 79GB memory, use a batch size of 3 369 | BATCH_SIZE=3 370 | elif [ "$GPU_MEMORY" -ge 20000 ]; then 371 | # For GPUs with 20GB to 39GB memory, use a batch size of 1 372 | BATCH_SIZE=1 373 | else 374 | # For GPUs with less than 20GB memory, also use a batch size of 1 375 | # This ensures that even lower-end GPUs can still participate 376 | BATCH_SIZE=1 377 | fi 378 | ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." 379 | execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" 380 | done 381 | else 382 | warn "No GPUs found. Skipping miner startup." 383 | fi 384 | 385 | pm2 list 386 | 387 | ohai "Script completed successfully." 388 | -------------------------------------------------------------------------------- /hparams.json: -------------------------------------------------------------------------------- 1 | { 2 | "epoch_length": 250000, 3 | "compression": 100, 4 | "sequence_length": 2048, 5 | "tokenizer_name": "togethercomputer/LLaMA-2-7B-32K", 6 | "num_hidden_layers": 16, 7 | "hidden_size": 2048, 8 | "intermediate_size": 8192, 9 | "num_attention_heads": 8, 10 | "num_key_value_heads": 8, 11 | "activation_function": "swiGLU", 12 | "max_position_embeddings": 2048, 13 | "mixed_precision_param": "bfloat16", 14 | "mixed_precision_reduce": "float32", 15 | "window_length": 2, 16 | "desired_batch_size": 512, 17 | "learning_rate": 7.5e-05, 18 | "optimizer_beta1": 0.9, 19 | "optimizer_beta2": 0.95, 20 | "optimizer_weight_decay": 0.1, 21 | "grad_clip": 1.0, 22 | "cosine_epoch_length": 5000, 23 | "eta_min": 1e-05, 24 | "max_history": 10, 25 | "window_speed": 100, 26 | "validator_moving_alpha": 0.999, 27 | "validator_norm_regularization": 0.01, 28 | "validator_weights_temperature": 10, 29 | "validator_window_eval_size": 3, 30 | "valdiator_sample_rate": 0.01 31 | } -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import os 19 | import json 20 | import time 21 | import requests 22 | from types import SimpleNamespace 23 | from transformers import AutoTokenizer, LlamaConfig 24 | 25 | from common import * 26 | 27 | # Cache file path 28 | HPARAMS_FILE = "hparams.json" 29 | 30 | def create_namespace(hparams: dict) -> SimpleNamespace: 31 | """ 32 | Create a SimpleNamespace from the hyperparameters and add model configuration. 33 | 34 | Args: 35 | hparams (dict): Hyperparameters dictionary. 36 | 37 | Returns: 38 | SimpleNamespace: Namespace containing hyperparameters and model configuration. 39 | """ 40 | hparams_ns = SimpleNamespace(**hparams) 41 | 42 | hparams_ns.tokenizer = AutoTokenizer.from_pretrained( 43 | hparams_ns.tokenizer_name, verbose=False, clean_up_tokenization_spaces=True 44 | ) 45 | hparams_ns.tokenizer.pad_token = hparams_ns.tokenizer.eos_token 46 | 47 | hparams_ns.model_config = LlamaConfig( 48 | vocab_size=hparams_ns.tokenizer.vocab_size, 49 | hidden_size=hparams_ns.hidden_size, 50 | num_hidden_layers=hparams_ns.num_hidden_layers, 51 | num_attention_heads=hparams_ns.num_attention_heads, 52 | intermediate_size=hparams_ns.intermediate_size, 53 | num_key_value_heads=hparams_ns.num_key_value_heads, 54 | activation_function=hparams_ns.activation_function, 55 | max_position_embeddings=hparams_ns.max_position_embeddings, 56 | ) 57 | 58 | return hparams_ns 59 | 60 | def load_hparams() -> SimpleNamespace: 61 | """ 62 | Load hyperparameters from a GitHub file, with caching and fallback mechanisms. 63 | 64 | Returns: 65 | SimpleNamespace: A namespace containing the hyperparameters and model configuration. 66 | 67 | Example: 68 | hparams = load_hparams() 69 | print(hparams.hidden_size) 70 | print(hparams.model_config) 71 | """ 72 | github_url = f"https://raw.githubusercontent.com/unconst/cont/master/hparams.json?timestamp={int(time.time())}" 73 | try: 74 | # Attempt to fetch from the GitHub file first 75 | response = requests.get(github_url, timeout=10, headers={'Cache-Control': 'no-cache'}) 76 | response.raise_for_status() 77 | hparams = json.loads(response.text) 78 | logger.debug("Successfully loaded parameters from GitHub.") 79 | except (requests.RequestException, json.JSONDecodeError) as e: 80 | logger.debug(f"Error loading parameters from GitHub: {e}") 81 | logger.debug("Attempting to load from cache...") 82 | with open(HPARAMS_FILE, "r") as f: 83 | hparams = json.load(f) 84 | # Cache the new parameters 85 | with open(HPARAMS_FILE, "w") as f: 86 | json.dump(hparams, f, indent=4) 87 | return create_namespace(hparams) 88 | -------------------------------------------------------------------------------- /miner.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | # fmt: off 18 | 19 | # Global imports. 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import wandb 25 | import torch 26 | import random 27 | import asyncio 28 | import argparse 29 | import threading 30 | import traceback 31 | from tqdm import tqdm 32 | import bittensor as bt 33 | from typing import List 34 | import torch.optim as optim 35 | from dotenv import dotenv_values 36 | from transformers import LlamaForCausalLM 37 | from torch.optim.lr_scheduler import CosineAnnealingLR 38 | 39 | # Import local files. 40 | from common import * 41 | from hparams import load_hparams 42 | from dataset import DatasetLoader 43 | 44 | # GPU optimizations. 45 | torch.backends.cudnn.benchmark = True 46 | torch.backends.cuda.matmul.allow_tf32 = True 47 | torch.backends.cudnn.allow_tf32 = True 48 | 49 | class Miner: 50 | 51 | @staticmethod 52 | def config(): 53 | parser = argparse.ArgumentParser(description='Miner script') 54 | parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') 55 | parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') 56 | parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') 57 | parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') 58 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') 59 | parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') 60 | parser.add_argument('--remote', action='store_true', help='Connect to other buckets') 61 | parser.add_argument('--debug', action='store_true', help='Enable debug logging') 62 | parser.add_argument('--trace', action='store_true', help='Enable trace logging') 63 | parser.add_argument('--random', action='store_true', help='Train on random') 64 | parser.add_argument('--sync_state', action='store_true', help='Syncs the model state by pulling from the history.') 65 | parser.add_argument('--baseline', action='store_true', help='Dont perform syncing with other peers, just train.') 66 | bt.wallet.add_args(parser) 67 | bt.subtensor.add_args(parser) 68 | config = bt.config(parser) 69 | config.subtensor.network = 'test' 70 | config.subtensor.chain_endpoint = 'wss://test.finney.opentensor.ai:443/' 71 | if config.debug: debug() 72 | if config.trace: trace() 73 | return config 74 | 75 | def __init__(self): 76 | # Init config. 77 | self.config = Miner.config() 78 | logger.info('\n' + '-' * 40 + ' Config ' + '-' * 40) 79 | logger.info(self.config) 80 | 81 | # Init bittensor objects. 82 | self.wallet = bt.wallet(config=self.config) 83 | self.subtensor = bt.subtensor(config=self.config) 84 | self.metagraph = self.subtensor.metagraph(netuid=self.config.netuid) 85 | if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys: 86 | raise ValueError(f'Wallet {self.wallet} is not registered on subnet: {self.metagraph.netuid}') 87 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 88 | logger.info('\n' + '-' * 40 + ' Objects ' + '-' * 40) 89 | logger.info(f'\nWallet: {self.wallet}\nSubtensor: {self.subtensor}\nMetagraph: {self.metagraph}\nUID: {self.uid}') 90 | 91 | # Init bucket. 92 | try: 93 | if self.config.bucket != self.subtensor.get_commitment(self.config.netuid, self.uid): 94 | raise ValueError('') 95 | except: 96 | self.subtensor.commit(self.wallet, self.config.netuid, self.config.bucket) 97 | logger.info('Bucket:' + self.config.bucket) 98 | 99 | # Init Wandb. 100 | if self.config.use_wandb: 101 | # Delete all runs with my name and create a new one. 102 | try: 103 | for run in wandb.Api().runs(path=self.config.project): 104 | if run.name == f'M{self.uid}': 105 | logger.info(f'Deleting old run: {run}'); run.delete() 106 | except: pass 107 | wandb.init(project=self.config.project, resume='allow', name=f'M{self.uid}', config=self.config) 108 | 109 | # Init model. 110 | logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) 111 | self.hparams = load_hparams() 112 | torch.manual_seed(42); np.random.seed(42); random.seed(42) 113 | self.model = LlamaForCausalLM(config=self.hparams.model_config) 114 | # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') 115 | self.model.to(self.config.device) 116 | self.model.train() 117 | self.optimizer = optim.AdamW( 118 | self.model.parameters(), 119 | lr=self.hparams.learning_rate, # Peak learning rate 120 | betas=(self.hparams.optimizer_beta1, self.hparams.optimizer_beta2), # B1 and B2 121 | weight_decay=self.hparams.optimizer_weight_decay, # Weight decay 122 | foreach=True, # more memory usage, but faster 123 | ) 124 | self.scheduler = CosineAnnealingLR( 125 | self.optimizer, T_max=self.hparams.cosine_epoch_length, 126 | eta_min=self.hparams.eta_min, last_epoch=-1 127 | ) 128 | 129 | # Init buckets. 130 | self.buckets = [] 131 | for uid in self.metagraph.uids: 132 | # Use --remote to connect to other miners, other wise, only see's config.bucket. 133 | try: self.buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid ) ) 134 | except: self.buckets.append(None) 135 | 136 | # Init run state. 137 | self.global_step = 0 138 | self.sample_rate = 1.0 139 | self.current_block = self.subtensor.block 140 | self.current_window = self.block_to_window( self.current_block ) 141 | self.window_seeds = {self.current_window: self.window_to_seed( self.current_window) } 142 | self.new_block_event = asyncio.Event() 143 | self.new_window_event = asyncio.Event() 144 | self.stop_event = asyncio.Event() 145 | self.last_full_steps = self.hparams.desired_batch_size // self.config.actual_batch_size 146 | print ( self.hparams ) 147 | 148 | async def update(self): 149 | while not self.stop_event.is_set(): 150 | st = T() 151 | self.subtensor = bt.subtensor(config=self.config) 152 | self.metagraph = self.subtensor.metagraph(self.config.netuid) 153 | self.hparams = load_hparams() 154 | next_buckets = [] 155 | for uid in self.metagraph.uids: 156 | try: next_buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid )) 157 | except: next_buckets.append(None) 158 | self.buckets = next_buckets 159 | logger.info(f"{P(self.current_window, T() - st)} Updated global state.") 160 | await asyncio.sleep(60) 161 | 162 | async def run(self): 163 | # Main loop. 164 | self.loop = asyncio.get_running_loop() 165 | self.update_task = asyncio.create_task(self.update()) 166 | self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() 167 | 168 | # Optionally sync the model state by pulling model states from the history. 169 | if self.config.sync_state: 170 | history_windows = [ self.current_window - i for i in range (self.hparams.max_history) ] 171 | state_slices = await download_slices_for_buckets_and_windows( 172 | buckets = self.buckets, 173 | windows = history_windows, 174 | key = 'state' 175 | ) 176 | for window in tqdm(history_windows, desc="Syncing state"): 177 | await apply_slices_to_model( 178 | model = self.model, 179 | window = window, 180 | seed = window, 181 | compression = self.hparams.compression, 182 | key = 'state' 183 | ) 184 | torch.cuda.empty_cache() 185 | 186 | # Main training loop. 187 | while True: 188 | try: 189 | # Start the window step. 190 | logger.info('[bold]' + '\n' + '-' * 40 + f' Step: {self.global_step} ' + '-' * 40) 191 | self.global_step += 1 192 | start_step = T() 193 | window = self.current_window 194 | 195 | # Run for non-baseline miners. 196 | if not self.config.baseline: 197 | st = T() 198 | state_slices = await download_slices_for_buckets_and_windows( 199 | buckets = self.buckets, 200 | windows = [ window ], 201 | key = 'state' 202 | ) 203 | n_slices = len(state_slices[ window ]) if window in state_slices else 0 204 | logger.info(f"{P(window, T() - st)}: Downloaded {n_slices} window states.") 205 | 206 | # Download the delta from the previous window. 207 | st = T() 208 | delta_slices = await download_slices_for_buckets_and_windows( 209 | buckets = self.buckets, 210 | windows = [ window - 1 ], 211 | key = 'delta' 212 | ) 213 | n_slices = len(delta_slices[ window - 1 ]) if window - 1 in delta_slices else 0 214 | logger.info(f"{P(window, T() - st)}: Download {n_slices} window deltas.") 215 | 216 | # Apply the state for the current window. 217 | st = T() 218 | await apply_slices_to_model( 219 | model = self.model, 220 | window = window, 221 | seed = window, 222 | compression = self.hparams.compression, 223 | key = 'state' 224 | ) 225 | logger.info(f"{P(window, T() - st)}: Applied window state.") 226 | 227 | # Download the page for the current window. 228 | st = T() 229 | pages = await DatasetLoader.next_pages( 230 | offset = window, 231 | n_pages = self.hparams.validator_window_eval_size, 232 | seed = self.uid if not self.config.random else random.randint(0, 1000) 233 | ) 234 | random.shuffle( pages ) 235 | dataset = await DatasetLoader.create( 236 | batch_size = self.config.actual_batch_size, 237 | sequence_length = self.hparams.sequence_length, 238 | pages_info = pages, 239 | tokenizer = self.hparams.tokenizer 240 | ) 241 | logger.info(f"{P(window, T() - st)}: Downloaded training page: [light_steel_blue]{[p[1] for p in pages]}[/light_steel_blue] random = {self.config.random}") 242 | 243 | # Accumualte gradients on the model applied to the base state. 244 | train_start = T() 245 | self.model.zero_grad(); self.model.eval() 246 | total_loss = 0.0 247 | full_steps = 0; total_steps = 0; 248 | exhuasted_window = False 249 | for batch in dataset: 250 | total_steps += 1 251 | if random.random() < self.sample_rate and not exhuasted_window: 252 | full_steps += 1 253 | input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) 254 | labels = input_ids.clone() 255 | labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) 256 | with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting 257 | outputs = self.model(input_ids=input_ids, labels=labels) 258 | total_loss += outputs.loss.item() 259 | outputs.loss.backward() 260 | if window != self.current_window and not self.config.baseline: exhuasted_window = True; continue 261 | if self.hparams.grad_clip: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.grad_clip) 262 | self.optimizer.step() 263 | self.scheduler.step() 264 | self.optimizer.zero_grad() 265 | torch.cuda.empty_cache() 266 | step_loss = total_loss/(full_steps+1) 267 | train_duration = T() - train_start 268 | tokens_per_step = self.hparams.sequence_length * self.config.actual_batch_size * (full_steps + 1) 269 | tokens_per_second = tokens_per_step / train_duration 270 | logger.info(f"{P(window, train_duration)} Accumulated gradients:") 271 | logger.info(f"{P(window, train_duration)} \tTotal steps: [tan]{full_steps}/{total_steps}[/tan], Rate: [tan]{(full_steps/total_steps):.2f}[/tan], Target: [tan]{self.sample_rate:.2f}[/tan]") 272 | logger.info(f"{P(window, train_duration)} \tTotal tokens: [tan]{tokens_per_step}[/tan], Tokens per second: [tan]{tokens_per_second:.2f}[/tan]") 273 | logger.info(f"{P(window, train_duration)} \tLoss: [tan]{step_loss}[tan]") 274 | if exhuasted_window: self.sample_rate = max(0.0001, self.sample_rate * 0.95) 275 | else: self.sample_rate = min(1, self.sample_rate * 1.05) 276 | 277 | # Run for non-baseline nodes. 278 | if not self.config.baseline: 279 | # Upload the delta for the previous window. 280 | st = T() 281 | await upload_slice_for_window( 282 | bucket = self.config.bucket, 283 | model = self.model, 284 | window = window, 285 | seed = window, 286 | wallet = self.wallet, 287 | compression = self.hparams.compression, 288 | key = 'delta' 289 | ) 290 | logger.info(f"{P(window, T() - st)}: Uploaded the delta.") 291 | 292 | # Apply the delta from the previous window. 293 | st = T() 294 | await apply_slices_to_model( 295 | model = self.model, 296 | window = window - 1, 297 | seed = window - 1, 298 | compression = self.hparams.compression, 299 | key = 'delta' 300 | ) 301 | logger.info(f"{P(window, T() - st)}: Applied window delta.") 302 | 303 | # Upload the state for the current window. 304 | st = T() 305 | await upload_slice_for_window( 306 | bucket = self.config.bucket, 307 | model = self.model, 308 | window = window + 1, 309 | seed = window + 1, 310 | wallet = self.wallet, 311 | compression = self.hparams.compression, 312 | key = 'state', 313 | ) 314 | logger.info(f"{P(window, T() - st)}: Uploaded the state.") 315 | 316 | # Clean file history. 317 | st = T() 318 | await delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state') 319 | await delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta') 320 | await delete_files_from_bucket_before_window( bucket = self.config.bucket, window_max = window - self.hparams.max_history, key = 'state' ) 321 | await delete_files_from_bucket_before_window( bucket = self.config.bucket, window_max = window - self.hparams.max_history, key = 'delta' ) 322 | logger.info(f"{P(window, T() - st)}: Cleaned file history.") 323 | 324 | # Wait until we are on a new window. 325 | end_step = T() 326 | while self.current_window == window: 327 | await asyncio.sleep(0.1) 328 | window_time_delta = self.window_time - end_step 329 | window_delta_str = f"[red]{window_time_delta:.2f}[/red]" if window_time_delta < 0 else f"[green]+{window_time_delta:.2f}[/green]" 330 | logger.info(f"{P(window, end_step - start_step)}[{window_delta_str}]: Finished step.") 331 | if self.config.use_wandb: 332 | wandb.log({ 333 | f"loss": step_loss, 334 | f"tokens_per_step": tokens_per_step, 335 | f"tokens_per_second": tokens_per_second, 336 | f"sample_rate": self.sample_rate, 337 | f"utilization": train_duration / (end_step - start_step), 338 | f"learning_rate": self.scheduler.get_last_lr()[0] 339 | }) 340 | 341 | # Catch keyboard interrrupt. 342 | except KeyboardInterrupt: 343 | logger.info("Training interrupted by user. Stopping the run.") 344 | self.stop_event.set() 345 | await self.update_task 346 | sys.exit(0) 347 | 348 | # Catch unknown. 349 | except Exception as e: 350 | logger.exception(f"Exception during training loop: {e}") 351 | continue 352 | 353 | # Returns the slice window based on a block. 354 | def block_to_window(self, block: int) -> int: 355 | return int( block / self.hparams.window_length ) # floor 356 | 357 | # Returns the slice window based on a block. 358 | def window_to_seed(self, window: int) -> int: 359 | return str( self.subtensor.get_block_hash( window * self.hparams.window_length ) ) 360 | 361 | # A listener thread which posts the block event 362 | # when the chain announces a new block. 363 | def block_listener(self, loop): 364 | def handler(event, _u, _s): 365 | self.current_block = int(event['header']['number']) 366 | loop.call_soon_threadsafe(self.new_block_event.set) 367 | if self.block_to_window(self.current_block) != self.current_window: 368 | self.window_seeds[ self.block_to_window(self.current_block) ] = self.window_to_seed( self.block_to_window(self.current_block) ) 369 | self.current_window = self.block_to_window(self.current_block) 370 | self.window_duration = T() - self.window_time if hasattr(self, 'window_time') else 0 371 | self.window_time = T() 372 | loop.call_soon_threadsafe(self.new_window_event.set) 373 | logger.info(f"{P(self.current_window, self.window_duration)} New Window.") 374 | # Run listener with retry. 375 | while not self.stop_event.is_set(): 376 | try: 377 | bt.subtensor(config=self.config).substrate.subscribe_block_headers(handler); break 378 | except Exception as e: 379 | # Wait for 5 seconds before retrying 380 | logger.error(f"Failed to subscribe to block headers: {e}.\nRetrying in 1 seconds...") 381 | time.sleep(1) 382 | 383 | if __name__ == "__main__": 384 | asyncio.run(Miner().run()) 385 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | aiohttp==3.9.2 19 | bittensor==7.4.0 20 | substrate-interface 21 | boto3==1.34.131 22 | safetensors==0.4.5 23 | torch>=2.4.0 24 | transformers==4.44.2 25 | python-dotenv==1.0.1 26 | datasets==3.0.0 27 | torchvision==0.19.1 28 | wandb>=0.18.3 29 | typer==0.12.5 30 | numpy==1.26 31 | aioboto3>=13.1.1 32 | loguru==0.7.2 33 | uvloop==0.20.0 34 | aiofiles>=24.1.0 35 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # The MIT License (MIT) 4 | # © 2024 Chakana.tech 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | 11 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 12 | # the Software. 13 | 14 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 15 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 17 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | set -euo pipefail 21 | 22 | # Initialize default values 23 | DEBUG=false 24 | PROJECT="aesop" 25 | AWS_ACCESS_KEY_ID="" 26 | AWS_SECRET_ACCESS_KEY="" 27 | BUCKET="" 28 | 29 | # Function to display help message 30 | display_help() { 31 | cat << EOF 32 | Usage: $0 [options] 33 | 34 | Options: 35 | --debug Enable debug mode 36 | --project Set the project name (default: aesop) 37 | --aws-access-key-id Set AWS Access Key ID 38 | --aws-secret-access-key Set AWS Secret Access Key 39 | --bucket Set the S3 bucket name 40 | -h, --help Display this help message 41 | 42 | Description: 43 | Installs and runs a Boltzmann miner on your GPU. 44 | EOF 45 | } 46 | 47 | # Parse command-line arguments 48 | while [[ $# -gt 0 ]]; do 49 | key="$1" 50 | case $key in 51 | --debug) 52 | DEBUG=true 53 | shift 54 | ;; 55 | --project) 56 | PROJECT="$2" 57 | shift 2 58 | ;; 59 | --aws-access-key-id) 60 | AWS_ACCESS_KEY_ID="$2" 61 | shift 2 62 | ;; 63 | --aws-secret-access-key) 64 | AWS_SECRET_ACCESS_KEY="$2" 65 | shift 2 66 | ;; 67 | --bucket) 68 | BUCKET="$2" 69 | shift 2 70 | ;; 71 | -h|--help|-help|--h) 72 | display_help 73 | exit 0 74 | ;; 75 | *) 76 | echo "Unknown option: $1" 77 | display_help 78 | exit 1 79 | ;; 80 | esac 81 | done 82 | 83 | # Set up colors and styles 84 | if [[ -t 1 ]]; then 85 | tty_escape() { printf "\033[%sm" "$1"; } 86 | else 87 | tty_escape() { :; } 88 | fi 89 | tty_mkbold() { tty_escape "1;$1"; } 90 | tty_blue="$(tty_mkbold 34)" 91 | tty_red="$(tty_mkbold 31)" 92 | tty_green="$(tty_mkbold 32)" 93 | tty_yellow="$(tty_mkbold 33)" 94 | tty_bold="$(tty_mkbold 39)" 95 | tty_reset="$(tty_escape 0)" 96 | 97 | # Logging functions 98 | ohai() { 99 | printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" 100 | } 101 | 102 | pdone() { 103 | printf " ${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" 104 | } 105 | 106 | info() { 107 | printf "${tty_green}%s${tty_reset}\n" "$*" 108 | } 109 | 110 | warn() { 111 | printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 112 | } 113 | 114 | error() { 115 | printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 116 | } 117 | 118 | abort() { 119 | error "$@" 120 | exit 1 121 | } 122 | 123 | trap 'abort "An unexpected error occurred."' ERR 124 | 125 | getc() { 126 | local save_state 127 | save_state="$(/bin/stty -g)" 128 | /bin/stty raw -echo 129 | IFS='' read -r -n 1 -d '' "$@" 130 | /bin/stty "${save_state}" 131 | } 132 | 133 | wait_for_user() { 134 | local c 135 | echo 136 | echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" 137 | getc c 138 | # we test for \r and \n because some stuff does \r instead 139 | if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] 140 | then 141 | exit 1 142 | fi 143 | } 144 | 145 | execute() { 146 | ohai "Running: $*" 147 | if ! "$@"; then 148 | abort "Failed during: $*" 149 | fi 150 | } 151 | 152 | have_sudo_access() { 153 | if ! command -v sudo &> /dev/null; then 154 | warn "sudo command not found. Please install sudo or run as root." 155 | return 1 156 | fi 157 | if [ "$EUID" -ne 0 ]; then 158 | if ! sudo -n true 2>/dev/null; then 159 | warn "This script requires sudo access to install packages. Please run as root or ensure your user has sudo privileges." 160 | return 1 161 | fi 162 | fi 163 | return 0 164 | } 165 | 166 | execute_sudo() { 167 | if have_sudo_access; then 168 | ohai "sudo $*" 169 | if ! sudo "$@"; then 170 | abort "Failed to execute: sudo $*" 171 | fi 172 | else 173 | warn "Sudo access is required, attempting to run without sudo" 174 | ohai "$*" 175 | if ! "$@"; then 176 | abort "Failed to execute: $*" 177 | fi 178 | fi 179 | } 180 | 181 | # Function to set or replace environment variables in bash_profile 182 | set_or_replace_env_var() { 183 | local var_name="$1" 184 | local var_value="$2" 185 | local profile_file="$3" 186 | 187 | # Escape special characters for sed 188 | local escaped_var_value=$(printf '%s\n' "$var_value" | sed -e 's/[\/&]/\\&/g') 189 | 190 | if grep -q "^export $var_name=" "$profile_file"; then 191 | # Variable exists, replace it 192 | sed -i.bak "s/^export $var_name=.*/export $var_name=\"$escaped_var_value\"/" "$profile_file" 193 | else 194 | # Variable does not exist, append it 195 | echo "export $var_name=\"$var_value\"" >> "$profile_file" 196 | fi 197 | } 198 | 199 | # Clear the screen and display the logo 200 | clear 201 | echo "" 202 | echo "" 203 | echo " ______ _____ _______ ______ _______ _______ __ _ __ _" 204 | echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" 205 | echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" 206 | echo " " 207 | echo "" 208 | echo "" 209 | 210 | echo "This script will do the following:" 211 | echo "1. Install required software (Git, npm, pm2, Python 3.12)" 212 | echo "2. Set up AWS credentials" 213 | echo "3. Clone and set up the Boltzmann repository" 214 | echo "4. Create and register Bittensor wallets" 215 | echo "5. Configure wandb for logging" 216 | echo "6. Clean the specified S3 bucket" 217 | echo "7. Start Boltzmann miners on available GPUs" 218 | echo "" 219 | echo "Please ensure you have a stable internet connection and sufficient permissions to install software." 220 | echo "" 221 | 222 | wait_for_user 223 | 224 | # Ensure ~/.bash_profile exists 225 | touch ~/.bash_profile 226 | source ~/.bash_profile 227 | 228 | # Backup the bash_profile 229 | cp ~/.bash_profile ~/.bash_profile.bak 230 | 231 | # Prompt the user for AWS credentials if not supplied via command-line 232 | ohai "Getting AWS credentials ..." 233 | if [[ -z "$AWS_ACCESS_KEY_ID" ]] || [[ -z "$AWS_SECRET_ACCESS_KEY" ]] || [[ -z "$BUCKET" ]]; then 234 | # TODO: Consider securely storing AWS credentials rather than storing them in plain text 235 | warn "This script will store your AWS credentials in your ~/.bash_profile file." 236 | warn "This is not secure and is not recommended." 237 | read -p "Do you want to proceed? [y/N]: " proceed 238 | if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then 239 | abort "Aborted by user." 240 | fi 241 | 242 | if [[ -z "$AWS_ACCESS_KEY_ID" ]]; then 243 | read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID 244 | fi 245 | if [[ -z "$AWS_SECRET_ACCESS_KEY" ]]; then 246 | read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY 247 | fi 248 | if [[ -z "$BUCKET" ]]; then 249 | read -p "Enter your S3 Bucket Name: " BUCKET 250 | fi 251 | fi 252 | 253 | # Overwrite or add the AWS credentials in the bash_profile 254 | set_or_replace_env_var "AWS_ACCESS_KEY_ID" "$AWS_ACCESS_KEY_ID" ~/.bash_profile 255 | set_or_replace_env_var "AWS_SECRET_ACCESS_KEY" "$AWS_SECRET_ACCESS_KEY" ~/.bash_profile 256 | set_or_replace_env_var "BUCKET" "$BUCKET" ~/.bash_profile 257 | 258 | # Source the bash_profile to apply the changes 259 | source ~/.bash_profile 260 | pdone "AWS credentials set in ~/.bash_profile" 261 | 262 | ohai "Installing requirements ..." 263 | # Install Git if not present 264 | if ! command -v git &> /dev/null; then 265 | ohai "Git not found. Installing git ..." 266 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 267 | ohai "Detected Linux" 268 | if [ -f /etc/os-release ]; then 269 | . /etc/os-release 270 | if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then 271 | ohai "Detected Ubuntu, installing Git..." 272 | if [[ "$DEBUG" == "true" ]]; then 273 | execute_sudo apt-get update -y 274 | execute_sudo apt-get install git -y 275 | else 276 | execute_sudo apt-get update -y > /dev/null 2>&1 277 | execute_sudo apt-get install git -y > /dev/null 2>&1 278 | fi 279 | else 280 | warn "Unsupported Linux distribution: $ID" 281 | abort "Cannot install Git automatically" 282 | fi 283 | else 284 | warn "Cannot detect Linux distribution" 285 | abort "Cannot install Git automatically" 286 | fi 287 | else 288 | abort "Unsupported OS type: $OSTYPE" 289 | fi 290 | else 291 | pdone "Git is already installed" 292 | fi 293 | 294 | # TODO: Add error handling for package installations 295 | # TODO: Ensure compatibility with different package managers 296 | 297 | # Check for Rust installation 298 | if ! command -v rustc &> /dev/null; then 299 | ohai "Installing Rust ..." 300 | if [[ "$DEBUG" == "true" ]]; then 301 | execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 302 | else 303 | execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y > /dev/null 2>&1 304 | fi 305 | # Add Rust to the PATH for the current session 306 | source $HOME/.cargo/env 307 | fi 308 | pdone "Rust is installed" 309 | 310 | # Install uv if not present 311 | if ! command -v uv &> /dev/null; then 312 | ohai "Installing uv ..." 313 | if [[ "$DEBUG" == "true" ]]; then 314 | execute curl -LsSf https://astral.sh/uv/install.sh | sh 315 | else 316 | execute curl -LsSf https://astral.sh/uv/install.sh | sh > /dev/null 2>&1 317 | fi 318 | # Add uv to the PATH for the current session 319 | export PATH="$HOME/.cargo/bin:$PATH" 320 | fi 321 | pdone "uv is installed" 322 | 323 | # Check if npm is installed 324 | if ! command -v npm &> /dev/null; then 325 | ohai "Installing npm ..." 326 | if ! command -v node &> /dev/null; then 327 | ohai "Node.js could not be found, installing..." 328 | if ! curl -fsSL https://deb.nodesource.com/setup_18.x | bash; then 329 | abort "Failed to download Node.js setup script" 330 | fi 331 | if ! execute_sudo apt-get install -y nodejs; then 332 | abort "Failed to install Node.js" 333 | fi 334 | fi 335 | if ! curl -L https://www.npmjs.com/install.sh | sh; then 336 | abort "Failed to install npm" 337 | fi 338 | fi 339 | pdone "npm is installed" 340 | 341 | # Install pm2 342 | if ! command -v pm2 &> /dev/null; then 343 | ohai "Installing pm2 ..." 344 | if [[ "$DEBUG" == "true" ]]; then 345 | execute npm install pm2 -g 346 | else 347 | execute npm install pm2 -g > /dev/null 2>&1 348 | fi 349 | fi 350 | pdone "pm2 is installed" 351 | 352 | ohai "Installing Boltzmann ..." 353 | # Check if we are inside the boltzmann repository 354 | if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then 355 | REPO_PATH="." 356 | else 357 | if [ ! -d "boltzmann" ]; then 358 | ohai "Cloning boltzmann ..." 359 | execute git clone https://github.com/unconst/boltzmann 360 | REPO_PATH="boltzmann/" 361 | else 362 | REPO_PATH="boltzmann/" 363 | fi 364 | fi 365 | pdone "Boltzmann repository is ready at $REPO_PATH" 366 | 367 | # Install Python 3.12 if not installed 368 | if ! command -v python3.12 &> /dev/null; then 369 | ohai "Installing python3.12 ..." 370 | if [[ "$OSTYPE" == "linux-gnu"* ]]; then 371 | ohai "Detected Linux" 372 | if [ -f /etc/os-release ]; then 373 | . /etc/os-release 374 | if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then 375 | ohai "Detected Ubuntu, installing Python 3.12..." 376 | if [[ "$DEBUG" == "true" ]]; then 377 | if have_sudo_access; then 378 | execute_sudo add-apt-repository ppa:deadsnakes/ppa -y 379 | else 380 | warn "Skipping add-apt-repository due to lack of sudo access" 381 | fi 382 | execute_sudo apt-get update -y 383 | else 384 | if have_sudo_access; then 385 | execute_sudo add-apt-repository ppa:deadsnakes/ppa -y > /dev/null 2>&1 386 | else 387 | warn "Skipping add-apt-repository due to lack of sudo access" 388 | fi 389 | execute_sudo apt-get update -y > /dev/null 2>&1 390 | execute_sudo apt-get install --reinstall python3-apt > /dev/null 2>&1 391 | execute_sudo apt-get install python3.12 -y > /dev/null 2>&1 392 | execute_sudo apt-get install python3.12-venv > /dev/null 2>&1 393 | fi 394 | else 395 | warn "Unsupported Linux distribution: $ID" 396 | abort "Cannot install Python 3.12 automatically" 397 | fi 398 | else 399 | warn "Cannot detect Linux distribution" 400 | abort "Cannot install Python 3.12 automatically" 401 | fi 402 | else 403 | abort "Unsupported OS type: $OSTYPE" 404 | fi 405 | fi 406 | pdone "Python 3.12 is installed" 407 | 408 | # Create a virtual environment if it does not exist 409 | if [ ! -d "$REPO_PATH/venv" ]; then 410 | ohai "Creating virtual environment at $REPO_PATH..." 411 | if [[ "$DEBUG" == "true" ]]; then 412 | execute uv venv "$REPO_PATH/.venv" 413 | else 414 | execute uv venv "$REPO_PATH/.venv" > /dev/null 2>&1 415 | fi 416 | fi 417 | pdone "Virtual environment is set up at $REPO_PATH" 418 | 419 | 420 | # Activate the virtual environment 421 | ohai "Activating virtual environment ..." 422 | source $REPO_PATH/.venv/bin/activate 423 | pdone "Virtual environment activated" 424 | 425 | ohai "Installing Python requirements ..." 426 | if [[ "$DEBUG" == "true" ]]; then 427 | execute uv pip install -r $REPO_PATH/requirements.txt 428 | execute uv pip install --upgrade cryptography pyOpenSSL 429 | else 430 | execute uv pip install -r $REPO_PATH/requirements.txt > /dev/null 2>&1 431 | execute uv pip install --upgrade cryptography pyOpenSSL > /dev/null 2>&1 432 | fi 433 | pdone "Python requirements installed" 434 | 435 | # Check for GPUs 436 | ohai "Checking for GPUs..." 437 | if ! command -v nvidia-smi &> /dev/null; then 438 | warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." 439 | NUM_GPUS=0 440 | else 441 | NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 442 | 443 | if [ "$NUM_GPUS" -gt 0 ]; then 444 | nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do 445 | pdone "Found GPU with $((memory / 1024)) GB of memory" 446 | done 447 | else 448 | warn "No GPUs found on this machine." 449 | fi 450 | fi 451 | 452 | # Check system RAM 453 | if command -v free &> /dev/null; then 454 | TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') 455 | pdone "System RAM: ${TOTAL_RAM} GB" 456 | else 457 | warn "Cannot determine system RAM. 'free' command not found." 458 | fi 459 | 460 | ohai "Creating wallets ..." 461 | # Create the default key 462 | if ! python3 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then 463 | execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 464 | fi 465 | pdone "Wallet 'default' is ready" 466 | 467 | # Ensure btcli is installed 468 | if ! command -v btcli &> /dev/null; then 469 | abort "btcli command not found. Please ensure it is installed." 470 | fi 471 | 472 | # Create hotkeys and register them 473 | if [ "$NUM_GPUS" -gt 0 ]; then 474 | for i in $(seq 0 $((NUM_GPUS - 1))); do 475 | # Check if the hotkey file exists on the device 476 | exists_on_device=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) 477 | if [ "$exists_on_device" != "True" ]; then 478 | echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; 479 | fi 480 | pdone "Created Hotkey 'C$i'" 481 | 482 | # Check if the hotkey is registered on subnet 220 483 | is_registered=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) 484 | if [[ "$is_registered" != *"True"* ]]; then 485 | ohai "Registering hotkey 'C$i' on subnet 220" 486 | btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; 487 | fi 488 | pdone "Registered Hotkey 'C$i' on subnet 220" 489 | done 490 | else 491 | warn "No GPUs found. Skipping hotkey creation." 492 | exit 493 | fi 494 | pdone "All hotkeys registered" 495 | 496 | ohai "Logging into wandb..." 497 | execute wandb login 498 | pdone "wandb is configured" 499 | 500 | # Clean the bucket 501 | ohai "Cleaning bucket $BUCKET..." 502 | if [[ "$DEBUG" == "true" ]]; then 503 | execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" 504 | else 505 | execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" > /dev/null 2>&1 506 | fi 507 | pdone "Bucket '$BUCKET' cleaned" 508 | 509 | # Close down all previous processes and restart them 510 | if pm2 list | grep -q 'online'; then 511 | ohai "Stopping old pm2 processes..." 512 | pm2 delete all 513 | pdone "Old processes stopped" 514 | fi 515 | 516 | # Start all the processes again 517 | if [ "$NUM_GPUS" -gt 0 ]; then 518 | for i in $(seq 0 $((NUM_GPUS - 1))); do 519 | # Adjust GPU index for zero-based numbering 520 | GPU_INDEX=$i 521 | GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") 522 | if [ -z "$GPU_MEMORY" ]; then 523 | warn "Could not get GPU memory for GPU $i" 524 | continue 525 | fi 526 | # Determine batch size based on GPU memory 527 | if [ "$GPU_MEMORY" -ge 80000 ]; then 528 | BATCH_SIZE=6 529 | elif [ "$GPU_MEMORY" -ge 40000 ]; then 530 | BATCH_SIZE=3 531 | elif [ "$GPU_MEMORY" -ge 20000 ]; then 532 | BATCH_SIZE=1 533 | else 534 | BATCH_SIZE=1 535 | fi 536 | ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." 537 | if [[ "$DEBUG" == "true" ]]; then 538 | execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" 539 | else 540 | execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" > /dev/null 2>&1 541 | fi 542 | done 543 | else 544 | warn "No GPUs found. Skipping miner startup." 545 | fi 546 | pdone "All miners started" 547 | pm2 list 548 | 549 | echo "" 550 | pdone "SUCCESS" 551 | echo "" 552 | 553 | # Start logging process 1 554 | pm2 logs C0 555 | 556 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | # Close down all previous processes and restart them. 19 | pm2 sendSignal SIGINT all 20 | pm2 delete all 21 | # Delete items from bucket 22 | BUCKET=${1:-decis} 23 | PROJECT=${2:-aesop} 24 | python3 tools/clean.py --bucket $BUCKET 25 | 26 | # Start all the processes again. 27 | pm2 start validator.py --interpreter python3 --name V1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey default --bucket $BUCKET --device cuda:0 --use_wandb --project $PROJECT 28 | pm2 start miner.py --interpreter python3 --name M1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M1 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT 29 | pm2 start miner.py --interpreter python3 --name M2 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M2 --bucket $BUCKET --device cuda:2 --use_wandb --project $PROJECT 30 | pm2 start miner.py --interpreter python3 --name M3 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:3 --use_wandb --project $PROJECT 31 | pm2 start miner.py --interpreter python3 --name M4 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M4 --bucket $BUCKET --device cuda:5 --use_wandb --random --project $PROJECT 32 | pm2 start miner.py --interpreter python3 --name M5 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M5 --bucket $BUCKET --device cuda:6 --use_wandb --random --project $PROJECT 33 | pm2 start miner.py --interpreter python3 --name M6 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:4 --use_wandb --baseline --project $PROJECT 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /tests/eval.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import os 19 | import sys 20 | import json 21 | import uuid 22 | import time 23 | import torch 24 | import wandb 25 | import boto3 26 | import shutil 27 | import argparse 28 | import tempfile 29 | import traceback 30 | import bittensor as bt 31 | from hparams import load_hparams 32 | from types import SimpleNamespace 33 | from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM 34 | 35 | def main(config): 36 | bt.logging.off() # Turn of bt logging. 37 | # Initialize Weights and Biases (wandb) for experiment tracking if enabled. 38 | if config.use_wandb: 39 | # Check for existing runs with the same name and delete them 40 | try: 41 | api = wandb.Api() 42 | runs = api.runs(path=config.project) 43 | for run in runs: 44 | if run.name == f'Eval': 45 | print(f'Deleting old run: {run}') 46 | run.delete() 47 | except: pass 48 | run = wandb.init(project=config.project, resume='allow', name='Eval', config=config) 49 | 50 | while True: 51 | 52 | print('Loading chain state.') 53 | hparams = load_hparams() 54 | subtensor = bt.subtensor(config=config) 55 | metagraph = subtensor.metagraph(config.netuid) 56 | 57 | print('Iterating miners') 58 | for uid in metagraph.uids: 59 | # Check if we are evaling a specific miner. 60 | if config.uid is not None: 61 | uid = config.uid 62 | # Try to eval. 63 | try: 64 | print("Getting commitment from subtensor...") 65 | try: 66 | bucket = subtensor.get_commitment(config.netuid, uid) 67 | except: 68 | print ('Miner has no registered bucket. Continuing.') 69 | time.sleep(1) 70 | continue 71 | 72 | print(f"Preparing to download model state dict for UID {uid}...") 73 | filename = f"master-{metagraph.hotkeys[uid]}.pt" 74 | temp_file = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.pt") 75 | 76 | print(f"Downloading file {filename} from bucket {bucket}...") 77 | # Initialize the S3 client (assuming AWS S3) 78 | CLIENT = boto3.client('s3') 79 | try: 80 | CLIENT.download_file(bucket, filename, temp_file) 81 | except Exception as e: 82 | print(f"No master for UID {uid}. Error: {e}") 83 | time.sleep(1) 84 | continue 85 | 86 | print("Loading model state dict...") 87 | model = LlamaForCausalLM(config=hparams.model_config) 88 | model_state_dict = torch.load(temp_file, map_location='cpu', weights_only = True) 89 | model.load_state_dict(model_state_dict) 90 | 91 | # Delete the temp file after loading the model state dict 92 | os.remove(temp_file) 93 | 94 | model_save_path = f'models/{uid}' 95 | if os.path.exists(model_save_path): 96 | print(f"Deleting existing model at {model_save_path}...") 97 | shutil.rmtree(model_save_path) 98 | 99 | print(f"Saving model to models/{uid}...") 100 | os.makedirs(model_save_path, exist_ok=True) 101 | model.save_pretrained(model_save_path) 102 | 103 | print(f"Saving tokenizer to models/{uid}...") 104 | hparams.tokenizer.save_pretrained(model_save_path) 105 | 106 | print("Running lm-eval harness...") 107 | lm_eval_command = ( 108 | f"lm-eval " 109 | f"--model hf " 110 | f"--model_args pretrained=./models/{uid},tokenizer=./models/{uid} " 111 | f"--tasks {config.tasks} " 112 | f"--device {config.device} " 113 | f"--batch_size {config.actual_batch_size} " 114 | f"--output_path models/{uid}/results " 115 | ) 116 | print(f"Executing command: {lm_eval_command}") 117 | exit_code = os.system(lm_eval_command) 118 | if exit_code != 0: 119 | print(f"Command eval script failed with exit code {exit_code}. Error: {os.strerror(exit_code)}. Continuing...") 120 | continue 121 | 122 | print("Loading evaluation results...") 123 | results_dir = f"models/{uid}/results/.__models__{uid}/" 124 | latest_file = max([os.path.join(results_dir, f) for f in os.listdir(results_dir)], key=os.path.getctime) 125 | with open(latest_file, "r") as f: 126 | results = json.load(f) 127 | 128 | print("Processing results...") 129 | for task_name, task_results in results['results'].items(): 130 | if task_name == 'winogrande': 131 | metric_name = 'acc,none' 132 | else: 133 | metric_name = 'acc_norm,none' 134 | metric_value = float(task_results.get(metric_name)) 135 | if metric_value is not None: 136 | print(f"{uid}/{task_name}: {metric_value}") 137 | if config.use_wandb: 138 | wandb.log({f"{task_name}": metric_value}) 139 | else: 140 | print(f"{uid} - {task_name} not found in results") 141 | 142 | # Delete the model after running the eval off the device 143 | del model 144 | torch.cuda.empty_cache() 145 | 146 | # Delete the storage of the model 147 | shutil.rmtree(model_save_path) 148 | 149 | # Error in eval loop. 150 | except KeyboardInterrupt: 151 | print("Keyboard interrupt received. Exiting gracefully.") 152 | sys.exit(0) 153 | except Exception as e: 154 | print(f"Error: {e}") 155 | traceback.print_exc() 156 | continue 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser(description='Miner script') 160 | parser.add_argument('--project', type=str, default='220A', help='Optional wandb project name') 161 | parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') 162 | parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') 163 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') 164 | parser.add_argument('--uid', type=int, default=None, help='The miner to eval. If None, eval all miners.') 165 | parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') 166 | parser.add_argument('--tasks', type=str, default='arc_challenge,arc_easy,openbookqa,winogrande,piqa,hellaswag', help='Comma-separated list of tasks to evaluate') 167 | bt.subtensor.add_args(parser) 168 | config = bt.config(parser) 169 | config.subtensor.network = 'test' 170 | config.subtensor.chain_endpoint = 'wss://test.finney.opentensor.ai:443/' 171 | main(config) -------------------------------------------------------------------------------- /tests/legacy_miner.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | # fmt: off 18 | 19 | # Global imports. 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import wandb 25 | import torch 26 | import random 27 | import asyncio 28 | import argparse 29 | import threading 30 | import traceback 31 | from tqdm import tqdm 32 | import bittensor as bt 33 | from typing import List 34 | import torch.optim as optim 35 | from dotenv import dotenv_values 36 | from transformers import LlamaForCausalLM 37 | from torch.optim.lr_scheduler import CosineAnnealingLR 38 | 39 | # Import local files. 40 | from common import * 41 | from hparams import load_hparams 42 | from dataset import DatasetLoader 43 | 44 | # GPU optimizations. 45 | torch.backends.cudnn.benchmark = True 46 | torch.backends.cuda.matmul.allow_tf32 = True 47 | torch.backends.cudnn.allow_tf32 = True 48 | 49 | class Miner: 50 | 51 | @staticmethod 52 | def config(): 53 | parser = argparse.ArgumentParser(description='Miner script') 54 | parser.add_argument('--project', type=str, default='QZWXEC', help='Optional wandb project name') 55 | parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') 56 | parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') 57 | parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') 58 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') 59 | parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') 60 | parser.add_argument('--remote', action='store_true', help='Connect to other buckets') 61 | parser.add_argument('--debug', action='store_true', help='Enable debug logging') 62 | parser.add_argument('--trace', action='store_true', help='Enable trace logging') 63 | parser.add_argument('--random', action='store_true', help='Train on random') 64 | bt.wallet.add_args(parser) 65 | bt.subtensor.add_args(parser) 66 | config = bt.config(parser) 67 | config.subtensor.network = 'test' 68 | config.subtensor.chain_endpoint = 'wss://test.finney.opentensor.ai:443/' 69 | if config.debug: debug() 70 | if config.trace: trace() 71 | return config 72 | 73 | def __init__(self): 74 | # Init config. 75 | self.config = Miner.config() 76 | logger.info('\n' + '-' * 40 + ' Config ' + '-' * 40) 77 | logger.info(self.config) 78 | 79 | # Init bittensor objects. 80 | self.wallet = bt.wallet(config=self.config) 81 | self.subtensor = bt.subtensor(config=self.config) 82 | self.metagraph = self.subtensor.metagraph(netuid=self.config.netuid) 83 | if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys: 84 | raise ValueError(f'Wallet {self.wallet} is not registered on subnet: {self.metagraph.netuid}') 85 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 86 | logger.info('\n' + '-' * 40 + ' Objects ' + '-' * 40) 87 | logger.info(f'\nWallet: {self.wallet}\nSubtensor: {self.subtensor}\nMetagraph: {self.metagraph}\nUID: {self.uid}') 88 | 89 | # Init bucket. 90 | try: 91 | if self.config.bucket != self.subtensor.get_commitment(self.config.netuid, self.uid): 92 | raise ValueError('') 93 | except: 94 | self.subtensor.commit(self.wallet, self.config.netuid, self.config.bucket) 95 | logger.info('Bucket:' + self.config.bucket) 96 | 97 | # Init Wandb. 98 | if self.config.use_wandb: 99 | # Delete all runs with my name and create a new one. 100 | try: 101 | [run.delete() for run in wandb.Api().runs(path=self.config.project) 102 | if run.name == f'M{self.uid}-{"r" if self.config.random else ""}' and logger.info(f'Deleting old run: {run}')] 103 | except: pass 104 | wandb.init(project=self.config.project, resume='allow', name=f'M{self.uid}', config=self.config) 105 | 106 | # Init model. 107 | logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) 108 | self.hparams = load_hparams() 109 | torch.manual_seed(42); np.random.seed(42); random.seed(42) 110 | # self.model = LlamaForCausalLM(config=self.hparams.model_config) 111 | self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') 112 | self.model.to(self.config.device) 113 | self.model.train() 114 | self.optimizer = optim.AdamW( 115 | self.model.parameters(), 116 | lr=self.hparams.learning_rate, # Peak learning rate 117 | betas=(self.hparams.optimizer_beta1, self.hparams.optimizer_beta2), # B1 and B2 118 | weight_decay=self.hparams.optimizer_weight_decay, # Weight decay 119 | foreach=True, # more memory usage, but faster 120 | ) 121 | self.scheduler = CosineAnnealingLR( 122 | self.optimizer, T_max=self.hparams.cosine_epoch_length, 123 | eta_min=self.hparams.eta_min, last_epoch=-1 124 | ) 125 | 126 | # Init buckets. 127 | self.buckets = [] 128 | for uid in self.metagraph.uids: 129 | # Use --remote to connect to other miners, other wise, only see's config.bucket. 130 | try: self.buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid ) ) 131 | except: self.buckets.append(None) 132 | 133 | # Init run state. 134 | self.global_step = 0 135 | self.sample_rate = 1.0 136 | self.current_block = self.subtensor.block 137 | self.current_window = self.block_to_window( self.current_block ) 138 | self.window_seeds = {self.current_window: self.window_to_seed( self.current_window) } 139 | self.new_block_event = asyncio.Event() 140 | self.new_window_event = asyncio.Event() 141 | self.stop_event = asyncio.Event() 142 | self.last_full_steps = self.hparams.desired_batch_size // self.config.actual_batch_size 143 | print ( self.hparams ) 144 | 145 | async def update(self): 146 | while not self.stop_event.is_set(): 147 | logger.info(f"\tUpdating global state.") 148 | start_time = time.time() 149 | self.subtensor = bt.subtensor(config=self.config) 150 | self.metagraph = self.subtensor.metagraph(self.config.netuid) 151 | self.hparams = load_hparams() 152 | next_buckets = [] 153 | for uid in self.metagraph.uids: 154 | try: next_buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid )) 155 | except: next_buckets.append(None) 156 | self.buckets = next_buckets 157 | logger.info(f"\t\tUpdated global state in {time.time() - start_time} seconds.") 158 | await asyncio.sleep(60) 159 | 160 | async def run(self): 161 | # Main loop. 162 | self.loop = asyncio.get_running_loop() 163 | self.update_task = asyncio.create_task(self.update()) 164 | self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() 165 | while True: 166 | 167 | try: 168 | # Start step. 169 | logger.info('\n' + '-' * 40 + f' Step: {self.global_step} ' + '-' * 40) 170 | logger.info(f"Step: {self.global_step}, Window: {self.current_window}, " 171 | f"Block: {self.current_block}, Time: {int(time.time())}") 172 | global_step_start_time = time.time() 173 | self.step_window = self.current_window 174 | self.global_step += 1 175 | 176 | # Download files. 177 | logger.info(f"\tDownloading slices from previous window: {self.step_window - 1}") 178 | start_time = time.time() 179 | slice_files = await download_slices_for_buckets_and_windows( 180 | buckets = self.buckets, 181 | windows = [self.step_window - 1] 182 | ) 183 | downloaded_per_step = sum([len(slice_files[k]) for k in slice_files]) 184 | logger.info(f"\t\tDownloaded {downloaded_per_step} slices for previous window: {self.step_window - 1} in {time.time() - start_time} seconds") 185 | 186 | # Apply slices to the model from the previous window. 187 | logger.info(f"\tApplying slices from previous window: {self.step_window - 1} to model.") 188 | start_time = time.time() 189 | slice_files = await apply_slices_to_model( 190 | model = self.model, 191 | window = self.step_window - 1, # Get files from previous window. 192 | seed = self.window_seeds[ self.step_window ], # Use seed as the hash of the current window. 193 | compression = self.hparams.compression 194 | ) 195 | applied_per_step = len(slice_files) 196 | logger.info(f"\t\tApplied {applied_per_step} from previous window: {self.step_window - 1} with seed: { self.window_seeds[ self.step_window ] } in {time.time() - start_time} seconds") 197 | 198 | # Train for performance on the current window. 199 | # Load pages from the current eval window. The validators will sample pages from (eval_pages_start, eval_pages_end) 200 | # eval_pages_start : ( window_idx * window_length * window_speed ) 201 | # eval_pages_end : ( window_idx * window_length * window_speed ) + window_eval_size 202 | start_time = time.time() 203 | offset = self.step_window * self.hparams.window_length * self.hparams.window_speed 204 | seed = self.uid if not self.config.random else random.randint(0, 1000) 205 | logger.info(f"\tLoading {self.hparams.validator_window_eval_size} pages for current window: { self.step_window } and offset: {offset} and uid: {self.uid} and seed: {seed}") 206 | pages = await DatasetLoader.next_pages( 207 | offset = offset, 208 | n_pages = self.hparams.validator_window_eval_size, 209 | seed = seed 210 | ) 211 | random.shuffle( pages ) 212 | dataset = await DatasetLoader.create( 213 | batch_size = self.config.actual_batch_size, 214 | sequence_length = self.hparams.sequence_length, 215 | pages_info = pages, 216 | tokenizer = self.hparams.tokenizer 217 | ) 218 | pages_per_step = len(pages) 219 | logger.info(f"\t\tLoaded dataset pages: {[p[1] for p in pages]} in {time.time() - start_time} seconds") 220 | 221 | # Train the model on the current page. 222 | logger.info(f"\tTraining on pages: {[p[1] for p in pages]} with sample rate: {self.sample_rate}") 223 | start_time = time.time() 224 | torch.cuda.empty_cache() # Empty cache going into the training step. 225 | self.optimizer.zero_grad() # Clear any lingering grads. 226 | total_loss = 0.0 227 | exhuasted_window = False 228 | self.full_steps = 0 229 | for idx, batch in enumerate(dataset): 230 | # Randomly sample every sample_rate examples 231 | if random.random() < self.sample_rate: 232 | self.full_steps += 1 233 | input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) 234 | labels = input_ids.clone() 235 | labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) 236 | with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting 237 | outputs = self.model(input_ids=input_ids, labels=labels) 238 | total_loss += outputs.loss.item() 239 | loss = outputs.loss / (self.last_full_steps + 1) # Divide by number of accumulations. 240 | loss.backward() 241 | if self.step_window != self.current_window: 242 | exhuasted_window = True 243 | break 244 | 245 | # Apply step and clean memory. 246 | if self.hparams.grad_clip: 247 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hparams.grad_clip) 248 | self.optimizer.step() 249 | self.scheduler.step() # Update the learning rate. 250 | self.optimizer.zero_grad() 251 | del input_ids, labels, outputs 252 | torch.cuda.empty_cache() 253 | 254 | # Calculate, print and log average loss 255 | self.last_full_steps = self.full_steps 256 | average_loss = total_loss / (self.full_steps + 1) 257 | total_time = time.time() - start_time 258 | tokens_per_step = self.hparams.sequence_length * self.config.actual_batch_size * (self.full_steps + 1) 259 | tokens_per_second = tokens_per_step / total_time 260 | logger.info(f"\t\tTotal steps: {idx}, Applied: {self.full_steps}, Rate: {self.full_steps/(idx + 1)}, Sample Probability: {self.sample_rate}") 261 | logger.info(f"\t\tLoss: {average_loss}, learning_rate: {self.scheduler.get_last_lr()[0]}") 262 | logger.info(f"\t\tTraining completed in {total_time} seconds, Tokens per step: {tokens_per_step}, Tokens per second: {tokens_per_second}") 263 | if exhuasted_window: 264 | self.sample_rate = max(0.0001, self.sample_rate * 0.95) 265 | else: 266 | self.sample_rate = min(1, self.sample_rate * 1.05) 267 | 268 | # Wait until we are on a new window. 269 | while self.current_window == self.step_window: 270 | await asyncio.sleep(0.1) 271 | 272 | # Upload our model slice to S3. 273 | logger.info(f"\tUploading for window: { self.step_window }") 274 | start_time = time.time() 275 | await upload_slice_for_window( 276 | bucket = self.config.bucket, 277 | model = self.model, 278 | window = self.step_window, # Upload for the previous window 279 | seed = self.window_seeds[ self.step_window + 1 ], # Seed the index by the hash of the new window. 280 | wallet = self.wallet, 281 | compression = self.hparams.compression 282 | ) 283 | logger.info(f"\t\tFinished upload for window: {self.step_window} with seed: {self.window_seeds[ self.step_window + 1 ]} in {time.time() - start_time} seconds.") 284 | 285 | # Delete lingering files 286 | logger.info(f"\tCleaning space.") 287 | start_time = time.time() 288 | await delete_files_before_window( window_max = self.current_window - self.hparams.max_history ) 289 | await delete_files_from_bucket_before_window( bucket = self.config.bucket, window_max = self.current_window - self.hparams.max_history ) 290 | logger.info(f"\t\tFinished cleaning space in {time.time() - start_time} seconds.") 291 | 292 | # Calculate and log global steps per second 293 | seconds_per_step = time.time() - global_step_start_time 294 | steps_per_second = 1 / seconds_per_step 295 | if self.config.use_wandb: 296 | wandb.log({ 297 | "step_loss": average_loss, 298 | "tokens_per_step": tokens_per_step, 299 | "tokens_per_second": tokens_per_second, 300 | "applied_per_step": applied_per_step, 301 | "pages_per_step": pages_per_step, 302 | "downloaded_per_step": downloaded_per_step, 303 | "incentive": float(self.metagraph.I[self.uid]), 304 | "learning_rate": self.scheduler.get_last_lr()[0], 305 | "seconds_per_step": seconds_per_step, 306 | "steps_per_second": steps_per_second, 307 | "sample_rate": self.sample_rate, 308 | }) 309 | 310 | logger.info(f'\nGlobal step completed in {seconds_per_step} seconds\n') 311 | 312 | # Catch keyboard interrrupt. 313 | except KeyboardInterrupt: 314 | logger.info("Training interrupted by user. Stopping the run.") 315 | self.stop_event.set() 316 | await self.update_task 317 | sys.exit(0) 318 | 319 | # Catch unknown. 320 | except Exception as e: 321 | logger.exception(f"Exception during training loop: {e}") 322 | continue 323 | 324 | # Returns the slice window based on a block. 325 | def block_to_window(self, block: int) -> int: 326 | return int( block / self.hparams.window_length ) # floor 327 | 328 | # Returns the slice window based on a block. 329 | def window_to_seed(self, window: int) -> int: 330 | return str( self.subtensor.get_block_hash( window * self.hparams.window_length ) ) 331 | 332 | # A listener thread which posts the block event 333 | # when the chain announces a new block. 334 | def block_listener(self, loop): 335 | def handler(event, _u, _s): 336 | self.current_block = int(event['header']['number']) 337 | loop.call_soon_threadsafe(self.new_block_event.set) 338 | if self.block_to_window(self.current_block) != self.current_window: 339 | self.window_seeds[ self.block_to_window(self.current_block) ] = self.window_to_seed( self.block_to_window(self.current_block) ) 340 | self.current_window = self.block_to_window(self.current_block) 341 | loop.call_soon_threadsafe(self.new_window_event.set) 342 | logger.info(f"-- New window: {self.current_window} -- ") 343 | # Run listener with retry. 344 | while not self.stop_event.is_set(): 345 | try: 346 | bt.subtensor(config=self.config).substrate.subscribe_block_headers(handler); break 347 | except Exception as e: 348 | # Wait for 5 seconds before retrying 349 | logger.error(f"Failed to subscribe to block headers: {e}.\nRetrying in 1 seconds...") 350 | time.sleep(1) 351 | 352 | if __name__ == "__main__": 353 | asyncio.run(Miner().run()) 354 | -------------------------------------------------------------------------------- /tests/legacy_validator.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | # fmt: off 18 | 19 | # Global imports. 20 | import os 21 | import sys 22 | import time 23 | import wandb 24 | import torch 25 | import random 26 | import asyncio 27 | import argparse 28 | import threading 29 | import traceback 30 | from tqdm import tqdm 31 | import bittensor as bt 32 | from typing import List 33 | import torch.optim as optim 34 | from dotenv import dotenv_values 35 | from transformers import LlamaForCausalLM 36 | from torch.optim.lr_scheduler import CosineAnnealingLR 37 | 38 | # Import local files. 39 | from common import * 40 | from hparams import load_hparams 41 | from dataset import DatasetLoader 42 | 43 | # GPU optimizations. 44 | torch.backends.cudnn.benchmark = True 45 | torch.backends.cuda.matmul.allow_tf32 = True 46 | torch.backends.cudnn.allow_tf32 = True 47 | 48 | class Validator: 49 | 50 | @staticmethod 51 | def config(): 52 | parser = argparse.ArgumentParser(description='Validator script') 53 | parser.add_argument('--project', type=str, default='QZWXEC', help='Optional wandb project name') 54 | parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') 55 | parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') 56 | parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') 57 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') 58 | parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') 59 | parser.add_argument('--debug', action='store_true', help='Enable debug logging') 60 | parser.add_argument('--trace', action='store_true', help='Enable trace logging') 61 | bt.wallet.add_args(parser) 62 | bt.subtensor.add_args(parser) 63 | config = bt.config(parser) 64 | config.subtensor.network = 'test' 65 | config.subtensor.chain_endpoint = 'wss://test.finney.opentensor.ai:443/' 66 | if config.debug: debug() 67 | if config.trace: trace() 68 | return config 69 | 70 | def __init__(self): 71 | # Init config. 72 | self.config = Validator.config() 73 | logger.info('\n' + '-' * 40 + ' Config ' + '-' * 40) 74 | logger.info(self.config) 75 | 76 | # Init bittensor objects. 77 | self.wallet = bt.wallet(config=self.config) 78 | self.subtensor = bt.subtensor(config=self.config) 79 | self.metagraph = self.subtensor.metagraph(netuid=self.config.netuid) 80 | if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys: 81 | raise ValueError(f'Wallet {self.wallet} is not registered on subnet: {self.metagraph.netuid}') 82 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 83 | logger.info('\n' + '-' * 40 + ' Objects ' + '-' * 40) 84 | logger.info(f'\nWallet: {self.wallet}\nSubtensor: {self.subtensor}\nMetagraph: {self.metagraph}\nUID: {self.uid}') 85 | 86 | # Init bucket. 87 | try: 88 | if self.config.bucket != self.subtensor.get_commitment(self.config.netuid, self.uid): 89 | raise ValueError('') 90 | except: 91 | self.subtensor.commit(self.wallet, self.config.netuid, self.config.bucket) 92 | logger.info('Bucket:' + self.config.bucket) 93 | 94 | # Init Wandb. 95 | if self.config.use_wandb: 96 | # Delete all runs with my name and create a new one. 97 | try: 98 | [run.delete() for run in wandb.Api().runs(path=self.config.project) 99 | if run.name == f'V{self.uid}' and logger.info(f'Deleting old run: {run}')] 100 | except: pass 101 | wandb.init(project=self.config.project, resume='allow', name=f'V{self.uid}', config=self.config) 102 | 103 | # Init model. 104 | logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) 105 | self.hparams = load_hparams() 106 | torch.manual_seed(42); np.random.seed(42); random.seed(42) 107 | #self.model = LlamaForCausalLM(config=self.hparams.model_config) 108 | self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') 109 | self.model.to(self.config.device) 110 | self.model.eval() 111 | 112 | # Init buckets. 113 | self.buckets = [] 114 | for uid in self.metagraph.uids: 115 | # Use --remote to connect to other miners, other wise, only see's config.bucket. 116 | try: self.buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid ) ) 117 | except: self.buckets.append(None) 118 | 119 | # Init run state. 120 | self.global_step = 0 121 | self.last_window = 0 122 | self.optimal_pages_per_step = 4 123 | self.current_block = self.subtensor.block 124 | self.current_window = self.block_to_window( self.current_block ) 125 | self.window_seeds = {self.current_window: self.window_to_seed( self.current_window) } 126 | self.block_event = asyncio.Event() 127 | self.new_window_event = asyncio.Event() 128 | self.stop_event = asyncio.Event() 129 | self.loss_change = torch.zeros( 256, dtype = torch.float32 ) 130 | self.scores = torch.zeros( 256, dtype = torch.float32 ) 131 | self.weights = torch.zeros( 256, dtype = torch.float32 ) 132 | self.sample_rate = 1.0 133 | print ( self.hparams ) 134 | 135 | async def update(self): 136 | while not self.stop_event.is_set(): # Loop until stop_event is set 137 | self.subtensor = bt.subtensor(config=self.config) # Reinitialize subtensor with current config 138 | nxt_meta = self.subtensor.metagraph(self.config.netuid) # Get the new metagraph for the given netuid 139 | self.hparams = load_hparams() # Reload hyperparameters 140 | next_buckets = [] # Initialize the next_buckets list 141 | for uid in nxt_meta.uids: # Iterate over new metagraph uids 142 | try: next_buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid )) 143 | except: next_buckets.append(None) 144 | self.buckets = next_buckets # Update self.buckets with next_buckets 145 | for idx, hotkey in enumerate(self.metagraph.hotkeys): # Iterate over current metagraph hotkeys 146 | if hotkey != nxt_meta.hotkeys[idx]: # Check if hotkey has changed in the new metagraph 147 | self.scores[idx] = 0 # Reset rewards for the changed hotkey 148 | self.weights[idx] = 0 # Reset weights for the changed hotkey 149 | self.metagraph = nxt_meta # Update self.metagraph with new_metagraph 150 | await asyncio.sleep(60) # Sleep for 60 seconds before the next iteration 151 | 152 | async def run(self): 153 | # Main loop. 154 | self.loop = asyncio.get_running_loop() 155 | self.update_task = asyncio.create_task(self.update()) 156 | self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() 157 | 158 | while True: 159 | 160 | try: 161 | # Start step. 162 | logger.info('\n' + '-' * 40 + f' Step: {self.global_step} ' + '-' * 40) 163 | step_start_time = time.time() 164 | self.global_step += 1 165 | self.step_window = self.current_window 166 | self.eval_window = self.current_window - 2 167 | logger.info(f"Step: {self.global_step}, Step Window: {self.step_window}, Eval Window: {self.eval_window}" 168 | f"Block: {self.current_block}, Time: {int(step_start_time)}") 169 | 170 | # Download the slices for the window. 171 | logger.info(f"\tDownloading slices from previous window: { self.eval_window }") 172 | start_time = time.time() 173 | slice_infos = await download_slices_for_buckets_and_windows( 174 | buckets=self.buckets, 175 | windows=[self.eval_window] 176 | ) 177 | await download_slices_for_buckets_and_windows( 178 | buckets=self.buckets, 179 | windows=[self.eval_window + 1] 180 | ) 181 | # If there are no slices to eval, wait until the next window then start again. 182 | if self.eval_window not in slice_infos or len(slice_infos[self.eval_window]) == 0: 183 | print ('\t\tNo slices to download, waiting for next window...') 184 | while self.current_window == self.step_window: await asyncio.sleep(0.1) 185 | continue 186 | slice_infos = slice_infos[self.eval_window] 187 | logger.info(f"\t\tDownloaded {len(slice_infos)} slices for previous window: {self.eval_window} in {time.time() - start_time} seconds") 188 | 189 | # Step 2: Apply slices to the model from the previous window. 190 | logger.info(f"\tApplying slices from previous window: {self.eval_window} to model.") 191 | start_time = time.time() 192 | eval_slices = await apply_slices_to_model( 193 | model=self.model, 194 | window=self.eval_window, # Get files from previous window. 195 | seed=self.step_window, # Use seed as the hash of the current window. 196 | compression=self.hparams.compression 197 | ) 198 | await apply_slices_to_model( 199 | model=self.model, 200 | window=self.eval_window + 1, # Get files from previous window. 201 | seed=self.step_window, # Use seed as the hash of the current window. 202 | compression=self.hparams.compression 203 | ) 204 | applied_per_step = len(eval_slices) 205 | logger.info(f"\t\tApplied {applied_per_step} slices from previous window: {self.eval_window} with seed: {self.window_seeds[self.step_window]} in {time.time() - start_time} seconds") 206 | 207 | indices = await get_indices_for_window( 208 | model=self.model, 209 | seed=self.window_to_seed(self.eval_window + 1), # Seed index for the eval window. 210 | compression=self.hparams.compression 211 | ) 212 | 213 | # Step 2: Compute slice importance using second-order approximation with Fisher Information Matrix. 214 | eval_start_time = time.time() 215 | info_i = random.choice(slice_infos) 216 | 217 | # Get the UID we are evalling. 218 | try: uid = self.metagraph.hotkeys.index(info_i.hotkey) 219 | except ValueError: 220 | logger.warning(f"Hotkey {info_i.hotkey} not found in metagraph hotkeys.") 221 | continue 222 | 223 | # Load the slice for the current miner. 224 | logger.info(f"\tEvalling slice from hotkey: {info_i.hotkey} and uid: {uid}") 225 | slice_data = await get_slices( info_i.temp_file, self.model.device ) 226 | 227 | # Load the dataset for this miner. 228 | start_time = time.time() 229 | offset_i = self.eval_window * self.hparams.window_length * self.hparams.window_speed 230 | seed = uid 231 | sampled_pages = await DatasetLoader.next_pages( 232 | offset = offset_i, 233 | n_pages = self.hparams.validator_window_eval_size, 234 | seed = seed 235 | ) 236 | random.shuffle(sampled_pages) # Important to not preference early pages. 237 | logger.info(f"\t\tLoading pages: {[p[1] for p in sampled_pages]} for offset: {offset_i}, uid: {uid} and seed: {seed}") 238 | eval_dataset = await DatasetLoader.create( 239 | batch_size=self.config.actual_batch_size, 240 | sequence_length=self.hparams.sequence_length, 241 | pages_info = sampled_pages, 242 | tokenizer=self.hparams.tokenizer 243 | ) 244 | logger.info(f"\t\t\tLoaded pages in {time.time() - start_time} seconds") 245 | 246 | # Run the eval. 247 | logger.info(f"\t\tRunning evaluation for uid: {uid} with sample rate: {self.sample_rate}") 248 | start_time = time.time() 249 | self.model.zero_grad() 250 | self.model.eval() 251 | # Enable gradient computation 252 | exhuasted_window = False 253 | full_steps = 0 254 | with torch.enable_grad(): 255 | for idx, batch in enumerate(eval_dataset): 256 | # Randomly sample every sample_rate examples 257 | if random.random() < self.sample_rate: 258 | full_steps += 1 259 | input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) 260 | labels = input_ids.clone() 261 | labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) 262 | with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting 263 | outputs = self.model(input_ids=input_ids, labels=labels) 264 | loss = outputs.loss 265 | loss.backward() 266 | if self.step_window != self.current_window: 267 | exhuasted_window = True 268 | break 269 | logger.info(f"\t\t\tTotal steps: {idx}, Applied: {full_steps}, Rate: {full_steps/(idx + 1)}, Sample Probability: {self.sample_rate}") 270 | logger.info(f"\t\t\tFinished running eval with sample rate: {self.sample_rate} on pages in {time.time() - start_time} seconds") 271 | if exhuasted_window: 272 | self.sample_rate = max(0.0001, self.sample_rate * 0.99) 273 | else: 274 | self.sample_rate = min(1, self.sample_rate * 1.01) 275 | 276 | # Collect gradients for all parameters. 277 | logger.info(f"\t\tComputing scores") 278 | start_time = time.time() 279 | gradients = {} 280 | for name, param in self.model.named_parameters(): 281 | if param.grad is None: 282 | continue 283 | gradients[name] = param.grad.view(-1).clone().detach() 284 | 285 | delta_L = 0.0 286 | for name, param in self.model.named_parameters(): 287 | if param.grad is None: 288 | continue # Skip parameters without gradients 289 | # Retrieve the indices for the current parameter subset. 290 | param_indices = indices[name].to(self.model.device) 291 | # Extract the gradient vector for the current parameter subset. 292 | g = gradients[name][param_indices].to(self.model.device) # Shape: [num_params_in_subset] 293 | # Extract and flatten the slice vector for the current parameter subset. 294 | s = slice_data[name].view(-1).to(self.model.device) # Shape: [num_params_in_subset] 295 | # Retrieve the current parameter values for the subset. 296 | theta = param.data.view(-1)[param_indices] # Shape: [num_params_in_subset] 297 | # Calculate the change in parameter values. 298 | delta_theta = theta - s 299 | # Compute the cosine similarity between delta_theta and the gradient vector. 300 | cosine_similarity = torch.nn.functional.cosine_similarity(delta_theta, gradients[name][param_indices], dim=0).item() 301 | # Calculate the weight of the parameter subset. 302 | weight = param.data.view(-1)[param_indices].norm().item() + 1e-8 303 | # Update the total importance score. 304 | delta_L += weight * cosine_similarity 305 | 306 | # Assign the computed importance score to the corresponding UID. 307 | logger.info(f"\t\t\tAssigning computed importance score to UID: {uid} with score {delta_L}") 308 | 309 | # Clean up GPU memory 310 | del slice_data 311 | del eval_dataset 312 | del gradients 313 | torch.cuda.empty_cache() 314 | 315 | # Step 7: Normalize the scores as rewards and use them as weights. 316 | start_time = time.time() 317 | logger.info('\t\t\tWeights:') 318 | self.loss_change[uid] = delta_L 319 | self.scores[uid] = (1 - self.hparams.validator_moving_alpha) * delta_L + self.hparams.validator_moving_alpha * self.scores[uid] 320 | # If a score is NaN, set it to zero 321 | self.scores[torch.isnan(self.scores)] = 0 322 | # Get all valid score value indices. 323 | valid_score_indices = torch.nonzero((self.scores != 0) & (~torch.isnan(self.scores))).squeeze().view(-1, 1) 324 | # Get all valid score values. 325 | valid_scores = self.scores[valid_score_indices].view(-1, 1) if valid_score_indices.dim() == 0 else self.scores[valid_score_indices] 326 | if len(valid_scores) > 0: 327 | max_score = torch.max(valid_scores) 328 | normalized_scores = torch.softmax((valid_scores - max_score) * self.hparams.validator_weights_temperature, dim=0) 329 | self.weights[valid_score_indices] = normalized_scores 330 | if self.config.use_wandb: 331 | for uid_i in valid_score_indices: 332 | wandb.log({ 333 | f"loss_change/{uid_i.item()}": self.loss_change[uid_i].item(), 334 | f"moving_scores/{uid_i.item()}": self.scores[uid_i].item(), 335 | f"weights/{uid_i.item()}": self.weights[uid_i].item(), 336 | 'self.sample_rate': self.sample_rate, 337 | }) 338 | for uid_i in valid_score_indices: 339 | moving_score = self.scores[uid_i].item() 340 | weight = self.weights[uid_i].item() 341 | step_score = self.loss_change[uid_i].item() 342 | logger.info(f"\t\t\t\tuid: {uid_i.item()}, loss_change: {step_score:.6f}, moving_score: {moving_score:.6f}, weight: {weight:.6f}") 343 | logger.info(f"\t\tFinished evalling uid: {uid} in {time.time() - eval_start_time} seconds") 344 | 345 | # Delete lingering files 346 | logger.info(f"\tCleaning space.") 347 | start_time = time.time() 348 | await delete_files_before_window( window_max = self.current_window - self.hparams.max_history ) 349 | logger.info(f"\t\tFinished cleaning space in {time.time() - start_time} seconds.") 350 | 351 | # Ensure window is over. 352 | logger.info(f'\nGlobal step completed in {time.time() - step_start_time} seconds\n') 353 | while self.current_window == self.step_window: await asyncio.sleep(0.1) 354 | 355 | # Catch keyboard interrrupt. 356 | except KeyboardInterrupt: 357 | logger.info("Training interrupted by user. Stopping the run.") 358 | self.stop_event.set() 359 | await self.update_task 360 | sys.exit(0) 361 | 362 | # Catch unknown. 363 | except Exception as e: 364 | logger.exception(f"Exception during training loop: {e}") 365 | continue 366 | 367 | # Returns the slice window based on a block. 368 | def block_to_window(self, block: int) -> int: 369 | return int(block / self.hparams.window_length) 370 | 371 | # Returns the slice window based on a block. 372 | def window_to_seed(self, window: int) -> int: 373 | return str( self.subtensor.get_block_hash( window * self.hparams.window_length ) ) 374 | 375 | # A listener thread which posts the block event 376 | # when the chain announces a new block. 377 | def block_listener(self, loop): 378 | def handler(event, _u, _s): 379 | self.current_block = int(event['header']['number']) 380 | loop.call_soon_threadsafe(self.block_event.set) 381 | if self.block_to_window(self.current_block) != self.current_window: 382 | self.window_seeds[ self.block_to_window(self.current_block) ] = self.window_to_seed( self.block_to_window(self.current_block) ) 383 | self.current_window = self.block_to_window(self.current_block) 384 | loop.call_soon_threadsafe(self.new_window_event.set) 385 | logger.info(f"-- New window: {self.current_window} -- ") 386 | # Run listener with retry. 387 | while not self.stop_event.is_set(): 388 | try: 389 | bt.subtensor(config=self.config).substrate.subscribe_block_headers(handler); break 390 | except Exception as e: 391 | # Wait for 5 seconds before retrying 392 | logger.error(f"Failed to subscribe to block headers: {e}.\nRetrying in 1 seconds...") 393 | time.sleep(1) 394 | 395 | if __name__ == "__main__": 396 | validator = Validator() 397 | asyncio.run(validator.run()) 398 | -------------------------------------------------------------------------------- /tools/clean.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | 18 | import io 19 | import os 20 | import sys 21 | import copy 22 | import json 23 | import time 24 | import types 25 | import boto3 26 | import torch 27 | import typer 28 | import wandb 29 | import random 30 | import argparse 31 | import tempfile 32 | from tqdm import tqdm 33 | import torch.optim as optim 34 | from dotenv import dotenv_values 35 | from types import SimpleNamespace 36 | from transformers import AutoTokenizer 37 | from transformers import GPT2Config, GPT2LMHeadModel 38 | 39 | env_config = {**dotenv_values(".env"), **os.environ} 40 | AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') 41 | AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') 42 | CLIENT: boto3.client = boto3.client( 43 | 's3', 44 | region_name='us-east-1', 45 | aws_access_key_id = AWS_ACCESS_KEY_ID, 46 | aws_secret_access_key = AWS_SECRET_ACCESS_KEY 47 | ) 48 | 49 | def main( 50 | bucket: str = 'decis', 51 | ): 52 | # Create your S3 connection. 53 | client: boto3.client = boto3.client( 54 | 's3', 55 | region_name = 'us-east-1', 56 | aws_access_key_id = AWS_ACCESS_KEY_ID, 57 | aws_secret_access_key = AWS_SECRET_ACCESS_KEY 58 | ) 59 | continuation_token = None 60 | while True: 61 | if continuation_token: 62 | response = client.list_objects_v2(Bucket=bucket, ContinuationToken=continuation_token) 63 | else: 64 | response = client.list_objects_v2(Bucket=bucket) 65 | 66 | file_names = [content['Key'] for content in response.get('Contents', [])] 67 | 68 | # Delete all the filenames 69 | for file_name in file_names: 70 | client.delete_object(Bucket=bucket, Key=file_name) 71 | print(f"Deleted {file_name}") 72 | 73 | # Check if there are more files to delete 74 | continuation_token = response.get('NextContinuationToken') 75 | if not continuation_token: 76 | break 77 | 78 | # Main function. 79 | if __name__ == "__main__": 80 | typer.run(main) 81 | -------------------------------------------------------------------------------- /tools/print.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # Copyright © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | import io 18 | import os 19 | import sys 20 | import copy 21 | import json 22 | import time 23 | import types 24 | import boto3 25 | import torch 26 | import typer 27 | import wandb 28 | import random 29 | import argparse 30 | import tempfile 31 | from tqdm import tqdm 32 | import torch.optim as optim 33 | from dotenv import dotenv_values 34 | from types import SimpleNamespace 35 | from transformers import AutoTokenizer 36 | from transformers import GPT2Config, GPT2LMHeadModel 37 | from rich.console import Console 38 | from rich.table import Table 39 | 40 | env_config = {**dotenv_values(".env"), **os.environ} 41 | 42 | def human_readable_size(size, decimal_places=2): 43 | for unit in ['B', 'KB', 'MB', 'GB', 'TB']: 44 | if size < 1024.0: 45 | return f"{size:.{decimal_places}f} {unit}" 46 | size /= 1024.0 47 | 48 | def main( 49 | bucket: str = 'decis', 50 | aws_access_key_id: str = env_config.get('AWS_ACCESS_KEY_ID'), 51 | aws_secret_access_key: str = env_config.get('AWS_SECRET_ACCESS_KEY'), 52 | ): 53 | # Create the hparams item. 54 | hparams = SimpleNamespace( 55 | bucket = bucket, 56 | aws_access_key_id = aws_access_key_id, 57 | aws_secret_access_key = aws_secret_access_key, 58 | ) 59 | # Create your S3 connection. 60 | client: boto3.client = boto3.client( 61 | 's3', 62 | region_name = 'us-east-1', 63 | aws_access_key_id = hparams.aws_access_key_id, 64 | aws_secret_access_key = hparams.aws_secret_access_key 65 | ) 66 | response = client.list_objects_v2(Bucket=hparams.bucket) 67 | if 'Contents' in response: 68 | # Extract both file names and sizes 69 | file_info = [(content['Key'], content['Size'], content['LastModified']) for content in response['Contents']] 70 | 71 | # Create a table using rich 72 | table = Table(title="S3 Bucket Files") 73 | table.add_column("File Name", justify="left", style="cyan", no_wrap=True) 74 | table.add_column("Size", justify="right", style="magenta") 75 | table.add_column("Upload Time", justify="right", style="green") 76 | table.add_column("Hotkey", justify="left", style="yellow") 77 | table.add_column("SS5D Address", justify="left", style="blue") 78 | table.add_column("Block", justify="right", style="red") 79 | 80 | for file_name, file_size, last_modified in file_info: 81 | # Extract hotkey, ss5d address, and block from the file name 82 | parts = file_name.split('-') 83 | if len(parts) >= 3: 84 | hotkey = parts[1] 85 | ss5d_address = parts[2] 86 | block = parts[-1].split('_')[0] 87 | else: 88 | hotkey = ss5d_address = block = "N/A" 89 | 90 | table.add_row( 91 | file_name, 92 | human_readable_size(file_size), 93 | last_modified.strftime("%Y-%m-%d %H:%M:%S"), 94 | hotkey, 95 | ss5d_address, 96 | block 97 | ) 98 | 99 | console = Console() 100 | console.print(table) 101 | 102 | print('\nStats') 103 | print('Total Files:', len(file_info)) 104 | total_size = sum(file_size for _, file_size, _ in file_info) 105 | print(f'Total Size: {human_readable_size(total_size)}') 106 | 107 | else: 108 | print('No files found in the bucket.') 109 | 110 | # Main function. 111 | if __name__ == "__main__": 112 | typer.run(main) 113 | -------------------------------------------------------------------------------- /validator.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # © 2024 Chakana.tech 3 | 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 5 | # documentation files (the “Software”), to deal in the Software without restriction, including without limitation 6 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 7 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 10 | # the Software. 11 | 12 | # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 14 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 15 | # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 16 | # DEALINGS IN THE SOFTWARE. 17 | # fmt: off 18 | 19 | # Global imports. 20 | import os 21 | import sys 22 | import time 23 | import wandb 24 | import torch 25 | import random 26 | import asyncio 27 | import argparse 28 | import threading 29 | import traceback 30 | from tqdm import tqdm 31 | import bittensor as bt 32 | from typing import List 33 | import torch.optim as optim 34 | from dotenv import dotenv_values 35 | from transformers import LlamaForCausalLM 36 | from torch.optim.lr_scheduler import CosineAnnealingLR 37 | 38 | # Import local files. 39 | from common import * 40 | from hparams import load_hparams 41 | from dataset import DatasetLoader 42 | 43 | # GPU optimizations. 44 | torch.backends.cudnn.benchmark = True 45 | torch.backends.cuda.matmul.allow_tf32 = True 46 | torch.backends.cudnn.allow_tf32 = True 47 | 48 | class Validator: 49 | 50 | @staticmethod 51 | def config(): 52 | parser = argparse.ArgumentParser(description='Validator script') 53 | parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') 54 | parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') 55 | parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') 56 | parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') 57 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') 58 | parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') 59 | parser.add_argument('--debug', action='store_true', help='Enable debug logging') 60 | parser.add_argument('--trace', action='store_true', help='Enable trace logging') 61 | parser.add_argument('--sync_state', action='store_true', help='Syncs the model state by pulling from the history.') 62 | bt.wallet.add_args(parser) 63 | bt.subtensor.add_args(parser) 64 | config = bt.config(parser) 65 | config.subtensor.network = 'test' 66 | config.subtensor.chain_endpoint = 'wss://test.finney.opentensor.ai:443/' 67 | if config.debug: debug() 68 | if config.trace: trace() 69 | return config 70 | 71 | def __init__(self): 72 | # Init config. 73 | self.config = Validator.config() 74 | logger.info('\n' + '-' * 40 + ' Config ' + '-' * 40) 75 | logger.info(self.config) 76 | 77 | # Init bittensor objects. 78 | self.wallet = bt.wallet(config=self.config) 79 | self.subtensor = bt.subtensor(config=self.config) 80 | self.metagraph = self.subtensor.metagraph(netuid=self.config.netuid) 81 | if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys: 82 | raise ValueError(f'Wallet {self.wallet} is not registered on subnet: {self.metagraph.netuid}') 83 | self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address) 84 | logger.info('\n' + '-' * 40 + ' Objects ' + '-' * 40) 85 | logger.info(f'\nWallet: {self.wallet}\nSubtensor: {self.subtensor}\nMetagraph: {self.metagraph}\nUID: {self.uid}') 86 | 87 | # Init bucket. 88 | try: 89 | if self.config.bucket != self.subtensor.get_commitment(self.config.netuid, self.uid): 90 | raise ValueError('') 91 | except: 92 | self.subtensor.commit(self.wallet, self.config.netuid, self.config.bucket) 93 | logger.info('Bucket:' + self.config.bucket) 94 | 95 | # Init Wandb. 96 | if self.config.use_wandb: 97 | # Delete all runs with my name and create a new one. 98 | try: 99 | for run in wandb.Api().runs(path=self.config.project): 100 | if run.name == f'V{self.uid}': 101 | logger.info(f'Deleting old run: {run}'); run.delete() 102 | except: pass 103 | wandb.init(project=self.config.project, resume='allow', name=f'V{self.uid}', config=self.config) 104 | 105 | # Init model. 106 | logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) 107 | self.hparams = load_hparams() 108 | torch.manual_seed(42); np.random.seed(42); random.seed(42) 109 | self.model = LlamaForCausalLM(config=self.hparams.model_config) 110 | # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') 111 | self.model.to(self.config.device) 112 | self.model.eval() 113 | 114 | # Init buckets. 115 | self.buckets = [] 116 | for uid in self.metagraph.uids: 117 | # Use --remote to connect to other miners, other wise, only see's config.bucket. 118 | try: self.buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid ) ) 119 | except: self.buckets.append(None) 120 | 121 | # Init run state. 122 | self.global_step = 0 123 | self.last_window = 0 124 | self.optimal_pages_per_step = 4 125 | self.current_block = self.subtensor.block 126 | self.current_window = self.block_to_window( self.current_block ) 127 | self.window_seeds = {self.current_window: self.window_to_seed( self.current_window) } 128 | self.block_event = asyncio.Event() 129 | self.new_window_event = asyncio.Event() 130 | self.stop_event = asyncio.Event() 131 | self.step_scores = torch.zeros( 256, dtype = torch.float32 ) 132 | self.scores = torch.zeros( 256, dtype = torch.float32 ) 133 | self.weights = torch.zeros( 256, dtype = torch.float32 ) 134 | self.sample_rate = 1.0 135 | print ( self.hparams ) 136 | 137 | async def update(self): 138 | while not self.stop_event.is_set(): # Loop until stop_event is set 139 | self.subtensor = bt.subtensor(config=self.config) # Reinitialize subtensor with current config 140 | nxt_meta = self.subtensor.metagraph(self.config.netuid) # Get the new metagraph for the given netuid 141 | self.hparams = load_hparams() # Reload hyperparameters 142 | next_buckets = [] # Initialize the next_buckets list 143 | for uid in nxt_meta.uids: # Iterate over new metagraph uids 144 | try: next_buckets.append(self.config.bucket if not self.config.remote else self.subtensor.get_commitment( self.config.netuid, uid )) 145 | except: next_buckets.append(None) 146 | self.buckets = next_buckets # Update self.buckets with next_buckets 147 | for idx, hotkey in enumerate(self.metagraph.hotkeys): # Iterate over current metagraph hotkeys 148 | if hotkey != nxt_meta.hotkeys[idx]: # Check if hotkey has changed in the new metagraph 149 | self.scores[idx] = 0 # Reset rewards for the changed hotkey 150 | self.weights[idx] = 0 # Reset weights for the changed hotkey 151 | self.metagraph = nxt_meta # Update self.metagraph with new_metagraph 152 | await asyncio.sleep(60) # Sleep for 60 seconds before the next iteration 153 | 154 | async def run(self): 155 | # Main loop. 156 | self.loop = asyncio.get_running_loop() 157 | self.update_task = asyncio.create_task(self.update()) 158 | self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() 159 | 160 | # Optionally sync the model state by pulling model states from the history. 161 | if self.config.sync_state: 162 | history_windows = [ self.current_window - i for i in range (self.hparams.max_history) ] 163 | state_slices = await download_slices_for_buckets_and_windows( 164 | buckets = self.buckets, 165 | windows = history_windows, 166 | key = 'state' 167 | ) 168 | for window in tqdm(history_windows, desc="Syncing state"): 169 | await apply_slices_to_model( 170 | model = self.model, 171 | window = window, 172 | seed = window, 173 | compression = self.hparams.compression, 174 | key = 'state' 175 | ) 176 | torch.cuda.empty_cache() 177 | 178 | # Run validation. 179 | while True: 180 | try: 181 | # Get the window we are evalling. 182 | logger.info('[bold]' + '\n' + '-' * 40 + f' Step: {self.global_step} ' + '-' * 40) 183 | gs_start = T() 184 | self.global_step += 1 185 | offset = 2 186 | window = self.current_window - offset 187 | 188 | # Download the state for the eval window. 189 | st = T() 190 | state_slices = await download_slices_for_buckets_and_windows( 191 | buckets = self.buckets, 192 | windows = [ window ], 193 | key = 'state' 194 | ) 195 | n_state_slices = len(state_slices[ window ]) if window in state_slices else 0 196 | logger.info(f"{P(window, T() - st)}: Downloaded {n_state_slices} window states.") 197 | 198 | # Download the delta for the eval window. 199 | st = T() 200 | eval_slices = await download_slices_for_buckets_and_windows( 201 | buckets = self.buckets, 202 | windows = [ window ], 203 | key = 'delta' 204 | ) 205 | n_eval_slices = len(eval_slices[ window ]) if window in eval_slices else 0 206 | logger.info(f"{P(window, T() - st)}: Downloaded {n_eval_slices} window deltas.") 207 | if n_eval_slices == 0: 208 | logger.info(f"{P(window, T() - st)}: No slices to eval, continue ...") 209 | while self.current_window - offset == window: await asyncio.sleep(0.1) # Wait for next window. 210 | continue 211 | 212 | # Applied the model state state for the eval window. 213 | st = T() 214 | await apply_slices_to_model( 215 | model = self.model, 216 | window = window, 217 | seed = window, 218 | compression = self.hparams.compression, 219 | key = 'state', 220 | ) 221 | logger.info(f"{P(window, T() - st)}: Applied window state.") 222 | 223 | # Attain the indicies for the eval window. 224 | st = T() 225 | indices = await get_indices_for_window( 226 | model = self.model, 227 | seed = window, 228 | compression = self.hparams.compression 229 | ) 230 | logger.info(f"{P(window, T() - st)}: Attained window indices.") 231 | 232 | 233 | # Attain the UID of this slice. 234 | st = T() 235 | eval_slice_info = random.choice( eval_slices[ window ] ) 236 | try: eval_uid = self.metagraph.hotkeys.index(eval_slice_info.hotkey) 237 | except ValueError: 238 | logger.warning(f"{P(window, T() - st)}: {eval_slice_info.hotkey} not found in metagraph") 239 | continue 240 | eval_slice_data = await get_slices( eval_slice_info.temp_file, self.model.device ) 241 | logger.info(f"{P(window, T() - st)}: Loaded window slices for uid: [dark_sea_green]{eval_uid}[/dark_sea_green].") 242 | 243 | # Download the eval page for this uid. 244 | st = T() 245 | eval_pages = await DatasetLoader.next_pages( 246 | offset = window, 247 | n_pages = self.hparams.validator_window_eval_size, 248 | seed = eval_uid 249 | ) 250 | random.shuffle( eval_pages ) 251 | eval_dataset = await DatasetLoader.create( 252 | batch_size = self.config.actual_batch_size, 253 | sequence_length = self.hparams.sequence_length, 254 | pages_info = eval_pages, 255 | tokenizer = self.hparams.tokenizer 256 | ) 257 | logger.info(f"{P(window, T() - st)}: Downloaded eval pages: [light_steel_blue]{[p[1] for p in eval_pages]}[/light_steel_blue].") 258 | 259 | # Accumulate gradients from this page. 260 | eval_start = T() 261 | self.model.zero_grad() 262 | total_loss = 0.0 263 | full_steps = 0; total_steps = 0; 264 | exhuasted_window = False 265 | with torch.enable_grad(): 266 | for idx, batch in enumerate(eval_dataset): 267 | total_steps += 1 268 | if random.random() < self.sample_rate and not exhuasted_window: 269 | full_steps += 1 270 | input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) 271 | labels = input_ids.clone() 272 | labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) 273 | with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting 274 | outputs = self.model(input_ids=input_ids, labels=labels) 275 | total_loss += outputs.loss.item() 276 | outputs.loss.backward() 277 | if self.current_window - offset != window: exhuasted_window = True; continue 278 | step_loss = total_loss/(full_steps+1) 279 | eval_duration = T() - eval_start 280 | tokens_per_step = self.hparams.sequence_length * self.config.actual_batch_size * (full_steps + 1) 281 | tokens_per_second = tokens_per_step / eval_duration 282 | logger.info(f"{P(window, eval_duration)}: Accumulated gradients:") 283 | logger.info(f"{P(window, eval_duration)}: \tTotal steps: [tan]{full_steps}/{total_steps}[/tan], Rate: [tan]{(full_steps/total_steps):.2f}[/tan], Target: [tan]{self.sample_rate:.2f}[/tan]") 284 | logger.info(f"{P(window, eval_duration)}: \tTotal tokens: [tan]{tokens_per_step}[/tan], Tokens per second: [tan]{tokens_per_second:.2f}[/tan]") 285 | logger.info(f"{P(window, eval_duration)}: \tLoss: [tan]{step_loss}[tan]") 286 | if exhuasted_window: self.sample_rate = max(0.0001, self.sample_rate * 0.95) 287 | else: self.sample_rate = min(1, self.sample_rate * 1.05) 288 | 289 | # Compute the score for this slice. 290 | st = T() 291 | score = 0.0 292 | for i, (name_i, param_i) in enumerate( self.model.named_parameters() ): 293 | if param_i.grad is None: continue # Skip parameters without gradients 294 | idxs_i = indices[name_i].to(self.model.device) 295 | grad_i = param_i.grad.view(-1).clone()[idxs_i].to(self.model.device) 296 | slice_i = eval_slice_data[name_i].view(-1).to(self.model.device) 297 | theta_i = param_i.data.view(-1)[idxs_i] 298 | delta_i = theta_i - slice_i 299 | sim_i = torch.nn.functional.cosine_similarity(delta_i, grad_i, dim=0).item() 300 | weight_i = param_i.data.view(-1)[idxs_i].norm().item() + 1e-8 301 | score += weight_i * sim_i 302 | logger.info(f"{P(window, T() - st)}: Computed score: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]") 303 | 304 | # Assign and log scores. 305 | start_time = T() 306 | self.step_scores[ eval_uid ] = score 307 | self.scores[ eval_uid ] = (1 - self.hparams.validator_moving_alpha) * score + self.hparams.validator_moving_alpha * self.scores[eval_uid] 308 | self.scores[ torch.isnan(self.scores) ] = 0 309 | valid_score_indices = torch.nonzero((self.scores != 0) & (~torch.isnan(self.scores))).squeeze().view(-1, 1) 310 | valid_scores = self.scores[valid_score_indices].view(-1, 1) if valid_score_indices.dim() == 1 else self.scores[valid_score_indices] 311 | if valid_scores.numel() > 0: 312 | self.weights[valid_score_indices] = valid_scores / (valid_scores.sum() + 1e-8) # Weights are normalized scores. 313 | for uid_i in valid_score_indices: 314 | moving_score = self.scores[ uid_i ].item() 315 | weight = self.weights[ uid_i ].item() 316 | step_score = self.step_scores[ uid_i ].item() 317 | logger.info( 318 | f"\tuid: [dark_sea_green]{uid_i.item()}[/dark_sea_green], " 319 | f"last: [dark_sea_green]{step_score:.3f}[/dark_sea_green], " 320 | f"moving: [dark_sea_green]{moving_score:.3f}[/dark_sea_green], " 321 | f"weight: [dark_sea_green]{weight:.3f}[/dark_sea_green]" 322 | ) 323 | 324 | # Apply all deltas to the model state. 325 | st = T() 326 | await apply_slices_to_model( 327 | model = self.model, 328 | window = window, 329 | seed = window, 330 | compression = self.hparams.compression, 331 | key = 'delta', 332 | ) 333 | logger.info(f"{P(window, T() - st)}: Applied window deltas.") 334 | 335 | # Clean local and remote space from old slices. 336 | st = T() 337 | await delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state') 338 | await delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta') 339 | await delete_files_from_bucket_before_window( bucket = self.config.bucket, window_max = window - self.hparams.max_history, key = 'state' ) 340 | await delete_files_from_bucket_before_window( bucket = self.config.bucket, window_max = window - self.hparams.max_history, key = 'delta' ) 341 | logger.info(f"{P(window, T() - st)}: Cleaned file history.") 342 | 343 | # Finish step. 344 | gs_end = T() 345 | while self.current_window - offset == window: 346 | await asyncio.sleep(0.1) 347 | window_time_delta = self.window_time - gs_end 348 | window_delta_str = f"[red]{window_time_delta:.2f}[/red]" if window_time_delta < 0 else f"[green]+{window_time_delta:.2f}[/green]" 349 | logger.info(f"{P(window, gs_end - gs_start)}[{window_delta_str}]: Finished step.") 350 | if self.config.use_wandb: 351 | wandb.log({ 352 | f"loss": step_loss, 353 | f"tokens_per_step": tokens_per_step, 354 | f"tokens_per_second": tokens_per_second, 355 | f"sample_rate": self.sample_rate, 356 | f"utilization": eval_duration / (gs_end - gs_start) 357 | }) 358 | for uid_i in valid_score_indices: 359 | wandb.log({ 360 | f"step_scores/{uid_i.item()}": self.step_scores[ uid_i ].item(), 361 | f"moving_scores/{uid_i.item()}": self.scores[ uid_i ].item(), 362 | f"weights/{uid_i.item()}": self.weights[ uid_i ].item(), 363 | }) 364 | 365 | 366 | # Catch keyboard interrrupt. 367 | except KeyboardInterrupt: 368 | logger.info("Training interrupted by user. Stopping the run.") 369 | self.stop_event.set() 370 | await self.update_task 371 | sys.exit(0) 372 | 373 | # Catch unknown. 374 | except Exception as e: 375 | logger.exception(f"Exception during training loop: {e}") 376 | continue 377 | 378 | # Returns the slice window based on a block. 379 | def block_to_window(self, block: int) -> int: 380 | return int(block / self.hparams.window_length) 381 | 382 | # Returns the slice window based on a block. 383 | def window_to_seed(self, window: int) -> int: 384 | return str( self.subtensor.get_block_hash( window * self.hparams.window_length ) ) 385 | 386 | # A listener thread which posts the block event 387 | # when the chain announces a new block. 388 | def block_listener(self, loop): 389 | def handler(event, _u, _s): 390 | self.current_block = int(event['header']['number']) 391 | loop.call_soon_threadsafe(self.block_event.set) 392 | if self.block_to_window(self.current_block) != self.current_window: 393 | self.window_seeds[ self.block_to_window(self.current_block) ] = self.window_to_seed( self.block_to_window(self.current_block) ) 394 | self.current_window = self.block_to_window(self.current_block) 395 | self.window_duration = T() - self.window_time if hasattr(self, 'window_time') else 0 396 | self.window_time = T() 397 | loop.call_soon_threadsafe(self.new_window_event.set) 398 | logger.info(f"{P(self.current_window, self.window_duration)} New Window.") 399 | 400 | # Run listener with retry. 401 | while not self.stop_event.is_set(): 402 | try: 403 | bt.subtensor(config=self.config).substrate.subscribe_block_headers(handler); break 404 | except Exception as e: 405 | # Wait for 5 seconds before retrying 406 | logger.error(f"Failed to subscribe to block headers: {e}.\nRetrying in 1 seconds...") 407 | time.sleep(1) 408 | 409 | if __name__ == "__main__": 410 | validator = Validator() 411 | asyncio.run(validator.run()) 412 | --------------------------------------------------------------------------------