├── tasks ├── __init__.py ├── mlm │ ├── __init__.py │ ├── data.py │ ├── whole_word_mask.py │ └── task.py └── base.py ├── src ├── training │ ├── __init__.py │ ├── sync.py │ ├── hf_trainer.py │ ├── callback.py │ └── lamb_8bit.py ├── models │ ├── __init__.py │ ├── config.py │ ├── lean_albert.py │ └── albert.py ├── __init__.py ├── modules │ ├── __init__.py │ ├── functional.py │ ├── sequence.py │ ├── rotary.py │ ├── attn.py │ ├── pixelfly.py │ ├── linear.py │ └── ffn.py ├── utils.py └── huggingface_auth.py ├── requirements.txt ├── .gitignore ├── run_trainer.py ├── field_collator.py ├── predict_on_lidirus.py ├── models.py ├── run_aux_peer.py ├── arguments.py ├── finetune_mlm.py ├── LICENSE └── README.md /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tasks/mlm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .albert import * 2 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attn import * 2 | from .ffn import * 3 | from .linear import * 4 | from .pixelfly import * 5 | from .rotary import * 6 | from .sequence import * 7 | -------------------------------------------------------------------------------- /src/models/config.py: -------------------------------------------------------------------------------- 1 | from transformers import AlbertConfig 2 | 3 | 4 | class LeanAlbertConfig(AlbertConfig): 5 | rotary_embedding_base: int = 10_000 6 | hidden_act_gated: bool = False 7 | 8 | def __hash__(self): 9 | return hash("\t".join(f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_"))) 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.32.3 2 | datasets==2.4.0 3 | einops==0.3.2 4 | fuzzysearch==0.7.3 5 | https://github.com/learning-at-home/hivemind/archive/calm.zip 6 | numpy==1.22.3 7 | pandas==1.4.3 8 | pymorphy2==0.9.1 9 | razdel==0.5.0 10 | requests==2.31.0 11 | scikit-learn==1.0.2 12 | sentencepiece==0.1.96 13 | termcolor==1.1.0 14 | tokenizers==0.12.1 15 | torch>=1.8.0 16 | torch_optimizer==0.1.0 17 | transformers==4.30.0 18 | wandb==0.12.1 19 | -------------------------------------------------------------------------------- /src/modules/functional.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | 4 | import torch 5 | from transformers.activations import ACT2FN as HF_ACT2FN 6 | 7 | 8 | @functools.lru_cache() 9 | def maybe_script(fn: callable) -> callable: 10 | """Apply torch.jit.script to function unless one is 
using TPU. TPU does not support torch.jit.script.""" 11 | using_tpu = bool(os.environ.get("TPU_NAME")) 12 | # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function 13 | should_script = int(os.environ.get("LEAN_USE_JIT", not using_tpu)) 14 | return torch.jit.script(fn) if should_script else fn 15 | 16 | 17 | @maybe_script 18 | def gelu_fused(x): 19 | """ 20 | Approximate GELU activation, same as in Google BERT and OpenAI GPT (as of Dec 2021) 21 | the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 22 | """ 23 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) # note: 0.7988.. = sqrt(2/pi) 24 | 25 | 26 | @maybe_script 27 | def gelu_fused_grad(grad_output, x): 28 | """Gradients of gelu_fwd w.r.t. input""" 29 | tanh = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 30 | jac = 0.5 * x * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh) 31 | return jac * grad_output 32 | 33 | 34 | class GELU(torch.autograd.Function): 35 | @staticmethod 36 | def forward(ctx, input: torch.Tensor): 37 | ctx.save_for_backward(input) 38 | return gelu_fused(input) 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output: torch.Tensor): 42 | input, = ctx.saved_tensors 43 | return gelu_fused_grad(grad_output, input) 44 | 45 | 46 | ACT2FN = dict(HF_ACT2FN, gelu_fused=GELU.apply) 47 | -------------------------------------------------------------------------------- /src/modules/sequence.py: -------------------------------------------------------------------------------- 1 | """ 2 | A module that implements sequential model type with optional keyword arguments. 3 | When using gradient checkpoints, keyword arguments should NOT require grad. 4 | """ 5 | from typing import Callable, Sequence 6 | 7 | import torch 8 | from logging import getLogger 9 | from torch import nn as nn 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | class ActiveKwargs(nn.Module): 16 | """ 17 | A module with selective kwargs, compatible with sequential, gradient checkpoints and 18 | Usage: ony use this as a part of SequentialWithKwargs 19 | """ 20 | 21 | def __init__(self, module: nn.Module, active_keys: Sequence[str], use_first_output: bool = False): 22 | super().__init__() 23 | self.module, self.active_keys, self.use_first_output = module, set(active_keys), use_first_output 24 | 25 | def forward(self, input: torch.Tensor, *args, **kwargs): 26 | kwargs = {key: value for key, value in kwargs.items() if key in self.active_keys} 27 | output = self.module(input, *args, **kwargs) 28 | if self.use_first_output and not isinstance(output, torch.Tensor): 29 | output = output[0] 30 | return output 31 | 32 | 33 | class SequentialWithKwargs(nn.Sequential): 34 | def __init__(self, *modules: ActiveKwargs): 35 | for module in modules: 36 | assert isinstance(module, ActiveKwargs) 37 | super().__init__(*modules) 38 | self.gradient_checkpointing = False 39 | 40 | def forward(self, input: torch.Tensor, *args, **kwargs): 41 | kwarg_keys, kwarg_values = zip(*kwargs.items()) if (self.gradient_checkpointing and kwargs) else ([], []) 42 | for module in self: 43 | if self.gradient_checkpointing and torch.is_grad_enabled(): 44 | # pack kwargs with args since gradient checkpoint does not support kwargs 45 | input = checkpoint(self._checkpoint_forward, module, input, kwarg_keys, *kwarg_values, *args) 46 | else: 47 | input = module(input, *args, **kwargs) 48 | return input 49 | 50 | def 
_checkpoint_forward(self, module: Callable, input: torch.Tensor, kwarg_keys: Sequence[str], *etc): 51 | kwargs = {key: etc[i] for i, key in enumerate(kwarg_keys)} 52 | args = etc[len(kwarg_keys) :] 53 | return module(input, *args, **kwargs) 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /run_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import scipy.stats # compatibility for internal testing environment 7 | import torch.distributed 8 | import transformers 9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 10 | from transformers import HfArgumentParser 11 | 12 | from arguments import CollaborativeArguments, HFTrainerArguments, TrainingPeerArguments 13 | from src import utils 14 | from src.training.callback import CollaborativeCallback 15 | from src.training.hf_trainer import CollaborativeHFTrainer, NOPtimizer 16 | from src.training.sync import SynchronizationCallback, is_main_process 17 | from tasks.mlm.task import MLMTrainingTask 18 | 19 | use_hivemind_log_handler("in_root_logger") 20 | logger = get_logger() 21 | torch.set_num_threads(1) # avoid quadratic number of threads 22 | 23 | 24 | def main(): 25 | parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments)) 26 | training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses() 27 | if torch.distributed.is_initialized(): 28 | assert not collab_args.reuse_grad_buffers, "Reuse_grad_buffers are not supported in distributed mode" 29 | 30 | logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}") 31 | if len(training_peer_args.initial_peers) == 0: 32 | logger.warning("Please specify initial peers or let others join your peer") 33 | 34 | utils.setup_logging(trainer_args) 35 | task = MLMTrainingTask(training_peer_args, trainer_args, collab_args) 36 | model = task.model.to(trainer_args.device) 37 | for param in model.parameters(): 38 | if param.grad is None: 39 | param.grad = torch.zeros_like(param) 40 | 41 | callbacks = [(CollaborativeCallback if is_main_process() else SynchronizationCallback)(task, training_peer_args)] 42 | assert trainer_args.do_train and not trainer_args.do_eval 43 | 44 | # Note: the code below creates the trainer with dummy scheduler and removes some callbacks. 45 | # This is done because collaborative training has its own callbacks that take other peers into account. 
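    # Only the main process (rank 0, or a non-distributed run) drives hivemind.Optimizer and CollaborativeCallback;
    # the remaining local DDP workers get a no-op NOPtimizer and a SynchronizationCallback that receives
    # parameters broadcast from the main process (see src/training/sync.py).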
46 | trainer = CollaborativeHFTrainer( 47 | model=model, 48 | args=trainer_args, 49 | tokenizer=task.tokenizer, 50 | data_collator=task.data_collator, 51 | data_seed=hash(task.local_public_key), 52 | train_dataset=task.training_dataset, 53 | reuse_grad_buffers=collab_args.reuse_grad_buffers, 54 | eval_dataset=None, 55 | optimizer=task.optimizer if is_main_process() else NOPtimizer(task._make_param_groups()), 56 | callbacks=callbacks, 57 | ) 58 | trainer.remove_callback(transformers.trainer_callback.PrinterCallback) 59 | trainer.remove_callback(transformers.trainer_callback.ProgressCallback) 60 | 61 | latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None) 62 | trainer.train(model_path=latest_checkpoint_dir) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import transformers.utils.logging 4 | from hivemind import choose_ip_address 5 | from hivemind.dht.crypto import RSASignatureValidator 6 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator 7 | from hivemind.dht.validation import RecordValidatorBase 8 | from hivemind.utils.logging import get_logger 9 | from multiaddr import Multiaddr 10 | from pydantic import BaseModel, StrictFloat, confloat, conint 11 | from transformers.trainer_utils import is_main_process 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class LocalMetrics(BaseModel): 17 | epoch: conint(ge=0, strict=True) 18 | samples_per_second: confloat(ge=0.0, strict=True) 19 | samples_accumulated: conint(ge=0, strict=True) 20 | loss: StrictFloat 21 | mini_steps: conint(ge=0, strict=True) # queries 22 | 23 | 24 | class MetricSchema(BaseModel): 25 | metrics: Dict[BytesWithPublicKey, LocalMetrics] 26 | 27 | 28 | def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]: 29 | signature_validator = RSASignatureValidator() 30 | validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator] 31 | return validators, signature_validator.local_public_key 32 | 33 | 34 | class TextStyle: 35 | BOLD = "\033[1m" 36 | BLUE = "\033[34m" 37 | RESET = "\033[0m" 38 | 39 | 40 | def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None: 41 | if only_p2p: 42 | unique_addrs = {addr["p2p"] for addr in visible_maddrs} 43 | initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs) 44 | else: 45 | available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr] 46 | if available_ips: 47 | preferred_ip = choose_ip_address(available_ips) 48 | selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)] 49 | else: 50 | selected_maddrs = visible_maddrs 51 | initial_peers_str = " ".join(str(addr) for addr in selected_maddrs) 52 | 53 | logger.info( 54 | f"Running a DHT peer. 
To connect other peers to this one over the Internet, use " 55 | f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}" 56 | ) 57 | logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}") 58 | 59 | 60 | def setup_logging(training_args): 61 | # Log on each process the small summary: 62 | logger.warning( 63 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 64 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 65 | ) 66 | # Set the verbosity to info of the Transformers logger (on main process only): 67 | if is_main_process(training_args.local_rank): 68 | transformers.utils.logging.set_verbosity_info() 69 | transformers.utils.logging.enable_default_handler() 70 | transformers.utils.logging.enable_explicit_format() 71 | logger.info("Training/evaluation parameters %s", training_args) 72 | -------------------------------------------------------------------------------- /field_collator.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from dataclasses import dataclass 3 | from typing import Iterable, Tuple, Optional, Union, List, Dict, Any 4 | 5 | import torch 6 | from transformers import PreTrainedTokenizerBase 7 | from transformers.file_utils import PaddingStrategy 8 | 9 | 10 | @dataclass 11 | class FieldDataCollatorWithPadding: 12 | """ 13 | A general-purpose data collator that can handle batches with arbitrary fields. 14 | Supports only PyTorch tensors as outputs. 15 | """ 16 | tokenizer: PreTrainedTokenizerBase 17 | # (field name, pad token index, index of the sequence axis or None) 18 | fields_to_pad: Iterable[Tuple[str, int, Optional[int]]] = () 19 | padding: Union[bool, str, PaddingStrategy] = True 20 | max_length: Optional[int] = None 21 | pad_to_multiple_of: Optional[int] = None 22 | 23 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: 24 | field_values = defaultdict(list) 25 | for field_name, field_pad_idx, field_seq_idx in self.fields_to_pad: 26 | for example in features: 27 | if field_name in example: 28 | field_values[field_name].append(example.pop(field_name)) 29 | 30 | batch = self.tokenizer.pad( 31 | features, 32 | padding=self.padding, 33 | max_length=self.max_length, 34 | pad_to_multiple_of=self.pad_to_multiple_of, 35 | return_tensors="pt", 36 | ) 37 | if "label" in batch: 38 | batch["labels"] = batch.pop("label") 39 | if "label_ids" in batch: 40 | batch["labels"] = batch.pop("label_ids") 41 | 42 | sequence_length = batch["input_ids"].size(1) 43 | batch_size = batch["input_ids"].size(0) 44 | 45 | for field_name, field_pad_idx, field_seq_idx in self.fields_to_pad: 46 | 47 | if field_values[field_name]: 48 | field_value_arrays = [torch.as_tensor(values) for values in field_values[field_name]] 49 | assert len(set(arr.ndim for arr in field_value_arrays)) == 1 50 | 51 | max_dim_lengths = [ 52 | max(arr.size(dim) for arr in field_value_arrays) 53 | for dim in range(field_value_arrays[0].ndim) 54 | ] 55 | 56 | if field_seq_idx is not None: 57 | max_dim_lengths[field_seq_idx] = sequence_length 58 | 59 | padded_tensor = torch.full([batch_size] + max_dim_lengths, field_pad_idx, dtype=torch.long) 60 | 61 | for i, ar in enumerate(field_value_arrays): 62 | paste_inds = [i] 63 | 64 | for dim in range(field_value_arrays[0].ndim): 65 | if dim == field_seq_idx and self.tokenizer.padding_side == "left": 66 
| inds = slice(-ar.size(dim), -1) 67 | else: 68 | inds = slice(ar.size(dim)) 69 | paste_inds.append(inds) 70 | 71 | padded_tensor[paste_inds] = ar 72 | 73 | batch[field_name] = padded_tensor 74 | 75 | return batch 76 | -------------------------------------------------------------------------------- /src/modules/rotary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Auxiliary modules for implementing Rotary Position Embeddings 3 | Original paper: https://arxiv.org/abs/2104.09864 4 | Based on reference implementation from https://blog.eleuther.ai/rotary-embeddings 5 | """ 6 | 7 | import torch 8 | import torch.distributed 9 | import torch.nn as nn 10 | 11 | from logging import getLogger 12 | from src.modules.functional import maybe_script 13 | 14 | logger = getLogger(__file__) 15 | 16 | 17 | class RotaryEmbeddings(nn.Module): 18 | """Applies rotary position embeddings to a tensor, uses caching to improve performance""" 19 | 20 | def __init__(self, dim: int, base: int = 10_000): 21 | super().__init__() 22 | self.dim, self.base = dim, base 23 | self._rotate = maybe_script(rotate) 24 | self._get_auxiliary_tensors = maybe_script(get_auxiliary_tensors) 25 | self.register_buffer("cos", torch.empty(1, dim), persistent=False) 26 | self.register_buffer("sin", torch.empty(1, dim), persistent=False) 27 | 28 | def forward(self, x: torch.Tensor, offset: int = 0): 29 | """ 30 | :param x: tensor of shape [batch_size, seq_len, nhead, hid_size] 31 | :param offset: add this value to all position indices 32 | """ 33 | seq_len = x.shape[1] 34 | if seq_len + offset > self.cos.shape[0] or x.device != self.cos.device: 35 | if torch.distributed.is_initialized(): 36 | logger.warning("Rebuilding auxiliary tensors for rotary embeddings, this may cause DDP to freeze. 
To " 37 | "avoid this, call _set_auxiliary_tensors for maximum length before the training starts") 38 | self._set_auxiliary_buffers(max_len=seq_len + offset, device=x.device) 39 | cosines_for_position = self.cos[None, offset : seq_len + offset, None, :] 40 | sines_for_position = self.sin[None, offset : seq_len + offset, None, :] 41 | return self._rotate(x, cosines_for_position, sines_for_position) 42 | 43 | def _set_auxiliary_buffers(self, max_len: int, device: torch.device): 44 | _cos, _sin = self._get_auxiliary_tensors(max_len, self.dim, torch.float32, device, self.base) 45 | self.register_buffer("cos", _cos, persistent=False) 46 | self.register_buffer("sin", _sin, persistent=False) 47 | 48 | 49 | @torch.no_grad() 50 | def get_auxiliary_tensors(seq_len: int, dim: int, dtype: torch.dtype, device: torch.device, base: int): 51 | """ 52 | Compute auxiliary sine and cosine tensors for rotary position embedding 53 | :returns: a tuple of (cos, sin) tensors of shape [seq_len, hid_size] 54 | """ 55 | _buf = torch.linspace(0, -1 + 2 / dim, dim // 2, dtype=torch.float32, device=device) 56 | inv_freq = torch.pow(base, _buf, out=_buf).repeat(2) 57 | time_ix = torch.arange(seq_len, dtype=inv_freq.dtype, device=device) 58 | 59 | freqs = time_ix[:, None] * inv_freq[None, :] 60 | cos = torch.cos(freqs) 61 | sin = freqs.sin_() 62 | return cos.to(dtype), sin.to(dtype) 63 | 64 | 65 | def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 66 | """rotate pairwise coordinate using precomputed cos & sin tensors""" 67 | dim = x.shape[-1] 68 | x_left, x_right = x.split(split_size=dim // 2, dim=x.ndim - 1) 69 | x_rotated = torch.cat([x_right.neg(), x_left], dim=x.ndim - 1) 70 | return x * cos + x_rotated * sin 71 | 72 | 73 | -------------------------------------------------------------------------------- /predict_on_lidirus.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | from functools import partial 5 | from pathlib import Path 6 | 7 | from datasets import load_dataset 8 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, \ 9 | TrainingArguments, AlbertTokenizerFast 10 | 11 | from data import TASK_TO_CONFIG, TASK_TO_NAME 12 | from src import LeanAlbertForSequenceClassification 13 | 14 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 15 | 16 | BATCH_SIZE = 8 17 | MAX_LENGTH = 512 18 | 19 | TASK = "lidirus" 20 | 21 | 22 | def main(model_path: Path, arch): 23 | run_dirname = model_path.parts[-1] 24 | assert 'terra' in run_dirname 25 | 26 | if arch == 'lean_albert': 27 | tokenizer = AlbertTokenizerFast.from_pretrained('tokenizer') 28 | else: 29 | tokenizer = AutoTokenizer.from_pretrained(arch) 30 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 31 | 32 | dataset = load_dataset("russian_super_glue", TASK) 33 | config = TASK_TO_CONFIG[TASK](dataset) 34 | 35 | processed_dataset = dataset.map(partial(config.process_data, tokenizer=tokenizer, max_length=512), 36 | batched=True, remove_columns=["label"]) 37 | test_without_labels = processed_dataset['test'].remove_columns(['labels']) 38 | 39 | last_path = max(Path(model_path).glob("checkpoint*"), default=None, key=os.path.getctime) 40 | with open(last_path / 'trainer_state.json') as f: 41 | trainer_state = json.load(f) 42 | best_path = Path(trainer_state['best_model_checkpoint']).parts[-1] 43 | 44 | if arch == 'lean_albert': 45 | model = 
LeanAlbertForSequenceClassification.from_pretrained(model_path / best_path) 46 | else: 47 | model = AutoModelForSequenceClassification.from_pretrained(model_path / best_path) 48 | 49 | training_args = TrainingArguments( 50 | output_dir=model_path, overwrite_output_dir=True, 51 | evaluation_strategy='epoch', logging_strategy='epoch', per_device_train_batch_size=BATCH_SIZE, 52 | per_device_eval_batch_size=BATCH_SIZE, save_strategy='epoch', save_total_limit=1, 53 | fp16=True, dataloader_num_workers=4, group_by_length=True, 54 | report_to='none', load_best_model_at_end=True, metric_for_best_model=config.best_metric 55 | ) 56 | 57 | trainer = Trainer( 58 | model=model, 59 | args=training_args, 60 | train_dataset=None, 61 | eval_dataset=None, 62 | compute_metrics=config.compute_metrics, 63 | tokenizer=tokenizer, 64 | data_collator=data_collator, 65 | ) 66 | 67 | predictions = trainer.predict(test_dataset=test_without_labels) 68 | processed_predictions = config.process_predictions(predictions.predictions, split="test") 69 | 70 | preds_dir = f"preds/{run_dirname.replace('_terra', '')}" 71 | 72 | os.makedirs(preds_dir, exist_ok=True) 73 | 74 | with open(f"{preds_dir}/{TASK_TO_NAME[TASK]}.jsonl", 'w+') as outf: 75 | for prediction in processed_predictions: 76 | print(json.dumps(prediction, ensure_ascii=True), file=outf) 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = ArgumentParser() 81 | parser.add_argument("-m", '--model', type=Path, required=True) 82 | parser.add_argument("-a", '--arch') 83 | args = parser.parse_args() 84 | main(args.model, args.arch) 85 | -------------------------------------------------------------------------------- /src/training/sync.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Sequence 3 | 4 | import torch 5 | import transformers 6 | from hivemind.utils import get_logger 7 | from torch.distributed.distributed_c10d import _get_default_group, _get_default_store 8 | from transformers import TrainerControl, TrainerState, TrainingArguments 9 | 10 | import arguments 11 | import tasks 12 | 13 | AUTHORITATIVE_RANK = 0 14 | BROADCAST_BUFFER_SIZE: int = 250 * 1024 * 1024 15 | logger = get_logger(__name__) 16 | 17 | 18 | def is_main_process() -> bool: 19 | """Whether this is the main process on **this peer's** distributed run. 
Non-distributed process is always main.""" 20 | return (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == AUTHORITATIVE_RANK 21 | 22 | 23 | class SynchronizationCallback(transformers.TrainerCallback): 24 | """Minimalistic callback for non-master DDP workers""" 25 | 26 | def __init__(self, task: "tasks.TrainingTaskBase", args: "arguments.TrainingPeerArguments"): 27 | self.task = task 28 | self.is_master = is_main_process() 29 | self._checksum_counter = 0 30 | self._state_tensors = None 31 | self._prev_version = self._prev_epoch = -1 32 | 33 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 34 | if torch.distributed.is_initialized(): 35 | self._maybe_sync_model_state() 36 | 37 | def on_step_end( 38 | self, 39 | args: TrainingArguments, 40 | state: transformers.TrainerState, 41 | control: transformers.TrainerControl, 42 | **kwargs, 43 | ): 44 | control.should_log = True 45 | model = self.task.model 46 | if torch.distributed.is_initialized(): 47 | self._maybe_sync_model_state() 48 | 49 | self._checksum_counter += 1 50 | if self._checksum_counter % 100 == 0: 51 | rank = torch.distributed.get_rank() 52 | print(end=f"CHECKSUM({rank})={float(sum(p.sum().item() for p in model.parameters()))}\n") 53 | self.task.on_step_end() 54 | 55 | @property 56 | def state_tensors(self) -> Sequence[torch.Tensor]: 57 | if self._state_tensors is None: 58 | self._state_tensors = list(self.task.model.state_dict().values()) 59 | return self._state_tensors 60 | 61 | def _compute_state_version(self) -> int: 62 | """return a non-decreasing integer that goes up whenever model params and/or buffers were updated""" 63 | assert self.is_master 64 | return sum(state["step"] for state in self.task.optimizer.opt.state.values()) 65 | 66 | def _should_broadcast_state(self): 67 | store = _get_default_store() 68 | if self.is_master: 69 | current_version = self._compute_state_version() 70 | if current_version == self._prev_version and self.task.optimizer.local_epoch > self._prev_epoch + 1: 71 | logger.warning("Model state version has not changed during a full epoch; " 72 | "broadcasting parameters between torch.distributed synchronization may be broken") 73 | 74 | if current_version != self._prev_version or self.task.optimizer.local_epoch > self._prev_epoch + 1: 75 | should_broadcast = True 76 | else: 77 | should_broadcast = False 78 | 79 | store.set(f"_hivemind_should_broadcast_state", str(int(should_broadcast))) 80 | torch.distributed.barrier() 81 | return should_broadcast 82 | else: 83 | torch.distributed.barrier() 84 | raw_should_broadcast = store.get(f"_hivemind_should_broadcast_state") 85 | return bool(int(raw_should_broadcast)) 86 | 87 | def _maybe_sync_model_state(self): 88 | """Synchronize model params and buffers from master""" 89 | if self.state_tensors and self._should_broadcast_state(): 90 | t_start = time.perf_counter() 91 | with torch.no_grad(): 92 | torch.distributed._broadcast_coalesced( 93 | _get_default_group(), self.state_tensors, BROADCAST_BUFFER_SIZE, AUTHORITATIVE_RANK 94 | ) 95 | if self.is_master: 96 | self._prev_version = self._compute_state_version() 97 | self._prev_epoch = self.task.optimizer.local_epoch 98 | logger.info(f"Broadcasting master params took {time.perf_counter() - t_start} seconds") 99 | else: 100 | logger.debug("Not broadcasting") 101 | -------------------------------------------------------------------------------- /tasks/mlm/data.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from functools import partial 4 | from typing import Optional 5 | 6 | import torch.utils.data 7 | from datasets import interleave_datasets, load_dataset 8 | from hivemind.utils.logging import get_logger 9 | from razdel import sentenize 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def make_training_dataset( 15 | tokenizer, 16 | shuffle_buffer_size: int = 10 ** 4, 17 | shuffle_seed: Optional[int] = None, 18 | preprocessing_batch_size: int = 256, 19 | max_sequence_length: int = 512, 20 | ): 21 | dataset1 = load_dataset(YOUR_DATASET_HERE) 22 | dataset2 = load_dataset(YOUR_DATASET_HERE) 23 | dataset3 = load_dataset(YOUR_DATASET_HERE) 24 | 25 | datasets = dict(dataset1=dataset1, dataset2=dataset2, dataset3=dataset3) 26 | weights = dict(dataset1=0.6, dataset2=0.2, dataset3=0.2) 27 | 28 | datasets = { 29 | key: dataset.map( 30 | partial(tokenize_function, tokenizer, max_sequence_length=max_sequence_length), 31 | batched=True, 32 | batch_size=preprocessing_batch_size, 33 | ) 34 | for key, dataset in datasets.items() 35 | } 36 | 37 | dataset = interleave_datasets( 38 | [datasets[k] for k in sorted(datasets.keys())], 39 | probabilities=[weights[k] for k in sorted(datasets.keys())], 40 | ) 41 | 42 | dataset = dataset.shuffle(seed=shuffle_seed, buffer_size=shuffle_buffer_size) 43 | dataset = dataset.with_format("torch") 44 | return WrappedIterableDataset(dataset) 45 | 46 | 47 | def create_instances_from_document(tokenizer, document, max_sequence_length): 48 | """Creates `TrainingInstance`s for a single document.""" 49 | # We DON'T just concatenate all of the tokens from a document into a long 50 | # sequence and choose an arbitrary split point because this would make the 51 | # next sentence prediction task too easy. Instead, we split the input into 52 | # segments "A" and "B" based on the actual "sentences" provided by the user 53 | # input. 54 | instances = [] 55 | current_chunk = [] 56 | current_length = 0 57 | max_sequence_length = int(max_sequence_length.value) 58 | segmented_sents = [s.text for s in sentenize(document)] 59 | 60 | for i, sent in enumerate(segmented_sents): 61 | current_chunk.append(sent) 62 | current_length += len(tokenizer.tokenize(sent)) 63 | if i == len(segmented_sents) - 1 or current_length >= max_sequence_length: 64 | if len(current_chunk) > 1: 65 | # `a_end` is how many segments from `current_chunk` go into the `A` 66 | # (first) sentence. 67 | a_end = random.randint(1, len(current_chunk) - 1) 68 | 69 | tokens_a = [] 70 | for j in range(a_end): 71 | tokens_a.append(current_chunk[j]) 72 | 73 | tokens_b = [] 74 | 75 | for j in range(a_end, len(current_chunk)): 76 | tokens_b.append(current_chunk[j]) 77 | 78 | if random.random() < 0.5: 79 | # Random next 80 | is_random_next = True 81 | # in this case, we just swap tokens_a and tokens_b 82 | tokens_a, tokens_b = tokens_b, tokens_a 83 | else: 84 | # Actual next 85 | is_random_next = False 86 | 87 | assert len(tokens_a) >= 1 88 | assert len(tokens_b) >= 1 89 | 90 | instance = tokenizer( 91 | " ".join(tokens_a), 92 | " ".join(tokens_b), 93 | padding="max_length", 94 | truncation="longest_first", 95 | max_length=max_sequence_length, 96 | # We use this option because DataCollatorForLanguageModeling 97 | # is more efficient when it receives the `special_tokens_mask`. 
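                # (DataCollatorForWholeWordMask in tasks/mlm/whole_word_mask.py pops this mask in __call__
                # and passes it to mask_tokens, avoiding a second pass over the batch.)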
98 | return_special_tokens_mask=True, 99 | ) 100 | assert len(instance["input_ids"]) <= max_sequence_length 101 | instance["sentence_order_label"] = 1 if is_random_next else 0 102 | instances.append(instance) 103 | 104 | current_chunk = [] 105 | current_length = 0 106 | return instances 107 | 108 | 109 | def tokenize_function(tokenizer, examples, max_sequence_length): 110 | # Remove empty texts 111 | texts = [text for text in examples["text"] if len(text) > 0 and not text.isspace()] 112 | new_examples = defaultdict(list) 113 | 114 | for text in texts: 115 | instances = create_instances_from_document(tokenizer, text, max_sequence_length) 116 | for instance in instances: 117 | for key, value in instance.items(): 118 | new_examples[key].append(value) 119 | 120 | return new_examples 121 | 122 | 123 | class WrappedIterableDataset(torch.utils.data.IterableDataset): 124 | """Wraps huggingface IterableDataset as pytorch IterableDataset, implement default methods for DataLoader""" 125 | 126 | def __init__(self, hf_iterable, verbose: bool = True): 127 | self.hf_iterable = hf_iterable 128 | self.verbose = verbose 129 | 130 | def __iter__(self): 131 | started = False 132 | logger.info("Pre-fetching training samples...") 133 | while True: 134 | for sample in self.hf_iterable: 135 | if not started: 136 | logger.info("Began iterating minibatches!") 137 | started = True 138 | yield sample 139 | -------------------------------------------------------------------------------- /src/training/hf_trainer.py: -------------------------------------------------------------------------------- 1 | """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training""" 2 | 3 | import hivemind 4 | import torch 5 | import torch.distributed 6 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from transformers.trainer import Trainer 10 | 11 | from arguments import HFTrainerArguments 12 | from src.modules.rotary import RotaryEmbeddings 13 | from src.training.sync import is_main_process 14 | 15 | use_hivemind_log_handler("in_root_logger") 16 | logger = get_logger(__name__) 17 | LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None) 18 | 19 | 20 | class CollaborativeHFTrainer(Trainer): 21 | """ 22 | A version of HuggingFace trainer that shuffles the dataset using a separate random seed. 23 | Used to ensure that peers don't process batches in the same order. 
24 | """ 25 | 26 | def __init__(self, *, data_seed: int, optimizer: hivemind.Optimizer, reuse_grad_buffers: bool, **kwargs): 27 | self.data_seed = data_seed 28 | self.optimizer = optimizer 29 | self.reuse_grad_buffers = reuse_grad_buffers 30 | assert not torch.distributed.is_initialized() or not reuse_grad_buffers, "DDP with reuse_grad_buffers is not implemented (yet)" 31 | super().__init__(optimizers=(optimizer, NoOpScheduler(optimizer)), **kwargs) 32 | 33 | if self.args.fp16 and self.reuse_grad_buffers: 34 | self.scaler = hivemind.GradScaler() 35 | 36 | def get_train_dataloader(self) -> DataLoader: 37 | """Shuffle data independently for each peer to avoid duplicating batches [important for quality]""" 38 | seed = self.data_seed 39 | if torch.distributed.is_initialized(): 40 | seed += torch.distributed.get_rank() 41 | torch.manual_seed(seed) 42 | return super().get_train_dataloader() 43 | 44 | def _wrap_model(self, model: nn.Module, training=True): 45 | assert training, "Evaluation (training=False) should be run on a separate dedicated worker." 46 | model = super()._wrap_model(model, training) 47 | if torch.distributed.is_initialized(): 48 | assert isinstance(model, nn.parallel.DistributedDataParallel) 49 | assert model.require_forward_param_sync 50 | logger.info("Pre-populating rotary embedding cache up to maximum length to enforce static graph") 51 | assert isinstance(self.args, HFTrainerArguments) 52 | device = f"{model.device_type}:{model.output_device}" 53 | for module in model.modules(): 54 | if isinstance(module, RotaryEmbeddings): 55 | module._set_auxiliary_buffers(max_len=self.args.max_sequence_length, device=device) 56 | 57 | logger.warning("DistributedDataParallel: triggering _set_static_graph() to allow checkpointing") 58 | model._set_static_graph() 59 | 60 | # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step 61 | should_override_zero_grad = self.reuse_grad_buffers if is_main_process() else False # replicas can reset grad 62 | return IgnoreGradManipulations(model, override_zero_grad=should_override_zero_grad) 63 | 64 | 65 | class NOPtimizer(torch.optim.SGD): 66 | def __init__(self, params): 67 | super().__init__(params, lr=0) 68 | 69 | def step(self, *args, **kwargs): 70 | pass 71 | 72 | 73 | class NoOpScheduler(LRSchedulerBase): 74 | """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler""" 75 | 76 | def get_lr(self): 77 | return [group["lr"] for group in self.optimizer.param_groups] 78 | 79 | def print_lr(self, *args, **kwargs): 80 | if self.optimizer.scheduler: 81 | return self.optimizer.scheduler.print_lr(*args, **kwargs) 82 | 83 | def step(self): 84 | logger.debug("Called NoOpScheduler.step") 85 | self._last_lr = self.get_lr() 86 | 87 | def state_dict(self): 88 | return {} 89 | 90 | def load_state_dict(self, *args, **kwargs): 91 | logger.debug("Called NoOpScheduler.load_state_dict") 92 | 93 | 94 | class IgnoreGradManipulations(nn.Module): 95 | """Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. 
zero_grad, clip_grad)""" 96 | 97 | def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True): 98 | super().__init__() 99 | self.module = module 100 | self.override_clipping = override_clipping 101 | self.override_zero_grad = override_zero_grad 102 | 103 | def forward(self, *args, **kwargs): 104 | return self.module.forward(*args, **kwargs) 105 | 106 | def zero_grad(self, set_to_none: bool = False) -> None: 107 | if self.override_zero_grad: 108 | grad_is_nan = all(param.grad.isfinite().all() for param in self.parameters()) 109 | if grad_is_nan: 110 | logger.debug("Successfully bypassed zero_grad") 111 | else: 112 | logger.debug("Encountered non-finite value in gradients!") 113 | self.module.zero_grad(set_to_none=set_to_none) 114 | else: 115 | 116 | self.module.zero_grad(set_to_none=set_to_none) 117 | 118 | def clip_grad_norm_(self, max_norm: float, norm_type: int = 2): 119 | """ignore clip_grad_norm on each step, clip in optimizer instead""" 120 | if self.override_clipping: 121 | logger.debug("Successfully bypassed clip_grad_norm_") 122 | else: 123 | return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type) 124 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import PreTrainedModel 5 | 6 | 7 | class FCLayer(nn.Module): 8 | def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True): 9 | super().__init__() 10 | 11 | layers = [nn.Dropout(dropout_rate)] 12 | if use_activation: 13 | layers.append(nn.Tanh()) 14 | layers.append(nn.Linear(input_dim, output_dim)) 15 | 16 | self.layers = nn.Sequential(*layers) 17 | 18 | def forward(self, x): 19 | return self.layers(x) 20 | 21 | 22 | class SpanClassificationModel(nn.Module): 23 | def __init__(self, backbone: PreTrainedModel, num_labels: int): 24 | super().__init__() 25 | self.backbone = backbone 26 | self.num_labels = num_labels 27 | self.config = backbone.config 28 | 29 | hidden_size = self.config.hidden_size 30 | 31 | if hasattr(self.config, "classifier_dropout"): 32 | dropout_rate = ( 33 | self.config.classifier_dropout if self.config.classifier_dropout is not None 34 | else self.config.hidden_dropout_prob 35 | ) 36 | else: 37 | dropout_rate = self.config.classifier_dropout_prob 38 | 39 | self.cls_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate) 40 | self.e1_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate) 41 | self.e2_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate) 42 | self.label_classifier = FCLayer(hidden_size * 3, num_labels, dropout_rate, use_activation=False) 43 | 44 | @staticmethod 45 | def _entity_average(hidden_output, e_mask): 46 | """ 47 | Average the entity hidden state vectors (H_i ~ H_j) 48 | :param hidden_output: [batch_size, j-i+1, dim] 49 | :param e_mask: [batch_size, seq_len] 50 | e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 
0] 51 | :return: [batch_size, dim] 52 | """ 53 | e_mask_unsqueeze = e_mask.unsqueeze(2) # [batch_size, seq_len, 1] 54 | length_tensor = (e_mask_unsqueeze != 0).sum(dim=1) # [batch_size, 1] 55 | 56 | sum_vector = (hidden_output * e_mask_unsqueeze).sum(1) 57 | avg_vector = sum_vector / length_tensor 58 | return avg_vector 59 | 60 | def forward(self, input_ids, attention_mask, e1_mask, e2_mask, token_type_ids, labels=None): 61 | # sequence_output, pooled_output, (hidden_states), (attentions) 62 | sequence_output, pooled_output, *outputs = self.backbone( 63 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False, 64 | ) 65 | 66 | # Average 67 | hidden_first = self._entity_average(sequence_output, e1_mask) 68 | hidden_second = self._entity_average(sequence_output, e2_mask) 69 | 70 | # Dropout -> tanh -> fc_layer 71 | pooled_output = self.cls_fc_layer(pooled_output) 72 | hidden_first = self.e1_fc_layer(hidden_first) 73 | hidden_second = self.e2_fc_layer(hidden_second) 74 | 75 | # Concat -> fc_layer 76 | concat_h = torch.cat([pooled_output, hidden_first, hidden_second], dim=-1) 77 | logits = self.label_classifier(concat_h) 78 | 79 | outputs = (logits,) + tuple(outputs) # add hidden states and attention if they are here 80 | 81 | # Softmax 82 | if labels is not None: 83 | loss_fct = nn.CrossEntropyLoss() 84 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 85 | 86 | outputs = (loss,) + outputs 87 | 88 | return outputs # (loss), logits, (hidden_states), (attentions) 89 | 90 | 91 | class EntityChoiceModel(nn.Module): 92 | def __init__(self, backbone: PreTrainedModel): 93 | super().__init__() 94 | self.backbone = backbone 95 | 96 | hidden_size = backbone.config.hidden_size 97 | 98 | if hasattr(backbone.config, "classifier_dropout"): 99 | dropout_rate = ( 100 | backbone.config.classifier_dropout if backbone.config.classifier_dropout is not None 101 | else backbone.config.hidden_dropout_prob 102 | ) 103 | else: 104 | dropout_rate = backbone.config.classifier_dropout_prob 105 | 106 | self.clf = nn.Sequential( 107 | nn.Linear(hidden_size, hidden_size), 108 | nn.GELU(), 109 | nn.Dropout(dropout_rate), 110 | nn.Linear(hidden_size, 1) 111 | ) 112 | 113 | self.loss = nn.BCEWithLogitsLoss(reduction="none") 114 | 115 | def forward(self, input_ids, attention_mask, entity_mask, token_type_ids, labels=None): 116 | sequence_output, pooled_output, *outputs = self.backbone( 117 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False 118 | ) 119 | 120 | entity_mask = entity_mask.unsqueeze(3) 121 | entity_lengths = entity_mask.sum(2) 122 | 123 | # [batch_size, num_entities, seq_len, hid_dim] 124 | embeds_for_entities = sequence_output.unsqueeze(1) * entity_mask 125 | aggregated_embeds = embeds_for_entities.sum(2) / (entity_lengths + 1e-8) 126 | 127 | # [batch_size, num_entities] 128 | logits = self.clf(aggregated_embeds).squeeze(2) 129 | 130 | # due to padding, we might have rows without entity indices at all 131 | # thus, we need to mask the logits/loss for them 132 | 133 | # [batch_size, num_entities] 134 | present_entities = (entity_lengths.squeeze(2) != 0).to(logits.dtype) 135 | logits = logits * present_entities - 10000.0 * (1 - present_entities) 136 | 137 | outputs = (logits,) + tuple(outputs) # add hidden states and attention if they are here 138 | 139 | if labels is not None: 140 | # do not penalize predictions for padded entities 141 | label_mask = (labels != -1) 142 | 143 | # BCEWithLogitsLoss requires targets to be 
from 0 to 1 144 | labels[~label_mask] = 0 145 | 146 | loss = self.loss(logits, labels.to(logits.dtype)) 147 | 148 | reduced_loss = ((loss * label_mask).sum(1) / (label_mask.sum(1) + 1e-8)).mean() 149 | outputs = (reduced_loss,) + outputs 150 | 151 | return outputs 152 | -------------------------------------------------------------------------------- /tasks/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Optional, Type 3 | 4 | import hivemind 5 | import torch 6 | import torch.nn as nn 7 | from hivemind import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization 8 | from hivemind.utils.logging import get_logger 9 | from torch.distributed.distributed_c10d import _get_default_store 10 | from transformers.data.data_collator import DataCollatorMixin 11 | 12 | from arguments import BasePeerArguments, CollaborativeArguments, HFTrainerArguments 13 | from src.huggingface_auth import authorize_with_huggingface 14 | from src.training.sync import AUTHORITATIVE_RANK, is_main_process 15 | 16 | try: 17 | from hivemind.optim.experimental.state_averager import LRSchedulerBase, ParamGroups 18 | except ImportError: 19 | from hivemind.optim.state_averager import LRSchedulerBase, ParamGroups 20 | 21 | from src import utils 22 | 23 | logger = get_logger(__name__) 24 | TASKS = {} 25 | 26 | 27 | def register_task(name: str): 28 | def _register(cls: Type[TrainingTaskBase]): 29 | if name in TASKS: 30 | logger.warning(f"Registering task {name} a second time, previous entry will be overwritten") 31 | TASKS[name] = cls 32 | return cls 33 | 34 | return _register 35 | 36 | 37 | class TrainingTaskBase: 38 | """A container that defines the training config, model, tokenizer, optimizer and other local training utilities""" 39 | 40 | _dht = _optimizer = _authorizer = None # for caching 41 | 42 | def __init__( 43 | self, 44 | model: nn.Module, 45 | peer_args: BasePeerArguments, 46 | trainer_args: HFTrainerArguments, 47 | collab_args: CollaborativeArguments, 48 | ): 49 | self.model, self.peer_args, self.trainer_args, self.collab_args = model, peer_args, trainer_args, collab_args 50 | self.validators, self.local_public_key = utils.make_validators(self.peer_args.run_id) 51 | 52 | if self.authorizer: 53 | self.trainer_args.run_name = self.authorizer.username # For wandb 54 | 55 | @property 56 | def authorizer(self): 57 | if self._authorizer is None and self.peer_args.authorize: 58 | self._authorizer = authorize_with_huggingface() 59 | return self._authorizer 60 | 61 | @property 62 | def dht(self): 63 | if self._dht is None: 64 | assert is_main_process() 65 | self._dht = hivemind.DHT( 66 | start=True, 67 | initial_peers=self.peer_args.initial_peers, 68 | client_mode=self.peer_args.client_mode, 69 | host_maddrs=self.peer_args.host_maddrs, 70 | announce_maddrs=self.peer_args.announce_maddrs, 71 | use_ipfs=self.peer_args.use_ipfs, 72 | record_validators=self.validators, 73 | identity_path=self.peer_args.identity_path, 74 | authorizer=self.authorizer, 75 | ) 76 | if self.peer_args.client_mode: 77 | logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}") 78 | else: 79 | utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs) 80 | return self._dht 81 | 82 | @property 83 | def optimizer(self) -> hivemind.Optimizer: 84 | if self._optimizer is None: 85 | assert is_main_process() 86 | averaging_compression = SizeAdaptiveCompression( 87 | threshold=2 ** 16 + 1, 
less=Float16Compression(), greater_equal=Uniform8BitQuantization() 88 | ) 89 | 90 | self._optimizer = hivemind.Optimizer( 91 | dht=self.dht, 92 | params=self._make_param_groups(), 93 | run_id=self.peer_args.run_id, 94 | optimizer=self._make_base_optimizer, 95 | scheduler=self._make_scheduler, 96 | grad_compression=averaging_compression, 97 | state_averaging_compression=averaging_compression, 98 | batch_size_per_step=self.trainer_args.batch_size_per_step, 99 | client_mode=self.peer_args.client_mode, 100 | verbose=True, 101 | averager_opts=dict(min_vector_size=self.peer_args.min_vector_size, bandwidth=self.peer_args.bandwidth), 102 | **asdict(self.collab_args), 103 | ) 104 | return self._optimizer 105 | 106 | def _make_param_groups(self) -> ParamGroups: 107 | """Return optimizer param groups: either list of parameters or a list of dictionaries in torch.optim format""" 108 | raise NotImplementedError() 109 | 110 | def _make_base_optimizer(self, param_groups: ParamGroups) -> torch.optim.Optimizer: 111 | """Return PyTorch optimizer to be wrapped with hivemind.Optimizer. Use only the specified param groups.""" 112 | raise NotImplementedError() 113 | 114 | def _make_scheduler(self, optimizer: torch.optim.Optimizer) -> Optional[LRSchedulerBase]: 115 | """Return PyTorch scheduler that will be ran on each synchronization epoch (each target_batch_size samples)""" 116 | return None # default = no scheduler 117 | 118 | @property 119 | def training_dataset(self): 120 | raise NotImplementedError() 121 | 122 | @property 123 | def data_collator(self) -> DataCollatorMixin: 124 | raise NotImplementedError() 125 | 126 | def on_step_end(self): 127 | """This will be called after each local step (in callback.py)""" 128 | pass 129 | 130 | def get_current_epoch(self): 131 | """ 132 | Return current epoch in a ddp-friendly way. 
When implementing your own task, please use this instead of 133 | directly accessing self.optimizer.global_epoch because non-main DDP workers will not run hivemind.Optimizer 134 | """ 135 | if not torch.distributed.is_initialized(): 136 | return self.optimizer.tracker.global_epoch 137 | else: 138 | store = _get_default_store() 139 | if torch.distributed.get_rank() == AUTHORITATIVE_RANK: 140 | current_epoch = self.optimizer.tracker.global_epoch 141 | store.set("_hivemind_current_epoch", str(current_epoch)) 142 | torch.distributed.barrier() 143 | return current_epoch 144 | else: 145 | torch.distributed.barrier() 146 | return int(store.get("_hivemind_current_epoch")) 147 | -------------------------------------------------------------------------------- /tasks/mlm/whole_word_mask.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from transformers import DataCollatorForLanguageModeling 7 | 8 | try: 9 | from transformers.data.data_collator import _torch_collate_batch as collate_batch 10 | from transformers.data.data_collator import tolist 11 | except ImportError: 12 | from transformers.data.data_collator import _collate_batch as collate_batch, tolist 13 | 14 | from transformers.tokenization_utils_base import BatchEncoding 15 | 16 | 17 | def _is_start_piece_sp(piece): 18 | """Check if the current word piece is the starting piece (sentence piece).""" 19 | special_pieces = set(list('!"#$%&"()*+,-./:;?@[\\]^_`{|}~')) 20 | special_pieces.add(u"€".encode("utf-8")) 21 | special_pieces.add(u"£".encode("utf-8")) 22 | if piece.startswith("▁") or piece.startswith("<") or piece in special_pieces: 23 | return True 24 | else: 25 | return False 26 | 27 | 28 | @dataclass 29 | class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): 30 | """ 31 | Data collator used for language modeling that masks entire words. 32 | 33 | - collates batches of tensors, honoring their tokenizer's pad_token 34 | - preprocesses batches for masked language modeling 35 | 36 | .. note:: 37 | 38 | This collator relies on details of the implementation of subword tokenization by 39 | :class:`~transformers.AlbertTokenizer`, specifically that start-of-word tokens are prefixed with `▁`. 40 | For tokenizers that do not adhere to this scheme, this collator will produce an output that is roughly 41 | equivalent to :class:`.DataCollatorForLanguageModeling`. 42 | """ 43 | 44 | def __call__( 45 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] 46 | ) -> Dict[str, torch.Tensor]: 47 | if isinstance(examples[0], (dict, BatchEncoding)): 48 | batch = self.tokenizer.pad( 49 | examples, 50 | return_tensors="pt", 51 | pad_to_multiple_of=self.pad_to_multiple_of, 52 | ) 53 | else: 54 | batch = {"input_ids": collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)} 55 | 56 | # If special token mask has been preprocessed, pop it from the dict. 
57 | special_tokens_mask = batch.pop("special_tokens_mask", None) 58 | 59 | mask_labels = [] 60 | for example in batch["input_ids"]: 61 | ref_tokens = self.tokenizer.convert_ids_to_tokens(tolist(example)) 62 | mask_labels.append(self._whole_word_mask(ref_tokens)) 63 | batch_mask = collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of) 64 | 65 | batch["input_ids"], batch["labels"] = self.mask_tokens( 66 | batch["input_ids"], batch_mask, special_tokens_mask=special_tokens_mask 67 | ) 68 | return batch 69 | 70 | def _whole_word_mask(self, input_tokens: List[str]): 71 | """ 72 | Get 0/1 labels for masked tokens with whole word mask proxy 73 | """ 74 | 75 | cand_indexes = [] 76 | num_tokens_exc_pad = 0 77 | for i, token in enumerate(input_tokens): 78 | if token in ( 79 | self.tokenizer.cls_token, 80 | self.tokenizer.sep_token, 81 | self.tokenizer.pad_token, 82 | ): 83 | continue 84 | num_tokens_exc_pad += 1 85 | if len(cand_indexes) >= 1 and not _is_start_piece_sp(token): 86 | cand_indexes[-1].append(i) 87 | else: 88 | cand_indexes.append([i]) 89 | 90 | random.shuffle(cand_indexes) 91 | 92 | mask_labels = torch.zeros((len(input_tokens),), dtype=torch.long) 93 | num_tokens_to_mask = min(num_tokens_exc_pad, max(1, int(round(num_tokens_exc_pad * self.mlm_probability)))) 94 | covered_indexes = set() 95 | 96 | for index_set in cand_indexes: 97 | if len(covered_indexes) >= num_tokens_to_mask: 98 | break 99 | 100 | # If adding a whole-word mask would exceed the maximum number of 101 | # predictions, then just skip this candidate. 102 | if len(covered_indexes) + len(index_set) > num_tokens_to_mask: 103 | continue 104 | 105 | is_any_index_covered = any(index in covered_indexes for index in index_set) 106 | if is_any_index_covered: 107 | continue 108 | 109 | for index in index_set: 110 | covered_indexes.add(index) 111 | mask_labels[index] = 1 112 | 113 | return mask_labels 114 | 115 | def mask_tokens( 116 | self, 117 | inputs: torch.Tensor, 118 | mask_labels: torch.Tensor, 119 | special_tokens_mask: Optional[torch.Tensor] = None, 120 | ) -> Tuple[torch.Tensor, torch.Tensor]: 121 | """ 122 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set 123 | 'mask_labels' means we use whole word mask (WMM), we directly mask idxs according to it's ref. 
124 | """ 125 | assert self.mlm 126 | labels = inputs.clone() 127 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 128 | 129 | probability_matrix = mask_labels 130 | 131 | if special_tokens_mask is None: 132 | special_tokens_mask = [ 133 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() 134 | ] 135 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 136 | else: 137 | special_tokens_mask = special_tokens_mask.bool() 138 | 139 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 140 | 141 | masked_indices = probability_matrix.bool() 142 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 143 | 144 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 145 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 146 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) 147 | 148 | # 10% of the time, we replace masked input tokens with random word 149 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 150 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) 151 | inputs[indices_random] = random_words[indices_random] 152 | 153 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 154 | return inputs, labels 155 | -------------------------------------------------------------------------------- /src/training/callback.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Any 3 | 4 | import hivemind 5 | import torch 6 | import torch.distributed 7 | import transformers 8 | from transformers import TrainingArguments 9 | 10 | from arguments import TrainingPeerArguments 11 | from src.training.sync import SynchronizationCallback 12 | from src.utils import LocalMetrics, logger 13 | from tasks.base import TrainingTaskBase 14 | 15 | 16 | class CollaborativeCallback(SynchronizationCallback): 17 | """ 18 | This callback monitors and reports collaborative training progress, 19 | In case of a catastrophic failure, it can also revert training to a backup 20 | """ 21 | 22 | def __init__(self, task: TrainingTaskBase, args: TrainingPeerArguments): 23 | super().__init__(task, args) 24 | self.dht, self.optimizer = task.dht, task.optimizer 25 | self.statistics_expiration = args.statistics_expiration 26 | self.last_reported_epoch = -1 27 | self.samples = 0 28 | self.mini_steps = 0 # number of calls to optimizer.step, NOT equal to the number of global steps 29 | self.loss = 0 30 | self.total_samples_processed = 0 31 | self.backup_every_epochs = args.backup_every_epochs 32 | self.state_path = args.state_path 33 | 34 | def on_train_begin( 35 | self, 36 | args: TrainingArguments, 37 | state: transformers.TrainerState, 38 | control: transformers.TrainerControl, 39 | **kwargs, 40 | ): 41 | if os.path.isfile(self.state_path): 42 | self.restore_from_backup(self.state_path) 43 | logger.info("Loaded state") 44 | else: 45 | logger.info("Loading state from peers") 46 | self.optimizer.load_state_from_peers() 47 | super().on_train_begin(args, state, control, **kwargs) 48 | 49 | def on_step_end( 50 | self, 51 | args: TrainingArguments, 52 | state: transformers.TrainerState, 53 | control: transformers.TrainerControl, 54 | **kwargs, 55 | ): 56 | super().on_step_end(args, 
state, control, **kwargs) 57 | if not self.params_are_finite(): 58 | if not os.path.exists(self.state_path): 59 | raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.") 60 | logger.warning("Parameters are invalid, reloading model from earlier state") 61 | self.restore_from_backup(self.state_path) 62 | return control 63 | 64 | if state.log_history: 65 | self.loss += state.log_history[-1]["loss"] 66 | self.mini_steps += 1 67 | if self.optimizer.local_epoch != self.last_reported_epoch: 68 | self.last_reported_epoch = self.optimizer.local_epoch 69 | self.total_samples_processed += self.samples 70 | samples_per_second = self.optimizer.tracker.performance_ema.samples_per_second 71 | statistics = LocalMetrics( 72 | epoch=self.optimizer.local_epoch, 73 | samples_per_second=samples_per_second, 74 | samples_accumulated=self.samples, 75 | loss=self.loss, 76 | mini_steps=self.mini_steps, 77 | ) 78 | logger.info(f"Current epoch: {self.optimizer.local_epoch}") 79 | logger.info(f"Your current contribution: {self.total_samples_processed} samples") 80 | logger.info(f"Performance: {samples_per_second} samples/sec") 81 | if self.mini_steps: 82 | logger.info(f"Local loss: {self.loss / self.mini_steps}") 83 | 84 | self.loss = 0 85 | self.mini_steps = 0 86 | if self.optimizer.local_epoch == self.optimizer.tracker.global_epoch: 87 | self.dht.store( 88 | key=self.optimizer.run_id + "_metrics", 89 | subkey=self.task.local_public_key, 90 | value=statistics.dict(), 91 | expiration_time=hivemind.get_dht_time() + self.statistics_expiration, 92 | return_future=True, 93 | ) 94 | if self.backup_every_epochs is not None and self.optimizer.local_epoch % self.backup_every_epochs == 0: 95 | self.backup_state() 96 | 97 | self.samples = self.optimizer.grad_averager.local_samples_accumulated 98 | 99 | return control 100 | 101 | @torch.no_grad() 102 | def params_are_finite(self): 103 | for param in self.task.model.parameters(): 104 | if not torch.all(torch.isfinite(param)): 105 | return False 106 | return True 107 | 108 | @torch.no_grad() 109 | def backup_state(self) -> Any: 110 | logger.info("Saving backup") 111 | return torch.save( 112 | { 113 | "model": self.task.model.state_dict(), 114 | "optimizer": self.optimizer.state_dict(), 115 | "scheduler": self.optimizer.state_averager.scheduler.state_dict(), 116 | "local_epoch": self.optimizer.local_epoch, 117 | }, 118 | self.state_path, 119 | ) 120 | 121 | @torch.no_grad() 122 | def restore_from_backup(self, path, check_epoch=False): 123 | state = torch.load(path) 124 | current_epoch = self.optimizer.local_epoch 125 | if "model" not in state: 126 | logger.warning("Found weights-only checkpoint") 127 | self.task.model.load_state_dict(state, strict=True) 128 | return 129 | backup_epoch = state["local_epoch"] 130 | 131 | if not check_epoch or backup_epoch >= current_epoch: 132 | self.task.model.load_state_dict(state["model"], strict=False) 133 | self.optimizer.load_state_dict(state["optimizer"]) 134 | self.optimizer.state_averager.scheduler.load_state_dict(state["scheduler"]) 135 | 136 | if self.optimizer.offload_optimizer: 137 | state_averager = self.optimizer.state_averager 138 | offloaded_parameters = [ 139 | param for group in state_averager.optimizer.param_groups for param in group["params"] 140 | ] 141 | assert len(offloaded_parameters) == len(state_averager.main_parameters) 142 | for main_param, offloaded_param in zip(state_averager.main_parameters, offloaded_parameters): 143 | offloaded_param.copy_(main_param, non_blocking=True) 144 | 145 
| self.optimizer.state_averager.local_epoch = backup_epoch 146 | 147 | if not self.optimizer.client_mode: 148 | self.optimizer.state_averager.state_sharing_priority = self.optimizer.local_epoch 149 | 150 | if self.optimizer.use_gradient_averaging: 151 | self.optimizer.grad_averager.reset_accumulated_grads_() 152 | if not self.optimizer.client_mode: 153 | self.optimizer.grad_averager.state_sharing_priority = self.optimizer.local_epoch 154 | 155 | logger.info("Restored from a backup") 156 | else: 157 | logger.info("Bypassed restoring state from local backup: backup state is too old.") 158 | -------------------------------------------------------------------------------- /tasks/mlm/task.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import multiprocessing as mp 3 | import os 4 | from pathlib import Path 5 | 6 | import hivemind 7 | import torch.distributed 8 | import transformers 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from transformers import AutoTokenizer 11 | import torch 12 | 13 | from arguments import BasePeerArguments, CollaborativeArguments, HFTrainerArguments 14 | from src.models.albert import LeanAlbertConfig, LeanAlbertForPreTraining 15 | from src.training.lamb_8bit import CPULAMB8Bit 16 | from tasks.base import LRSchedulerBase, ParamGroups, TrainingTaskBase, register_task 17 | 18 | from .data import make_training_dataset 19 | from .whole_word_mask import DataCollatorForWholeWordMask 20 | 21 | hivemind.use_hivemind_log_handler("in_root_logger") 22 | logger = hivemind.get_logger() 23 | 24 | 25 | @register_task("mlm") 26 | class MLMTrainingTask(TrainingTaskBase): 27 | """A container that defines the training config, model, tokenizer, optimizer and other local training utilities""" 28 | 29 | _dht = _optimizer = _training_dataset = _authorizer = None 30 | 31 | def __init__( 32 | self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments 33 | ): 34 | transformers.set_seed(trainer_args.seed) # seed used for initialization 35 | 36 | self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path) 37 | self.tokenizer = AutoTokenizer.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir) 38 | 39 | output_dir = Path(trainer_args.output_dir) 40 | logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}') 41 | latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime) 42 | 43 | if latest_checkpoint_dir is None: 44 | logger.info(f"Creating model") 45 | model = LeanAlbertForPreTraining(self.config) 46 | model.resize_token_embeddings(len(self.tokenizer)) 47 | else: 48 | logger.info(f"Loading model from {latest_checkpoint_dir}") 49 | model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir) 50 | if trainer_args.gradient_checkpointing: 51 | model.gradient_checkpointing_enable() 52 | 53 | super().__init__(model, peer_args, trainer_args, collab_args) 54 | self.current_sequence_length = mp.Value(ctypes.c_int64, self.trainer_args.max_sequence_length) 55 | self.update_sequence_length() # updated by callback 56 | 57 | def _make_param_groups(self) -> ParamGroups: 58 | no_decay = ["bias", "norm.weight"] 59 | return [ 60 | { 61 | "params": [p for n, p in self.model.named_parameters() if not any(n.endswith(nd) for nd in no_decay)], 62 | "weight_decay": self.trainer_args.weight_decay, 63 | }, 64 | { 65 | "params": [p for n, p in self.model.named_parameters() if any(n.endswith(nd) 
for nd in no_decay)],
66 | "weight_decay": 0.0,
67 | },
68 | ]
69 |
70 | def _make_base_optimizer(self, param_groups: ParamGroups) -> torch.optim.Optimizer:
71 | return CPULAMB8Bit(
72 | param_groups,
73 | lr=self.trainer_args.learning_rate,
74 | betas=(self.trainer_args.adam_beta1, self.trainer_args.adam_beta2),
75 | min_8bit_size=self.trainer_args.min_8bit_size,
76 | max_grad_norm=self.trainer_args.max_grad_norm,
77 | clamp_value=self.trainer_args.clamp_value,
78 | eps=self.trainer_args.adam_epsilon,
79 | weight_decay=self.trainer_args.weight_decay,
80 | reuse_grad_buffers=True,
81 | bias_correction=True,
82 | )
83 |
84 | def _make_scheduler(self, optimizer: torch.optim.Optimizer) -> LRSchedulerBase:
85 | num_warmup_steps = self.trainer_args.warmup_steps
86 | num_training_steps = self.trainer_args.total_steps
87 | min_learning_rate = self.trainer_args.min_learning_rate
88 |
89 | # schedule used below: hold the learning rate at zero for the first 50 steps, then warm up linearly
90 | # until num_warmup_steps, then decay linearly, never dropping below 2% of the peak learning rate
91 | def lr_lambda(current_step: int):
92 | if current_step < 50:
93 | return 0
94 | if current_step < num_warmup_steps:
95 | return float(current_step-50) / float(max(1, num_warmup_steps-50))
96 | decaying = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
97 | return max(0.02, decaying)
98 |
99 |
100 | return LambdaLR(optimizer, lr_lambda)
101 | # note: this alternative schedule (the only variant that honors min_learning_rate) was defined after the return statement and is never used:
102 | # def lr_lambda(current_step: int):
103 | #     if current_step < num_warmup_steps:
104 | #         return float(current_step) / float(max(1, num_warmup_steps))
105 | #     decaying = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
106 | #     return max(min_learning_rate, decaying)
107 | @property
108 | def training_dataset(self):
109 | if self._training_dataset is None:
110 | self._training_dataset = make_training_dataset(
111 | self.tokenizer,
112 | shuffle_seed=hash(self.local_public_key) % 2 ** 31,
113 | max_sequence_length=self.current_sequence_length, # this is a mp.Value that will be changed later
114 | )
115 | return self._training_dataset
116 |
117 | def on_step_end(self):
118 | return self.update_sequence_length()
119 |
120 | def update_sequence_length(self):
121 | """
122 | If ramp-up is enabled, start with smaller sequences of initial_sequence_length tokens, then increase linearly
123 | to max_sequence_length over the first sequence_length_warmup_steps optimizer epochs
124 | """
125 | current_epoch = self.get_current_epoch()
126 | if (
127 | self.trainer_args.sequence_length_warmup_steps == 0
128 | or current_epoch > self.trainer_args.sequence_length_warmup_steps
129 | ):
130 | current_sequence_length = self.trainer_args.max_sequence_length
131 | else:
132 | increment_size = self.trainer_args.pad_to_multiple_of
133 | max_sequence_length = self.trainer_args.max_sequence_length
134 | initial_sequence_length = self.trainer_args.initial_sequence_length or increment_size
135 | sequence_length_warmup_steps = self.trainer_args.sequence_length_warmup_steps
136 | assert sequence_length_warmup_steps > 0 and max_sequence_length >= initial_sequence_length
137 | length_range = max_sequence_length - initial_sequence_length
138 | warmup_relative = min(1, current_epoch / sequence_length_warmup_steps)
139 | current_sequence_length = initial_sequence_length + warmup_relative * length_range
140 | current_sequence_length = (current_sequence_length // increment_size) * increment_size
141 | current_sequence_length = min(max(current_sequence_length, initial_sequence_length), max_sequence_length)
142 |
143 | current_sequence_length = int(current_sequence_length)
144 | if current_sequence_length != self.current_sequence_length.value:
145
| logger.info(f"Transitioning to sequence length {current_sequence_length}") 146 | self.current_sequence_length.value = current_sequence_length 147 | # note: it may take time for new sequence length to take effect due to buffering 148 | 149 | @property 150 | def data_collator(self): 151 | return DataCollatorForWholeWordMask( 152 | tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of 153 | ) 154 | -------------------------------------------------------------------------------- /run_aux_peer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import threading 3 | import time 4 | 5 | import scipy.stats # compatibility for internal testing environment 6 | import torch 7 | import transformers 8 | import wandb 9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 10 | from huggingface_hub import Repository 11 | from transformers import HfArgumentParser 12 | 13 | from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments 14 | from src import utils 15 | from tasks.mlm.task import MLMTrainingTask 16 | 17 | transformers.utils.logging.set_verbosity_warning() 18 | use_hivemind_log_handler("in_root_logger") 19 | logger = get_logger(__name__) 20 | torch.set_num_threads(1) # avoid quadratic number of threads 21 | 22 | 23 | class CheckpointHandler: 24 | def __init__(self, task: MLMTrainingTask, peer_args: AuxiliaryPeerArguments): 25 | self.task, self.peer_args = task, peer_args 26 | self.save_checkpoint_epoch_interval = peer_args.save_checkpoint_epoch_interval 27 | self.prefix = peer_args.run_id 28 | self.local_path = peer_args.local_path 29 | self.upload_interval = peer_args.upload_interval 30 | if self.upload_interval is not None: 31 | assert task.authorizer is not None, "Model uploading needs Hugging Face auth to be enabled" 32 | self.repo = Repository( 33 | local_dir=self.local_path, 34 | clone_from=peer_args.repo_url, 35 | use_auth_token=task.authorizer.hf_user_access_token, 36 | ) 37 | self.last_upload_time = None 38 | self.previous_epoch = -1 39 | 40 | def should_save_state(self, current_epoch: int): 41 | if self.save_checkpoint_epoch_interval is None: 42 | return False 43 | elif current_epoch - self.previous_epoch >= self.save_checkpoint_epoch_interval: 44 | return True 45 | else: 46 | return False 47 | 48 | def save_state(self, current_epoch: int): 49 | logger.info("Saving state from peers") 50 | self.task.optimizer.load_state_from_peers() 51 | self.previous_epoch = current_epoch 52 | 53 | def is_time_to_upload(self): 54 | if self.upload_interval is None: 55 | return False 56 | elif self.last_upload_time is None or time.time() - self.last_upload_time >= self.upload_interval: 57 | return True 58 | else: 59 | return False 60 | 61 | def upload_checkpoint(self, current_loss: float): 62 | self.last_upload_time = time.time() 63 | 64 | logger.info("Saving model") 65 | torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt") 66 | logger.info("Saving optimizer") 67 | torch.save(self.task.optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt") 68 | self.previous_timestamp = time.time() 69 | logger.info("Started uploading to Model Hub") 70 | try: 71 | # We start by pulling the remote changes (for example a change in the readme file) 72 | self.repo.git_pull() 73 | 74 | # Then we add / commmit and push the changes 75 | self.repo.push_to_hub(commit_message=f"Epoch {self.task.optimizer.local_epoch}, loss {current_loss:.3f}") 76 | 
logger.info("Finished uploading to Model Hub") 77 | except Exception: 78 | logger.exception("Uploading the checkpoint to HF Model Hub failed:") 79 | logger.warning("Ensure that your access token is valid and has WRITE permissions") 80 | 81 | 82 | def assist_averaging_in_background( 83 | lock: threading.Lock, task: MLMTrainingTask, peer_args: AuxiliaryPeerArguments, finished: threading.Event 84 | ): 85 | while not finished.is_set(): 86 | try: 87 | time.sleep(peer_args.assist_refresh) 88 | with lock: 89 | task.optimizer.step() 90 | except Exception as e: 91 | logger.exception(e, exc_info=True) 92 | 93 | 94 | if __name__ == "__main__": 95 | parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments)) 96 | peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses() 97 | finished, lock = threading.Event(), threading.Lock() 98 | 99 | task = MLMTrainingTask(peer_args, trainer_args, collab_args) 100 | dht, optimizer = task.dht, task.optimizer 101 | 102 | if peer_args.wandb_project is not None: 103 | wandb.init(project=peer_args.wandb_project) 104 | 105 | current_epoch = 0 106 | if peer_args.store_checkpoints: 107 | checkpoint_handler = CheckpointHandler(task, peer_args) 108 | 109 | if peer_args.assist_in_averaging: 110 | assert not peer_args.client_mode, "client-mode peers cannot assist in averaging" 111 | averaging_thread = threading.Thread( 112 | name="AveragingAuxThread", 113 | target=assist_averaging_in_background, 114 | args=[lock, task, peer_args, finished], 115 | daemon=True, 116 | ) 117 | averaging_thread.start() 118 | 119 | try: 120 | while True: 121 | metrics_entry = dht.get(peer_args.run_id + "_metrics", latest=True) 122 | if metrics_entry is not None and len(metrics_entry.value) > 0: 123 | metrics_dict = metrics_entry.value 124 | metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict] 125 | latest_epoch = max(item.epoch for item in metrics) 126 | 127 | if latest_epoch != current_epoch: 128 | logger.debug(f"Got metrics from {len(metrics)} peers") 129 | 130 | for i, metrics_for_peer in enumerate(metrics): 131 | logger.debug(f"{i} peer {metrics_for_peer}") 132 | 133 | current_epoch = latest_epoch 134 | alive_peers = 0 135 | sum_loss = 0 136 | num_samples = 0 137 | sum_perf = 0 138 | sum_mini_steps = 0 139 | 140 | for item in metrics: 141 | sum_loss += item.loss 142 | alive_peers += 1 143 | sum_perf += item.samples_per_second 144 | num_samples += item.samples_accumulated 145 | sum_mini_steps += item.mini_steps 146 | current_loss = sum_loss / sum_mini_steps 147 | logger.info(f"Epoch #{current_epoch}\tloss = {current_loss:.5f}") 148 | 149 | if peer_args.wandb_project is not None: 150 | wandb.log( 151 | { 152 | "loss": current_loss, 153 | "alive peers": alive_peers, 154 | "samples": num_samples, 155 | "performance": sum_perf, 156 | "optimizer_step": latest_epoch, 157 | }, 158 | step=latest_epoch, 159 | ) 160 | 161 | if peer_args.store_checkpoints: 162 | if checkpoint_handler.should_save_state(current_epoch): 163 | with lock: 164 | checkpoint_handler.save_state(current_epoch) 165 | if checkpoint_handler.is_time_to_upload(): 166 | checkpoint_handler.upload_checkpoint(current_loss) 167 | logger.debug("Peer is still alive...") 168 | time.sleep(peer_args.refresh_period) 169 | finally: 170 | finished.set() 171 | -------------------------------------------------------------------------------- /src/huggingface_auth.py: -------------------------------------------------------------------------------- 1 | import os 2 
| import time 3 | from datetime import datetime, timedelta 4 | from getpass import getpass 5 | 6 | import requests 7 | from hivemind.proto.auth_pb2 import AccessToken 8 | from hivemind.utils.auth import TokenAuthorizerBase 9 | from hivemind.utils.crypto import RSAPublicKey 10 | from hivemind.utils.logging import get_logger 11 | from huggingface_hub import HfApi 12 | from termcolor import colored 13 | 14 | logger = get_logger("root." + __name__) 15 | 16 | 17 | class NonRetriableError(Exception): 18 | pass 19 | 20 | 21 | def call_with_retries(func, n_retries=10, initial_delay=1.0): 22 | for i in range(n_retries): 23 | try: 24 | return func() 25 | except NonRetriableError: 26 | raise 27 | except Exception as e: 28 | if i == n_retries - 1: 29 | raise 30 | 31 | delay = initial_delay * (2 ** i) 32 | logger.warning(f"Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec") 33 | time.sleep(delay) 34 | 35 | 36 | class InvalidCredentialsError(NonRetriableError): 37 | pass 38 | 39 | 40 | class NotInAllowlistError(NonRetriableError): 41 | pass 42 | 43 | 44 | class HuggingFaceAuthorizer(TokenAuthorizerBase): 45 | _AUTH_SERVER_URL = "https://collaborative-training-auth.huggingface.co" 46 | 47 | def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str): 48 | super().__init__() 49 | 50 | self.organization_name = organization_name 51 | self.model_name = model_name 52 | self.hf_user_access_token = hf_user_access_token 53 | 54 | self._authority_public_key = None 55 | self.coordinator_ip = None 56 | self.coordinator_port = None 57 | 58 | self._hf_api = HfApi() 59 | 60 | async def get_token(self) -> AccessToken: 61 | """ 62 | Hivemind calls this method to refresh the token when necessary. 63 | """ 64 | 65 | self.join_experiment() 66 | return self._local_access_token 67 | 68 | @property 69 | def username(self): 70 | return self._local_access_token.username 71 | 72 | def join_experiment(self) -> None: 73 | call_with_retries(self._join_experiment) 74 | 75 | def _join_experiment(self) -> None: 76 | try: 77 | url = f"{self._AUTH_SERVER_URL}/api/experiments/join" 78 | headers = {"Authorization": f"Bearer {self.hf_user_access_token}"} 79 | response = requests.put( 80 | url, 81 | headers=headers, 82 | params={ 83 | "organization_name": self.organization_name, 84 | "model_name": self.model_name, 85 | }, 86 | json={ 87 | "experiment_join_input": { 88 | "peer_public_key": self.local_public_key.to_bytes().decode(), 89 | }, 90 | }, 91 | ) 92 | 93 | response.raise_for_status() 94 | response = response.json() 95 | 96 | self._authority_public_key = RSAPublicKey.from_bytes(response["auth_server_public_key"].encode()) 97 | self.coordinator_ip = response["coordinator_ip"] 98 | self.coordinator_port = response["coordinator_port"] 99 | 100 | token_dict = response["hivemind_access"] 101 | access_token = AccessToken() 102 | access_token.username = token_dict["username"] 103 | access_token.public_key = token_dict["peer_public_key"].encode() 104 | access_token.expiration_time = str(datetime.fromisoformat(token_dict["expiration_time"])) 105 | access_token.signature = token_dict["signature"].encode() 106 | self._local_access_token = access_token 107 | logger.info( 108 | f"Access for user {access_token.username} " f"has been granted until {access_token.expiration_time} UTC" 109 | ) 110 | except requests.exceptions.HTTPError as e: 111 | if e.response.status_code == 401: # Unauthorized 112 | raise NotInAllowlistError() 113 | raise 114 | 115 | def is_token_valid(self, access_token: 
AccessToken) -> bool: 116 | data = self._token_to_bytes(access_token) 117 | if not self._authority_public_key.verify(data, access_token.signature): 118 | logger.exception("Access token has invalid signature") 119 | return False 120 | 121 | try: 122 | expiration_time = datetime.fromisoformat(access_token.expiration_time) 123 | except ValueError: 124 | logger.exception(f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}") 125 | return False 126 | if expiration_time.tzinfo is not None: 127 | logger.exception(f"Expected to have no timezone for expiration time: {access_token.expiration_time}") 128 | return False 129 | if expiration_time < datetime.utcnow(): 130 | logger.exception("Access token has expired") 131 | return False 132 | 133 | return True 134 | 135 | _MAX_LATENCY = timedelta(minutes=1) 136 | 137 | def does_token_need_refreshing(self, access_token: AccessToken) -> bool: 138 | expiration_time = datetime.fromisoformat(access_token.expiration_time) 139 | return expiration_time < datetime.utcnow() + self._MAX_LATENCY 140 | 141 | @staticmethod 142 | def _token_to_bytes(access_token: AccessToken) -> bytes: 143 | return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode() 144 | 145 | 146 | def authorize_with_huggingface() -> HuggingFaceAuthorizer: 147 | while True: 148 | organization_name = os.getenv("HF_ORGANIZATION_NAME") 149 | if organization_name is None: 150 | organization_name = input("HuggingFace organization name: ") 151 | 152 | model_name = os.getenv("HF_MODEL_NAME") 153 | if model_name is None: 154 | model_name = input("HuggingFace model name: ") 155 | 156 | hf_user_access_token = os.getenv("HF_USER_ACCESS_TOKEN") 157 | if hf_user_access_token is None: 158 | print( 159 | "\nCopy a token from 🤗 Hugging Face settings page at " 160 | f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} " 161 | "and paste it here.\n\n" 162 | f"💡 {colored('Tip:', attrs=['bold'])} " 163 | "If you don't already have one, you can create a dedicated user access token.\n" 164 | f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, " 165 | f"click the {colored('New token', attrs=['bold'])} button, " 166 | f"and choose the {colored('read', attrs=['bold'])} role.\n" 167 | ) 168 | os.environ["HF_USER_ACCESS_TOKEN"] = hf_user_access_token = getpass( 169 | "🤗 Hugging Face user access token (characters will be hidden): " 170 | ) 171 | 172 | authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token) 173 | 174 | try: 175 | authorizer.join_experiment() 176 | print(f"🚀 You will contribute to the collaborative training under the username {authorizer.username}") 177 | return authorizer 178 | except InvalidCredentialsError: 179 | print("Invalid user access token, please try again") 180 | except NotInAllowlistError: 181 | print( 182 | "\n😥 Authentication has failed.\n\n" 183 | "This error may be due to the fact:\n" 184 | " 1. Your user access token is not valid. You can try to delete the previous token and " 185 | "recreate one. Be careful, organization tokens can't be used to join a collaborative " 186 | "training.\n" 187 | f" 2. You have not yet joined the {organization_name} organization. You can request to " 188 | "join this organization by clicking on the 'request to join this org' button at " 189 | f"https://huggingface.co/{organization_name}.\n" 190 | f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n" 191 | f" 4. 
No {organization_name}'s admin has created a collaborative training for the {organization_name} " 192 | f"organization and the {model_name} model." 193 | ) 194 | -------------------------------------------------------------------------------- /src/modules/attn.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | from src.modules.rotary import RotaryEmbeddings 9 | 10 | 11 | class LeanSelfAttention(nn.Module): 12 | def __init__( 13 | self, 14 | hidden_size: int, 15 | num_attention_heads: int, 16 | dropout: float = 0, 17 | layer_norm_eps: float = 1e-12, 18 | sandwich_norm: bool = False, 19 | dense_qkv: Optional[nn.Linear] = None, 20 | dense_out: Optional[nn.Linear] = None, 21 | residual: bool = True, 22 | attention_core: Optional[nn.Module] = None, 23 | checkpoint_attention_core: bool = True, 24 | **kwargs, 25 | ): 26 | """ 27 | Attention layer that does not hog GPU memory. Re-computes pairwise attention weights instead of storing them. 28 | 29 | :note: this code is relatively less optimized than FFN because attention is usually not a memory bottleneck 30 | for typical sequence lengths (e.g. 2048 in language or 1024 in vision). If training with longer sequences, 31 | one can use query chunking: running one chunk of queries at a time, without storing the full QxK matrix. 32 | This technique runs in O(length) memory instead of O(length^2), making it not-a-bottleneck compared to FFN 33 | 34 | :param hidden_size: base hidden size of the transformer, before q/k/v projections 35 | :param num_attention_heads: number of heads, as defined in the original transformer 36 | :param dropout: hidden dropout probability, applied to the output projection (before adding residual) 37 | :param layer_norm_eps: see torch.nn.functional.layer_norm 38 | :param sandwich_norm: if set, applies an additional layer norm to projected attention outputs before residuals, 39 | as proposed in the CogView paper ( arXiv:2105.13290 ). This is meant to make fp16 training 40 | more stable for deep transformers. This technique is also a part of NormFormer ( arXiv:2110.09456 ) 41 | :param residual: if True, adds the original layer input to the final layer output 42 | :param attention_core: optionally provide custom attention function. See SimpleAttentionCore for inspiration. 
43 | :param checkpoint_attention_core: re-compute attention weights during backward pass instead of storing them 44 | :param dense_qkv: custom QKV projection layer (hidden_size -> 3 * hidden_size) 45 | :param dense_out: custom output projection layer (hidden_size -> hidden_size) 46 | :param kwargs: additional kwargs are passed to the chosen attention core 47 | """ 48 | super().__init__() 49 | if attention_core is None: 50 | attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, **kwargs) 51 | else: 52 | assert len(kwargs) == 0, f"Unexpected parameters: {kwargs}" 53 | 54 | self.hidden_size = hidden_size 55 | self.attention_core = attention_core 56 | self.dense_qkv = nn.Linear(hidden_size, hidden_size * 3) if dense_qkv is None else dense_qkv 57 | self.dense_out = nn.Linear(hidden_size, hidden_size) if dense_out is None else dense_out 58 | assert self.dense_qkv.in_features == self.dense_out.in_features == self.dense_out.out_features == hidden_size 59 | assert self.dense_qkv.out_features == hidden_size * 3 60 | 61 | self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) 62 | self.sandwich_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if sandwich_norm else None 63 | self.output_dropout = nn.Dropout(dropout, inplace=False) 64 | self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core 65 | 66 | def forward(self, hidden_states, attention_mask=None, output_attentions=False): 67 | hidden_states_ln = self.layer_norm(hidden_states) 68 | qkv_output = self.dense_qkv(hidden_states_ln) 69 | query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1) 70 | attention_output, attention_probs = self._maybe_checkpoint( 71 | self.attention_core, query, key, value, attention_mask 72 | ) 73 | outputs = self.dense_out(attention_output) 74 | if self.sandwich_norm: 75 | outputs = self.sandwich_norm(outputs) 76 | outputs = self.output_dropout(outputs) 77 | if self.residual: 78 | outputs = outputs + hidden_states.to(torch.float32, copy=False) 79 | return (outputs, attention_probs) if output_attentions else (outputs,) 80 | 81 | def _maybe_checkpoint(self, func, *args): 82 | return checkpoint(func, *args) if torch.is_grad_enabled() and self.checkpoint_attention_core else func(*args) 83 | 84 | 85 | class SimpleAttentionCore(nn.Module): 86 | def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0): 87 | super().__init__() 88 | assert hidden_size % num_attention_heads == 0 89 | self.attention_dropout = nn.Dropout(attention_probs_dropout, inplace=False) 90 | self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads 91 | self.attention_head_size = hidden_size // num_attention_heads 92 | 93 | def forward(self, query, key, value, attention_mask): 94 | """ 95 | :param query: [batch_size, query_seq_len, hidden_size] 96 | :param key: [batch_size, kv_seq_len, hidden_size] 97 | :param value: [batch_size, kv_seq_len, hidden_size] 98 | :param attention_mask: float [(optional heads), batch, query_seq_len, kv_seq_length] 99 | :note: attention_mask should be equal to zero for non-masked tokens and a large negative value for masked ones 100 | :return: (outputs, probs) 101 | - outputs shape: [batch_size, query_seq_len, hidden_size] 102 | - probs shape: [batch_size, num_heads, query_seq_len, kv_seq_len] 103 | """ 104 | if attention_mask is not None: 105 | assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items" 106 | return 
self._attention_core_forward( 107 | query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p, self.training 108 | ) 109 | 110 | @staticmethod 111 | def _attention_core_forward( 112 | query: torch.Tensor, 113 | key: torch.Tensor, 114 | value: torch.Tensor, 115 | attention_mask: Optional[torch.Tensor], 116 | num_attention_heads: int, attention_dropout: float, training: bool 117 | ) -> Tuple[torch.Tensor, torch.Tensor]: 118 | # transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size] 119 | new_query_shape = query.shape[:-1] + (num_attention_heads, -1) 120 | new_kv_shape = key.shape[:-1] + (num_attention_heads, -1) 121 | 122 | query = query.view(new_query_shape).permute(0, 2, 1, 3) 123 | key_transposed = key.view(new_kv_shape).permute(0, 2, 3, 1) # swap to [..., head_size, seq_length] 124 | value = value.view(new_kv_shape).permute(0, 2, 1, 3) 125 | del key # not to confuse with key_transposed 126 | 127 | # Take the dot product between "query" and "key" to get the raw attention scores. 128 | attention_scores = torch.matmul(query, key_transposed / math.sqrt(query.shape[-1])) 129 | 130 | if attention_mask is not None: 131 | attention_scores += attention_mask 132 | 133 | # Normalize the attention scores to probabilities. 134 | attention_probs = torch.softmax(attention_scores, dim=-1) 135 | 136 | # This is actually dropping out entire tokens to attend to, which might 137 | # seem a bit unusual, but is taken from the original Transformer paper. 138 | attention_probs = torch.dropout(attention_probs, attention_dropout, training) 139 | attention_output = torch.matmul(attention_probs, value) 140 | attention_output = attention_output.transpose(2, 1).flatten(2) 141 | 142 | return attention_output, attention_probs 143 | 144 | 145 | class RotaryAttentionCore(SimpleAttentionCore): 146 | """Attention core that applies rotary embeddings to queries and keys before computing dot products""" 147 | 148 | def __init__( 149 | self, hidden_size: int, num_attention_heads: int, rotary_emb: Optional[RotaryEmbeddings] = None, **kwargs 150 | ): 151 | super().__init__(hidden_size, num_attention_heads, **kwargs) 152 | if rotary_emb is None: 153 | rotary_emb = RotaryEmbeddings(self.attention_head_size) 154 | self.rotary_emb = rotary_emb 155 | 156 | def rotate(self, tensor: torch.Tensor): 157 | """:param tensor: query or key, shape: [batch_size, query_seq_len, hidden_size]""" 158 | tensor_split_heads = tensor.view(*(tensor.shape[:-1] + (self.num_attention_heads, self.attention_head_size))) 159 | return self.rotary_emb(tensor_split_heads).view(*tensor.shape) 160 | 161 | def forward(self, query, key, value, attention_mask): 162 | return super().forward(self.rotate(query), self.rotate(key), value, attention_mask) 163 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | import torch 5 | from transformers import TrainingArguments 6 | 7 | 8 | @dataclass 9 | class CollaborativeArguments: 10 | """Configuration for CollaborativeOptimizer and its internals""" 11 | 12 | target_batch_size: int = field( 13 | default=16384, 14 | metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"}, 15 | ) 16 | matchmaking_time: float = field( 17 | default=60.0, 18 | metadata={"help": "Averaging group will wait for stragglers for 
at most this many seconds"}, 19 | ) 20 | next_chunk_timeout: float = field( 21 | default=60.0, 22 | metadata={"help": "Consider allreduce peer failed if it does not respond in this many seconds"}, 23 | ) 24 | averaging_timeout: float = field( 25 | default=600.0, 26 | metadata={"help": "Give up on averaging step after this many seconds"}, 27 | ) 28 | offload_optimizer: bool = field(default=True, metadata={"help": "Whether or not to offload optimizer into RAM"}) 29 | delay_optimizer_step: bool = field( 30 | default=True, 31 | metadata={"help": "Whether or not to run optimizer step in background"}, 32 | ) 33 | delay_grad_averaging: bool = field( 34 | default=True, 35 | metadata={"help": "Whether or not to run gradient averaging in background"}, 36 | ) 37 | average_state_every: int = field(default=5, metadata={"help": "Average parameters every this many epochs"}) 38 | reuse_grad_buffers: bool = field( 39 | default=True, 40 | metadata={ 41 | "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This " 42 | "optimization reduces GPU memory consumption but may result in incorrect gradients when using some " 43 | "advanced techniques (e.g. changing loss scaler to a custom one)." 44 | }, 45 | ) 46 | 47 | 48 | @dataclass 49 | class HFTrainerArguments(TrainingArguments): 50 | """Arguments for huggingface/transformers.Trainer""" 51 | 52 | per_device_train_batch_size: int = 1 53 | per_device_eval_batch_size: int = 1 54 | gradient_accumulation_steps: int = 1 55 | 56 | learning_rate: float = 2.5e-3 # based on https://arxiv.org/abs/1904.00962 57 | total_steps: int = 15625 # total number of collaborative optimizer updates, used for learning rate schedule 58 | warmup_steps: int = 3125 # based on https://arxiv.org/abs/1904.00962 59 | min_learning_rate: float = 1e-5 # learning rate after total_steps have passed 60 | adam_beta1: float = 0.9 61 | adam_beta2: float = 0.95 62 | adam_epsilon: float = 1e-6 63 | weight_decay: float = 0.01 64 | max_grad_norm: float = 1.0 # clipping performed by the optimizer; trainer is modified to disable builtin clipping 65 | clamp_value: float = 1e9 # no clipping by value 66 | min_8bit_size: int = 2 ** 20 67 | 68 | gradient_checkpointing: bool = False # can be enabled to save memory at the cost of ~30% slower training 69 | fp16: bool = False # can be enabled depending on the device 70 | 71 | max_sequence_length: int = 2048 72 | initial_sequence_length: Optional[int] = 256 # used only if warmup > 0, default = pad_to_multiple_of 73 | sequence_length_warmup_steps: int = 7_000 74 | pad_to_multiple_of: int = 128 # sequence length will be divisible by this value 75 | 76 | output_dir: str = "outputs" 77 | logging_steps: int = 100 78 | 79 | # params that should *not* be changed* 80 | do_train: bool = True 81 | do_eval: bool = False 82 | logging_first_step = True 83 | dataloader_num_workers: int = 0 # temporary fix for https://github.com/huggingface/datasets/issues/3148 84 | max_steps: int = 10 ** 30 85 | save_steps: int = 10 ** 30 86 | save_total_limit: int = 2 87 | ddp_find_unused_parameters: bool = False 88 | 89 | @property 90 | def batch_size_per_step(self): 91 | """Compute the number of training sequences contributed by each .step() from this peer""" 92 | total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps 93 | if torch.cuda.device_count() > 0: 94 | total_batch_size_per_step *= torch.cuda.device_count() 95 | return total_batch_size_per_step 96 | 97 | 98 | @dataclass 99 | class BasePeerArguments: 
100 | """Base arguments that are used for both trainers and for auxiliary peers such as training monitor""" 101 | 102 | run_id: str = field(metadata={"help": "A unique experiment name, used as prefix for all DHT keys"}) 103 | model_config_path: Optional[str] = field(default="./config.json", metadata={"help": "Path to the model config"}) 104 | tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"}) 105 | cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"}) 106 | authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"}) 107 | client_mode: bool = field( 108 | default=False, 109 | metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"}, 110 | ) 111 | bandwidth: Optional[float] = field( 112 | default=None, 113 | metadata={"help": "Min(upload & download speed) in megabits/s, used to assign averaging tasks between peers"}, 114 | ) 115 | min_vector_size: int = 4_000_000 # minimum slice of gradients assigned to one reducer, should be same across peers 116 | initial_peers: List[str] = field( 117 | default_factory=list, 118 | metadata={ 119 | "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. " 120 | "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY" 121 | }, 122 | ) 123 | use_ipfs: bool = field( 124 | default=False, 125 | metadata={ 126 | "help": "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of multiaddrs " 127 | "for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)" 128 | }, 129 | ) 130 | host_maddrs: List[str] = field( 131 | default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"], 132 | metadata={ 133 | "help": "Multiaddrs to listen for external connections from other p2p instances. " 134 | "Defaults to all IPv4 interfaces with TCP protocol: /ip4/0.0.0.0/tcp/0" 135 | }, 136 | ) 137 | announce_maddrs: List[str] = field( 138 | default_factory=list, 139 | metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"}, 140 | ) 141 | identity_path: Optional[str] = field( 142 | default=None, 143 | metadata={ 144 | "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. " 145 | "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``." 
146 | },
147 | )
148 |
149 |
150 | @dataclass
151 | class TrainingPeerArguments(BasePeerArguments):
152 | statistics_expiration: float = field(
153 | default=600,
154 | metadata={"help": "Statistics will be removed if not updated in this many seconds"},
155 | )
156 | backup_every_epochs: Optional[int] = field(
157 | default=None,
158 | metadata={
159 | "help": "Update training state backup on disk once in this many global steps "
160 | "(default = do not update local state)"
161 | },
162 | )
163 | state_path: str = field(
164 | default="state.zip",
165 | metadata={"help": "Load this state upon init and when recovering from NaN parameters"},
166 | )
167 |
168 |
169 | @dataclass
170 | class AuxiliaryPeerArguments(BasePeerArguments):
171 | """
172 | Arguments for run_aux_peer.py that is responsible for connecting peers to one another, tracking
173 | learning curves, assisting in all-reduce and uploading checkpoints to the hub
174 | """
175 |
176 | refresh_period: float = field(
177 | default=10,
178 | metadata={"help": "Period (in seconds) for fetching the keys from DHT"},
179 | )
180 | wandb_project: Optional[str] = field(
181 | default=None,
182 | metadata={"help": "Name of Weights & Biases project to report the training progress to"},
183 | )
184 | save_checkpoint_epoch_interval: int = field(
185 | default=5,
186 | metadata={"help": "Frequency (in epochs) of fetching and saving state from peers"},
187 | )
188 | repo_url: Optional[str] = field(
189 | default=None,
190 | metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"},
191 | )
192 | local_path: Optional[str] = field(
193 | default="Repo",
194 | metadata={"help": "Path to local repository to store the model and optimizer states"},
195 | )
196 | upload_interval: Optional[float] = field(
197 | default=None,
198 | metadata={"help": "Frequency (in seconds) of uploading the model to Hub"},
199 | )
200 | store_checkpoints: bool = field(default=True, metadata={"help": "If True, enables CheckpointHandler"})
201 | assist_in_averaging: bool = field(
202 | default=False,
203 | metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"},
204 | )
205 | assist_refresh: float = field(
206 | default=1.0,
207 | metadata={"help": "Period (in seconds) for trying to assist averaging"},
208 | )
209 |
--------------------------------------------------------------------------------
/src/modules/pixelfly.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 | from typing import Optional, Tuple
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import einops  # required by the einops.repeat / einops.rearrange calls below
8 | from src.modules.functional import maybe_script
9 |
10 |
11 | @functools.lru_cache()
12 | def get_butterfly_indices(
13 | out_features: int,
14 | in_features: int,
15 | block_size: int = 256,
16 | butterfly_size: Optional[int] = None,
17 | n_factors: Optional[int] = None,
18 | stretch: bool = False,
19 | ) -> Tuple[torch.IntTensor, torch.IntTensor]:
20 | """
21 | Get a matrix [num_output_blocks, num_active_input_blocks] with int32 indices for additive butterfly.
22 | The values in the matrix represent the (flattened) input blocks that are summed to produce each output block.
23 | Based on the original implementation from https://arxiv.org/abs/2112.00029 .
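A quick shape check (illustrative doctest, not from the original tests; assumes square weights):
    >>> fwd, bwd = get_butterfly_indices(1024, 1024, block_size=256)
    >>> fwd.shape[0] == 1024 // 256   # one row of flattened input-block indices per output block
    True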
24 | 25 | :param stretch: by default, non-square matrices will have stretched butterfly patterns, 26 | otherwise the square pattern will be repeated a given number of times 27 | 28 | :returns: tuple (forward_indices, backward_indices), where 29 | - (forward) indices of non-zero blocks that contribute to each output -- assuming all input blocks are flattened 30 | - (backward) indices of output blocks to which a given input block contributes 31 | """ 32 | if butterfly_size is None: 33 | butterfly_size = 2 ** int(math.ceil(math.log2(min(in_features, out_features) / block_size))) 34 | assert out_features % in_features == 0 or in_features % out_features == 0, \ 35 | "if matrix is not square, the longer dimension must be a multiple of the shorter dimension" 36 | assert out_features % block_size == 0 and in_features % block_size == 0 37 | log_n = int(math.log2(butterfly_size)) 38 | n_factors = log_n if n_factors is None else n_factors 39 | if butterfly_size != 2 ** log_n or butterfly_size < 2: 40 | raise NotImplementedError("butterfly_size must be a power of 2") 41 | if not (1 <= n_factors <= log_n): 42 | raise NotImplementedError("n_factors must be a between 1 and log_2(butterfly_size)") 43 | 44 | twiddle = torch.ones(butterfly_size // 2, 2, 2) 45 | layout = sum(butterfly_factor_to_matrix(twiddle, index) for index in range(n_factors)).bool().int() 46 | # Convert from (butterfly_size, butterfly_size) mask to (out_features, in_features) mask 47 | layout = einops.repeat( 48 | layout, 49 | "b b1 -> (b f) (b1 f1)", 50 | f=out_features // butterfly_size, 51 | f1=in_features // butterfly_size, 52 | ) 53 | # Convert from (out_features, in_features) mask to 54 | # (out_features // block_size, in_features // block_size) mask 55 | layout = einops.rearrange( 56 | layout, 57 | "(p blksz) (r blksz1) -> p r (blksz blksz1)", 58 | blksz=block_size, 59 | blksz1=block_size, 60 | ) 61 | 62 | layout = (layout > 0).any(dim=-1) # [out_features // block_size, in_features // block_size] 63 | if not stretch: 64 | out_blocks, in_blocks = layout.shape 65 | if out_blocks > in_blocks: 66 | ratio = out_blocks // in_blocks 67 | layout = layout.view(out_blocks // ratio, ratio, in_blocks).permute(1, 0, 2).reshape_as(layout) 68 | elif out_blocks < in_blocks: 69 | ratio = in_blocks // out_blocks 70 | layout = layout.view(out_blocks, in_blocks // ratio, ratio).permute(0, 2, 1).reshape_as(layout) 71 | 72 | # convert boolean layout to indices for F.embedding_bag 73 | num_output_blocks = out_features // block_size 74 | num_input_blocks = in_features // block_size 75 | active_blocks_per_output = layout.sum(1).unique() 76 | assert len(active_blocks_per_output) == 1, "butterfly layout must have the same number of blocks per row" 77 | active_blocks_per_output = active_blocks_per_output.item() 78 | 79 | active_blocks_per_input = layout.sum(0).unique() 80 | assert len(active_blocks_per_input) == 1, "butterfly layout must have the same number of blocks per row" 81 | active_blocks_per_input = active_blocks_per_input.item() 82 | 83 | # which input blocks should be added for i-th output 84 | input_block_index = layout.nonzero()[:, 1].view(num_output_blocks, active_blocks_per_output) 85 | # which output blocks does j-th input contribute to 86 | output_block_index = layout.t().nonzero()[:, 1].view(num_input_blocks, active_blocks_per_input) 87 | 88 | # which of the active blocks from the corresponding input_block should be used for i-th output 89 | active_block_index = torch.where( 90 | torch.eq( 91 | output_block_index[input_block_index], 92 | 
torch.arange(len(input_block_index))[:, None, None], 93 | ) 94 | )[-1].view(input_block_index.shape) 95 | 96 | forward_indices = input_block_index * active_blocks_per_input + active_block_index 97 | backward_indices = output_block_index 98 | return forward_indices.to(torch.int32), backward_indices.to(torch.int64) # dtypes tuned for max throughput 99 | 100 | 101 | def butterfly_factor_to_matrix(twiddle: torch.Tensor, factor_index: int) -> torch.Tensor: 102 | """ 103 | Let b be the base (most commonly 2). 104 | Parameters: 105 | twiddle: (n // b, b, b) 106 | factor_index: an int from 0 to log_b(n) - 1 107 | """ 108 | n_div_b, b, _ = twiddle.shape 109 | n = b * n_div_b 110 | log_b_n = int(math.log(n) / math.log(b)) 111 | assert n == b ** log_b_n, f"n must be a power of {b}" 112 | assert twiddle.shape == (n // b, b, b) 113 | assert 0 <= factor_index <= log_b_n 114 | stride = b ** factor_index 115 | x = einops.rearrange(torch.eye(n), "bs (diagblk j stride) -> bs diagblk j stride", stride=stride, j=b) 116 | t = einops.rearrange(twiddle, "(diagblk stride) i j -> diagblk stride i j", stride=stride) 117 | out = torch.einsum("d s i j, b d j s -> b d i s", t, x) 118 | out = einops.rearrange(out, "b diagblk i stride -> b (diagblk i stride)") 119 | return out.t() # Transpose because we assume the 1st dimension of x is the batch dimension 120 | 121 | 122 | @maybe_script 123 | def butterfly_matmul(input: torch.Tensor, weight: torch.Tensor, forward_indices: torch.Tensor) -> torch.Tensor: 124 | """ 125 | :param input: tensor [*batch_dims, in_features] 126 | :param weight: tensor [in_features, active_blocks_per_input, block_size] 127 | :param forward_indices: the first output of get_butterfly_indices(...) 128 | :returns: tensor [*batch_dims, out_features] 129 | """ 130 | assert input.shape[-1] == weight.shape[0] 131 | in_features, active_blocks_per_input, block_size = weight.shape 132 | num_input_blocks = in_features // block_size 133 | batch_dims = input.shape[:-1] 134 | input = input.flatten(0, -2) 135 | 136 | input_permuted = input.t().view(input.shape[1] // block_size, block_size, input.shape[0]) 137 | output_blocks = torch.matmul(weight.view(num_input_blocks, -1, block_size), input_permuted) 138 | # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims] 139 | 140 | blocks_for_indexing = output_blocks.view(num_input_blocks * active_blocks_per_input, block_size * input.shape[0]) 141 | # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size, flat_batch_dims)] 142 | 143 | aggregated_blocks = F.embedding_bag(forward_indices, blocks_for_indexing, mode="sum") 144 | # ^-- shape: [num_ouput_blocks, (block_size, flat_batch_dims)] 145 | 146 | outputs = aggregated_blocks.view(-1, input.shape[0]).t() 147 | # ^-- shape: [flat_batch_dims, (num_output_blocks * block_size)] aka [flat_batch_dims, out_features] 148 | return outputs.view(batch_dims + outputs.shape[-1:]) 149 | 150 | 151 | @maybe_script 152 | def butterfly_matmul_backward( 153 | grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, backward_indices: torch.Tensor, 154 | input_requires_grad: bool = True, weight_requires_grad: bool = True): 155 | """Compute gradients of butterfly_matmul w.r.t. 
input and/or weight without relying on pytorch autograd""" 156 | assert input_requires_grad or weight_requires_grad, "computing backward but none of the inputs requires grad" 157 | grad_input = grad_weight = torch.empty(0) 158 | out_features = grad_output.shape[-1] 159 | in_features, active_blocks_per_input, block_size = weight.shape 160 | num_input_blocks = input.shape[-1] // block_size 161 | num_output_blocks = out_features // block_size 162 | grad_output_flat = grad_output.flatten(0, -2) 163 | input_flat = input.flatten(0, -2) 164 | 165 | flat_batch_dims = grad_output_flat.shape[0] 166 | 167 | grad_aggregated_blocks = grad_output_flat.t().reshape(num_output_blocks, (block_size * flat_batch_dims)) 168 | # [num_output_blocks, (block_size, flat_batch_dims)] 169 | 170 | grad_blocks_for_indexing = F.embedding(backward_indices, grad_aggregated_blocks).flatten(0, -2) 171 | # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size, flat_batch_dims)] 172 | 173 | grad_output_blocks = grad_blocks_for_indexing.view( 174 | num_input_blocks, active_blocks_per_input * block_size, flat_batch_dims 175 | ) 176 | # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims] 177 | 178 | if input_requires_grad: 179 | grad_input_permuted = torch.matmul( 180 | weight.view(num_input_blocks, -1, block_size).permute(0, 2, 1), grad_output_blocks 181 | ) 182 | grad_input = grad_input_permuted.flatten(0, -2).t().view(grad_output.shape[:-1] + input.shape[-1:]) 183 | 184 | if weight_requires_grad: 185 | grad_weight = torch.matmul( 186 | grad_output_blocks, input_flat.t().view(num_input_blocks, block_size, flat_batch_dims).permute(0, 2, 1) 187 | ).view_as(weight) 188 | 189 | return grad_input, grad_weight 190 | -------------------------------------------------------------------------------- /finetune_mlm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | from functools import partial 5 | from pathlib import Path 6 | from shutil import rmtree 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import transformers.utils.logging 12 | import wandb 13 | from datasets import load_dataset, Dataset, DatasetDict 14 | from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, \ 15 | TrainingArguments, AlbertTokenizerFast, AutoModel 16 | from transformers.trainer_utils import set_seed 17 | 18 | from data import TASK_TO_CONFIG, TASK_TO_NAME 19 | from field_collator import FieldDataCollatorWithPadding 20 | from models import SpanClassificationModel, EntityChoiceModel 21 | from src import LeanAlbertConfig, LeanAlbertForSequenceClassification, LeanAlbertForPreTraining 22 | 23 | 24 | class NumpyEncoder(json.JSONEncoder): 25 | def default(self, obj): 26 | if isinstance(obj, np.integer): 27 | return int(obj) 28 | if isinstance(obj, np.floating): 29 | return float(obj) 30 | if isinstance(obj, np.ndarray): 31 | return obj.tolist() 32 | return super().default(obj) 33 | 34 | 35 | MODEL_TO_HUB_NAME = { 36 | 'ruroberta-large': "sberbank-ai/ruRoberta-large", 37 | } 38 | 39 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 40 | 41 | N_EPOCHS = 40 42 | LEARNING_RATE = 1e-5 43 | MAX_LENGTH = 512 44 | 45 | 46 | def main(task, model_name, checkpoint_path, data_dir, batch_size, grad_acc_steps, dropout, weight_decay, num_seeds): 47 | if checkpoint_path is not None: 48 | tokenizer = 
AlbertTokenizerFast.from_pretrained('tokenizer') 49 | assert model_name is None 50 | model_name = 'lean_albert' 51 | else: 52 | tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_TO_HUB_NAME[model_name], 53 | cache_dir=data_dir / 'transformers_cache') 54 | 55 | if task == "russe": 56 | data_collator = FieldDataCollatorWithPadding(tokenizer, fields_to_pad=(("e1_mask", 0, 0), ("e2_mask", 0, 0)), 57 | pad_to_multiple_of=8) 58 | elif task == "rucos": 59 | data_collator = FieldDataCollatorWithPadding(tokenizer=tokenizer, 60 | fields_to_pad=( 61 | ("entity_mask", 0, 1), 62 | ("labels", -1, None) 63 | )) 64 | else: 65 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 66 | 67 | if task == "rucos": 68 | # we're using a custom dataset, because the HF hub version has no information about entity indices 69 | train = Dataset.from_json(["RuCoS/train.jsonl"]) 70 | val = Dataset.from_json(["RuCoS/val.jsonl"]) 71 | test = Dataset.from_json(["RuCoS/test.jsonl"]) 72 | dataset = DatasetDict(train=train, validation=val, test=test) 73 | elif task == "rucola": 74 | train_df, in_domain_dev_df, out_of_domain_dev_df, test_df = map( 75 | pd.read_csv, ("RuCoLA/in_domain_train.csv", "RuCoLA/in_domain_dev.csv", "RuCoLA/out_of_domain_dev.csv", 76 | "RuCoLA/test.csv") 77 | ) 78 | 79 | # concatenate datasets to get aggregate metrics 80 | dev_df = pd.concat((in_domain_dev_df, out_of_domain_dev_df)) 81 | train, dev, test = map(Dataset.from_pandas, (train_df, dev_df, test_df)) 82 | dataset = DatasetDict(train=train, validation=dev, test=test) 83 | else: 84 | dataset = load_dataset("russian_super_glue", task) 85 | 86 | config = TASK_TO_CONFIG[task](dataset) 87 | 88 | processed_dataset = dataset.map(partial(config.process_data, tokenizer=tokenizer, max_length=MAX_LENGTH), 89 | num_proc=32, keep_in_memory=True, batched=True) 90 | 91 | if "labels" in processed_dataset["test"].column_names: 92 | test_without_labels = processed_dataset['test'].remove_columns(['labels']) 93 | else: 94 | test_without_labels = processed_dataset["test"] 95 | 96 | transformers.utils.logging.enable_progress_bar() 97 | 98 | model_prefix = f"{model_name}_" \ 99 | f"{task}_" \ 100 | f"dr{dropout}_" \ 101 | f"wd{weight_decay}_" \ 102 | f"bs{batch_size * grad_acc_steps}" 103 | 104 | dev_metrics_per_run = [] 105 | predictions_per_run = [] 106 | 107 | for seed in range(num_seeds): 108 | set_seed(seed) 109 | 110 | if checkpoint_path is not None: 111 | model_config = LeanAlbertConfig.from_pretrained('config.json') 112 | model_config.num_labels = config.num_classes 113 | model_config.classifier_dropout_prob = dropout 114 | 115 | if task in ("russe", "rucos"): 116 | model = LeanAlbertForPreTraining(model_config) 117 | else: 118 | model = LeanAlbertForSequenceClassification(model_config) 119 | 120 | model.resize_token_embeddings(len(tokenizer)) 121 | checkpoint = torch.load(checkpoint_path, map_location='cpu')['model'] 122 | incompat_keys = model.load_state_dict(checkpoint, strict=False) 123 | print("missing", incompat_keys.missing_keys) 124 | print("unexpected", incompat_keys.unexpected_keys) 125 | 126 | if task in ("russe", "rucos"): 127 | model = model.albert 128 | else: 129 | if task in ("russe", "rucos"): 130 | model = AutoModel.from_pretrained(MODEL_TO_HUB_NAME[model_name], 131 | attention_probs_dropout_prob=dropout, 132 | hidden_dropout_prob=dropout, 133 | cache_dir=data_dir / 'transformers_cache') 134 | else: 135 | model = AutoModelForSequenceClassification.from_pretrained(MODEL_TO_HUB_NAME[model_name], 136 | 
num_labels=config.num_classes, 137 | attention_probs_dropout_prob=dropout, 138 | hidden_dropout_prob=dropout, 139 | cache_dir=data_dir / 'transformers_cache') 140 | 141 | if task == "russe": 142 | model = SpanClassificationModel(model, num_labels=config.num_classes) 143 | elif task == "rucos": 144 | model = EntityChoiceModel(model) 145 | 146 | run_base_dir = f"{model_prefix}_{seed}" 147 | 148 | run = wandb.init(project='brbert', name=run_base_dir) 149 | run.config.update({"task": task, "model": model_name, "checkpoint": str(checkpoint_path)}) 150 | 151 | training_args = TrainingArguments( 152 | output_dir=data_dir / 'checkpoints' / run_base_dir, overwrite_output_dir=True, 153 | evaluation_strategy='epoch', logging_strategy='epoch', logging_first_step=True, 154 | per_device_train_batch_size=batch_size, 155 | per_device_eval_batch_size=batch_size, gradient_accumulation_steps=grad_acc_steps, 156 | optim="adamw_torch", learning_rate=LEARNING_RATE, weight_decay=weight_decay, 157 | num_train_epochs=N_EPOCHS, warmup_ratio=0.1, save_strategy='epoch', 158 | seed=seed, fp16=True, dataloader_num_workers=4, group_by_length=True, 159 | report_to='wandb', run_name=run_base_dir, save_total_limit=1, 160 | load_best_model_at_end=True, metric_for_best_model=config.best_metric 161 | ) 162 | 163 | trainer = Trainer( 164 | model=model, 165 | args=training_args, 166 | train_dataset=processed_dataset['train'], 167 | eval_dataset=processed_dataset['validation'], 168 | compute_metrics=partial(config.compute_metrics, split="validation", 169 | processed_dataset=processed_dataset["validation"]), 170 | tokenizer=tokenizer, 171 | data_collator=data_collator, 172 | ) 173 | 174 | train_result = trainer.train() 175 | print(run_base_dir) 176 | print('train', train_result.metrics) 177 | 178 | dev_predictions = trainer.predict(test_dataset=processed_dataset['validation']) 179 | print('dev', dev_predictions.metrics) 180 | 181 | run.summary.update(dev_predictions.metrics) 182 | wandb.finish() 183 | 184 | dev_metrics_per_run.append(dev_predictions.metrics[f"test_{config.best_metric}"]) 185 | 186 | predictions = trainer.predict(test_dataset=test_without_labels) 187 | predictions_per_run.append(predictions.predictions) 188 | 189 | if task != "terra": 190 | rmtree(data_dir / 'checkpoints' / run_base_dir) 191 | 192 | best_run = np.argmax(dev_metrics_per_run) 193 | best_predictions = predictions_per_run[best_run] 194 | processed_predictions = config.process_predictions(best_predictions, split="test", 195 | processed_dataset=processed_dataset["test"]) 196 | 197 | prefix_without_task = model_prefix.replace(f"{task}_", "") 198 | 199 | os.makedirs(f"preds/{prefix_without_task}", exist_ok=True) 200 | 201 | if task == "rucola": 202 | result_df = pd.DataFrame.from_records(processed_predictions, index="id") 203 | result_df.to_csv(f"preds/{prefix_without_task}/{TASK_TO_NAME[task]}.csv") 204 | else: 205 | with open(f"preds/{prefix_without_task}/{TASK_TO_NAME[task]}.jsonl", 'w+') as outf: 206 | for prediction in processed_predictions: 207 | print(json.dumps(prediction, ensure_ascii=True, cls=NumpyEncoder), file=outf) 208 | 209 | 210 | if __name__ == '__main__': 211 | parser = ArgumentParser() 212 | parser.add_argument("-t", '--task', choices=TASK_TO_CONFIG.keys()) 213 | parser.add_argument("-m", '--model-name', choices=MODEL_TO_HUB_NAME.keys()) 214 | parser.add_argument("-c", '--checkpoint', type=Path) 215 | parser.add_argument("-d", "--data-dir", type=Path) 216 | parser.add_argument("--batch-size", required=True, type=int) 217 | 
parser.add_argument("--grad-acc-steps", required=True, type=int) 218 | parser.add_argument("--dropout", required=True, type=float) 219 | parser.add_argument("--weight-decay", required=True, type=float) 220 | parser.add_argument("--num-seeds", required=True, type=int) 221 | args = parser.parse_args() 222 | main(args.task, args.model_name, args.checkpoint, args.data_dir, args.batch_size, args.grad_acc_steps, 223 | args.dropout, args.weight_decay, args.num_seeds) 224 | -------------------------------------------------------------------------------- /src/models/lean_albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch ALBERT modules that do not hog your GPU memory """ 16 | from functools import lru_cache 17 | 18 | import torch 19 | import torch.nn as nn 20 | from transformers.file_utils import add_start_docstrings 21 | from transformers.modeling_outputs import BaseModelOutput 22 | from transformers.modeling_utils import PreTrainedModel 23 | from transformers.models.albert import AlbertConfig 24 | from transformers.models.albert.modeling_albert import ( 25 | ACT2FN, 26 | ALBERT_START_DOCSTRING, 27 | AlbertForPreTraining, 28 | AlbertLayerGroup, 29 | AlbertMLMHead, 30 | AlbertModel, 31 | AlbertSOPHead, 32 | AlbertTransformer 33 | ) 34 | from transformers.utils import logging 35 | 36 | from src.models.config import LeanAlbertConfig 37 | from src.modules.attn import LeanSelfAttention, RotaryAttentionCore 38 | from src.modules.ffn import LeanFFN 39 | from src.modules.rotary import RotaryEmbeddings 40 | 41 | logger = logging.get_logger(__name__) 42 | 43 | _CONFIG_FOR_DOC = "LeanAlbertConfig" 44 | _TOKENIZER_FOR_DOC = "AlbertTokenizer" 45 | 46 | 47 | def get_input_embedding(config: LeanAlbertConfig): 48 | if config.position_embedding_type == "absolute": 49 | return nn.Embedding(config.max_position_embeddings, config.embedding_size) 50 | elif config.position_embedding_type == "rotary": 51 | return None 52 | else: 53 | raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}") 54 | 55 | 56 | @lru_cache() 57 | def get_attention_core(config: LeanAlbertConfig): 58 | if config.position_embedding_type == "absolute": 59 | return None 60 | elif config.position_embedding_type == "rotary": 61 | rotary_emb = RotaryEmbeddings(config.hidden_size // config.num_attention_heads, config.rotary_embedding_base) 62 | return RotaryAttentionCore( 63 | config.hidden_size, 64 | config.num_attention_heads, 65 | rotary_emb, 66 | attention_probs_dropout=config.attention_probs_dropout_prob, 67 | ) 68 | else: 69 | raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}") 70 | 71 | 72 | class LeanAlbertEmbeddings(nn.Module): 73 | """ 74 | Construct the embeddings from word, position and token_type embeddings.
75 | """ 76 | 77 | def __init__(self, config): 78 | super().__init__() 79 | self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) 80 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) 81 | self.position_embeddings = get_input_embedding(config) 82 | 83 | self.layernorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) 84 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 85 | 86 | if self.position_embeddings is not None: 87 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 88 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 89 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 90 | 91 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward 92 | def forward( 93 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 94 | ): 95 | if input_ids is not None: 96 | input_shape = input_ids.size() 97 | else: 98 | input_shape = inputs_embeds.size()[:-1] 99 | 100 | seq_length = input_shape[1] 101 | 102 | if token_type_ids is None: 103 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 104 | 105 | if inputs_embeds is None: 106 | inputs_embeds = self.word_embeddings(input_ids) 107 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 108 | 109 | embeddings = inputs_embeds + token_type_embeddings 110 | 111 | if self.position_embeddings is not None: 112 | if position_ids is None: 113 | position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length] 114 | position_embeddings = self.position_embeddings(position_ids) 115 | embeddings += position_embeddings 116 | 117 | embeddings = self.layernorm(embeddings) 118 | embeddings = self.dropout(embeddings) 119 | return embeddings 120 | 121 | 122 | class LeanAlbertLayer(nn.Module): 123 | def __init__(self, config: LeanAlbertConfig): 124 | super().__init__() 125 | 126 | self.config = config 127 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 128 | self.seq_len_dim = 1 129 | self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 130 | 131 | self.attention = LeanSelfAttention( 132 | config.hidden_size, 133 | config.num_attention_heads, 134 | attention_core=get_attention_core(config), 135 | hidden_dropout_prob=config.hidden_dropout_prob, 136 | layer_norm_eps=config.layer_norm_eps, 137 | ) 138 | 139 | self.ffn = LeanFFN( 140 | config.hidden_size, 141 | config.intermediate_size, 142 | activation=ACT2FN[config.hidden_act], 143 | gated=config.hidden_act_gated, 144 | layer_norm_eps=config.layer_norm_eps, 145 | dropout=config.hidden_dropout_prob, 146 | ) 147 | 148 | def forward(self, hidden_states, attention_mask=None, output_attentions=False): 149 | attention_output, *extras = self.attention(hidden_states, attention_mask, output_attentions) 150 | ffn_output = self.ffn(attention_output) 151 | return (ffn_output, attention_output, *extras) 152 | 153 | 154 | class LeanAlbertLayerGroup(AlbertLayerGroup): 155 | def __init__(self, config): 156 | nn.Module.__init__(self) 157 | self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)]) 158 | 159 | def forward( 160 | self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False 161 | ): 
162 | if any(head_mask): 163 | raise NotImplementedError(f"head mask was provided, but it is not supported") 164 | 165 | layer_hidden_states = () 166 | layer_attentions = () 167 | 168 | for layer_index, albert_layer in enumerate(self.albert_layers): 169 | layer_output = albert_layer(hidden_states, attention_mask, output_attentions) 170 | hidden_states = layer_output[0] 171 | 172 | if output_attentions: 173 | layer_attentions = layer_attentions + (layer_output[1],) 174 | 175 | if output_hidden_states: 176 | layer_hidden_states = layer_hidden_states + (hidden_states,) 177 | 178 | outputs = (hidden_states,) 179 | if output_hidden_states: 180 | outputs = outputs + (layer_hidden_states,) 181 | if output_attentions: 182 | outputs = outputs + (layer_attentions,) 183 | return outputs # last-layer hidden state, (layer hidden states), (layer attentions) 184 | 185 | 186 | class LeanAlbertTransformer(AlbertTransformer): 187 | def __init__(self, config): 188 | nn.Module.__init__(self) 189 | self.config = config 190 | self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size) 191 | self.albert_layer_groups = nn.ModuleList( 192 | [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)] 193 | ) 194 | self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps) 195 | 196 | def forward( 197 | self, 198 | hidden_states, 199 | attention_mask=None, 200 | head_mask=None, 201 | output_attentions=False, 202 | output_hidden_states=False, 203 | return_dict=True, 204 | ): 205 | hidden_states = self.embedding_hidden_mapping_in(hidden_states) 206 | 207 | all_hidden_states = (hidden_states,) if output_hidden_states else None 208 | all_attentions = () if output_attentions else None 209 | 210 | for i in range(self.config.num_hidden_layers): 211 | # Number of layers in a hidden group 212 | layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) 213 | 214 | # Index of the hidden group 215 | group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) 216 | 217 | layer_group_output = self.albert_layer_groups[group_idx]( 218 | hidden_states, 219 | attention_mask, 220 | head_mask[group_idx * layers_per_group: (group_idx + 1) * layers_per_group], 221 | output_attentions, 222 | output_hidden_states, 223 | ) 224 | hidden_states = layer_group_output[0] 225 | 226 | if output_attentions: 227 | all_attentions = all_attentions + layer_group_output[-1] 228 | 229 | if output_hidden_states: 230 | all_hidden_states = all_hidden_states + (hidden_states,) 231 | 232 | if not return_dict: 233 | return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) 234 | 235 | return BaseModelOutput( 236 | last_hidden_state=self.post_layer_norm(hidden_states), 237 | hidden_states=all_hidden_states, 238 | attentions=all_attentions, 239 | ) 240 | 241 | 242 | @add_start_docstrings( 243 | "The bare LeanALBERT Model transformer outputting raw hidden-states without any specific head on top.", 244 | ALBERT_START_DOCSTRING, 245 | ) 246 | class LeanAlbertModel(AlbertModel): 247 | config_class = LeanAlbertConfig 248 | 249 | def __init__(self, config: AlbertConfig, add_pooling_layer=True): 250 | PreTrainedModel.__init__(self, config) 251 | 252 | self.config = config 253 | self.embeddings = LeanAlbertEmbeddings(config) 254 | self.encoder = LeanAlbertTransformer(config) 255 | 256 | if add_pooling_layer: 257 | self.pooler = nn.Linear(config.hidden_size, config.hidden_size) 258 | self.pooler_activation = nn.Tanh() 259 | 
else: 260 | self.pooler = None 261 | self.pooler_activation = None 262 | 263 | self.init_weights() 264 | 265 | 266 | class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel): 267 | config_class = LeanAlbertConfig 268 | base_model_prefix = "albert" 269 | 270 | def __init__(self, config: AlbertConfig): 271 | PreTrainedModel.__init__(self, config) 272 | 273 | self.albert = LeanAlbertModel(config) 274 | self.predictions = AlbertMLMHead(config) 275 | self.sop_classifier = AlbertSOPHead(config) 276 | 277 | self.init_weights() 278 | -------------------------------------------------------------------------------- /src/modules/linear.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements weight matrix sharing for linear layer: full sharing and sharing with adapters 3 | """ 4 | import math 5 | from itertools import zip_longest 6 | from typing import Optional, Tuple, List 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.cuda.amp import custom_bwd, custom_fwd 12 | 13 | from src.modules.pixelfly import butterfly_matmul, butterfly_matmul_backward, get_butterfly_indices 14 | from src.modules.functional import maybe_script 15 | 16 | 17 | class GeneralizedMatrix(nn.Module): 18 | """A module that stores a shared pytorch tensor for use in GeneralizedLinear layers""" 19 | 20 | def __init__(self, in_features: int, out_features: int, block_size: int = 0, lowrank_dim: int = 0): 21 | super().__init__() 22 | self.out_features, self.in_features = out_features, in_features 23 | 24 | if block_size == 0: 25 | # fully-connected weight matrix 26 | self.weight = nn.Parameter(torch.empty(out_features, in_features)) 27 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 28 | # note: this is usually overwritten by the model-wide initialization 29 | self.forward_indices = self.backward_indices = None 30 | else: 31 | # block-sparse weights with additive butterfly pattern 32 | forward_indices, backward_indices = get_butterfly_indices( 33 | out_features, in_features, block_size, stretch=False 34 | ) 35 | self.register_buffer("forward_indices", forward_indices) 36 | self.register_buffer("backward_indices", backward_indices) 37 | active_blocks_per_input = self.forward_indices.numel() // (in_features // block_size) 38 | self.weight = nn.Parameter(torch.empty(in_features, active_blocks_per_input, block_size)) 39 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 40 | # note: this is usually overwritten by the model-wide init 41 | 42 | if lowrank_dim: 43 | self.lowrank_first = nn.Parameter(torch.zeros(lowrank_dim, self.in_features)) 44 | self.lowrank_second = nn.Parameter(torch.zeros(self.out_features, lowrank_dim)) 45 | nn.init.normal_(self.lowrank_first, std=math.sqrt(2.0 / (5 * min(out_features, in_features)))) 46 | nn.init.normal_(self.lowrank_second, std=math.sqrt(2.0 / (5 * min(out_features, in_features)))) 47 | else: 48 | self.lowrank_first = self.lowrank_second = None 49 | 50 | @property 51 | def shape(self): 52 | return (self.out_features, self.in_features) 53 | 54 | def __repr__(self): 55 | return f"{self.__class__.__name__}{tuple(self.shape)}" 56 | 57 | def forward(self, input: torch.Tensor, *, ignore_lowrank: bool = False): 58 | """ 59 | Multiply input tensor by this matrix with the same semantics as in torch.nn.Linear(..., bias=False) 60 | 61 | :param ignore_lowrank: if True, the low-rank components (lowrank_dim) will not be used in matrix multiplication 62 | """ 63 | if 
self.forward_indices is not None: 64 | output = butterfly_matmul(input, self.weight, self.forward_indices) 65 | else: 66 | output = F.linear(input, self.weight) 67 | if self.lowrank_first is not None and not ignore_lowrank: 68 | output = F.linear(F.linear(input, self.lowrank_first), self.lowrank_second, output) 69 | return output 70 | 71 | 72 | class GeneralizedLinear(nn.Linear): 73 | """A linear layer with a shared full-rank matrix and an individual low-rank adapter""" 74 | 75 | def __init__(self, shared_matrix: GeneralizedMatrix, adapter_dim: int = 0, bias: bool = True): 76 | nn.Module.__init__(self) 77 | self.shared_matrix = shared_matrix 78 | self.out_features, self.in_features = self.shared_matrix.shape 79 | self.bias = nn.Parameter(torch.zeros(self.out_features)) if bias else None 80 | 81 | if adapter_dim != 0: 82 | self.adapter_first = nn.Parameter(torch.zeros(adapter_dim, self.in_features)) 83 | self.adapter_second = nn.Parameter(torch.zeros(self.out_features, adapter_dim)) 84 | 85 | # initialize in accordance with https://arxiv.org/pdf/2106.09685.pdf 86 | nn.init.xavier_normal_(self.adapter_first) 87 | nn.init.zeros_(self.adapter_second) 88 | else: 89 | self.adapter_first = self.adapter_second = None 90 | 91 | def get_combined_lowrank_components(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: 92 | """Group together low-rank matrix components from this layer's adapter and GeneralizedMatrix for faster matmul""" 93 | if self.adapter_first is not None and self.shared_matrix.lowrank_first is None: 94 | return self.adapter_first, self.adapter_second 95 | elif self.adapter_first is None and self.shared_matrix.lowrank_first is not None: 96 | return self.shared_matrix.lowrank_first, self.shared_matrix.lowrank_second 97 | elif self.adapter_first is not None and self.shared_matrix.lowrank_first is not None: 98 | combined_first = torch.cat([self.shared_matrix.lowrank_first, self.adapter_first], dim=0) 99 | # ^-- cat0[(lowrank_dim x input_dim), (adapter_dim, input_dim)] -> (combined_dim, input_dim) 100 | combined_second = torch.cat([self.shared_matrix.lowrank_second, self.adapter_second], dim=1) 101 | # ^-- cat1[(output_dim x lowrank_dim), (output_dim, adapter_dim)] -> (combined_dim, input_dim) 102 | return combined_first, combined_second 103 | else: 104 | assert self.adapter_first is None and self.adapter_second is None 105 | assert self.shared_matrix.lowrank_first is None and self.shared_matrix.lowrank_second is None 106 | return None, None 107 | 108 | @property 109 | def weight(self): 110 | return self.shared_matrix.weight 111 | 112 | def forward(self, input: torch.Tensor) -> torch.Tensor: 113 | return _GeneralizedLinear.apply( 114 | input, self.weight, self.bias, *self.get_combined_lowrank_components(), 115 | self.shared_matrix.forward_indices, self.shared_matrix.backward_indices) 116 | 117 | 118 | class _GeneralizedLinear(torch.autograd.Function): 119 | @staticmethod 120 | @custom_fwd 121 | def forward(ctx, *args): 122 | output, *tensors_to_save = _GeneralizedLinear._forward_impl(*args) 123 | ctx.save_for_backward(*tensors_to_save) 124 | return output 125 | 126 | @staticmethod 127 | @maybe_script 128 | def _forward_impl( 129 | input: torch.Tensor, 130 | main_weight: torch.Tensor, 131 | bias: Optional[torch.Tensor], 132 | lowrank_first: Optional[torch.Tensor], 133 | lowrank_second: Optional[torch.Tensor], 134 | forward_indices: Optional[torch.Tensor], 135 | backward_indices: Optional[torch.Tensor], 136 | ): 137 | input_flat = input.view(-1, input.shape[-1]) 138 | if 
forward_indices is not None: 139 | output = butterfly_matmul(input_flat, main_weight, forward_indices) 140 | if bias is not None: 141 | output.add_(bias.to(output.dtype)) 142 | else: 143 | output = F.linear(input_flat, main_weight, bias) 144 | 145 | if lowrank_first is not None and lowrank_second is not None: 146 | lowrank_hid = F.linear(input_flat, lowrank_first) 147 | if "xla" in output.device.type: # xla does not support in-place ops 148 | output = torch.addmm(output, lowrank_hid, lowrank_second.t().to(output.dtype)) 149 | else: 150 | output = torch.addmm(output, lowrank_hid, lowrank_second.t().to(output.dtype), out=output) 151 | else: 152 | lowrank_hid = None 153 | output = output.view(input.shape[:-1] + output.shape[-1:]) 154 | return output, input, lowrank_hid, main_weight, lowrank_first, lowrank_second, backward_indices 155 | 156 | @staticmethod 157 | @custom_bwd 158 | def backward(ctx, grad_output: torch.Tensor): 159 | grads = _GeneralizedLinear._backward_impl( 160 | grad_output, *ctx.saved_tensors, needs_input_grad=ctx.needs_input_grad 161 | ) 162 | return tuple(grad if needed else None for grad, needed in zip_longest(grads, ctx.needs_input_grad)) 163 | 164 | @staticmethod 165 | @maybe_script 166 | def _backward_impl(grad_output: torch.Tensor, 167 | input: torch.Tensor, 168 | lowrank_hid: Optional[torch.Tensor], 169 | main_weight: torch.Tensor, 170 | lowrank_first: Optional[torch.Tensor], 171 | lowrank_second: Optional[torch.Tensor], 172 | backward_indices: Optional[torch.Tensor], 173 | needs_input_grad: List[bool]): 174 | grad_input = grad_input_flat = grad_main_weight = grad_lowrank_first = grad_lowrank_second = grad_bias \ 175 | = grad_output_flat_transposed = grad_lowrank_hid_flat = lowrank_hid_flat = torch.empty(0) 176 | input_flat = input.flatten(0, -2) # [etc, in_features] 177 | grad_output_flat = grad_output.flatten(0, -2) # [etc, out_features] 178 | 179 | if lowrank_hid is not None: 180 | lowrank_hid_flat = lowrank_hid.flatten(0, -2) # [etc, lowrank_dim] 181 | if lowrank_first is not None and (needs_input_grad[0] or needs_input_grad[3]): 182 | assert lowrank_second is not None 183 | grad_lowrank_hid_flat = torch.matmul(grad_output_flat, lowrank_second) # [etc, lowrank_dim] 184 | if needs_input_grad[1] or needs_input_grad[4]: 185 | grad_output_flat_transposed = grad_output_flat.t() # [out_features, etc] 186 | 187 | if needs_input_grad[4]: 188 | assert lowrank_second is not None 189 | grad_lowrank_second = torch.matmul(grad_output_flat_transposed, lowrank_hid_flat) 190 | # ^-- [out_features, lowrank_dim] 191 | if needs_input_grad[3]: 192 | grad_lowrank_hid_flat_transposed = grad_lowrank_hid_flat.t() # [lowrank_dim, etc] 193 | grad_lowrank_first = torch.matmul(grad_lowrank_hid_flat_transposed, input_flat) 194 | # ^-- [lowrank_dim, in_features] 195 | if needs_input_grad[2]: 196 | grad_bias = grad_output_flat.sum(dim=0) # [out_features] 197 | if backward_indices is None: 198 | # dense shared matrix 199 | if needs_input_grad[1]: 200 | grad_main_weight = torch.matmul(grad_output_flat_transposed, input_flat) 201 | # ^-- [out_features, in_features] 202 | if needs_input_grad[0]: 203 | grad_input_flat = torch.matmul(grad_output_flat, main_weight) 204 | else: 205 | # block-sparse shared matrix 206 | grad_input_flat, grad_main_weight = butterfly_matmul_backward( 207 | grad_output_flat, input_flat, main_weight, backward_indices, 208 | input_requires_grad=needs_input_grad[0], weight_requires_grad=needs_input_grad[1]) 209 | 210 | if needs_input_grad[0] and lowrank_first is not None: 211 
| # grad w.r.t. input through low-rank components 212 | if 'xla' not in grad_output.device.type: 213 | grad_input_flat = grad_input_flat.addmm_( 214 | grad_lowrank_hid_flat.to(grad_output_flat.dtype), 215 | lowrank_first.to(grad_output_flat.dtype) 216 | ) 217 | else: 218 | grad_input_flat = torch.addmm(grad_input_flat, grad_lowrank_hid_flat, lowrank_first) 219 | if needs_input_grad[0]: 220 | grad_input = grad_input_flat.view_as(input) 221 | return grad_input, grad_main_weight, grad_bias, grad_lowrank_first, grad_lowrank_second, None, None 222 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Yandex Research 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /src/training/lamb_8bit.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is a joint work with Tim Dettmers, based on his library https://github.com/facebookresearch/bitsandbytes ; 3 | Unlike the rest of bnb optimizers, CPULAMB8Bit is tuned to further reduce memory footprint at the cost of performance. 4 | The intended use-case of CPULamb8Bit is to run in background on CPU while training with large batches. 
5 | """ 6 | import math 7 | from typing import Any, Dict, Optional 8 | 9 | import torch 10 | from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise 11 | from bitsandbytes.optim.optimizer import Optimizer2State 12 | from torch_optimizer.types import Betas2, Params 13 | 14 | __all__ = ("CPULAMB8Bit",) 15 | 16 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler 17 | 18 | use_hivemind_log_handler("in_root_logger") 19 | logger = get_logger(__name__) 20 | 21 | class CPULAMB8Bit(Optimizer2State): 22 | r""" 23 | Implements LAMB with quantized 8-bit statistics. The statistics are stored in host memory in the quantized form. 24 | The LAMB optimizer and block-wise quantization are described in the following papers: 25 | - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962 26 | - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861 27 | This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb 28 | - bias correction is enabled by default; set bias_correction=False to match paper v3, which does not use debiasing 29 | - it has built-in clipping by a global max_grad_norm 30 | Arguments: 31 | params: iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr: learning rate (default: 1e-3) 34 | betas: coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)) 36 | eps: term added to the denominator to improve 37 | numerical stability (default: 1e-6) 38 | weight_decay: weight decay (L2 penalty) (default: 0) 39 | clamp_value: clamp weight_norm in (0,clamp_value) (default: 10) 40 | set to a high value to avoid it (e.g. 10e3) 41 | bias_correction: debias statistics by (1 - beta**step) (default: True) 42 | min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized 43 | reuse_grad_buffers: if True, optimizer will modify gradients in-place to save memory. 44 | If enabled, one must ensure that .zero_grad() is called after each optimizer step. 45 | update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements.
46 | """ 47 | 48 | def __init__( 49 | self, 50 | params: Params, 51 | lr: float = 1e-3, 52 | betas: Betas2 = (0.9, 0.999), 53 | eps: float = 1e-6, 54 | weight_decay: float = 0, 55 | clamp_value: float = 10, 56 | bias_correction: bool = True, 57 | min_8bit_size: int = 65536, 58 | reuse_grad_buffers: bool = False, 59 | update_chunk_size: int = 2 ** 24, 60 | max_grad_norm: Optional[float] = None, 61 | ) -> None: 62 | if eps < 0.0: 63 | raise ValueError("Invalid epsilon value: {}".format(eps)) 64 | if not 0.0 <= betas[0] < 1.0: 65 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 66 | if not 0.0 <= betas[1] < 1.0: 67 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 68 | if weight_decay < 0: 69 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 70 | if clamp_value < 0.0: 71 | raise ValueError("Invalid clamp value: {}".format(clamp_value)) 72 | 73 | self.clamp_value = clamp_value 74 | self.bias_correction = bias_correction 75 | self.reuse_grad_buffers = reuse_grad_buffers 76 | self.update_chunk_size = update_chunk_size 77 | self.max_grad_norm = max_grad_norm 78 | 79 | super(CPULAMB8Bit, self).__init__( 80 | "cpu-lamb", 81 | params, 82 | lr, 83 | betas, 84 | eps, 85 | weight_decay, 86 | optim_bits=8, 87 | min_8bit_size=min_8bit_size, 88 | args=None, 89 | percentile_clipping=100, 90 | block_wise=4096, 91 | max_unorm=0, 92 | ) 93 | 94 | @torch.no_grad() 95 | def step(self, closure=None): 96 | if self.max_grad_norm is not None: 97 | iter_params = (param for group in self.param_groups for param in group["params"]) 98 | torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm) 99 | return super().step(closure=closure) 100 | 101 | @torch.no_grad() 102 | def init_state(self, group, p, gindex, pindex): 103 | config = self.get_config(gindex, pindex, group) 104 | assert config["percentile_clipping"] == 100, "percentile clipping is not implemented on CPU" 105 | assert config["max_unorm"] == 0 106 | 107 | if config["optim_bits"] == 32: 108 | dtype = torch.float32 109 | elif config["optim_bits"] == 8: 110 | dtype = torch.uint8 111 | else: 112 | raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') 113 | 114 | if p.numel() < config["min_8bit_size"]: 115 | dtype = torch.float32 116 | 117 | state = self.state[p] 118 | state["step"] = 0 119 | 120 | if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): 121 | state["state1"] = torch.zeros_like( 122 | p, 123 | memory_format=torch.preserve_format, 124 | dtype=torch.float32, 125 | device=p.device, 126 | ) 127 | state["state2"] = torch.zeros_like( 128 | p, 129 | memory_format=torch.preserve_format, 130 | dtype=torch.float32, 131 | device=p.device, 132 | ) 133 | elif dtype == torch.uint8: 134 | if state["step"] == 0: 135 | if "dynamic" not in self.name2qmap: 136 | self.fill_qmap() 137 | self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) 138 | self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) 139 | 140 | n = p.numel() 141 | blocks = (n - 1) // config["block_wise"] + 1 142 | 143 | state["state1"] = torch.zeros_like( 144 | p, 145 | memory_format=torch.preserve_format, 146 | dtype=torch.uint8, 147 | device=p.device, 148 | ) 149 | state["qmap1"] = self.name2qmap["dynamic"] 150 | 151 | state["state2"] = torch.zeros_like( 152 | p, 153 | memory_format=torch.preserve_format, 154 | dtype=torch.uint8, 155 | device=p.device, 156 | ) 157 | state["qmap2"] = self.name2qmap["udynamic"] 158 | 159 | 
state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) 160 | state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) 161 | 162 | @torch.no_grad() 163 | def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int): 164 | state = self.state[p] 165 | config = self.get_config(gindex, pindex, group) 166 | 167 | p_cpu, grad_cpu = p.cpu(), p.grad.cpu() 168 | # this is a no-op if parameters are already on CPU 169 | 170 | step = state["step"] = state["step"] + 1 171 | beta1, beta2 = group["betas"] 172 | 173 | param_delta = self._update_moments_and_compute_delta( 174 | state, config, p_cpu, grad_cpu, beta1, beta2, group["eps"], group["weight_decay"] 175 | ) 176 | del grad_cpu # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers 177 | 178 | step_norm = torch.norm(param_delta) 179 | weight_norm = p_cpu.norm().clamp(0, self.clamp_value) 180 | 181 | trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0 182 | state["weight_norm"], state["step_norm"], state["trust_ratio"] = (weight_norm, step_norm, trust_ratio) 183 | 184 | # Apply bias to lr to avoid broadcast. 185 | bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1 186 | step_size = group["lr"] * bias_correction 187 | p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio) 188 | 189 | def _update_moments_and_compute_delta( 190 | self, 191 | state: Dict, 192 | config: Dict, 193 | p_cpu: torch.Tensor, 194 | grad_cpu: torch.Tensor, 195 | beta1: float, 196 | beta2: float, 197 | eps: float, 198 | weight_decay: float, 199 | ) -> torch.Tensor: 200 | step, block_size, chunk_size = (state["step"], config["block_wise"], self.update_chunk_size) 201 | 202 | if state["state1"].dtype != torch.uint8: 203 | # not quantized: update normally 204 | exp_avg, exp_avg_sq = state["state1"], state["state2"] 205 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1) 206 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2) 207 | 208 | sqrt_out = grad_cpu if self.reuse_grad_buffers else None 209 | _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps) 210 | param_delta = torch.div(exp_avg, _denominator, out=_denominator) 211 | if weight_decay != 0: 212 | param_delta.add_(p_cpu, alpha=weight_decay) 213 | return param_delta 214 | elif p_cpu.numel() <= chunk_size: 215 | # quantized tensor within chunk size 216 | exp_avg = dequantize_blockwise(state["state1"], (state["absmax1"], state["qmap1"]), blocksize=block_size) 217 | exp_avg_sq = dequantize_blockwise(state["state2"], (state["absmax2"], state["qmap2"]), blocksize=block_size) 218 | 219 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1) 220 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2) 221 | 222 | quantize_blockwise(exp_avg, state["qmap1"], state["absmax1"], out=state["state1"]) 223 | quantize_blockwise(exp_avg_sq, state["qmap2"], state["absmax2"], out=state["state2"]) 224 | # note: quantize_blockwise also modifies qmap and absmax in-place 225 | 226 | param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps)) 227 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version 228 | 229 | if weight_decay != 0: 230 | param_delta.add_(p_cpu, alpha=weight_decay) 231 | return param_delta 232 | 233 | else: 234 | # very large quantized tensor, compute updates in chunks to save RAM 235 | flat_p, flat_grad, flat_state1, flat_state2 = ( 236 | tensor.view(-1) for tensor in 
(p_cpu, grad_cpu, state["state1"], state["state2"]) 237 | ) 238 | output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad) 239 | 240 | for chunk_index, chunk_start in enumerate(range(0, len(flat_p), chunk_size)): 241 | chunk = slice(chunk_start, chunk_start + chunk_size) 242 | chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size) 243 | 244 | chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk] 245 | chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk] 246 | chunk_absmax1, chunk_absmax2 = ( 247 | state["absmax1"][chunk_blocks], 248 | state["absmax2"][chunk_blocks], 249 | ) 250 | if chunk_state1.storage_offset() != 0: 251 | # clone chunks to ensure that tensors do not have offsets (bnb hack, possibly no longer needed) 252 | chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map( 253 | torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2), 254 | ) 255 | 256 | exp_avg_chunk = dequantize_blockwise( 257 | chunk_state1, (chunk_absmax1, state["qmap1"]), blocksize=block_size 258 | ) 259 | exp_avg_sq_chunk = dequantize_blockwise( 260 | chunk_state2, (chunk_absmax2, state["qmap2"]), blocksize=block_size 261 | ) 262 | 263 | exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1) 264 | exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2) 265 | 266 | # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu 267 | del chunk_grad 268 | 269 | flat_state1[chunk], ( 270 | state["absmax1"][chunk_blocks], 271 | state["qmap1"], 272 | ) = quantize_blockwise(exp_avg_chunk, state["qmap1"], chunk_absmax1, out=chunk_state1) 273 | flat_state2[chunk], ( 274 | state["absmax2"][chunk_blocks], 275 | state["qmap2"], 276 | ) = quantize_blockwise(exp_avg_sq_chunk, state["qmap2"], chunk_absmax2, out=chunk_state2) 277 | # note: we need to explicitly assign new quantized tensors because of cloning earlier 278 | 279 | torch.div( 280 | exp_avg_chunk, 281 | exp_avg_sq_chunk.sqrt_().add_(eps), 282 | out=output_buffer[chunk], 283 | ) 284 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version 285 | 286 | if weight_decay != 0: 287 | output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay) 288 | 289 | param_delta = output_buffer.view_as(grad_cpu) 290 | 291 | return param_delta 292 | -------------------------------------------------------------------------------- /src/modules/ffn.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | from src.modules.functional import ACT2FN 10 | from src.modules.linear import GeneralizedLinear, _GeneralizedLinear 11 | 12 | 13 | class LeanFFN(nn.Module): 14 | """ 15 | A transformer FFN module that doesn't hog your GPU memory. Uses a manually optimized differentiation algorithm. 16 | 17 | :param hidden_size: base hidden size of the transformer 18 | :param intermediate_size: a (typically larger) hidden dimension where activation is applied 19 | :param activation: a pytorch nonlinearity to use in the intermediate layer 20 | :param gated: use gated activations based on https://arxiv.org/abs/2002.05202 and https://arxiv.org/abs/2102.11972 21 | note: gated activations require 1.5x more parameters compared to their non-gated variants. 
22 | :param layer_norm_eps: see torch.nn.functional.layer_norm 23 | :param sandwich_norm: if set, applies an additional layer norm to projected attention outputs before residuals, 24 | as proposed in the CogView paper ( arXiv:2105.13290 ). This is meant to make fp16 training 25 | more stable for deep transformers. This technique is also a part of NormFormer ( arXiv:2110.09456 ) 26 | :param dropout: hidden dropout probability, applied to the output projection (before adding residual) 27 | :param residual: if True, adds the original layer input to the final layer output 28 | 29 | :param dense_i2h: custom *first* linear layer (hidden_size -> intermediate_size or 2x indermediate_size) 30 | :param dense_h2o: custom *second* linear layer (intermediate_size -> hidden_size) 31 | """ 32 | 33 | def __init__( 34 | self, 35 | hidden_size: int, 36 | intermediate_size: int, 37 | activation=ACT2FN["gelu_fused"], 38 | gated: bool = False, 39 | layer_norm_eps: float = 1e-12, 40 | dropout: float = 0.0, 41 | sandwich_norm: bool = False, 42 | dense_i2h: Optional[nn.Linear] = None, 43 | dense_h2o: Optional[nn.Linear] = None, 44 | residual: bool = True, 45 | ): 46 | super().__init__() 47 | i2h_out_features = intermediate_size * 2 if gated else intermediate_size 48 | self.dense_i2h = nn.Linear(hidden_size, i2h_out_features) if dense_i2h is None else dense_i2h 49 | self.dense_h2o = nn.Linear(intermediate_size, hidden_size) if dense_h2o is None else dense_h2o 50 | assert type(self.dense_i2h) in ( 51 | nn.Linear, GeneralizedLinear), "only Linear and GeneralizedLinear are supported" 52 | assert type(self.dense_h2o) in ( 53 | nn.Linear, GeneralizedLinear), "only Linear and GeneralizedLinear are supported" 54 | assert self.dense_i2h.in_features == self.dense_h2o.out_features == hidden_size 55 | assert self.dense_i2h.out_features == i2h_out_features and self.dense_h2o.in_features == intermediate_size 56 | self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) 57 | self.sandwich_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if sandwich_norm else None 58 | self.activation = activation 59 | self.gated = gated 60 | self.dropout = dropout 61 | self.residual = residual 62 | 63 | def forward(self, input): 64 | sandwich_ln_weight = sandwich_ln_bias = None 65 | if self.sandwich_norm is not None: 66 | sandwich_ln_weight, sandwich_ln_bias = self.sandwich_norm.weight, self.sandwich_norm.bias 67 | i2h_lowrank_first = i2h_lowrank_second = h2o_lowrank_first = h2o_lowrank_second = None 68 | i2h_forward_indices = i2h_backward_indices = h2o_forward_indices = h2o_backward_indices = None 69 | if isinstance(self.dense_i2h, GeneralizedLinear): 70 | i2h_lowrank_first, i2h_lowrank_second = self.dense_i2h.get_combined_lowrank_components() 71 | i2h_forward_indices = self.dense_i2h.shared_matrix.forward_indices 72 | i2h_backward_indices = self.dense_i2h.shared_matrix.backward_indices 73 | if isinstance(self.dense_h2o, GeneralizedLinear): 74 | h2o_lowrank_first, h2o_lowrank_second = self.dense_h2o.get_combined_lowrank_components() 75 | h2o_forward_indices = self.dense_h2o.shared_matrix.forward_indices 76 | h2o_backward_indices = self.dense_h2o.shared_matrix.backward_indices 77 | 78 | output = _LeanFFN.apply( 79 | input, 80 | self.layer_norm.weight, 81 | self.layer_norm.bias, 82 | self.dense_i2h.weight, 83 | self.dense_i2h.bias, 84 | i2h_lowrank_first, 85 | i2h_lowrank_second, 86 | i2h_forward_indices, 87 | i2h_backward_indices, 88 | self.dense_h2o.weight, 89 | self.dense_h2o.bias, 90 | h2o_lowrank_first, 91 | h2o_lowrank_second, 
92 | h2o_forward_indices, 93 | h2o_backward_indices, 94 | sandwich_ln_weight, 95 | sandwich_ln_bias, 96 | self.activation, 97 | self.gated, 98 | self.dropout, 99 | self.training, 100 | self.layer_norm.eps, 101 | self.residual, 102 | ) 103 | return output 104 | 105 | 106 | class _LeanFFN(torch.autograd.Function): 107 | """Autograd function for transformer FFN, manually optimized to reduce memory without affecting performance""" 108 | 109 | @staticmethod 110 | def _apply_activation(pre_activation: torch.Tensor, activation: callable, gated: bool): 111 | if not gated: 112 | return activation(pre_activation) 113 | else: 114 | pre_gate, lin = pre_activation.split(pre_activation.shape[-1] // 2, dim=-1) 115 | return activation(pre_gate).mul_(lin) 116 | 117 | @staticmethod 118 | @custom_fwd 119 | def forward( 120 | ctx, 121 | input: torch.Tensor, 122 | ln_weight: torch.Tensor, 123 | ln_bias: torch.Tensor, 124 | i2h_weight: torch.Tensor, 125 | i2h_bias: Optional[torch.Tensor], 126 | i2h_lowrank_first: Optional[torch.Tensor], 127 | i2h_lowrank_second: Optional[torch.Tensor], 128 | i2h_forward_indices: Optional[torch.IntTensor], 129 | i2h_backward_indices: Optional[torch.IntTensor], 130 | h2o_weight: torch.Tensor, 131 | h2o_bias: Optional[torch.Tensor], 132 | h2o_lowrank_first: Optional[torch.Tensor], 133 | h2o_lowrank_second: Optional[torch.Tensor], 134 | h2o_forward_indices: Optional[torch.IntTensor], 135 | h2o_backward_indices: Optional[torch.IntTensor], 136 | sandwich_ln_weight: Optional[torch.Tensor], 137 | sandwich_ln_bias: Optional[torch.Tensor], 138 | activation: callable, 139 | gated: bool, 140 | dropout: float, 141 | training: bool, 142 | ln_eps: float, 143 | residual: bool, 144 | ): 145 | ctx._dropout, ctx._training, ctx._ln_eps = dropout, training, ln_eps 146 | ctx._activation, ctx._gated, ctx._residual = activation, gated, residual 147 | ctx._use_sandwich = sandwich_ln_weight is not None 148 | 149 | dropout_mask, pre_sandwich = None, None # optional tensors to save 150 | input_2d = input.view(-1, input.shape[-1]) 151 | 152 | input_ln = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ln_eps) 153 | 154 | pre_activation, *i2h_tensors = _GeneralizedLinear._forward_impl( 155 | input_ln, i2h_weight, i2h_bias, i2h_lowrank_first, i2h_lowrank_second, i2h_forward_indices, 156 | i2h_backward_indices 157 | ) 158 | 159 | hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, ctx._gated) 160 | 161 | out, *h2o_tensors = _GeneralizedLinear._forward_impl( 162 | hid_act, h2o_weight, h2o_bias, h2o_lowrank_first, h2o_lowrank_second, h2o_forward_indices, 163 | h2o_backward_indices 164 | ) 165 | 166 | if ctx._use_sandwich: 167 | pre_sandwich = out 168 | out = F.layer_norm(pre_sandwich, pre_sandwich.shape[-1:], sandwich_ln_weight, sandwich_ln_bias, eps=ln_eps) 169 | 170 | out = F.dropout(out, dropout, training, inplace=True) 171 | if training and dropout: 172 | dropout_mask = (out == 0.0).to(torch.int8) 173 | 174 | if residual: 175 | out = torch.add(out, input_2d, out=out if 'xla' not in out.device.type else None) 176 | 177 | assert i2h_tensors[0] is input_ln and h2o_tensors[0] is hid_act # we can rematerialize these tensors 178 | tensors_to_save = [ 179 | input, pre_activation, ln_weight, ln_bias, pre_sandwich, sandwich_ln_weight, sandwich_ln_bias, dropout_mask 180 | ] 181 | tensors_to_save.extend((*i2h_tensors[1:], *h2o_tensors[1:])) 182 | ctx.save_for_backward(*tensors_to_save) 183 | ctx._num_i2h_tensors = len(i2h_tensors) 184 | ctx._num_h2o_tensors = len(h2o_tensors) 185 | return 
186 | 
187 |     @staticmethod
188 |     def _h2o_backward(ctx, grad_output: torch.Tensor, hid_act: torch.Tensor):
189 |         h2o_tensors = ctx.saved_tensors[-ctx._num_h2o_tensors + 1:]
190 |         needs_input_grad = [hid_act.requires_grad, *ctx.needs_input_grad[9:15]]
191 |         grads = _GeneralizedLinear._backward_impl(grad_output, hid_act, *h2o_tensors,
192 |                                                   needs_input_grad=needs_input_grad)
193 |         return tuple(grad if needed else None for grad, needed in zip_longest(grads, needs_input_grad))
194 | 
195 |     @staticmethod
196 |     def _i2h_backward(ctx, grad_output: torch.Tensor, input_ln: torch.Tensor):
197 |         i2h_tensors = ctx.saved_tensors[-ctx._num_i2h_tensors - ctx._num_h2o_tensors + 2: -ctx._num_h2o_tensors + 1]
198 |         needs_input_grad = [input_ln.requires_grad, *ctx.needs_input_grad[3:9]]
199 |         grads = _GeneralizedLinear._backward_impl(grad_output, input_ln, *i2h_tensors,
200 |                                                   needs_input_grad=needs_input_grad)
201 |         return tuple(grad if needed else None for grad, needed in zip_longest(grads, needs_input_grad))
202 | 
203 |     @staticmethod
204 |     @custom_bwd
205 |     def backward(ctx, grad_output):
206 |         grad_input = grad_ln_weight = grad_ln_bias = grad_sandwich_ln_weight = grad_sandwich_ln_bias = None
207 |         input, pre_activation, ln_weight, ln_bias = ctx.saved_tensors[:4]
208 |         pre_sandwich, sandwich_ln_weight, sandwich_ln_bias, dropout_mask = ctx.saved_tensors[4: 8]
209 |         grad_output_2d = grad_output.view(-1, grad_output.shape[-1])
210 | 
211 |         # backward(... -> sandwich_norm -> dropout -> residual)
212 |         grad_residual_2d = grad_output_2d if ctx._residual else None
213 |         if dropout_mask is not None:
214 |             grad_output_2d = grad_output_2d.mul(dropout_mask.to(grad_output_2d.dtype))
215 |         if ctx._use_sandwich:
216 |             assert pre_sandwich is not None
217 |             with torch.enable_grad():
218 |                 required_grad = pre_sandwich.requires_grad
219 |                 pre_sandwich.requires_grad_(True)
220 |                 sandwich_out = F.layer_norm(
221 |                     pre_sandwich, pre_sandwich.shape[-1:], sandwich_ln_weight, sandwich_ln_bias, eps=ctx._ln_eps
222 |                 )
223 |             grad_output_2d, grad_sandwich_ln_weight, grad_sandwich_ln_bias = torch.autograd.grad(
224 |                 sandwich_out, [pre_sandwich, sandwich_ln_weight, sandwich_ln_bias], grad_outputs=grad_output_2d
225 |             )
226 |             pre_sandwich.requires_grad_(required_grad)
227 |             del pre_sandwich, sandwich_out
228 | 
229 |         # backward(... -> nonlinearity -> intermediate_layernorm -> linear_h2o -> ...)
230 |         input_2d = input.view(-1, input.shape[-1])
231 |         grad_h2o_output_2d = grad_output_2d  # gradient after the residual / dropout / sandwich backward steps
232 | 
233 |         with torch.enable_grad():
234 |             # rematerialize activation
235 |             pre_activation.requires_grad_(True)
236 |             hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, ctx._gated)
237 | 
238 |         with torch.no_grad():
239 |             (grad_hid_act, grad_h2o_weight, grad_h2o_bias, grad_h2o_lowrank_first, grad_h2o_lowrank_second,
240 |              unused_grad_forward_indices, unused_grad_backward_indices) = \
241 |                 _LeanFFN._h2o_backward(ctx, grad_h2o_output_2d, hid_act)
242 | 
243 |         (grad_hid,) = torch.autograd.grad(hid_act, pre_activation, grad_outputs=grad_hid_act)
244 |         pre_activation.requires_grad_(False)
245 |         del hid_act
246 | 
247 |         # backward(... -> input_layernorm -> linear_i2h -> ...)
248 | with torch.enable_grad(): 249 | # rematerialize input_ln 250 | input_2d.requires_grad_(True) 251 | input_ln_2d = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ctx._ln_eps) 252 | 253 | with torch.no_grad(): 254 | (grad_input_ln_2d, grad_i2h_weight, grad_i2h_bias, grad_i2h_lowrank_first, grad_i2h_lowrank_second, 255 | unused_grad_forward_indices, unused_grad_backward_indices) = \ 256 | _LeanFFN._i2h_backward(ctx, grad_hid, input_ln_2d) 257 | 258 | if any(ctx.needs_input_grad[0:3]): 259 | partial_grad_input_2d, grad_ln_weight, grad_ln_bias = torch.autograd.grad( 260 | outputs=input_ln_2d, inputs=[input_2d, ln_weight, ln_bias], grad_outputs=grad_input_ln_2d 261 | ) 262 | del input_2d, input_ln_2d, grad_input_ln_2d 263 | 264 | # add up residual grads 265 | if ctx.needs_input_grad[0]: 266 | grad_input = partial_grad_input_2d 267 | if ctx._residual: 268 | grad_input = grad_input.add_(grad_residual_2d) 269 | grad_input = grad_input.view(*input.shape) 270 | 271 | return (grad_input, grad_ln_weight, grad_ln_bias, 272 | grad_i2h_weight, grad_i2h_bias, grad_i2h_lowrank_first, grad_i2h_lowrank_second, None, None, 273 | grad_h2o_weight, grad_h2o_bias, grad_h2o_lowrank_first, grad_h2o_lowrank_second, None, None, 274 | grad_sandwich_ln_weight, grad_sandwich_ln_bias, None, None, None, None, None, None) 275 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RuLeanALBERT 2 | 3 | RuLeanALBERT is a pretrained masked language model for the Russian language using a memory-efficient architecture. 4 | 5 | ## Using the model 6 | You can download the pretrained weights, the tokenizer and the config file used for pretraining from the Hugging Face Hub: [huggingface.co/yandex/RuLeanALBERT](https://huggingface.co/yandex/RuLeanALBERT). 7 | Download them directly by running the following code: 8 | ``` 9 | wget https://huggingface.co/yandex/RuLeanALBERT/resolve/main/state.pth state.pth 10 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/config.json config.json 11 | mkdir tokenizer 12 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/config.json tokenizer/config.json 13 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/special_tokens_map.json tokenizer/special_tokens_map.json 14 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/tokenizer.json tokenizer/tokenizer.json 15 | ``` 16 | 17 | As the model itself is using custom code (see [`src/models`](./src/models)), right now the simplest solution is to clone the repository and to use the relevant classes (`LeanAlbertForPreTraining` and its dependencies) in your code. 18 | 19 | Loading the model is as simple as 20 | ``` 21 | tokenizer = AlbertTokenizerFast.from_pretrained('tokenizer') 22 | config = LeanAlbertConfig.from_pretrained('config.json') 23 | model = LeanAlbertForPreTraining(config) 24 | model.resize_token_embeddings(len(tokenizer)) 25 | checkpoint = torch.load(checkpoint_path, map_location='cpu')['model'] 26 | model.load_state_dict(checkpoint) 27 | ``` 28 | 29 | ## Fine-tuning guide 30 | 31 | Once you have downloaded the model, you can use [`finetune_mlm.py`](./finetune_mlm.py) to evaluate it on [RussianSuperGLUE](https://russiansuperglue.com/) and [RuCoLA](https://rucola-benchmark.com/). 32 | 33 | To do this, you can run the command as follows: 34 | ``` 35 | python finetune_mlm.py -t TASK_NAME \ 36 | --checkpoint state.pth \ 37 | --data-dir . 
\ 38 | --batch-size 32 \ 39 | --grad-acc-steps 1 \ 40 | --dropout 0.1 \ 41 | --weight-decay 0.01 \ 42 | --num-seeds 1 43 | ``` 44 | 45 | Most datasets will be loaded from the Hugging Face Hub. However, you need to download [RuCoLA](https://github.com/RussianNLP/RuCoLA) and [RuCoS](https://russiansuperglue.com/tasks/task_info/RuCoS) from their respective sources and place them in `RuCoLA` and `RuCoS` directories respectively. 46 | 47 | For reference, finetuning with a batch size of 32 on RuCoLA should take approximately 10 GB of GPU memory. If you exceed the GPU memory limits for a specific task, you can reduce the batch size by reducing `--batch-size` and increasing `--grad-acc-steps` accordingly. 48 | 49 | If you want to finetune an existing masked language model from the Hugging Face Hub, you can do it with the same code. The script directly supports fine-tuning RuRoBERTa-large: simply change `--checkpoint` to `--model-name ruroberta-large`. 50 | 51 | For LiDiRus, you should use the model trained on the TERRa dataset with `predict_on_lidirus.py`. 52 | For RWSD, all ML-based solutions perform worse than the most-frequent class baseline: in our experiments, we found that the majority of runs with RuLeanALBERT or RuRoBERTa converge to a similar solution. 53 | 54 | ## Pretraining a new model 55 | 56 | Here you can find the best practices that we learned from running the experiment. These may help you set up your own collaborative experiment. 57 | 58 | If your training run is not confidential, feel free to ask for help on the [Hivemind discord server](https://discord.gg/vRNN9ua2). 59 | 60 |
61 | 1. Choose and verify your training configuration
62 | 
63 | Depending on your use case, you may want to change:
64 | - Dataset and preprocessing ([`data.py`](tasks/mlm/data.py));
65 | - Tokenizer ([`arguments.py`](arguments.py));
66 | - Model config
67 | 
68 | In particular, you need to specify the datasets in the `make_training_dataset` function from [`data.py`](tasks/mlm/data.py).
69 | One solution is to use the [`datasets`](https://github.com/huggingface/datasets) library and stream one of the existing large datasets for your target domain. **A working example** of `data.py` can be found in the [NCAI-research/CALM](https://github.com/NCAI-Research/CALM/blob/main/tasks/mlm/data.py) project; a minimal streaming sketch is also shown at the end of this step.
70 | 
71 | 
72 | When transitioning to a new language or new dataset, it is important to check that the tokenizer/collator works as intended **before** you begin training.
73 | The best way to do that is to manually look at training minibatches:
74 | ```python
   | import torch
   | from tqdm.auto import tqdm
75 | from tasks.mlm.data import make_training_dataset
76 | from tasks.mlm.whole_word_mask import DataCollatorForWholeWordMask
77 | 
78 | tokenizer = create_tokenizer_here(...)
79 | dataset = make_training_dataset(tokenizer, max_sequence_length=...)  # see arguments.py
80 | collator = DataCollatorForWholeWordMask(tokenizer, pad_to_multiple_of=...)  # see arguments.py
81 | data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collator, batch_size=4)
82 | 
83 | # generate a few batches
84 | rows = []
85 | with tqdm(enumerate(data_loader)) as progress:
86 |     for i, row in progress:
87 |         rows.append(row)
88 |         if i > 10:
89 |             break
90 | 
91 | # look into the training data
92 | row_ix, sample_ix = 0, 1
93 | sources = [tokenizer.decode([i]) for i in rows[row_ix]['input_ids'][sample_ix].data.numpy()]
94 | print("MASK RATE:", (rows[row_ix]['input_ids'][sample_ix] == 4).data.numpy().sum() / (rows[row_ix]['input_ids'][sample_ix] != 0).data.numpy().sum())
95 | 
96 | for i in range(len(sources)):
97 |     if sources[i] == '[MASK]':
98 |         sources[i] = '[[' + tokenizer.decode(rows[row_ix]['labels'][sample_ix][i].item()) + ']]'  # show the masked-out target token
99 | 
100 | print(' '.join(sources))
101 | ```
102 | 
103 | If you make many changes, it also helps to train a very small model on your own device first to check that everything works as intended. A good initial configuration is 6 layers, hidden size 512, intermediate size 2048.
104 | 
105 | If you're training with volunteers, the most convenient way is to set up a Hugging Face organization. For instructions on that, see the "make your own" section of https://training-transformers-together.github.io . We use WandB for tracking logs and training progress: we've set up a [WandB team](https://docs.wandb.ai/ref/app/features/teams) for this experiment. Alternatively, you can use hivemind standalone (and even without internet access) by setting `--authorize False` and `WANDB_DISABLED=true` -- or by manually removing the corresponding options from the code.
106 | 
107 | 
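Here is the minimal streaming sketch mentioned above. It is illustrative only: the corpus name ("oscar" with the Russian config) and the "text"/"id" columns are placeholders, and the real `make_training_dataset` in [`data.py`](tasks/mlm/data.py) should match the signature used in the snippet above and whatever fields your collator expects.

```python
# Sketch only: the corpus and column names below are placeholder choices, not the project's actual dataset.
from datasets import load_dataset


def make_training_dataset(tokenizer, max_sequence_length: int):
    # stream the corpus so the full dataset never has to be downloaded to disk
    raw = load_dataset("oscar", "unshuffled_deduplicated_ru", split="train", streaming=True)

    def tokenize(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=max_sequence_length,
            return_special_tokens_mask=True,
        )

    # keep only the tokenized fields; whole-word masking is applied later by the collator
    return raw.map(tokenize, batched=True, remove_columns=["id", "text"])
```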
108 | 109 |
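For the "very small model" sanity check mentioned in step 1, a rough sketch is shown below. The parameter names follow `AlbertConfig`, which `LeanAlbertConfig` extends, and the exact values (e.g. the number of attention heads) are only illustrative.

```python
# Sketch: instantiate the small debug configuration (6 layers, hidden 512, intermediate 2048)
# and push one dummy batch through it; values such as num_attention_heads are illustrative.
import torch

from src.models import LeanAlbertConfig, LeanAlbertForPreTraining

config = LeanAlbertConfig(
    num_hidden_layers=6,
    hidden_size=512,
    intermediate_size=2048,
    num_attention_heads=8,
)
model = LeanAlbertForPreTraining(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")

dummy_ids = torch.randint(0, config.vocab_size, (2, 128))
with torch.no_grad():
    outputs = model(input_ids=dummy_ids, attention_mask=torch.ones_like(dummy_ids))
print(outputs.prediction_logits.shape)  # expected: (2, 128, vocab_size)
```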
110 | 2. Setting up auxiliary peers
111 | 
112 | Auxiliary peers are low-end servers without a GPU that keep track of the latest model checkpoint, report metrics, and assist in communication.
113 | You will need 1-3 workers that track metrics, upload statistics, etc. These peers do not use a GPU.
114 | If many of your participants are behind a firewall (i.e. run in `--client_mode`), it helps to add more auxiliary servers, as they can serve as relays and help with all-reduce.
115 | 
116 | __Minimum requirements:__ 15+ GB RAM, at least 100Mbit/s download/upload speed, at least one port opened to incoming connections.
117 | 
118 | __Where to get:__ cloud providers that have cheap ingress/egress pricing. Good examples: [pebblehost](https://pebblehost.com/dedicated/) "Essential-01" and [hetzner](https://www.hetzner.com/cloud) CX41. Path of the true jedi: use your homelab or university server -- but that may require networking experience. AWS/GCP/Azure have similar offers, but they cost more due to [egress pricing](https://cloud.google.com/vpc/network-pricing).
119 | 
120 | 
121 | __Setup env:__
122 | 
123 | ```
124 | sudo apt install -y git tmux
125 | curl https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh > Anaconda3-2021.11-Linux-x86_64.sh
126 | bash Anaconda3-2021.11-Linux-x86_64.sh -b -p ~/anaconda3
127 | source ~/anaconda3/bin/activate
128 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
129 | 
130 | git clone https://github.com/yandex-research/RuLeanALBERT
131 | cd RuLeanALBERT && pip install -q -r requirements.txt &> log
132 | 
133 | # re-install bitsandbytes for the actual CUDA version
134 | pip uninstall -y bitsandbytes-cuda111
135 | pip install bitsandbytes-cuda113==0.26.0
136 | 
137 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
138 | sudo apt-get install git-lfs
139 | git lfs install
140 | ```
141 | 
142 | 
143 | __Run auxiliary worker:__
144 | 
145 | 1. Open a tmux (or screen) session that will stay up after you log out. (`tmux new`, [about tmux](https://tmuxcheatsheet.com/))
146 | 
147 | 2. Measure internet bandwidth and set the `$BANDWIDTH` variable:
148 | ```bash
149 | 
150 | # You can measure bandwidth automatically:
151 | curl -s https://gist.githubusercontent.com/justheuristic/5467799d8f2ad59b36fa75f642cc9b87/raw/c5a4b9b66987c2115e6c54a07d97e0104dfbcd97/speedtest.py | python - --json > speedtest.json
152 | export BANDWIDTH=`python -c "import json; speedtest = json.load(open('speedtest.json')); print(int(max(1, min(speedtest['upload'], speedtest['download']) / 1e6)))"`
153 | echo "Internet Bandwidth (Mb/s) = $BANDWIDTH"
154 | 
155 | # If that doesn't work, you can simply `export BANDWIDTH=TODOyour_bandwidth_mbits_here` using the minimum of download and upload speed.
156 | ```
157 | 
158 | 
159 | 3. Run the auxiliary peer
160 | ```bash
161 | export MY_IP=`curl --ipv4 -s http://whatismyip.akamai.com/`
162 | echo "MY IP (check not empty):" $MY_IP
163 | # If empty, please set this manually: export MY_IP=...
164 | # When training on internal infrastructure, feel free to use internal IP.
165 | # If using IPv6, please replace /ip4/ with /ip6/ in subsequent lines
166 | 
167 | export PORT=12345 # please choose a port where you can accept incoming tcp connections (or open that port if you're on a cloud)
168 | export LISTEN_ON=/ip4/0.0.0.0/tcp/$PORT
169 | export ANNOUNCE_ON=/ip4/$MY_IP/tcp/$PORT
170 | export WANDB_START_METHOD=thread
171 | export CUDA_VISIBLE_DEVICES= # do not use GPUs even if they are available
172 | 
173 | # organizations
174 | export WANDB_ENTITY=YOUR_USERNAME_HERE
175 | export HF_ORGANIZATION_NAME=YOUR_ORG_HERE
176 | 
177 | # experiment name
178 | export EXP_NAME=my-exp
179 | export WANDB_PROJECT=$EXP_NAME
180 | export HF_MODEL_NAME=$EXP_NAME
181 | 
182 | export WANDB_API_KEY=TODO_get_your_wandb_key_here_wandb.ai/authorize
183 | export HF_USER_ACCESS_TOKEN=TODO_create_user_access_token_here_with_WRITE_permissions_https://huggingface.co/settings/token
184 | # note: you can avoid setting the two tokens above: in that case, the script will ask you to log in to wandb and huggingface
185 | 
186 | # activate your anaconda environment
187 | source ~/anaconda3/bin/activate
188 | 
189 | 
190 | ulimit -n 16384 # this line is important, ignoring it may cause a "Too Many Open Files" error
191 | 
192 | python run_aux_peer.py --run_id $EXP_NAME --host_maddrs $LISTEN_ON --announce_maddrs $ANNOUNCE_ON --wandb_project $WANDB_PROJECT --store_checkpoints --upload_interval 43200 --repo_url $HF_ORGANIZATION_NAME/$HF_MODEL_NAME --assist_in_averaging --bandwidth $BANDWIDTH
193 | # Optionally, add more peers to the training via `--initial_peers ONE_OR_MORE PEERS_HERE`
194 | ```
195 | 
196 | If everything went right, it will print its address like this:
197 | ![image](https://user-images.githubusercontent.com/3491902/146950956-0ea06e77-15b4-423f-aeaa-02eb6aec06db.png)
198 | 
199 | Please copy this address and use it as ``--initial_peers`` with GPU trainers and other auxiliary peers.
200 | 
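If you want to double-check that this address is reachable from the outside before inviting trainers, one option is to start a throwaway hivemind DHT client from another machine. This is only a sketch (not one of the repository's scripts); the multiaddress below is a placeholder for the one printed by `run_aux_peer.py`:

```python
# Sketch: connect a short-lived DHT client to the auxiliary peer and do a store/get round-trip.
# The multiaddress is a placeholder -- paste the address printed by run_aux_peer.py instead.
import hivemind

dht = hivemind.DHT(
    initial_peers=["/ip4/1.2.3.4/tcp/12345/p2p/PEER_ID_HERE"],
    client_mode=True,
    start=True,
)
# a successful store/get round-trip means the auxiliary peer is reachable from this machine
dht.store("reachability_check", b"ok", expiration_time=hivemind.get_dht_time() + 60)
print(dht.get("reachability_check"))
dht.shutdown()
```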
201 | 202 | 203 |
204 | 3. Setting up a trainer
205 | Trainers are peers with GPUs (or other compute accelerators) that compute gradients, average them via all-reduce and perform optimizer steps.
206 | There are two broad types of trainers: normal (full) peers and client mode peers. Client peers rely on others to average their gradients, but otherwise behave the same as full peers. You can designate your trainer as client-only using the `--client_mode` flag.
207 | 
208 | __When do I need client mode?__ If a peer is unreliable (e.g. will likely be gone in 1 hour) OR sits behind a firewall that blocks incoming connections OR has a very unstable internet connection, it should be a client. For instance, it is recommended to set colab / kaggle peers as clients. In turn, cloud GPUs (even spot instances!) are generally more reliable and should be full peers.
209 | 
210 | Participating as a client is easy: you can find the code for that in **this colab notebook(TODO)**. Setting up a full peer takes a few more steps, described below.
211 | ### Set up environment:
212 | 
213 | This part is the same as for the auxiliary peer, except we don't need LFS (that was needed to upload checkpoints).
214 | ```bash
215 | sudo apt install -y git tmux
216 | curl https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh > Anaconda3-2021.11-Linux-x86_64.sh
217 | bash Anaconda3-2021.11-Linux-x86_64.sh -b -p ~/anaconda3
218 | source ~/anaconda3/bin/activate
219 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
220 | 
221 | git clone https://github.com/yandex-research/RuLeanALBERT
222 | cd RuLeanALBERT && pip install -q -r requirements.txt &> log
223 | 
224 | # re-install bitsandbytes for the actual CUDA version
225 | pip uninstall -y bitsandbytes-cuda111
226 | pip install bitsandbytes-cuda113==0.26.0
227 | 
228 | # note: we use bitsandbytes for 8-bit LAMB, and in turn, bitsandbytes needs cuda -- even if you run on a non-CUDA device.
229 | ```
230 | 
231 | ```bash
232 | export MY_IP=`curl --ipv4 -s http://whatismyip.akamai.com/`
233 | echo "MY IP (check not empty):" $MY_IP
234 | # If empty, please set this manually: export MY_IP=...
235 | # When training on internal infrastructure, feel free to use internal IP.
236 | # If using IPv6, please replace /ip4/ with /ip6/ in subsequent lines
237 | 
238 | export PORT=31337 # same requirements as for aux peer
239 | export LISTEN_ON=/ip4/0.0.0.0/tcp/$PORT
    | export ANNOUNCE_ON=/ip4/$MY_IP/tcp/$PORT # needed by the run_trainer.py command below (mirrors the aux-peer setup)
240 | 
241 | export CUDA_VISIBLE_DEVICES=0 # supports multiple cuda devices!
242 | 
243 | # organization & experiment name
244 | export WANDB_ENTITY=YOUR_USERNAME_HERE
245 | export HF_ORGANIZATION_NAME=YOUR_ORG_HERE
246 | export EXP_NAME=my-exp
247 | export WANDB_PROJECT=$EXP_NAME-hivemind-trainers
248 | export HF_MODEL_NAME=$EXP_NAME
249 | 
250 | export WANDB_API_KEY=get_your_wandb_key_here_https://wandb.ai/authorize_OR_just_login_on_wandb
251 | export HF_USER_ACCESS_TOKEN=create_user_access_token_here_with_WRITE_permissions_https://huggingface.co/settings/token
252 | # note: you can avoid setting the two tokens above: in that case, the script will ask you to log in to wandb and huggingface
253 | 
254 | export INITIAL_PEERS="/ip4/IP_ADDR/tcp/12345/p2p/PEER_ID"
255 | # ^-- If you're running an independent experiment, this must be your own initial peers. Can be either auxiliary peers or full GPU peers.
256 | 
257 | 
258 | curl -s https://raw.githubusercontent.com/sivel/speedtest-cli/master/speedtest.py | python - --json > speedtest.json
259 | export BANDWIDTH=`python -c "import json; speedtest = json.load(open('speedtest.json')); print(int(max(1, min(speedtest['upload'], speedtest['download']) / 1e6)))"`
260 | echo "Internet Bandwidth (Mb/s) = $BANDWIDTH"
261 | 
262 | ulimit -n 16384 # this line is important, ignoring it may cause a "Too Many Open Files" error
263 | 
264 | python run_trainer.py --run_id $EXP_NAME --host_maddrs $LISTEN_ON --announce_maddrs $ANNOUNCE_ON --initial_peers $INITIAL_PEERS --bandwidth $BANDWIDTH \
265 |     --per_device_train_batch_size 1 --gradient_accumulation_steps 1
266 | # you can tune --per_device_train_batch_size, --gradient_accumulation_steps, --fp16 and --gradient_checkpointing based on the device. A good rule of thumb is that the device should compute (batch size x num accumulations) gradients over 1-10 seconds. Setting very large gradient_accumulation_steps can cause your peer to miss an averaging round.
267 | 
268 | ```
269 | 
270 | 
271 | 
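If you are unsure which values satisfy the 1-10 second rule of thumb above, you can time a single forward/backward pass offline. The snippet below is only a sketch: the default config and the batch shape are illustrative, so substitute your actual pretraining configuration.

```python
# Sketch: measure how long one micro-batch takes for a candidate per-device batch size.
import time

import torch

from src.models import LeanAlbertConfig, LeanAlbertForPreTraining

config = LeanAlbertConfig()  # substitute the config you actually plan to pretrain
model = LeanAlbertForPreTraining(config).cuda()

batch_size, seq_len = 4, 512  # candidate --per_device_train_batch_size and sequence length
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device="cuda")

start = time.perf_counter()
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()
torch.cuda.synchronize()
print(f"one micro-batch took {time.perf_counter() - start:.2f}s")
```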
272 | 273 |
274 | Best (and worst) practices
275 | 
276 | - __Hardware requirements:__ The code is meant to run with the following specs: 2-core CPU, 12 GB RAM (more if you train a bigger model). Peers used as `--initial_peers` must be accessible by others, so you may need to open a network port for incoming connections. The rest depends on what role you're playing:
277 | 
278 | - __Auxiliary peers:__ If you use `--upload_interval X --repo_url Y`, the peer must have enough disk space to store all the checkpoints. For instance, assuming that training takes 1 month and the model+optimizer state takes 1GB, you will need 30GB with `--upload_interval 86400`, 60GB with `--upload_interval 43200`, etc. If `assist_in_averaging`, ensure you have at least 100Mbit/s bandwidth, more is better.
279 | 
280 | - __Trainers__ need *some* means of compute: a GPU with at least 6GB memory or a TPU - as long as you can run pytorch on that. You will need to tune `--per_device_train_batch_size X` to fit into memory. Also, you can use `--fp16` even on old GPUs to save memory. Finally, `--gradient_checkpointing` can reduce memory usage at the cost of 30-40% slower training. Non-client-mode peers must have at least 100Mbit/s network bandwidth, more is better.
281 | 
282 | 
283 | - __Swarm composition:__ you will need 2-3 peers with public IP as `--initial_peers` for redundancy. If some participants are behind firewalls, we recommend finding at least one non-firewalled participant per 5 peers behind a firewall.
284 | 
285 | 
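As a quick check of the disk-space arithmetic in the auxiliary-peer bullet above (a sketch; the 30-day duration and 1 GB state size are the same assumptions as in that bullet):

```python
# Back-of-the-envelope: disk space needed by an auxiliary peer that keeps every uploaded checkpoint.
training_days, state_size_gb = 30, 1.0  # assumptions from the bullet above

for upload_interval in (86400, 43200, 21600):  # seconds between uploads
    num_checkpoints = training_days * 86400 // upload_interval
    print(f"--upload_interval {upload_interval}: ~{num_checkpoints * state_size_gb:.0f} GB")
```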
286 | 287 | ## Acknowledgements 288 | 289 | Many of the best practices from this guide were found in the [CALM](https://github.com/NCAI-Research/CALM) project for collaborative training by NCAI. 290 | -------------------------------------------------------------------------------- /src/models/albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch ALBERT modules that do not hog your GPU memory """ 16 | 17 | import torch 18 | from torch import nn as nn 19 | from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss 20 | from transformers import AlbertPreTrainedModel 21 | from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput 22 | from transformers.modeling_utils import PreTrainedModel 23 | from transformers.models.albert.modeling_albert import AlbertForPreTrainingOutput 24 | from transformers.utils import logging 25 | 26 | from src.models.transformer import GradientCheckpointingMixin, LeanTransformer, LeanTransformerConfig 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | _CONFIG_FOR_DOC = "LeanAlbertConfig" 31 | _TOKENIZER_FOR_DOC = "AlbertTokenizer" 32 | 33 | 34 | class LeanAlbertConfig(LeanTransformerConfig): 35 | def __init__( 36 | self, 37 | *args, 38 | vocab_size: int = 30000, 39 | embedding_size: int = 128, 40 | classifier_dropout_prob: float = 0.1, 41 | type_vocab_size: int = 2, 42 | pad_token_id: int = 0, 43 | bos_token_id: int = 2, 44 | eos_token_id: int = 3, 45 | **kwargs 46 | ): 47 | super().__init__( 48 | *args, 49 | pad_token_id=pad_token_id, 50 | bos_token_id=bos_token_id, 51 | eos_token_id=eos_token_id, 52 | type_vocab_size=type_vocab_size, 53 | **kwargs 54 | ) 55 | self.vocab_size = vocab_size 56 | self.embedding_size = embedding_size 57 | self.classifier_dropout_prob = classifier_dropout_prob 58 | self.type_vocab_size = type_vocab_size 59 | 60 | 61 | class LeanAlbertEmbeddings(nn.Module): 62 | """ 63 | Construct the embeddings from word, position and token_type embeddings. 
64 | """ 65 | 66 | def __init__(self, config: LeanTransformerConfig): 67 | super().__init__() 68 | self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) 69 | 70 | self.token_type_embeddings = config.get_token_type_embeddings() 71 | self.position_embeddings = config.get_input_position_embeddings() 72 | 73 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) 74 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 75 | if config.embedding_size != config.hidden_size: 76 | self.embedding_hidden_mapping = nn.Linear(config.embedding_size, config.hidden_size) 77 | 78 | if self.position_embeddings is not None: 79 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 80 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 81 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 82 | 83 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward 84 | def forward( 85 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 86 | ): 87 | if input_ids is not None: 88 | input_shape = input_ids.size() 89 | else: 90 | input_shape = inputs_embeds.size()[:-1] 91 | 92 | seq_length = input_shape[1] 93 | 94 | if token_type_ids is None: 95 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 96 | 97 | if inputs_embeds is None: 98 | inputs_embeds = self.word_embeddings(input_ids) 99 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 100 | 101 | embeddings = inputs_embeds + token_type_embeddings 102 | 103 | if self.position_embeddings is not None: 104 | if position_ids is None: 105 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] 106 | position_embeddings = self.position_embeddings(position_ids) 107 | embeddings += position_embeddings 108 | 109 | embeddings = self.layer_norm(embeddings) 110 | embeddings = self.dropout(embeddings) 111 | if hasattr(self, "embedding_hidden_mapping"): 112 | embeddings = self.embedding_hidden_mapping(embeddings) 113 | return embeddings 114 | 115 | 116 | class LeanAlbertModel(GradientCheckpointingMixin, PreTrainedModel): 117 | config_class = LeanAlbertConfig 118 | base_model_prefix = "lean_albert" 119 | _keys_to_ignore_on_load_missing = [r"position_ids"] 120 | 121 | def __init__(self, config: config_class, add_pooling_layer=True): 122 | super().__init__(config) 123 | 124 | self.config = config 125 | self.embeddings = LeanAlbertEmbeddings(config) 126 | self.transformer = LeanTransformer(config) 127 | 128 | if add_pooling_layer: 129 | self.pooler = nn.Linear(config.hidden_size, config.hidden_size) 130 | self.pooler_activation = nn.Tanh() 131 | else: 132 | self.pooler = None 133 | self.pooler_activation = None 134 | 135 | self.init_weights() 136 | 137 | def get_input_embeddings(self): 138 | return self.embeddings.word_embeddings 139 | 140 | def set_input_embeddings(self, value): 141 | self.embeddings.word_embeddings = value 142 | 143 | def _init_weights(self, module: nn.Module): 144 | return self.config.init_weights(module) 145 | 146 | def forward( 147 | self, 148 | input_ids=None, 149 | attention_mask=None, 150 | token_type_ids=None, 151 | position_ids=None, 152 | head_mask=None, 153 | inputs_embeds=None, 154 | output_attentions=None, 155 | output_hidden_states=None, 156 | return_dict=None, 157 | ): 158 | 
assert head_mask is None and output_attentions is None and output_hidden_states is None, "not implemented" 159 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 160 | 161 | if input_ids is not None and inputs_embeds is not None: 162 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 163 | elif input_ids is not None: 164 | input_shape = input_ids.size() 165 | elif inputs_embeds is not None: 166 | input_shape = inputs_embeds.size()[:-1] 167 | else: 168 | raise ValueError("You have to specify either input_ids or inputs_embeds") 169 | 170 | batch_size, seq_length = input_shape 171 | device = input_ids.device if input_ids is not None else inputs_embeds.device 172 | 173 | if attention_mask is None: 174 | attention_mask = torch.ones(input_shape, device=device, dtype=int) 175 | else: 176 | assert not torch.is_floating_point(attention_mask), "The model requires boolean or int mask with 0/1 entries" 177 | 178 | if token_type_ids is None: 179 | if hasattr(self.embeddings, "token_type_ids"): 180 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 181 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 182 | token_type_ids = buffered_token_type_ids_expanded 183 | else: 184 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 185 | 186 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 187 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility 188 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 189 | 190 | embedding_output = self.embeddings( 191 | input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds 192 | ) 193 | transformer_outputs = self.transformer(embedding_output, extended_attention_mask) 194 | 195 | sequence_output = transformer_outputs[0] 196 | 197 | pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None 198 | 199 | if not return_dict: 200 | return (sequence_output, pooled_output) + transformer_outputs[1:] 201 | 202 | return BaseModelOutputWithPooling( 203 | last_hidden_state=sequence_output, 204 | pooler_output=pooled_output, 205 | hidden_states=transformer_outputs.hidden_states, 206 | attentions=transformer_outputs.attentions, 207 | ) 208 | 209 | 210 | class AlbertMLMHead(nn.Module): 211 | def __init__(self, config): 212 | super().__init__() 213 | 214 | self.layer_norm = nn.LayerNorm(config.embedding_size) 215 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 216 | self.dense = nn.Linear(config.hidden_size, config.embedding_size) 217 | self.decoder = nn.Linear(config.embedding_size, config.vocab_size) 218 | self.activation = config.get_activation_callable() 219 | self.decoder.bias = self.bias 220 | 221 | def forward(self, hidden_states): 222 | hidden_states = self.dense(hidden_states) 223 | hidden_states = self.activation(hidden_states) 224 | hidden_states = self.layer_norm(hidden_states) 225 | hidden_states = self.decoder(hidden_states) 226 | 227 | prediction_scores = hidden_states 228 | 229 | return prediction_scores 230 | 231 | def _tie_weights(self): 232 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 233 | self.bias = self.decoder.bias 234 | 235 | 236 | class AlbertSOPHead(nn.Module): 237 | def __init__(self, config): 238 | super().__init__() 239 | 240 | self.dropout = 
nn.Dropout(config.classifier_dropout_prob) 241 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 242 | 243 | def forward(self, pooled_output): 244 | dropout_pooled_output = self.dropout(pooled_output) 245 | logits = self.classifier(dropout_pooled_output) 246 | return logits 247 | 248 | 249 | class LeanAlbertForPreTraining(GradientCheckpointingMixin, PreTrainedModel): 250 | config_class = LeanAlbertConfig 251 | base_model_prefix = "lean_albert" 252 | 253 | def __init__(self, config: config_class): 254 | super().__init__(config) 255 | 256 | self.albert = LeanAlbertModel(config) 257 | self.predictions = AlbertMLMHead(config) 258 | self.sop_classifier = AlbertSOPHead(config) 259 | 260 | def get_input_embeddings(self): 261 | return self.albert.embeddings.word_embeddings 262 | 263 | def set_input_embeddings(self, new_embeddings: nn.Module): 264 | self.albert.embeddings.word_embeddings = new_embeddings 265 | 266 | def get_output_embeddings(self): 267 | return self.predictions.decoder 268 | 269 | def set_output_embeddings(self, new_embeddings): 270 | self.predictions.decoder = new_embeddings 271 | 272 | def _init_weights(self, module: nn.Module): 273 | return self.config.init_weights(module) 274 | 275 | def forward( 276 | self, 277 | input_ids=None, 278 | attention_mask=None, 279 | token_type_ids=None, 280 | position_ids=None, 281 | head_mask=None, 282 | inputs_embeds=None, 283 | labels=None, 284 | sentence_order_label=None, 285 | output_attentions=None, 286 | output_hidden_states=None, 287 | return_dict=None, 288 | ): 289 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 290 | 291 | outputs = self.albert( 292 | input_ids, 293 | attention_mask=attention_mask, 294 | token_type_ids=token_type_ids, 295 | position_ids=position_ids, 296 | head_mask=head_mask, 297 | inputs_embeds=inputs_embeds, 298 | output_attentions=output_attentions, 299 | output_hidden_states=output_hidden_states, 300 | return_dict=return_dict, 301 | ) 302 | 303 | sequence_output, pooled_output = outputs[:2] 304 | 305 | prediction_scores = self.predictions(sequence_output) 306 | sop_scores = self.sop_classifier(pooled_output) 307 | 308 | total_loss = None 309 | if labels is not None: 310 | loss_fct = nn.CrossEntropyLoss() 311 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 312 | if sentence_order_label is not None: 313 | sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1)) 314 | total_loss = masked_lm_loss + sentence_order_loss 315 | else: 316 | total_loss = masked_lm_loss 317 | 318 | if not return_dict: 319 | output = (prediction_scores, sop_scores) + outputs[2:] 320 | return ((total_loss,) + output) if total_loss is not None else output 321 | 322 | return AlbertForPreTrainingOutput( 323 | loss=total_loss, 324 | prediction_logits=prediction_scores, 325 | sop_logits=sop_scores, 326 | hidden_states=outputs.hidden_states, 327 | attentions=outputs.attentions, 328 | ) 329 | 330 | 331 | class LeanAlbertForSequenceClassification(AlbertPreTrainedModel): 332 | config_class = LeanAlbertConfig 333 | base_model_prefix = "albert" 334 | 335 | def __init__(self, config: LeanAlbertConfig): 336 | super().__init__(config) 337 | self.num_labels = config.num_labels 338 | self.config = config 339 | 340 | self.albert = LeanAlbertModel(config, add_pooling_layer=False) 341 | 342 | self.classifier = nn.Sequential( 343 | nn.Dropout(config.classifier_dropout_prob), 344 | nn.Linear(config.hidden_size, config.hidden_size), 
345 | nn.Tanh(), 346 | nn.Dropout(config.classifier_dropout_prob), 347 | nn.Linear(config.hidden_size, self.config.num_labels) 348 | ) 349 | 350 | # Initialize weights and apply final processing 351 | self.post_init() 352 | 353 | def forward( 354 | self, 355 | input_ids=None, 356 | attention_mask=None, 357 | token_type_ids=None, 358 | position_ids=None, 359 | head_mask=None, 360 | inputs_embeds=None, 361 | labels=None, 362 | output_attentions=None, 363 | output_hidden_states=None, 364 | return_dict=None, 365 | ): 366 | r""" 367 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 368 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 369 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 370 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 371 | """ 372 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 373 | 374 | outputs = self.albert( 375 | input_ids=input_ids, 376 | attention_mask=attention_mask, 377 | token_type_ids=token_type_ids, 378 | position_ids=position_ids, 379 | head_mask=head_mask, 380 | inputs_embeds=inputs_embeds, 381 | output_attentions=output_attentions, 382 | output_hidden_states=output_hidden_states, 383 | return_dict=return_dict, 384 | ) 385 | sequence_output = outputs[0] 386 | logits = self.classifier(sequence_output[:, 0, :]) 387 | 388 | loss = None 389 | if labels is not None: 390 | if self.config.problem_type is None: 391 | if self.num_labels == 1: 392 | self.config.problem_type = "regression" 393 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 394 | self.config.problem_type = "single_label_classification" 395 | else: 396 | self.config.problem_type = "multi_label_classification" 397 | 398 | if self.config.problem_type == "regression": 399 | loss_fct = MSELoss() 400 | if self.num_labels == 1: 401 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 402 | else: 403 | loss = loss_fct(logits, labels) 404 | elif self.config.problem_type == "single_label_classification": 405 | loss_fct = CrossEntropyLoss() 406 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 407 | elif self.config.problem_type == "multi_label_classification": 408 | loss_fct = BCEWithLogitsLoss() 409 | loss = loss_fct(logits, labels) 410 | 411 | if not return_dict: 412 | output = (logits,) + outputs[2:] 413 | return ((loss,) + output) if loss is not None else output 414 | 415 | return SequenceClassifierOutput( 416 | loss=loss, 417 | logits=logits, 418 | hidden_states=outputs.hidden_states, 419 | attentions=outputs.attentions, 420 | ) 421 | --------------------------------------------------------------------------------