├── lib ├── __init__.py └── training │ ├── __init__.py │ ├── clipped_lamb.py │ ├── wrapper.py │ ├── hf_trainer.py │ ├── offload.py │ ├── tpu.py │ └── lamb_8bit.py ├── inference ├── .gitignore └── run_inference.py ├── requirements.txt ├── LICENSE ├── README.md ├── .gitignore ├── data.py ├── run_trainer.py ├── utils.py ├── run_trainer_tpu.py ├── callback.py ├── run_aux_peer.py ├── arguments.py ├── task.py ├── huggingface_auth.py └── manage_scaleset.py /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /inference/.gitignore: -------------------------------------------------------------------------------- 1 | queries.txt 2 | outputs 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | https://github.com/learning-at-home/hivemind/archive/dalle-v1.zip 2 | https://github.com/learning-at-home/dalle-pytorch/archive/weight-sharing.zip 3 | torchmetrics==0.6.2 4 | bitsandbytes-cuda111==0.26.0.post2 5 | transformers==4.12.2 6 | tokenizers==0.10.3 7 | datasets==1.14.0 8 | torch-optimizer==0.1.0 9 | wandb==0.12.1 10 | nltk==3.6.2 11 | sentencepiece==0.1.96 12 | aiohttp==3.7.4.post0 13 | requests>=2.23.0 14 | termcolor==1.1.0 15 | -------------------------------------------------------------------------------- /lib/training/clipped_lamb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_optimizer import Lamb 3 | 4 | 5 | class LambWithGradientClipping(Lamb): 6 | """ A version of LAMB that clips gradients based on their norm. """ 7 | def __init__(self, *args, max_grad_norm: float, **kwargs): 8 | self.max_grad_norm = max_grad_norm 9 | super().__init__(*args, **kwargs) 10 | 11 | def step(self, *args, **kwargs): 12 | iter_params = (param for group in self.param_groups for param in group['params']) 13 | torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm) 14 | return super().step(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-2022 Learning@home authors and collaborators 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Training DALL-E with volunteers from all over the Internet 2 | 3 |

4 | This repository is a part of the NeurIPS 2021 demonstration "Training Transformers Together". 5 |

6 |

7 | In this demo, we train a model similar to OpenAI DALL-E — 8 | a Transformer "language model" that generates images from text descriptions. 9 | Training happens collaboratively — volunteers from all over the Internet contribute to the training using hardware available to them. 10 | We use LAION-400M, 11 | the world's largest openly available image-text-pair dataset with 400 million samples. Our model is based on 12 | the dalle‑pytorch implementation 13 | by Phil Wang with a few tweaks to make it communication-efficient. 14 |

15 |

16 | See details about how to join and how it works on our website. 17 |
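As a minimal sketch of what joining looks like from the command line, a GPU peer can be started with `run_trainer.py`; the flags below correspond to the dataclasses in `arguments.py`, and the multiaddress is a placeholder for one announced by an existing peer:

```bash
# Illustrative sketch only: replace the multiaddress with one printed by a running peer
# (see the --initial_peers help text in arguments.py).
python run_trainer.py \
    --experiment_prefix my-model \
    --initial_peers /ip4/203.0.113.1/tcp/31337/p2p/XXXX \
    --per_device_train_batch_size 1 \
    --output_dir outputs
```

Auxiliary peers that track metrics and upload checkpoints use `run_aux_peer.py`, and TPU hosts use `run_trainer_tpu.py` in the same way.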

18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # node and NPM 2 | npm-debug.log 3 | node_modules 4 | 5 | # swap files 6 | *~ 7 | *.swp 8 | 9 | examples/data/* 10 | examples/runs/* 11 | examples/.ipynb_checkpoints/* 12 | 13 | env.sh 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | bin/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg/ 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | .idea 57 | .ipynb_checkpoints 58 | 59 | # Rope 60 | .ropeproject 61 | 62 | # Django stuff: 63 | *.log 64 | *.pot 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | docs/tmp* 69 | 70 | # OS X garbage 71 | .DS_Store 72 | 73 | # Debian things 74 | debian/reproducible-experiment-platform 75 | debian/files 76 | *.substvars 77 | *.debhelper.log 78 | 79 | # protobuf stuff 80 | hivemind/proto/*_pb2* 81 | 82 | # libp2p-daemon binary 83 | hivemind/hivemind_cli/p2pd 84 | -------------------------------------------------------------------------------- /lib/training/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class OptimizerWrapper(torch.optim.Optimizer): 5 | r""" 6 | A wrapper for pytorch.optimizer that forwards all methods to the wrapped optimizer 7 | """ 8 | 9 | def __init__(self, optim: torch.optim.Optimizer): 10 | object.__init__(self) 11 | self.optim = optim 12 | 13 | @property 14 | def defaults(self): 15 | return self.optim.defaults 16 | 17 | @property 18 | def state(self): 19 | return self.optim.state 20 | 21 | def __getstate__(self): 22 | return self.optim.__getstate__() 23 | 24 | def __setstate__(self, state): 25 | self.optim.__setstate__(state) 26 | 27 | def __repr__(self): 28 | return f"{self.__class__.__name__}({repr(self.optim)})" 29 | 30 | def state_dict(self): 31 | return self.optim.state_dict() 32 | 33 | def load_state_dict(self, state_dict: dict) -> None: 34 | return self.optim.load_state_dict(state_dict) 35 | 36 | def step(self, *args, **kwargs): 37 | return self.optim.step(*args, **kwargs) 38 | 39 | def zero_grad(self, *args, **kwargs): 40 | return self.optim.zero_grad(*args, **kwargs) 41 | 42 | @property 43 | def param_groups(self): 44 | return self.optim.param_groups 45 | 46 | def add_param_group(self, param_group: dict) -> None: 47 | return self.optim.add_param_group(param_group) 48 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Optional 3 | 4 | import hivemind 5 | import numpy as np 6 | from datasets import load_dataset 7 | 8 | logger = hivemind.get_logger(__name__) 9 | 10 | 11 | def preprocess_batch(batch, tokenizer, max_sequence_length: int): 12 | mask = [ 13 | ( 14 | caption is not None and len(caption) >= 3 and 15 | nsfw == 'UNLIKELY' and 16 | orig_width > 0 and orig_height > 0 and 17 | max(orig_height / orig_width, 
orig_width / orig_height) <= 2 18 | ) for caption, nsfw, orig_width, orig_height in 19 | zip(batch['caption'], batch['NSFW'], batch['original_width'], batch['original_height']) 20 | ] 21 | logger.debug(f'{np.mean(mask) * 100:.1f}% of examples left after filtering') 22 | 23 | if any(mask): 24 | result = tokenizer(list(itertools.compress(batch['caption'], mask)), 25 | add_special_tokens=False, max_length=max_sequence_length, truncation=True) 26 | else: 27 | # This branch is necessary because tokenizer([]) raises IndexError 28 | result = {'input_ids': [], 'attention_mask': []} 29 | result['image'] = [np.frombuffer(encoded, np.int16).astype(np.int64) 30 | for encoded in itertools.compress(batch['code'], mask)] 31 | return result 32 | 33 | 34 | def make_dataset( 35 | tokenizer, 36 | *, 37 | shuffle_buffer_size: int = 8192, 38 | shuffle_seed: Optional[int], 39 | preprocessing_batch_size: int = 256, 40 | max_sequence_length: int, 41 | ): 42 | ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True) 43 | ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed) 44 | ds = ds.map(lambda batch: preprocess_batch(batch, tokenizer, max_sequence_length), 45 | batched=True, batch_size=preprocessing_batch_size) 46 | ds = ds.with_format('torch') 47 | return ds 48 | -------------------------------------------------------------------------------- /run_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import torch 7 | import transformers 8 | from transformers import HfArgumentParser 9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 10 | 11 | from lib.training.hf_trainer import CollaborativeHFTrainer 12 | 13 | import callback 14 | import utils 15 | from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments 16 | from task import TrainingTask 17 | 18 | 19 | transformers.utils.logging.set_verbosity_warning() 20 | use_hivemind_log_handler("in_root_logger") 21 | logger = get_logger(__name__) 22 | 23 | torch.set_num_threads(1) # Otherwise, it becomes very slow on machines with ~100 CPUs 24 | 25 | 26 | def main(): 27 | parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments)) 28 | training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses() 29 | 30 | logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}") 31 | 32 | utils.log_process_rank(trainer_args) 33 | task = TrainingTask(training_peer_args, trainer_args, collab_args) 34 | model = task.model.to(trainer_args.device) 35 | 36 | collaborative_callback = callback.CollaborativeCallback(task, training_peer_args) 37 | assert trainer_args.do_train and not trainer_args.do_eval 38 | 39 | # Note: the code below creates the trainer with dummy scheduler and removes some callbacks. 40 | # This is done because collaborative training has its own callbacks that take other peers into account. 
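# The data_seed below is derived from this peer's public key, so each peer shuffles the streaming dataset
# in a different order (see CollaborativeHFTrainer.get_train_dataloader in lib/training/hf_trainer.py).
# The collaborative optimizer built by TrainingTask already wraps the local CPULAMB8Bit optimizer and the
# warmup schedule, which is why transformers.Trainer only receives a NoOpScheduler placeholder.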
41 | trainer = CollaborativeHFTrainer( 42 | model=model, 43 | args=trainer_args, 44 | tokenizer=task.tokenizer, 45 | data_collator=task.data_collator, 46 | data_seed=hash(task.local_public_key), 47 | train_dataset=task.training_dataset, 48 | eval_dataset=None, 49 | collaborative_optimizer=task.collaborative_optimizer, 50 | callbacks=[collaborative_callback], 51 | ) 52 | trainer.remove_callback(transformers.trainer_callback.PrinterCallback) 53 | trainer.remove_callback(transformers.trainer_callback.ProgressCallback) 54 | 55 | latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None) 56 | trainer.train(model_path=latest_checkpoint_dir) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from multiaddr import Multiaddr 4 | from pydantic import BaseModel, StrictFloat, confloat, conint 5 | 6 | from hivemind import choose_ip_address 7 | from hivemind.dht.crypto import RSASignatureValidator 8 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator 9 | from hivemind.dht.validation import RecordValidatorBase 10 | from hivemind.utils.logging import get_logger 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | class LocalMetrics(BaseModel): 16 | step: conint(ge=0, strict=True) 17 | samples_per_second: confloat(ge=0.0, strict=True) 18 | samples_accumulated: conint(ge=0, strict=True) 19 | loss: StrictFloat 20 | mini_steps: conint(ge=0, strict=True) 21 | 22 | 23 | class MetricSchema(BaseModel): 24 | metrics: Dict[BytesWithPublicKey, LocalMetrics] 25 | 26 | 27 | def make_validators(experiment_prefix: str) -> Tuple[List[RecordValidatorBase], bytes]: 28 | signature_validator = RSASignatureValidator() 29 | validators = [SchemaValidator(MetricSchema, prefix=experiment_prefix), signature_validator] 30 | return validators, signature_validator.local_public_key 31 | 32 | 33 | class TextStyle: 34 | BOLD = "\033[1m" 35 | BLUE = "\033[34m" 36 | RESET = "\033[0m" 37 | 38 | 39 | def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None: 40 | if only_p2p: 41 | unique_addrs = {addr["p2p"] for addr in visible_maddrs} 42 | initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs) 43 | else: 44 | available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr] 45 | if available_ips: 46 | preferred_ip = choose_ip_address(available_ips) 47 | selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)] 48 | else: 49 | selected_maddrs = visible_maddrs 50 | initial_peers_str = " ".join(str(addr) for addr in selected_maddrs) 51 | 52 | logger.info( 53 | f"Running a DHT peer. 
To connect other peers to this one over the Internet, use " 54 | f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}" 55 | ) 56 | logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}") 57 | 58 | 59 | def log_process_rank(training_args): 60 | logger.info( 61 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 62 | f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 63 | ) 64 | -------------------------------------------------------------------------------- /lib/training/hf_trainer.py: -------------------------------------------------------------------------------- 1 | """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training""" 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from transformers.trainer import Trainer 6 | from hivemind import CollaborativeOptimizer 7 | from hivemind.optim import HivemindGradScaler 8 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 9 | 10 | use_hivemind_log_handler("in_root_logger") 11 | logger = get_logger() 12 | LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None) 13 | 14 | 15 | class CollaborativeHFTrainer(Trainer): 16 | """ 17 | A version of HuggingFace trainer that shuffles the dataset using a separate random seed. 18 | Used to ensure that peers don't process batches in the same order. 19 | """ 20 | 21 | def __init__(self, *, data_seed: int, collaborative_optimizer: CollaborativeOptimizer, **kwargs): 22 | self.data_seed = data_seed 23 | self.collaborative_optimizer = collaborative_optimizer 24 | super().__init__(optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)), **kwargs) 25 | 26 | if self.fp16_backend is not None: 27 | assert self.use_amp 28 | self.scaler = HivemindGradScaler() 29 | 30 | def get_train_dataloader(self) -> DataLoader: 31 | """Shuffle data independently for each peer to avoid duplicating batches [important for quality]""" 32 | torch.manual_seed(self.data_seed) 33 | return super().get_train_dataloader() 34 | 35 | def _wrap_model(self, model, training=True): 36 | # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step 37 | return IgnoreGradManipulations(super()._wrap_model(model, training=training), 38 | override_zero_grad=self.collaborative_optimizer.grad_averager.reuse_grad_buffers) 39 | 40 | 41 | class NoOpScheduler(LRSchedulerBase): 42 | """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler""" 43 | 44 | def get_lr(self): 45 | return [group['lr'] for group in self.optimizer.param_groups] 46 | 47 | def print_lr(self, *args, **kwargs): 48 | if self.optimizer.scheduler: 49 | return self.optimizer.scheduler.print_lr(*args, **kwargs) 50 | 51 | def step(self): 52 | logger.debug("Called NoOpScheduler.step") 53 | self._last_lr = self.get_lr() 54 | 55 | def state_dict(self): 56 | return {} 57 | 58 | def load_state_dict(self, *args, **kwargs): 59 | logger.debug("Called NoOpScheduler.load_state_dict") 60 | 61 | 62 | class IgnoreGradManipulations(nn.Module): 63 | """ Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. 
zero_grad, clip_grad) """ 64 | def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True): 65 | super().__init__() 66 | self.module = module 67 | self.override_clipping = override_clipping 68 | self.override_zero_grad = override_zero_grad 69 | 70 | def forward(self, *args, **kwargs): 71 | return self.module.forward(*args, **kwargs) 72 | 73 | def zero_grad(self, set_to_none: bool = False) -> None: 74 | if self.override_zero_grad and \ 75 | all(param.grad.isfinite().all() for param in self.parameters() if param.requires_grad): 76 | logger.debug("Successfully bypassed zero_grad") 77 | else: 78 | self.module.zero_grad(set_to_none=set_to_none) 79 | 80 | def clip_grad_norm_(self, max_norm: float, norm_type: int = 2): 81 | """ ignore clip_grad_norm on each step, clip in optimizer instead """ 82 | if self.override_clipping: 83 | logger.debug("Successfully bypassed clip_grad_norm_") 84 | else: 85 | return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type) 86 | -------------------------------------------------------------------------------- /run_trainer_tpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | 4 | import wandb 5 | import torch 6 | import transformers 7 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 8 | from transformers import HfArgumentParser 9 | 10 | import utils 11 | from arguments import TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments 12 | from lib.training.tpu import TPUManager 13 | from callback import CollaborativeCallback 14 | from task import TrainingTask 15 | 16 | 17 | transformers.utils.logging.set_verbosity_warning() 18 | use_hivemind_log_handler("in_root_logger") 19 | logger = get_logger() 20 | 21 | transformers.training_args.is_torch_tpu_available = lambda: False # disable builtin TPU support to use custom code 22 | 23 | torch.set_num_threads(min(torch.get_num_threads(), 4)) # Otherwise, it becomes very slow on machines with ~100 CPUs 24 | 25 | 26 | def main(): 27 | parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments)) 28 | peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses() 29 | 30 | logger.info(f"Found {len(peer_args.initial_peers)} initial peers: {peer_args.initial_peers}") 31 | if len(peer_args.initial_peers) == 0: 32 | logger.warning("Please specify at least one network endpoint in initial peers.") 33 | 34 | utils.log_process_rank(trainer_args) 35 | task = TrainingTask(peer_args, trainer_args, collab_args) 36 | model = task.model 37 | 38 | # BEGIN init TPU 39 | assert trainer_args.do_train and not trainer_args.do_eval 40 | tpu_manager = TPUManager(model, dataset=task.training_dataset, collate_fn=task.data_collator, 41 | grad_accumulation_steps=trainer_args.gradient_accumulation_steps, 42 | batch_size_per_device=trainer_args.per_device_train_batch_size, 43 | nprocs=trainer_args.n_tpus, start=True) 44 | 45 | model = task.model = tpu_manager._synchronizer.master_model 46 | 47 | # warmup tpus 48 | logger.info("Waiting for TPUs to warm up, this may take a minute...") 49 | tpu_manager.step() 50 | logger.info("Warmup step 1 / 3 done.") 51 | tpu_manager.update_model_parameters(model.parameters()) 52 | tpu_manager.step() 53 | logger.info("Warmup step 2 / 3 done.") 54 | tpu_manager.step() 55 | tpu_manager.get_aggregated_gradients() 56 | tpu_manager.zero_grad() 57 | logger.info("Warmup step 3 / 3 done.") 58 | # END init 
TPU 59 | 60 | def push_params_onto_tpu(): 61 | logger.info("Pushing new params onto TPU.") 62 | tpu_manager.update_model_parameters(model.parameters()) 63 | tpu_manager.zero_grad() 64 | 65 | collaborative_optimizer = task.collaborative_optimizer 66 | collaborative_optimizer.callbacks.on_after_global_step.add(push_params_onto_tpu) 67 | collaborative_optimizer.callbacks.on_load_state_from_peers(push_params_onto_tpu) 68 | 69 | collaborative_training_callback = CollaborativeCallback(task, peer_args) 70 | 71 | state = transformers.TrainerState() 72 | control = transformers.TrainerControl() 73 | collaborative_training_callback.on_train_begin(trainer_args, state, control) 74 | tpu_manager.update_model_parameters(model.parameters()) 75 | 76 | wandb.init(project=trainer_args.wandb_project, name=trainer_args.run_name) 77 | 78 | while True: 79 | start_time = time.perf_counter() 80 | loss, num_accumulated = tpu_manager.step() 81 | time_delta = time.perf_counter() - start_time 82 | logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.") 83 | wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.state_averager.scheduler.get_lr()[0]}) 84 | 85 | with torch.no_grad(): 86 | for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()): 87 | param.grad[...] = grad_from_tpu 88 | collaborative_optimizer.step() 89 | 90 | state.log_history.append(dict(loss=loss)) 91 | collaborative_training_callback.on_step_end(trainer_args, state, control) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /lib/training/offload.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Type, Iterable, Dict, Union, Optional 3 | import multiprocessing as mp 4 | 5 | import torch 6 | 7 | from .wrapper import OptimizerWrapper 8 | 9 | 10 | class OffloadOptimizer(OptimizerWrapper): 11 | r""" A wrapper that stores optimizer statistics and performs updates on the offloaded device (e.g. CPU RAM). 
""" 12 | 13 | def __init__( 14 | self, param_groups: Union[Iterable[torch.nn.Parameter], Iterable[Dict]], 15 | optim_cls: Type[torch.optim.Optimizer], *args, full_sync: bool = True, 16 | offload_device=torch.device('cpu'), offload_dtype: Optional[torch.dtype] = None, **kwargs): 17 | param_groups = list(param_groups) 18 | if not isinstance(param_groups[0], dict): 19 | param_groups = [{'params': param_groups}] 20 | super().__init__(optim_cls(param_groups, *args, **kwargs)) 21 | self.full_sync = full_sync 22 | self.lock = mp.Lock() 23 | 24 | with torch.no_grad(): 25 | self.offload_params_by_group = tuple( 26 | [torch.nn.Parameter(torch.empty_like(param, device=offload_device, dtype=offload_dtype), 27 | requires_grad=param.requires_grad) 28 | for param in group["params"]] for group in param_groups) 29 | 30 | for group, offload_params in zip(param_groups, self.offload_params_by_group): 31 | for param, offload_param in zip(group['params'], offload_params): 32 | offload_param.copy_(param, non_blocking=True) 33 | if offload_param.grad is None: 34 | offload_param.grad = torch.zeros_like(offload_param) 35 | if param.grad is not None: 36 | offload_param.grad.copy_(param.grad, non_blocking=True) 37 | 38 | @contextlib.contextmanager 39 | def _use_offloaded_params(self, *, 40 | sync_params_before: bool, sync_grads_before: bool, 41 | sync_params_after: bool, sync_grads_after: bool): 42 | assert len(self.param_groups) == len(self.offload_params_by_group) 43 | original_params_per_group = [group["params"] for group in self.param_groups] 44 | with self.lock: 45 | try: 46 | with torch.no_grad(): 47 | for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group): 48 | for original_param, replacement_param in zip(original_params, replacement_params): 49 | if sync_params_before: 50 | replacement_param.copy_(original_param, non_blocking=True) 51 | if sync_grads_before and original_param.grad is not None: 52 | replacement_param.grad.copy_(original_param.grad, non_blocking=True) 53 | 54 | for group, replacement_params in zip(self.param_groups, self.offload_params_by_group): 55 | group["params"] = replacement_params 56 | yield self.param_groups 57 | finally: 58 | for group, original_params in zip(self.param_groups, original_params_per_group): 59 | group["params"] = original_params 60 | 61 | with torch.no_grad(): 62 | for original_params, replacement_params in zip(original_params_per_group, self.offload_params_by_group): 63 | for original_param, replacement_param in zip(original_params, replacement_params): 64 | if sync_params_after: 65 | original_param.copy_(replacement_param, non_blocking=True) 66 | if sync_grads_after and original_param.grad is not None: 67 | original_param.grad.copy_(replacement_param.grad) 68 | 69 | def add_param_group(self, param_group: dict) -> None: 70 | raise NotImplementedError(f"{self.__class__.__name__} does not support add_param_group.") 71 | 72 | def step(self, closure=None, *args, **kwargs): 73 | assert closure is None, "closure not supported in cpu offload mode" 74 | with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=True, 75 | sync_params_after=True, sync_grads_after=self.full_sync): 76 | return self.optim.step(*args, **kwargs) 77 | 78 | def zero_grad(self, set_to_none: bool = False, *args, **kwargs): 79 | if not self.full_sync: 80 | torch.optim.Optimizer.zero_grad(self, set_to_none) 81 | with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync, 82 | 
sync_params_after=self.full_sync, sync_grads_after=self.full_sync): 83 | return super().zero_grad(*args, set_to_none=False, **kwargs) 84 | 85 | def state_dict(self): 86 | with self._use_offloaded_params(sync_params_before=self.full_sync, sync_grads_before=self.full_sync, 87 | sync_params_after=False, sync_grads_after=False): 88 | return self.optim.state_dict() 89 | 90 | def load_state_dict(self, state_dict: dict) -> None: 91 | with self._use_offloaded_params(sync_params_before=False, sync_grads_before=False, 92 | sync_params_after=True, sync_grads_after=self.full_sync): 93 | return self.optim.load_state_dict(state_dict) 94 | -------------------------------------------------------------------------------- /inference/run_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | import pickle 6 | from collections import OrderedDict 7 | from datetime import datetime 8 | from itertools import cycle, islice 9 | 10 | import clip 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from PIL import Image 16 | from einops import rearrange 17 | # Note: Use dalle_pytorch >= 1.4.2 for this script (newer than in the rest of the repo) 18 | from dalle_pytorch import DALLE 19 | from dalle_pytorch.vae import VQGanVAE 20 | from transformers import T5TokenizerFast 21 | from tqdm import tqdm 22 | 23 | torch.set_grad_enabled(False) 24 | 25 | 26 | class VQGanParams(VQGanVAE): 27 | def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True): 28 | nn.Module.__init__(self) 29 | 30 | self.num_layers = num_layers 31 | self.image_size = image_size 32 | self.num_tokens = num_tokens 33 | self.is_gumbel = is_gumbel 34 | 35 | 36 | class ModelWrapper(nn.Module): 37 | def __init__(self, model): 38 | super().__init__() 39 | self.model = model 40 | 41 | def forward(self, input_ids, attention_mask, image): 42 | loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True) 43 | return {'loss': loss} 44 | 45 | 46 | def make_model(): 47 | tokenizer = T5TokenizerFast.from_pretrained('t5-small') 48 | tokenizer.pad_token = tokenizer.eos_token 49 | 50 | depth = 64 51 | attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1)) 52 | attn_types.append('conv_like') 53 | shared_layer_ids = list(islice(cycle(range(4)), depth - 1)) 54 | shared_layer_ids.append('w_conv') 55 | 56 | dalle = DALLE( 57 | vae=VQGanParams(), 58 | num_text_tokens=tokenizer.vocab_size, 59 | text_seq_len=256, 60 | dim=1024, 61 | depth=depth, 62 | heads=16, 63 | dim_head=64, 64 | attn_types=attn_types, 65 | ff_dropout=0, 66 | attn_dropout=0, 67 | shared_attn_ids=shared_layer_ids, 68 | shared_ff_ids=shared_layer_ids, 69 | rotary_emb=True, 70 | reversible=True, 71 | share_input_output_emb=True, 72 | optimize_for_inference=True, 73 | ) 74 | model = ModelWrapper(dalle) 75 | 76 | return tokenizer, model 77 | 78 | 79 | def generate(query, *, tokenizer, model, 80 | batch_size, n_iters, temperature, top_k, top_p): 81 | input_ids = torch.tensor(tokenizer(query, add_special_tokens=False, max_length=256, truncation=True)['input_ids']) 82 | input_ids = F.pad(input_ids, (0, 256 - len(input_ids)), value=1) 83 | input_ids = input_ids.repeat(batch_size, 1) 84 | input_ids = input_ids.cuda() 85 | 86 | result = [] 87 | for _ in tqdm(range(n_iters), desc=query, leave=False): 88 | output = model.model.generate_images( 89 | 
input_ids, temperature=temperature, top_k=top_k, top_p=top_p, use_cache=True) 90 | output = rearrange(output, 'b c h w -> b h w c').cpu().numpy() 91 | result.extend(output) 92 | return result 93 | 94 | 95 | def main(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('--queries', type=str, help='List of queries (*.txt, newline-separated)') 98 | parser.add_argument('--temperature', type=float, help='Sampling temperature') 99 | parser.add_argument('--top-k', type=int, default=0) 100 | parser.add_argument('--top-p', type=float, default=1.0) 101 | parser.add_argument('--model', type=str, help='DALL-E checkpoint (*.pt)') 102 | parser.add_argument('--vqgan', type=str, help='VQGAN checkpoint (*.ckpt)') 103 | parser.add_argument('--vqgan-config', type=str, help='VQGAN config (*.yaml)') 104 | parser.add_argument('--output-dir', type=str, help='Output directory') 105 | args = parser.parse_args() 106 | 107 | with open(args.queries) as f: 108 | queries = [line.rstrip() for line in f] 109 | queries = [item for item in queries if len(item) > 0] 110 | print(f'[*] Loaded {len(queries)} queries') 111 | 112 | tokenizer, model = make_model() 113 | 114 | print(f'[*] Model modification time: {datetime.fromtimestamp(os.stat(args.model).st_mtime)}') 115 | state_dict = torch.load(args.model) 116 | # The model version optimized for inference requires some renaming in state_dict 117 | state_dict = OrderedDict([(key.replace('net.fn.fn', 'net.fn.fn.fn').replace('to_qkv', 'fn.to_qkv').replace('to_out', 'fn.to_out'), value) 118 | for key, value in state_dict.items()]) 119 | ok = model.load_state_dict(state_dict) 120 | print(f'[*] Loaded model: {ok}') 121 | 122 | gan = VQGanVAE(args.vqgan, args.vqgan_config).cuda() 123 | model.model.vae = gan 124 | model = model.cuda() 125 | 126 | clip_model, clip_preprocess = clip.load("ViT-B/32", device='cuda') 127 | 128 | os.makedirs(args.output_dir, exist_ok=True) 129 | print(f'[*] Saving results to `{args.output_dir}`') 130 | 131 | for query in tqdm(queries): 132 | images = generate(query, tokenizer=tokenizer, model=model, batch_size=16, n_iters=8, 133 | temperature=args.temperature, top_k=args.top_k, top_p=args.top_p) 134 | 135 | images_for_clip = torch.cat([clip_preprocess(Image.fromarray((img * 255).astype(np.uint8))).unsqueeze(0).cuda() for img in images]) 136 | text = clip.tokenize([query]).cuda() 137 | _, logits_per_text = clip_model(images_for_clip, text) 138 | clip_scores = logits_per_text[0].softmax(dim=-1).cpu().numpy() 139 | 140 | with open(os.path.join(args.output_dir, f'{query}.pickle'), 'wb') as f: 141 | outputs = {'query': query, 'temperature': args.temperature, 'images': images, 'clip_scores': clip_scores} 142 | pickle.dump(outputs, f) 143 | 144 | 145 | if __name__ == '__main__': 146 | main() 147 | -------------------------------------------------------------------------------- /callback.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Any 3 | 4 | import hivemind 5 | import torch 6 | import transformers 7 | from transformers import TrainingArguments 8 | 9 | from arguments import TrainingPeerArguments 10 | from task import TrainingTask 11 | from utils import LocalMetrics, logger 12 | 13 | 14 | class CollaborativeCallback(transformers.TrainerCallback): 15 | """ 16 | This callback monitors and reports collaborative training progress, 17 | In case of a catastrophic failure, it can also revert training to a backup 18 | """ 19 | 20 | def __init__(self, task: TrainingTask, args: 
TrainingPeerArguments): 21 | super().__init__() 22 | self.task = task 23 | self.dht, self.collaborative_optimizer = task.dht, task.collaborative_optimizer 24 | self.statistics_expiration = args.statistics_expiration 25 | self.last_reported_collaboration_step = -1 26 | self.samples = 0 27 | self.steps = 0 28 | self.loss = 0 29 | self.total_samples_processed = 0 30 | self.backup_every_steps = args.backup_every_steps 31 | self.state_path = args.state_path 32 | 33 | def on_train_begin( 34 | self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs 35 | ): 36 | if os.path.isfile(self.state_path): 37 | self.restore_from_backup(self.state_path) 38 | logger.info("Loaded state") 39 | 40 | logger.info("Loading state from peers") 41 | self.collaborative_optimizer.load_state_from_peers() 42 | 43 | if os.path.isfile(self.state_path): 44 | self.restore_from_backup(self.state_path, check_step=True) 45 | 46 | def on_step_end( 47 | self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs 48 | ): 49 | control.should_log = True 50 | if not self.params_are_finite(): 51 | if not os.path.exists(self.state_path): 52 | raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.") 53 | logger.warning("Parameters are invalid, reloading model from earlier state") 54 | self.restore_from_backup(self.state_path) 55 | return control 56 | 57 | if state.log_history: 58 | self.loss += state.log_history[-1]["loss"] 59 | self.steps += 1 60 | if self.collaborative_optimizer.local_epoch != self.last_reported_collaboration_step: 61 | self.last_reported_collaboration_step = self.collaborative_optimizer.local_epoch 62 | self.total_samples_processed += self.samples 63 | samples_per_second = self.collaborative_optimizer.tracker.performance_ema.samples_per_second 64 | statistics = LocalMetrics( 65 | step=self.collaborative_optimizer.local_epoch, 66 | samples_per_second=samples_per_second, 67 | samples_accumulated=self.samples, 68 | loss=self.loss, 69 | mini_steps=self.steps, 70 | ) 71 | logger.info(f"Current epoch: {self.collaborative_optimizer.local_epoch}") 72 | logger.info(f"Your current contribution: {self.total_samples_processed} samples") 73 | logger.info(f"Performance: {samples_per_second} samples/sec") 74 | if self.steps: 75 | logger.info(f"Local loss: {self.loss / self.steps}") 76 | 77 | self.loss = 0 78 | self.steps = 0 79 | if self.collaborative_optimizer.local_epoch == self.collaborative_optimizer.tracker.global_epoch: 80 | self.dht.store( 81 | key=self.collaborative_optimizer.run_id + "_metrics", 82 | subkey=self.task.local_public_key, 83 | value=statistics.dict(), 84 | expiration_time=hivemind.get_dht_time() + self.statistics_expiration, 85 | return_future=True, 86 | ) 87 | if self.backup_every_steps is not None and \ 88 | self.collaborative_optimizer.local_epoch % self.backup_every_steps == 0: 89 | self.backup_state() 90 | 91 | self.samples = self.collaborative_optimizer.grad_averager.local_samples_accumulated 92 | 93 | return control 94 | 95 | @torch.no_grad() 96 | def params_are_finite(self): 97 | for param in self.task.model.parameters(): 98 | if not torch.all(torch.isfinite(param)): 99 | return False 100 | return True 101 | 102 | @torch.no_grad() 103 | def backup_state(self) -> Any: 104 | logger.info("Saving backup") 105 | return torch.save( 106 | { 107 | "model": self.task.model.state_dict(), 108 | "training": self.collaborative_optimizer.state_dict(), 109 | "scheduler": 
self.collaborative_optimizer.state_averager.scheduler.state_dict(), 110 | "local_epoch": self.collaborative_optimizer.local_epoch, 111 | }, 112 | self.state_path, 113 | ) 114 | 115 | @torch.no_grad() 116 | def restore_from_backup(self, path, check_step=False): 117 | state = torch.load(path) 118 | current_step = self.collaborative_optimizer.local_epoch 119 | backup_step = state['local_epoch'] 120 | if not check_step or backup_step >= current_step: 121 | self.task.model.load_state_dict(state["model"], strict=False) 122 | self.collaborative_optimizer.load_state_dict(state["training"]) 123 | self.collaborative_optimizer.state_averager.scheduler.load_state_dict(state["scheduler"]) 124 | self.collaborative_optimizer.state_averager.local_epoch = backup_step 125 | logger.info("Restored from a backup") 126 | else: 127 | logger.info("Bypassed restoring state from local backup: backup state is too old.") 128 | -------------------------------------------------------------------------------- /run_aux_peer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | 4 | import torch 5 | import wandb 6 | import transformers 7 | from transformers import HfArgumentParser 8 | from huggingface_hub import HfFolder, Repository 9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 10 | 11 | import utils 12 | from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments 13 | from task import TrainingTask 14 | 15 | 16 | transformers.utils.logging.set_verbosity_warning() 17 | use_hivemind_log_handler("in_root_logger") 18 | logger = get_logger(__name__) 19 | 20 | 21 | class CheckpointHandler: 22 | def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments): 23 | self.task, self.peer_args = task, peer_args 24 | self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval 25 | self.prefix = peer_args.experiment_prefix 26 | self.local_path = peer_args.local_path 27 | self.upload_interval = peer_args.upload_interval 28 | if self.upload_interval is not None: 29 | assert task.authorizer is not None, 'Model uploading needs Hugging Face auth to be enabled' 30 | self.repo = Repository( 31 | local_dir=self.local_path, 32 | clone_from=peer_args.repo_url, 33 | use_auth_token=task.authorizer.hf_user_access_token, 34 | ) 35 | self.last_upload_time = None 36 | self.previous_step = -1 37 | 38 | def should_save_state(self, cur_step): 39 | if self.save_checkpoint_step_interval is None: 40 | return False 41 | elif cur_step - self.previous_step >= self.save_checkpoint_step_interval: 42 | return True 43 | else: 44 | return False 45 | 46 | def save_state(self, cur_step): 47 | logger.info("Saving state from peers") 48 | self.task.collaborative_optimizer.load_state_from_peers() 49 | self.previous_step = cur_step 50 | 51 | def is_time_to_upload(self): 52 | if self.upload_interval is None: 53 | return False 54 | elif self.last_upload_time is None or time.time() - self.last_upload_time >= self.upload_interval: 55 | return True 56 | else: 57 | return False 58 | 59 | def upload_checkpoint(self, current_loss): 60 | self.last_upload_time = time.time() 61 | 62 | logger.info("Saving model") 63 | torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt") 64 | logger.info("Saving optimizer") 65 | torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt") 66 | logger.info("Started uploading to Model Hub") 67 | try: 68 | # We start by pulling 
the remote changes (for example a change in the readme file) 69 | self.repo.git_pull() 70 | 71 | # Then we add / commmit and push the changes 72 | self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}") 73 | logger.info("Finished uploading to Model Hub") 74 | except Exception: 75 | logger.exception("Uploading the checkpoint to HF Model Hub failed:") 76 | logger.warning("Ensure that your access token is valid and has WRITE permissions") 77 | 78 | 79 | def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments): 80 | while True: 81 | time.sleep(peer_args.assist_refresh) 82 | task.collaborative_optimizer.step() 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments)) 87 | peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses() 88 | 89 | task = TrainingTask(peer_args, trainer_args, collab_args) 90 | dht, collaborative_optimizer = task.dht, task.collaborative_optimizer 91 | 92 | if peer_args.wandb_project is not None: 93 | wandb.init(project=peer_args.wandb_project) 94 | 95 | current_step = 0 96 | if peer_args.store_checkpoints: 97 | checkpoint_handler = CheckpointHandler(task, peer_args) 98 | 99 | if peer_args.assist_in_averaging: 100 | # assert not peer_args.client_mode, "client-mode peers cannot assist in averaging" 101 | # averaging_thread = threading.Thread( 102 | # name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True) 103 | # averaging_thread.start() 104 | raise NotImplementedError('aux peers with hivemind.optim.experimental are not supported yet') 105 | 106 | while True: 107 | metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True) 108 | if metrics_entry is not None and len(metrics_entry.value) > 0: 109 | metrics_dict = metrics_entry.value 110 | metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict] 111 | latest_step = max(item.step for item in metrics) 112 | 113 | if latest_step != current_step: 114 | logger.debug(f"Got metrics from {len(metrics)} peers") 115 | 116 | for i, metrics_for_peer in enumerate(metrics): 117 | logger.debug(f"{i} peer {metrics_for_peer}") 118 | 119 | current_step = latest_step 120 | alive_peers = 0 121 | sum_loss = 0 122 | num_samples = 0 123 | sum_perf = 0 124 | sum_mini_steps = 0 125 | 126 | for item in metrics: 127 | sum_loss += item.loss 128 | alive_peers += 1 129 | sum_perf += item.samples_per_second 130 | num_samples += item.samples_accumulated 131 | sum_mini_steps += item.mini_steps 132 | current_loss = sum_loss / sum_mini_steps 133 | logger.info(f"Epoch #{current_step}\tloss = {current_loss:.5f}") 134 | 135 | if peer_args.wandb_project is not None: 136 | wandb.log( 137 | { 138 | "loss": current_loss, 139 | "alive peers": alive_peers, 140 | "samples": num_samples, 141 | "performance": sum_perf, 142 | "step": latest_step, 143 | } 144 | ) 145 | 146 | if peer_args.store_checkpoints: 147 | if checkpoint_handler.should_save_state(current_step): 148 | checkpoint_handler.save_state(current_step) 149 | if checkpoint_handler.is_time_to_upload(): 150 | checkpoint_handler.upload_checkpoint(current_loss) 151 | logger.debug("Peer is still alive...") 152 | time.sleep(peer_args.refresh_period) 153 | -------------------------------------------------------------------------------- /arguments.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | import torch 5 | from transformers import TrainingArguments 6 | 7 | 8 | @dataclass 9 | class HFTrainerArguments(TrainingArguments): 10 | """Arguments for huggingface/transformers.Trainer""" 11 | dataloader_num_workers: int = 1 12 | per_device_train_batch_size: int = 2 13 | per_device_eval_batch_size: int = 2 14 | gradient_accumulation_steps: int = 1 15 | text_seq_length: int = 256 16 | 17 | # DALLE-specific params 18 | learning_rate: float = 0.0025 19 | adam_beta1: float = 0.9 20 | adam_beta2: float = 0.96 21 | max_grad_norm: float = 4.0 22 | weight_decay: float = 0.045 23 | 24 | total_steps: int = 31250 # total number of collaborative SGD updates, used for learning rate schedule 25 | warmup_steps: int = 3125 26 | adam_epsilon: float = 1e-6 27 | clamp_value: float = 10000.0 28 | 29 | fp16: bool = False 30 | do_train: bool = True 31 | 32 | logging_steps: int = 100 33 | max_steps: int = 10 ** 20 34 | save_steps: int = 10 ** 20 35 | save_total_limit: int = 2 36 | 37 | output_dir: str = "outputs" 38 | 39 | @property 40 | def batch_size_per_step(self): 41 | """Compute the number of training sequences contributed by each .step() from this peer""" 42 | total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps 43 | if torch.cuda.device_count() > 0: 44 | total_batch_size_per_step *= torch.cuda.device_count() 45 | return total_batch_size_per_step 46 | 47 | 48 | @dataclass 49 | class TPUTrainerArguments(HFTrainerArguments): 50 | n_tpus: int = 8 # the total number of TPU cores in use 51 | wandb_project: str = "huggingface" 52 | 53 | @property 54 | def batch_size_per_step(self): 55 | """Compute the number of training sequences contributed by each .step() from this peer""" 56 | return self.per_device_train_batch_size * self.gradient_accumulation_steps * self.n_tpus 57 | 58 | 59 | @dataclass 60 | class CollaborativeArguments: 61 | """Configuration for CollaborativeOptimizer and its internals""" 62 | target_batch_size: int = field( 63 | default=4096, 64 | metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}, 65 | ) 66 | matchmaking_time: float = field( 67 | default=15.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"} 68 | ) 69 | allreduce_timeout: float = field( 70 | default=60, metadata={"help": "Give up on a given all-reduce round after this many seconds"} 71 | ) 72 | averaging_timeout: float = field( 73 | default=180, metadata={"help": "Give up on averaging step after this many seconds"} 74 | ) 75 | reuse_grad_buffers: bool = field(default=True, metadata={ 76 | "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This " 77 | "optimization reduces GPU memory consumption but may result in incorrect gradients when using some " 78 | "advanced techniques (e.g. 
applying custom loss scaler)"}) 79 | 80 | 81 | @dataclass 82 | class BasePeerArguments: 83 | """Base arguments that are used for both trainers and for auxiliary peers such as training monitor""" 84 | experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"}) 85 | tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"}) 86 | cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"}) 87 | 88 | authorize: bool = field(default=True, metadata={"help": "Whether or not to use HF authorizer"}) 89 | client_mode: bool = field( 90 | default=False, 91 | metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"}, 92 | ) 93 | initial_peers: List[str] = field( 94 | default_factory=list, 95 | metadata={ 96 | "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. " 97 | "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY" 98 | }, 99 | ) 100 | use_ipfs: bool = field( 101 | default=False, 102 | metadata={ 103 | "help": "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of multiaddrs " 104 | "for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)" 105 | }, 106 | ) 107 | host_maddrs: List[str] = field( 108 | default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"], 109 | metadata={ 110 | "help": "Multiaddrs to listen for external connections from other p2p instances. " 111 | "Defaults to all IPv4 interfaces with TCP protocol: /ip4/0.0.0.0/tcp/0" 112 | }, 113 | ) 114 | announce_maddrs: List[str] = field( 115 | default_factory=list, 116 | metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}, 117 | ) 118 | identity_path: Optional[str] = field( 119 | default=None, 120 | metadata={ 121 | "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. " 122 | "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``." 
123 | }, 124 | ) 125 | 126 | 127 | @dataclass 128 | class TrainingPeerArguments(BasePeerArguments): 129 | statistics_expiration: float = field( 130 | default=600, metadata={"help": "Statistics will be removed if not updated in this many seconds"} 131 | ) 132 | backup_every_steps: Optional[int] = field( 133 | default=None, metadata={"help": "Update training state backup on disk once in this many global steps " 134 | "(default = do not update local state)"} 135 | ) 136 | state_path: str = field( 137 | default="state.zip", metadata={"help": "Load this state upon init and when recovering from NaN parameters"}) 138 | 139 | 140 | @dataclass 141 | class AuxiliaryPeerArguments(BasePeerArguments): 142 | """ 143 | Arguments for run_aux_peer.py that is responsible for connecting peers to one another, tracking 144 | learning curves, assisting in all-reduce and uploading checkpoints to the hub 145 | """ 146 | refresh_period: float = field(default=10, metadata={"help": "Period (in seconds) for fetching the keys from DHT"}) 147 | wandb_project: Optional[str] = field( 148 | default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"} 149 | ) 150 | save_checkpoint_step_interval: int = field( 151 | default=2, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"} 152 | ) 153 | repo_url: Optional[str] = field( 154 | default=None, metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"} 155 | ) 156 | local_path: Optional[str] = field( 157 | default="Repo", metadata={"help": "Path to local repository to store the model and optimizer states"} 158 | ) 159 | upload_interval: Optional[float] = field( 160 | default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"} 161 | ) 162 | store_checkpoints: bool = field(default=True, metadata={"help": "If True, enables CheckpointHandler"}) 163 | assist_in_averaging: bool = field( 164 | default=False, metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"}) 165 | assist_refresh: float = field(default=1.0, metadata={"help": "Period (in seconds) for trying to assist averaging"}) 166 | -------------------------------------------------------------------------------- /task.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | from itertools import cycle, islice 4 | from pathlib import Path 5 | 6 | import hivemind 7 | import torch 8 | import torch.nn as nn 9 | import transformers 10 | from dalle_pytorch import DALLE 11 | from dalle_pytorch.vae import VQGanVAE 12 | from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization 13 | from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup 14 | 15 | import utils 16 | from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments 17 | from data import make_dataset 18 | from huggingface_auth import authorize_with_huggingface 19 | from lib.training.lamb_8bit import CPULAMB8Bit 20 | 21 | 22 | logger = hivemind.get_logger(__name__) 23 | 24 | 25 | class VQGanParams(VQGanVAE): 26 | def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True): 27 | nn.Module.__init__(self) 28 | 29 | self.num_layers = num_layers 30 | self.image_size = image_size 31 | self.num_tokens = num_tokens 32 | self.is_gumbel = is_gumbel 33 | 34 | 35 | class ModelWrapper(nn.Module): 36 | def __init__(self, 
model): 37 | super().__init__() 38 | self.model = model 39 | 40 | def forward(self, input_ids, attention_mask, image): 41 | loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True) 42 | return {'loss': loss} 43 | 44 | 45 | class TrainingTask: 46 | """A container that defines the training config, model, tokenizer, optimizer and other local training utilities""" 47 | _authorizer = _dht = _collaborative_optimizer = _training_dataset = None 48 | 49 | 50 | def __init__( 51 | self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments): 52 | self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args 53 | self.trainer_args.run_name = self.authorizer.username # For wandb 54 | 55 | self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix) 56 | transformers.set_seed(trainer_args.seed) # seed used for initialization 57 | 58 | self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path) 59 | self.tokenizer.pad_token = self.tokenizer.eos_token 60 | 61 | logger.info(f"Creating model") 62 | depth = 64 63 | attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1)) 64 | attn_types.append('conv_like') 65 | shared_layer_ids = list(islice(cycle(range(4)), depth - 1)) 66 | shared_layer_ids.append('w_conv') 67 | dalle = DALLE( 68 | vae=VQGanParams(), 69 | num_text_tokens=self.tokenizer.vocab_size, 70 | text_seq_len=trainer_args.text_seq_length, 71 | dim=1024, 72 | depth=depth, 73 | heads=16, 74 | dim_head=64, 75 | attn_types=attn_types, 76 | ff_dropout=0, 77 | attn_dropout=0, 78 | shared_attn_ids=shared_layer_ids, 79 | shared_ff_ids=shared_layer_ids, 80 | rotary_emb=True, 81 | reversible=True, 82 | share_input_output_emb=True, 83 | ) 84 | logger.info(f"Trainable parameters: " 85 | f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}") 86 | self.model = ModelWrapper(dalle) 87 | 88 | output_dir = Path(trainer_args.output_dir) 89 | logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}') 90 | latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime) 91 | if latest_checkpoint_dir is not None: 92 | logger.info(f"Loading model from {latest_checkpoint_dir}") 93 | self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt")) 94 | 95 | @property 96 | def authorizer(self): 97 | if self._authorizer is None and self.peer_args.authorize: 98 | self._authorizer = authorize_with_huggingface() 99 | return self._authorizer 100 | 101 | @property 102 | def dht(self): 103 | if self._dht is None: 104 | self._dht = hivemind.DHT( 105 | start=True, 106 | initial_peers=self.peer_args.initial_peers, 107 | client_mode=self.peer_args.client_mode, 108 | host_maddrs=self.peer_args.host_maddrs, 109 | announce_maddrs=self.peer_args.announce_maddrs, 110 | use_ipfs=self.peer_args.use_ipfs, 111 | record_validators=self.validators, 112 | identity_path=self.peer_args.identity_path, 113 | authorizer=self.authorizer, 114 | ) 115 | if self.peer_args.client_mode: 116 | logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}") 117 | else: 118 | utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs) 119 | return self._dht 120 | 121 | @property 122 | def collaborative_optimizer(self): 123 | if self._collaborative_optimizer is None: 124 | params, opt, scheduler = 
self._get_local_optimizer_and_scheduler(self.trainer_args) 125 | averaging_compression = SizeAdaptiveCompression( 126 | threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()) 127 | self._collaborative_optimizer = hivemind.Optimizer( 128 | dht=self.dht, run_id=self.peer_args.experiment_prefix, 129 | params=params, optimizer=opt, scheduler=scheduler, 130 | offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True, 131 | batch_size_per_step=self.trainer_args.batch_size_per_step, 132 | grad_compression=averaging_compression, state_averaging_compression=averaging_compression, 133 | client_mode=self.peer_args.client_mode, verbose=True, 134 | **asdict(self.collab_args)) 135 | return self._collaborative_optimizer 136 | 137 | def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments): 138 | no_decay = ["bias", "LayerNorm.weight"] 139 | params = [ 140 | { 141 | "params": [p for n, p in self.model.named_parameters() 142 | if not any(nd in n for nd in no_decay) and p.requires_grad], 143 | "weight_decay": training_args.weight_decay, 144 | }, 145 | { 146 | "params": [p for n, p in self.model.named_parameters() 147 | if any(nd in n for nd in no_decay) and p.requires_grad], 148 | "weight_decay": 0.0, 149 | }, 150 | ] 151 | 152 | opt = lambda params: CPULAMB8Bit( 153 | params, 154 | lr=training_args.learning_rate, 155 | betas=(training_args.adam_beta1, training_args.adam_beta2), 156 | eps=training_args.adam_epsilon, 157 | weight_decay=training_args.weight_decay, 158 | max_grad_norm=training_args.max_grad_norm, 159 | clamp_value=training_args.clamp_value, 160 | reuse_grad_buffers=True, 161 | ) 162 | 163 | scheduler = lambda opt: get_linear_schedule_with_warmup( 164 | opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps 165 | ) 166 | 167 | return params, opt, scheduler 168 | 169 | @property 170 | def training_dataset(self): 171 | if self._training_dataset is None: 172 | self._training_dataset = make_dataset( 173 | self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31, 174 | max_sequence_length=self.trainer_args.text_seq_length 175 | ) 176 | return self._training_dataset 177 | 178 | @property 179 | def data_collator(self): 180 | return DataCollatorWithPadding(tokenizer=self.tokenizer, 181 | padding='max_length', max_length=self.trainer_args.text_seq_length) 182 | -------------------------------------------------------------------------------- /huggingface_auth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from datetime import datetime, timedelta 4 | from getpass import getpass 5 | 6 | import requests 7 | from huggingface_hub import HfApi 8 | from termcolor import colored 9 | 10 | from hivemind.proto.auth_pb2 import AccessToken 11 | from hivemind.utils.auth import TokenAuthorizerBase 12 | from hivemind.utils.crypto import RSAPublicKey 13 | from hivemind.utils.logging import get_logger 14 | 15 | 16 | logger = get_logger("root." + __name__) 17 | 18 | 19 | class NonRetriableError(Exception): 20 | pass 21 | 22 | 23 | def call_with_retries(func, n_retries=10, initial_delay=1.0): 24 | for i in range(n_retries): 25 | try: 26 | return func() 27 | except NonRetriableError: 28 | raise 29 | except Exception as e: 30 | if i == n_retries - 1: 31 | raise 32 | 33 | delay = initial_delay * (2 ** i) 34 | logger.warning(f'Failed to call `{func.__name__}` with exception: {e}. 
Retrying in {delay:.1f} sec') 35 | time.sleep(delay) 36 | 37 | 38 | class InvalidCredentialsError(NonRetriableError): 39 | pass 40 | 41 | 42 | class NotInAllowlistError(NonRetriableError): 43 | pass 44 | 45 | 46 | class HuggingFaceAuthorizer(TokenAuthorizerBase): 47 | _AUTH_SERVER_URL = 'https://collaborative-training-auth.huggingface.co' 48 | 49 | def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str): 50 | super().__init__() 51 | 52 | self.organization_name = organization_name 53 | self.model_name = model_name 54 | self.hf_user_access_token = hf_user_access_token 55 | 56 | self._authority_public_key = None 57 | self.coordinator_ip = None 58 | self.coordinator_port = None 59 | 60 | self._hf_api = HfApi() 61 | 62 | async def get_token(self) -> AccessToken: 63 | """ 64 | Hivemind calls this method to refresh the token when necessary. 65 | """ 66 | 67 | self.join_experiment() 68 | return self._local_access_token 69 | 70 | @property 71 | def username(self): 72 | return self._local_access_token.username 73 | 74 | def join_experiment(self) -> None: 75 | call_with_retries(self._join_experiment) 76 | 77 | def _join_experiment(self) -> None: 78 | try: 79 | url = f'{self._AUTH_SERVER_URL}/api/experiments/join' 80 | headers = {'Authorization': f'Bearer {self.hf_user_access_token}'} 81 | response = requests.put( 82 | url, 83 | headers=headers, 84 | params={ 85 | 'organization_name': self.organization_name, 86 | 'model_name': self.model_name, 87 | }, 88 | json={ 89 | 'experiment_join_input': { 90 | 'peer_public_key': self.local_public_key.to_bytes().decode(), 91 | }, 92 | }, 93 | ) 94 | 95 | response.raise_for_status() 96 | response = response.json() 97 | 98 | self._authority_public_key = RSAPublicKey.from_bytes(response['auth_server_public_key'].encode()) 99 | self.coordinator_ip = response['coordinator_ip'] 100 | self.coordinator_port = response['coordinator_port'] 101 | 102 | token_dict = response['hivemind_access'] 103 | access_token = AccessToken() 104 | access_token.username = token_dict['username'] 105 | access_token.public_key = token_dict['peer_public_key'].encode() 106 | access_token.expiration_time = str(datetime.fromisoformat(token_dict['expiration_time'])) 107 | access_token.signature = token_dict['signature'].encode() 108 | self._local_access_token = access_token 109 | logger.info(f'Access for user {access_token.username} ' 110 | f'has been granted until {access_token.expiration_time} UTC') 111 | except requests.exceptions.HTTPError as e: 112 | if e.response.status_code == 401: # Unauthorized 113 | raise NotInAllowlistError() 114 | raise 115 | 116 | def is_token_valid(self, access_token: AccessToken) -> bool: 117 | data = self._token_to_bytes(access_token) 118 | if not self._authority_public_key.verify(data, access_token.signature): 119 | logger.exception('Access token has invalid signature') 120 | return False 121 | 122 | try: 123 | expiration_time = datetime.fromisoformat(access_token.expiration_time) 124 | except ValueError: 125 | logger.exception( 126 | f'datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}') 127 | return False 128 | if expiration_time.tzinfo is not None: 129 | logger.exception(f'Expected to have no timezone for expiration time: {access_token.expiration_time}') 130 | return False 131 | if expiration_time < datetime.utcnow(): 132 | logger.exception('Access token has expired') 133 | return False 134 | 135 | return True 136 | 137 | _MAX_LATENCY = timedelta(minutes=1) 138 | 139 | def 
does_token_need_refreshing(self, access_token: AccessToken) -> bool:
140 | expiration_time = datetime.fromisoformat(access_token.expiration_time)
141 | return expiration_time < datetime.utcnow() + self._MAX_LATENCY
142 | 
143 | @staticmethod
144 | def _token_to_bytes(access_token: AccessToken) -> bytes:
145 | return f'{access_token.username} {access_token.public_key} {access_token.expiration_time}'.encode()
146 | 
147 | 
148 | def authorize_with_huggingface() -> HuggingFaceAuthorizer:
149 | while True:
150 | organization_name = os.getenv('HF_ORGANIZATION_NAME')
151 | if organization_name is None:
152 | organization_name = input('HuggingFace organization name: ')
153 | 
154 | model_name = os.getenv('HF_MODEL_NAME')
155 | if model_name is None:
156 | model_name = input('HuggingFace model name: ')
157 | 
158 | hf_user_access_token = os.getenv('HF_USER_ACCESS_TOKEN')
159 | if hf_user_access_token is None:
160 | print(
161 | "\nCopy a token from the 🤗 Hugging Face settings page at "
162 | f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} "
163 | "and paste it here.\n\n"
164 | f"💡 {colored('Tip:', attrs=['bold'])} "
165 | "If you don't already have one, you can create a dedicated user access token.\n"
166 | f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, "
167 | f"click the {colored('New token', attrs=['bold'])} button, "
168 | f"and choose the {colored('read', attrs=['bold'])} role.\n"
169 | )
170 | hf_user_access_token = getpass('🤗 Hugging Face user access token (characters will be hidden): ')
171 | 
172 | authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)
173 | 
174 | try:
175 | authorizer.join_experiment()
176 | print(f"🚀 You will contribute to the collaborative training under the username {authorizer.username}")
177 | return authorizer
178 | except InvalidCredentialsError:
179 | print('Invalid user access token, please try again')
180 | except NotInAllowlistError:
181 | print(
182 | '\n😥 Authentication has failed.\n\n'
183 | 'This error may happen for one of the following reasons:\n'
184 | " 1. Your user access token is not valid. You can try to delete the previous token and "
185 | "recreate one. Note that organization tokens can't be used to join a collaborative "
186 | "training.\n"
187 | f" 2. You have not yet joined the {organization_name} organization. You can request to "
188 | "join this organization by clicking on the 'request to join this org' button at "
189 | f"https://huggingface.co/{organization_name}.\n"
190 | f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n"
191 | f" 4. No admin of the {organization_name} organization has created a collaborative training for this "
192 | f"organization and the {model_name} model."
193 | ) 194 | -------------------------------------------------------------------------------- /manage_scaleset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from base64 import b64encode 4 | 5 | from azure.identity import DefaultAzureCredential 6 | from azure.mgmt.compute import ComputeManagementClient 7 | from azure.mgmt.network import NetworkManagementClient 8 | from azure.mgmt.resource import ResourceManagementClient 9 | 10 | 11 | print("=======================WARNING=======================") 12 | print("= The code may fail to import 'gi' but that is okay =") 13 | print("===================END OF WARNING====================") 14 | SUBSCRIPTION_ID = os.environ["SUBSCRIPTION_ID"] 15 | GROUP_NAME = "dalle_west2" 16 | NETWORK_NAME = "vnet" 17 | SUBNET_NAME = "subnet" 18 | LOCATION = "westus2" 19 | ADMIN_PASS = os.environ['AZURE_PASS'] 20 | 21 | SCALE_SETS = ('worker',) 22 | SWARM_SIZE = 4 23 | 24 | WORKER_CLOUD_INIT = """#cloud-config 25 | package_update: true 26 | packages: 27 | - build-essential 28 | - wget 29 | - git 30 | - vim 31 | write_files: 32 | - path: /home/hivemind/init_worker.sh 33 | permissions: '0766' 34 | owner: root:root 35 | content: | 36 | #!/usr/bin/env bash 37 | set -e 38 | wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O install_miniconda.sh 39 | bash install_miniconda.sh -b -p /opt/conda 40 | export PATH="/opt/conda/bin:${PATH}" 41 | conda install python~=3.8.0 pip 42 | conda install pytorch cudatoolkit=11.1 -c pytorch -c nvidia 43 | conda clean --all 44 | pip install https://github.com/learning-at-home/hivemind/archive/scaling_tweaks.zip 45 | systemctl enable testserv 46 | systemctl start testserv 47 | - path: /etc/systemd/system/testserv.service 48 | permissions: '0777' 49 | owner: root:root 50 | content: | 51 | [Unit] 52 | Description=One Shot 53 | 54 | [Service] 55 | ExecStart=/etc/createfile 56 | Type=oneshot 57 | RemainAfterExit=yes 58 | 59 | [Install] 60 | WantedBy=multi-user.target 61 | - path: /etc/createfile 62 | permissions: '0777' 63 | owner: root:root 64 | content: | 65 | #!/bin/bash 66 | export PATH="/opt/conda/bin:${PATH}" 67 | cd /home/hivemind 68 | ulimit -n 8192 69 | 70 | git clone https://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@github.com/learning-at-home/dalle-hivemind.git -b azure 71 | cd dalle-hivemind 72 | pip install -r requirements.txt 73 | pip install -U transformers==4.10.2 datasets==1.11.0 74 | 75 | WANDB_API_KEY=7cc938e45e63ef7d2f88f811be240ba0395c02dd python run_trainer.py --run_name $(hostname) \ 76 | --experiment_prefix dalle_large_5groups \ 77 | --initial_peers /ip4/52.232.13.142/tcp/31334/p2p/QmZLrSPKAcP4puJ8gUGvQ155thk5Q6J7oE5exMUSq1oD5i \ 78 | --per_device_train_batch_size 1 --gradient_accumulation_steps 1 79 | runcmd: 80 | - bash /home/hivemind/init_worker.sh 81 | """ 82 | 83 | 84 | def main(): 85 | parser = ArgumentParser() 86 | parser.add_argument('command', choices=('create', 'delete')) 87 | args = parser.parse_args() 88 | 89 | resource_client = ResourceManagementClient( 90 | credential=DefaultAzureCredential(), 91 | subscription_id=SUBSCRIPTION_ID 92 | ) 93 | network_client = NetworkManagementClient( 94 | credential=DefaultAzureCredential(), 95 | subscription_id=SUBSCRIPTION_ID 96 | ) 97 | compute_client = ComputeManagementClient( 98 | credential=DefaultAzureCredential(), 99 | subscription_id=SUBSCRIPTION_ID 100 | ) 101 | 102 | # Create resource group 103 | 
resource_client.resource_groups.create_or_update( 104 | GROUP_NAME, 105 | {"location": LOCATION} 106 | ) 107 | 108 | # Create virtual network 109 | network_client.virtual_networks.begin_create_or_update( 110 | GROUP_NAME, 111 | NETWORK_NAME, 112 | { 113 | 'location': LOCATION, 114 | 'address_space': { 115 | 'address_prefixes': ['10.0.0.0/16'] 116 | } 117 | } 118 | ).result() 119 | 120 | subnet = network_client.subnets.begin_create_or_update( 121 | GROUP_NAME, 122 | NETWORK_NAME, 123 | SUBNET_NAME, 124 | {'address_prefix': '10.0.0.0/16'} 125 | ).result() 126 | 127 | if args.command == 'create': 128 | 129 | scalesets = [] 130 | 131 | for scaleset_name in SCALE_SETS: 132 | cloud_init_cmd = WORKER_CLOUD_INIT 133 | vm_image = { 134 | "exactVersion": "21.06.0", 135 | "offer": "ngc_base_image_version_b", 136 | "publisher": "nvidia", 137 | "sku": "gen2_21-06-0", 138 | "version": "latest", 139 | } 140 | 141 | vm_config = { 142 | "sku": { 143 | "tier": "Standard", 144 | "capacity": SWARM_SIZE, 145 | "name": "Standard_NC4as_T4_v3" 146 | }, 147 | "plan": { 148 | "name": "gen2_21-06-0", 149 | "publisher": "nvidia", 150 | "product": "ngc_base_image_version_b" 151 | }, 152 | "location": LOCATION, 153 | "virtual_machine_profile": { 154 | "storage_profile": { 155 | "image_reference": vm_image, 156 | "os_disk": { 157 | "caching": "ReadWrite", 158 | "managed_disk": {"storage_account_type": "Standard_LRS"}, 159 | "create_option": "FromImage", 160 | "disk_size_gb": "32", 161 | }, 162 | }, 163 | "os_profile": { 164 | "computer_name_prefix": scaleset_name, 165 | "admin_username": "hivemind", 166 | "admin_password": ADMIN_PASS, 167 | "linux_configuration": { 168 | "disable_password_authentication": True, 169 | "ssh": { 170 | "public_keys": [ 171 | { 172 | "key_data": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDPFugAsrqEsqxj+hKDTfgrtkY26jqCjRubT5vhnJLhtkDAqe5vJ1donWfUVhtBfnqGr92LPmJewPUd9hRa1i33FLVVdkFAs5/Cg8/YbzR8B8e1Y+Nl5HeT7Dq1i+cPEbA1EZAm9tqK4VWYeCMd3CDkoJVuweTwyja08mxtnVNwKCeY4oBKQCE5QlliAKaQnGpJE6MRnbudWM9Ly1wM6OaJVdGwsfPfEG/sSDip4q/8x/KGAzKbhE6ax15Yu/Bu12ahcIdScQsYK9Y6Sm57MHQQLWQO1G+3h3oCTXQ0BGaSMWKXsjmHsB7f9kLZ1j8yMoGlgbpWbjB0ZVsK/4Zh8Ho3h9gDXADzt1j69qT1aERWCt7fxp9+WOLsCTw1W/W9FY2Ia4niVh2/wEwT9AcOBcAqBl7kXQAoUpP8b2Xb+KNXyTEtVB562EdFn+LmG1gZAy8J3piy2/zoo16QJP5PjpKW5GFxL6BRYLtG+uxgx1Glya617T0dtJF/X2vxjT45QK3FaFH1Zd+vhpcLg94fOPNPEhNU7EeBVp8CGYNd+aXVIPsb0I7EIVu9wWi3/a7y86cUedal61fEigfmAQkC7AHYiAiiT94eARj0N+KgjEy2UOITSCJJTHuamYWO8jZc/n7yAqr6mxOKn5ZjBTfAR9bNB/D+HpL6yepI1UDGBVk4DQ== justHeuristic@gmail.com\n", 173 | "path": "/home/hivemind/.ssh/authorized_keys" 174 | } 175 | ] 176 | } 177 | }, 178 | "custom_data": b64encode(cloud_init_cmd.encode('utf-8')).decode('latin-1'), 179 | }, 180 | "network_profile": { 181 | "network_interface_configurations": [ 182 | { 183 | "name": "test", 184 | "primary": True, 185 | "enable_accelerated_networking": True, 186 | "ip_configurations": [ 187 | { 188 | "name": "test", 189 | "subnet": { 190 | "id": f"/subscriptions/{SUBSCRIPTION_ID}/resourceGroups/{GROUP_NAME}/providers/Microsoft.Network/virtualNetworks/{NETWORK_NAME}/subnets/{SUBNET_NAME}" 191 | }, 192 | "public_ip_address_configuration": { 193 | "name": "pub1", 194 | "idle_timeout_in_minutes": 15 195 | } 196 | 197 | } 198 | ] 199 | } 200 | ] 201 | }, 202 | "diagnostics_profile": {"boot_diagnostics": {"enabled": True}}, 203 | "priority": "spot", 204 | "eviction_policy": "deallocate", 205 | }, 206 | "upgrade_policy": { 207 | "mode": "Manual" 208 | }, 209 | "upgrade_mode": "Manual", 210 | "spot_restore_policy": {"enabled": True} 
211 | } 212 | 213 | # Create virtual machine scale set 214 | vmss = compute_client.virtual_machine_scale_sets.begin_create_or_update( 215 | GROUP_NAME, 216 | scaleset_name, 217 | vm_config, 218 | ) 219 | print(f"{scaleset_name} {vmss.status()}") 220 | scalesets.append(vmss) 221 | 222 | for scaleset_name, vmss in zip(SCALE_SETS, scalesets): 223 | print(f"Created scale set {scaleset_name}:\n{vmss.result()}") 224 | 225 | else: 226 | delete_results = [] 227 | for scaleset_name in SCALE_SETS: 228 | delete_results.append(compute_client.virtual_machine_scale_sets.begin_delete(GROUP_NAME, scaleset_name)) 229 | 230 | for scaleset_name, delete_result in zip(SCALE_SETS, delete_results): 231 | delete_result.result() 232 | print(f"Deleted scale set {scaleset_name}") 233 | 234 | 235 | if __name__ == "__main__": 236 | main() 237 | -------------------------------------------------------------------------------- /lib/training/tpu.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import threading 3 | from functools import partial 4 | from contextlib import nullcontext 5 | from copy import deepcopy 6 | import multiprocessing as mp 7 | from itertools import zip_longest 8 | from typing import Iterable 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.data 13 | import torch_xla.core.xla_model as xm 14 | import torch_xla.distributed.xla_multiprocessing as xmp 15 | import torch_xla.distributed.parallel_loader as pl 16 | 17 | from hivemind.utils.logging import get_logger 18 | 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | class TPUManager(mp.Process): 24 | """Auxiliary class that manages model training over an array of TPU cores""" 25 | 26 | def __init__(self, 27 | model, 28 | dataset, 29 | *, 30 | collate_fn: callable = None, 31 | nprocs: int = 8, 32 | prefetch: int = 16, 33 | batch_size_per_device: int = 1, 34 | grad_accumulation_steps: int = 1, 35 | seed_base: int = 42, 36 | start: bool): 37 | super().__init__() 38 | self.lock = mp.Lock() 39 | self.nprocs, self.prefetch, self.seed_base = nprocs, prefetch, seed_base 40 | self.batch_size_per_device, self.grad_accumulation_steps = batch_size_per_device, grad_accumulation_steps 41 | self.collate_fn = collate_fn 42 | self.step_triggered, self.step_finished = mp.Event(), mp.Event() 43 | self._synchronizer = TPUSynchronizer(model) 44 | self._data_manager = TPUDataManager(dataset, nprocs, prefetch) 45 | 46 | # shared fields for communicating statistics after each step 47 | self.should_load_parameters = mp.Value(ctypes.c_bool, False) 48 | self.gradients_accumulated = mp.Value(ctypes.c_long, 0) 49 | self.loss_accumulated = mp.Value(ctypes.c_double, 0) 50 | if start: 51 | self.start() 52 | 53 | def run(self): 54 | thread = threading.Thread( 55 | target=partial(xmp.spawn, self.runner, nprocs=self.nprocs, start_method='fork'), 56 | daemon=True) 57 | thread.start() 58 | thread.join() 59 | 60 | def update_model_parameters(self, new_host_parameters): 61 | """Schedule TPUs to update model parameters during at the beginning of the next step""" 62 | with self.lock, torch.no_grad(): 63 | self._synchronizer.set_host_parameters(new_host_parameters) 64 | self.should_load_parameters.value = True 65 | 66 | def get_aggregated_gradients(self): 67 | """Get current accumulated gradients from the master model""" 68 | with self.lock, torch.no_grad(): 69 | return self._synchronizer.get_aggregated_gradients() 70 | 71 | def zero_grad(self): 72 | """Reset master accumulated gradients to zeros""" 73 | with 
self.lock, torch.no_grad():
74 | for param in self._synchronizer.master_model.parameters():
75 | param.grad.zero_()
76 | 
77 | def step(self):
78 | """Run a forward/backward step on all TPUs and collect the gradients"""
79 | self.loss_accumulated.value = self.gradients_accumulated.value = 0
80 | self.step_finished.clear()
81 | self.step_triggered.set()
82 | self.step_finished.wait()
83 | return self.loss_accumulated.value, self.gradients_accumulated.value
84 | 
85 | def runner(self, tpu_index):
86 | """Run training steps from the perspective of a single TPU core"""
87 | # acquire the (unique) Cloud TPU core corresponding to this process's index
88 | device = xm.xla_device()
89 | logger.info(f"Process {tpu_index} is using {xm.xla_real_devices([str(device)])[0]}")
90 | 
91 | # set a distinct random seed for each TPU process
92 | torch.manual_seed(self.seed_base + tpu_index)
93 | 
94 | # use staged init to minimize peak RAM usage
95 | for init_index in range(xm.xrt_world_size()):
96 | xm.rendezvous(f'init_{init_index}')
97 | if tpu_index == init_index:
98 | model = self._synchronizer.get_device_model_replica(device)
99 | data_loader = self._data_manager.get_device_dataloader(
100 | batch_size=self.batch_size_per_device, num_workers=0, collate_fn=self.collate_fn, pin_memory=False)
101 | data_loader_iter = iter(data_loader)
102 | logger.info(f"Process {tpu_index} initialized.")
103 | 
104 | xm.rendezvous('init_finished')
105 | 
106 | while True:
107 | self.step_triggered.wait()
108 | xm.rendezvous('before_step')
109 | if xm.is_master_ordinal():
110 | self.step_triggered.clear()
111 | 
112 | if bool(self.should_load_parameters.value):
113 | with self.lock if xm.is_master_ordinal() else nullcontext():
114 | self._synchronizer.send_params_to_device(model)
115 | self.should_load_parameters.value = False
116 | 
117 | ### compute loss and gradients
118 | loss = 0.0
119 | for i in range(self.grad_accumulation_steps):
120 | inputs = next(data_loader_iter)
121 | outputs = model(**inputs)
122 | loss_i = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
123 | loss_i = loss_i / (self.grad_accumulation_steps * self.nprocs)
124 | loss_i.backward()
125 | loss += loss_i
126 | del inputs, outputs, loss_i
127 | 
128 | ### aggregate gradients from TPUs
129 | with self.lock if xm.is_master_ordinal() else nullcontext():
130 | self._synchronizer.aggregate_grads_on_host(model, add=True)
131 | # clear aggregated gradients from all devices
132 | model.zero_grad()
133 | 
134 | ### accumulate statistics to host
135 | loss = xm.all_reduce(xm.REDUCE_SUM, loss, scale=1.0)
136 | xm.do_on_ordinals(self._mark_step_finished, data=(loss,), ordinals=(0,))
137 | 
138 | def _mark_step_finished(self, loss):
139 | self.gradients_accumulated.value = self.batch_size_per_device * self.nprocs * self.grad_accumulation_steps
140 | self.loss_accumulated.value = float(loss)
141 | self.step_finished.set()
142 | 
143 | 
144 | class TPUSynchronizer:
145 | """An auxiliary class for manipulating parameters and gradients without producing a ton of XLA graphs"""
146 | 
147 | def __init__(self, model: nn.Module):
148 | self.master_model = model.share_memory()
149 | for param in self.master_model.parameters():
150 | if param.grad is None:
151 | param.grad = torch.zeros_like(param)
152 | param.grad = param.grad.share_memory_()
153 | 
154 | def get_device_model_replica(self, device: torch.device, tie_weights: bool = True):
155 | replica = deepcopy(self.master_model).to(device)
156 | if tie_weights:
157 | replica.tie_weights()
158 | for param in replica.parameters():
159 | param.grad = torch.zeros_like(param, device=device)
160 | return replica
161 | 
162 | def set_host_parameters(self, new_host_parameters):
163 | return self._assign(source=self.master_model.parameters(), target=new_host_parameters, add=False, strict=True)
164 | 
165 | def get_aggregated_gradients(self):
166 | return [param.grad for param in self.master_model.parameters()]
167 | 
168 | def send_params_to_device(self, replica: nn.Module):
169 | """Copy params from master_model to this device_model replica"""
170 | with torch.no_grad():
171 | replica_params = list(replica.parameters())
172 | master_params = list(self.master_model.parameters())
173 | master_params = xm.send_cpu_data_to_device(master_params, xm.xla_device())
174 | self._assign(source=master_params, target=replica_params, add=False)
175 | xm.rendezvous("params_replicated")
176 | 
177 | def aggregate_grads_on_host(self, replica: nn.Module, *, add: bool):
178 | """Aggregate grads from all TPU devices and move them to the host"""
179 | with torch.no_grad():
180 | replica_grads = [param.grad for param in replica.parameters()]
181 | replica_grads = xm.all_reduce(xm.REDUCE_SUM, replica_grads, scale=1.0)
182 | master_grads = [hp.grad for hp in self.master_model.parameters()]
183 | xm.do_on_ordinals(lambda *replica_grads: self._assign(source=replica_grads, target=master_grads, add=add),
184 | data=tuple(replica_grads), ordinals=(0,))
185 | # ^-- do_on_ordinals already runs rendezvous at the end
186 | 
187 | def _assign(self, source: Iterable[torch.Tensor], target: Iterable[torch.Tensor], add: bool, strict: bool = False):
188 | for source_tensor, target_tensor in zip_longest(source, target):
189 | assert source_tensor is not None and target_tensor is not None, "Source and target lengths must match exactly"
190 | if strict:
191 | assert source_tensor.shape == target_tensor.shape
192 | assert source_tensor.device == target_tensor.device
193 | assert source_tensor.dtype == target_tensor.dtype
194 | if add:
195 | target_tensor.add_(source_tensor)
196 | else:
197 | target_tensor.copy_(source_tensor)
198 | 
199 | 
200 | class TPUDataManager:
201 | """An auxiliary class that loads a centralized dataset from the master process into multiple TPU devices"""
202 | def __init__(self, dataset: torch.utils.data.Dataset, nprocs: int, master_prefetch: int = 16):
203 | self.dataset, self.nprocs = dataset, nprocs
204 | self.device_queues = [mp.Queue(master_prefetch) for _ in range(nprocs)]
205 | self._loader_thread = threading.Thread(target=self._load_data_into_queues)
206 | self._loader_thread.start()
207 | 
208 | def _load_data_into_queues(self):
209 | try:
210 | for i, batch in enumerate(self.dataset):
211 | self.device_queues[i % self.nprocs].put(batch)
212 | finally:
213 | logger.warning("Minibatch generator finished.")
214 | 
215 | def get_device_dataloader(self, **kwargs):
216 | data_loader = torch.utils.data.DataLoader(QueueDataset(self.device_queues[xm.get_ordinal()]), **kwargs)
217 | return pl.ParallelLoader(data_loader, [xm.xla_device()]).per_device_loader(xm.xla_device())
218 | 
219 | 
220 | class QueueDataset(torch.utils.data.IterableDataset):
221 | """A dataset that ceaselessly iterates over a queue"""
222 | def __init__(self, queue: mp.Queue):
223 | super().__init__()
224 | self.queue = queue
225 | 
226 | def __iter__(self):
227 | while True:
228 | yield self.queue.get()
229 | 
230 | def __len__(self):
231 | return 10 ** 12 # TODO deprecate this when the issue is resolved: https://github.com/googlecolab/colabtools/issues/2237
232 | 
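
Usage sketch (editor's illustration, not a file in this repository): a minimal host-side loop around TPUManager. The names `model`, `dataset`, `collate_fn`, `host_optimizer`, and `total_steps` are hypothetical placeholders; in this project the class is presumably driven by run_trainer_tpu.py, so treat this only as a sketch of the intended call sequence under those assumptions.

from lib.training.tpu import TPUManager

# `model`, `dataset`, `collate_fn`, `host_optimizer`, and `total_steps` are placeholders (assumptions)
tpu_manager = TPUManager(
    model, dataset,
    collate_fn=collate_fn,
    nprocs=8,                     # one worker process per TPU core
    batch_size_per_device=1,
    grad_accumulation_steps=1,
    start=True,                   # spawn the XLA worker processes immediately
)

for _ in range(total_steps):
    loss, samples_accumulated = tpu_manager.step()  # forward/backward on all cores
    # the model passed to TPUManager shares memory with the master copy,
    # so its .grad buffers now hold the aggregated gradients
    host_optimizer.step()
    tpu_manager.zero_grad()                         # reset the shared gradient accumulators
    tpu_manager.update_model_parameters(model.parameters())  # flag replicas to reload weights on the next step

In this sketch, step() blocks until every core has finished its accumulation and returns the averaged loss together with the number of samples processed during that step.
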
-------------------------------------------------------------------------------- /lib/training/lamb_8bit.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Any, Optional 3 | 4 | import torch 5 | 6 | from torch_optimizer.types import Betas2, Params 7 | from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise 8 | from bitsandbytes.optim.optimizer import Optimizer2State 9 | 10 | __all__ = ('CPULAMB8Bit',) 11 | 12 | 13 | class CPULAMB8Bit(Optimizer2State): 14 | r""" 15 | Implements Lamb with quantized 8-bit statistics. The statistics are stored in host memory in the quantized form. 16 | The LAMB optimizer and block-wise quantization are described in the following papers: 17 | - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962 18 | - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861 19 | This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb 20 | - bias correction defaults to False because paper v3 does not use debiasing 21 | - it has baked in clipping by global max_grad_norm 22 | Arguments: 23 | params: iterable of parameters to optimize or dicts defining 24 | parameter groups 25 | lr: learning rate (default: 1e-3) 26 | betas: coefficients used for computing 27 | running averages of gradient and its square (default: (0.9, 0.999)) 28 | eps: term added to the denominator to improve 29 | numerical stability (default: 1e-8) 30 | weight_decay: weight decay (L2 penalty) (default: 0) 31 | clamp_value: clamp weight_norm in (0,clamp_value) (default: 10) 32 | set to a high value to avoid it (e.g 10e3) 33 | bias_correction: debias statistics by (1 - beta**step) (default: False) 34 | min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized 35 | reuse_grad_buffers: if True, optimizer will modify gradients in-place to save memory. 36 | If enabled, one must ensure that .zero_grad() is called after each optimizer step. 37 | update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements. 
38 | """ 39 | 40 | def __init__( 41 | self, 42 | params: Params, 43 | lr: float = 1e-3, 44 | betas: Betas2 = (0.9, 0.999), 45 | eps: float = 1e-6, 46 | weight_decay: float = 0, 47 | clamp_value: float = 10, 48 | bias_correction: bool = False, 49 | min_8bit_size: int = 65536, 50 | reuse_grad_buffers: bool = False, 51 | update_chunk_size: int = 2 ** 24, 52 | max_grad_norm: Optional[float] = None, 53 | ) -> None: 54 | if lr <= 0.0: 55 | raise ValueError('Invalid learning rate: {}'.format(lr)) 56 | if eps < 0.0: 57 | raise ValueError('Invalid epsilon value: {}'.format(eps)) 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError( 60 | 'Invalid beta parameter at index 0: {}'.format(betas[0]) 61 | ) 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise ValueError( 64 | 'Invalid beta parameter at index 1: {}'.format(betas[1]) 65 | ) 66 | if weight_decay < 0: 67 | raise ValueError( 68 | 'Invalid weight_decay value: {}'.format(weight_decay) 69 | ) 70 | if clamp_value < 0.0: 71 | raise ValueError('Invalid clamp value: {}'.format(clamp_value)) 72 | 73 | self.clamp_value = clamp_value 74 | self.bias_correction = bias_correction 75 | self.reuse_grad_buffers = reuse_grad_buffers 76 | self.update_chunk_size = update_chunk_size 77 | self.max_grad_norm = max_grad_norm 78 | 79 | super(CPULAMB8Bit, self).__init__( 80 | 'cpu-lamb', params, lr, betas, eps, weight_decay, optim_bits=8, min_8bit_size=min_8bit_size, args=None, 81 | percentile_clipping=100, block_wise=4096, max_unorm=0) 82 | 83 | @torch.no_grad() 84 | def step(self, closure=None): 85 | if self.max_grad_norm is not None: 86 | iter_params = (param for group in self.param_groups for param in group['params']) 87 | torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm) 88 | return super().step(closure=closure) 89 | 90 | @torch.no_grad() 91 | def init_state(self, group, p, gindex, pindex): 92 | config = self.get_config(gindex, pindex, group) 93 | assert config['percentile_clipping'] == 100, "percentile clipping is not implemented on CPU" 94 | assert config['max_unorm'] == 0 95 | 96 | if config['optim_bits'] == 32: 97 | dtype = torch.float32 98 | elif config['optim_bits'] == 8: 99 | dtype = torch.uint8 100 | else: 101 | raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') 102 | 103 | if p.numel() < config['min_8bit_size']: dtype = torch.float32 104 | 105 | state = self.state[p] 106 | state['step'] = 0 107 | 108 | if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): 109 | state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, 110 | device=p.device) 111 | state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, 112 | device=p.device) 113 | elif dtype == torch.uint8: 114 | if state['step'] == 0: 115 | if 'dynamic' not in self.name2qmap: self.fill_qmap() 116 | self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) 117 | self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) 118 | 119 | n = p.numel() 120 | blocks = (n - 1) // config['block_wise'] + 1 121 | 122 | state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, 123 | device=p.device) 124 | state['qmap1'] = self.name2qmap['dynamic'] 125 | 126 | state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, 127 | device=p.device) 128 | state['qmap2'] = self.name2qmap['udynamic'] 129 | 130 | state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) 131 
| state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) 132 | 133 | @torch.no_grad() 134 | def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int): 135 | state = self.state[p] 136 | config = self.get_config(gindex, pindex, group) 137 | 138 | p_cpu, grad_cpu = p.cpu(), p.grad.cpu() 139 | # this is a no-op if parameters are already on CPU 140 | 141 | step = state['step'] = state['step'] + 1 142 | beta1, beta2 = group['betas'] 143 | 144 | param_delta = self._update_moments_and_compute_delta( 145 | state, config, p_cpu, grad_cpu, beta1, beta2, group['eps'], group['weight_decay'] 146 | ) 147 | del grad_cpu # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers 148 | 149 | step_norm = torch.norm(param_delta) 150 | weight_norm = p_cpu.norm().clamp(0, self.clamp_value) 151 | 152 | trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0 153 | state['weight_norm'], state['step_norm'], state['trust_ratio'] = weight_norm, step_norm, trust_ratio 154 | 155 | # Apply bias to lr to avoid broadcast. 156 | bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1 157 | step_size = group['lr'] * bias_correction 158 | p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio) 159 | 160 | def _update_moments_and_compute_delta( 161 | self, state: Dict, config: Dict, 162 | p_cpu: torch.Tensor, grad_cpu: torch.Tensor, 163 | beta1: float, beta2: float, eps: float, weight_decay: float 164 | ) -> torch.Tensor: 165 | step, block_size, chunk_size = state['step'], config['block_wise'], self.update_chunk_size 166 | 167 | if state['state1'].dtype != torch.uint8: 168 | # not quantized: update normally 169 | exp_avg, exp_avg_sq = state['state1'], state['state2'] 170 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1) 171 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2) 172 | 173 | sqrt_out = grad_cpu if self.reuse_grad_buffers else None 174 | _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps) 175 | param_delta = torch.div(exp_avg, _denominator, out=_denominator) 176 | if weight_decay != 0: 177 | param_delta.add_(p_cpu, alpha=weight_decay) 178 | return param_delta 179 | elif p_cpu.numel() <= chunk_size: 180 | # quantized tensor within chunk size 181 | exp_avg = dequantize_blockwise( 182 | state['state1'], (state['absmax1'], state['qmap1']), blocksize=block_size 183 | ) 184 | exp_avg_sq = dequantize_blockwise( 185 | state['state2'], (state['absmax2'], state['qmap2']), blocksize=block_size 186 | ) 187 | 188 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1) 189 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2) 190 | 191 | quantize_blockwise(exp_avg, state['qmap1'], state['absmax1'], out=state['state1']) 192 | quantize_blockwise(exp_avg_sq, state['qmap2'], state['absmax2'], out=state['state2']) 193 | # note: quantize_blockwise also modifies qmap and absmax in-place 194 | 195 | param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps)) 196 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version 197 | 198 | if weight_decay != 0: 199 | param_delta.add_(p_cpu, alpha=weight_decay) 200 | return param_delta 201 | 202 | else: 203 | # very large quantized tensor, compute updates in chunks to save RAM 204 | flat_p, flat_grad, flat_state1, flat_state2 = ( 205 | tensor.view(-1) for tensor in (p_cpu, grad_cpu, state['state1'], state['state2']) 206 | ) 207 | output_buffer = flat_grad if 
self.reuse_grad_buffers else torch.empty_like(flat_grad) 208 | 209 | for chunk_index, chunk_start in enumerate(range(0, len(flat_p), chunk_size)): 210 | chunk = slice(chunk_start, chunk_start + chunk_size) 211 | chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size) 212 | 213 | chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk] 214 | chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk] 215 | chunk_absmax1, chunk_absmax2 = state['absmax1'][chunk_blocks], state['absmax2'][chunk_blocks] 216 | if chunk_state1.storage_offset() != 0: 217 | chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map( 218 | torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2) 219 | ) # clone chunks to ensure that tensors do not have offsets 220 | 221 | exp_avg_chunk = dequantize_blockwise( 222 | chunk_state1, (chunk_absmax1, state['qmap1']), blocksize=block_size 223 | ) 224 | exp_avg_sq_chunk = dequantize_blockwise( 225 | chunk_state2, (chunk_absmax2, state['qmap2']), blocksize=block_size 226 | ) 227 | 228 | exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1) 229 | exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2) 230 | 231 | # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu 232 | del chunk_grad 233 | 234 | flat_state1[chunk], (state['absmax1'][chunk_blocks], state['qmap1']) = quantize_blockwise( 235 | exp_avg_chunk, state['qmap1'], chunk_absmax1, out=chunk_state1 236 | ) 237 | flat_state2[chunk], (state['absmax2'][chunk_blocks], state['qmap2']) = quantize_blockwise( 238 | exp_avg_sq_chunk, state['qmap2'], chunk_absmax2, out=chunk_state2 239 | ) 240 | # note: we need to explicitly assign new quantized tensors because of cloning earlier 241 | 242 | torch.div(exp_avg_chunk, exp_avg_sq_chunk.sqrt_().add_(eps), out=output_buffer[chunk]) 243 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version 244 | 245 | if weight_decay != 0: 246 | output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay) 247 | 248 | param_delta = output_buffer.view_as(grad_cpu) 249 | 250 | return param_delta --------------------------------------------------------------------------------
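
Usage sketch (editor's illustration, not a file in this repository): constructing CPULAMB8Bit directly on a toy model, mirroring the keyword arguments that task.py passes when it wraps this optimizer. The model and hyperparameter values below are placeholder assumptions, not the project's actual settings.

import torch
from lib.training.lamb_8bit import CPULAMB8Bit

model = torch.nn.Linear(512, 512)  # toy stand-in for the real model (assumption)
optimizer = CPULAMB8Bit(
    model.parameters(),
    lr=2e-3,                   # illustrative values only
    betas=(0.9, 0.96),
    eps=1e-6,
    weight_decay=0.01,
    max_grad_norm=1.0,         # enables the baked-in global gradient clipping
    clamp_value=10,            # upper bound on the weight norm used in the trust ratio
    reuse_grad_buffers=False,  # if True, gradients may be modified in-place and zero_grad() must follow every step()
)

loss = model(torch.randn(4, 512)).pow(2).mean()
loss.backward()
optimizer.step()      # moments for large tensors are kept on the CPU in block-wise 8-bit form
optimizer.zero_grad()

For comparison, task.py passes an equivalent constructor (with reuse_grad_buffers=True) to hivemind.Optimizer as the optimizer factory, so the collaborative optimizer owns the step schedule instead of the plain loop shown here.
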