├── tasks
│   ├── __init__.py
│   ├── mlm
│   │   ├── __init__.py
│   │   ├── data.py
│   │   ├── whole_word_mask.py
│   │   └── task.py
│   └── base.py
├── src
│   ├── training
│   │   ├── __init__.py
│   │   ├── sync.py
│   │   ├── hf_trainer.py
│   │   ├── callback.py
│   │   └── lamb_8bit.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── lean_albert.py
│   │   └── albert.py
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── functional.py
│   │   ├── sequence.py
│   │   ├── rotary.py
│   │   ├── attn.py
│   │   ├── pixelfly.py
│   │   ├── linear.py
│   │   └── ffn.py
│   ├── utils.py
│   └── huggingface_auth.py
├── requirements.txt
├── .gitignore
├── run_trainer.py
├── field_collator.py
├── predict_on_lidirus.py
├── models.py
├── run_aux_peer.py
├── arguments.py
├── finetune_mlm.py
├── LICENSE
└── README.md
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tasks/mlm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .albert import *
2 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .attn import *
2 | from .ffn import *
3 | from .linear import *
4 | from .pixelfly import *
5 | from .rotary import *
6 | from .sequence import *
7 |
--------------------------------------------------------------------------------
/src/models/config.py:
--------------------------------------------------------------------------------
1 | from transformers import AlbertConfig
2 |
3 |
4 | class LeanAlbertConfig(AlbertConfig):
5 | rotary_embedding_base: int = 10_000
6 | hidden_act_gated: bool = False
7 |
8 | def __hash__(self):
9 | return hash("\t".join(f"{k}={v}" for k, v in self.__dict__.items() if not k.startswith("_")))
10 |
--------------------------------------------------------------------------------
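
Because LeanAlbertConfig defines __hash__, config objects can serve as cache keys (e.g. for functools.lru_cache). A minimal illustration, with arbitrary field values chosen only for this example:

import functools

from src.models.config import LeanAlbertConfig

config = LeanAlbertConfig(hidden_size=128, num_attention_heads=4, hidden_act_gated=True)

@functools.lru_cache()
def describe(cfg: LeanAlbertConfig) -> str:
    # evaluated once per distinct config, thanks to LeanAlbertConfig.__hash__
    return f"hidden_size={cfg.hidden_size}, gated={cfg.hidden_act_gated}"

print(describe(config))
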
/requirements.txt:
--------------------------------------------------------------------------------
1 | bitsandbytes==0.32.3
2 | datasets==2.4.0
3 | einops==0.3.2
4 | fuzzysearch==0.7.3
5 | https://github.com/learning-at-home/hivemind/archive/calm.zip
6 | numpy==1.22.3
7 | pandas==1.4.3
8 | pymorphy2==0.9.1
9 | razdel==0.5.0
10 | requests==2.31.0
11 | scikit-learn==1.0.2
12 | sentencepiece==0.1.96
13 | termcolor==1.1.0
14 | tokenizers==0.12.1
15 | torch>=1.8.0
16 | torch_optimizer==0.1.0
17 | transformers==4.30.0
18 | wandb==0.12.1
19 |
--------------------------------------------------------------------------------
/src/modules/functional.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 |
4 | import torch
5 | from transformers.activations import ACT2FN as HF_ACT2FN
6 |
7 |
8 | @functools.lru_cache()
9 | def maybe_script(fn: callable) -> callable:
10 | """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
11 | using_tpu = bool(os.environ.get("TPU_NAME"))
12 | # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
13 | should_script = int(os.environ.get("LEAN_USE_JIT", not using_tpu))
14 | return torch.jit.script(fn) if should_script else fn
15 |
16 |
17 | @maybe_script
18 | def gelu_fused(x):
19 | """
 20 |     Approximate GELU activation, same as in Google BERT and OpenAI GPT (as of Dec 2021).
 21 |     See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
22 | """
 23 |     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))  # note: 0.79788456 ≈ sqrt(2/pi)
24 |
25 |
26 | @maybe_script
27 | def gelu_fused_grad(grad_output, x):
 28 |     """Gradients of gelu_fused w.r.t. its input"""
29 | tanh = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
30 | jac = 0.5 * x * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh)
31 | return jac * grad_output
32 |
33 |
34 | class GELU(torch.autograd.Function):
35 | @staticmethod
36 | def forward(ctx, input: torch.Tensor):
37 | ctx.save_for_backward(input)
38 | return gelu_fused(input)
39 |
40 | @staticmethod
41 | def backward(ctx, grad_output: torch.Tensor):
42 | input, = ctx.saved_tensors
43 | return gelu_fused_grad(grad_output, input)
44 |
45 |
46 | ACT2FN = dict(HF_ACT2FN, gelu_fused=GELU.apply)
47 |
--------------------------------------------------------------------------------
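
As a sanity check for the fused activation above, the following sketch compares the custom forward and hand-written backward against PyTorch's built-in tanh-approximated GELU (assumption: torch >= 1.12, which added the approximate="tanh" flag; run from the repo root so that src is importable):

import torch

from src.modules.functional import ACT2FN

x = torch.randn(4, 8, requires_grad=True)
y = ACT2FN["gelu_fused"](x)  # GELU.apply under the hood
y.sum().backward()

x_ref = x.detach().clone().requires_grad_(True)
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
y_ref.sum().backward()

print(torch.allclose(y, y_ref, atol=1e-5), torch.allclose(x.grad, x_ref.grad, atol=1e-4))
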
/src/modules/sequence.py:
--------------------------------------------------------------------------------
1 | """
2 | A module that implements sequential model type with optional keyword arguments.
3 | When using gradient checkpoints, keyword arguments should NOT require grad.
4 | """
5 | from typing import Callable, Sequence
6 |
7 | import torch
8 | from logging import getLogger
9 | from torch import nn as nn
10 | from torch.utils.checkpoint import checkpoint
11 |
12 | logger = getLogger(__name__)
13 |
14 |
15 | class ActiveKwargs(nn.Module):
16 | """
 17 |     A module that forwards only the selected kwargs to its inner module; compatible with nn.Sequential and gradient checkpointing.
 18 |     Usage: only use this as a part of SequentialWithKwargs
19 | """
20 |
21 | def __init__(self, module: nn.Module, active_keys: Sequence[str], use_first_output: bool = False):
22 | super().__init__()
23 | self.module, self.active_keys, self.use_first_output = module, set(active_keys), use_first_output
24 |
25 | def forward(self, input: torch.Tensor, *args, **kwargs):
26 | kwargs = {key: value for key, value in kwargs.items() if key in self.active_keys}
27 | output = self.module(input, *args, **kwargs)
28 | if self.use_first_output and not isinstance(output, torch.Tensor):
29 | output = output[0]
30 | return output
31 |
32 |
33 | class SequentialWithKwargs(nn.Sequential):
34 | def __init__(self, *modules: ActiveKwargs):
35 | for module in modules:
36 | assert isinstance(module, ActiveKwargs)
37 | super().__init__(*modules)
38 | self.gradient_checkpointing = False
39 |
40 | def forward(self, input: torch.Tensor, *args, **kwargs):
41 | kwarg_keys, kwarg_values = zip(*kwargs.items()) if (self.gradient_checkpointing and kwargs) else ([], [])
42 | for module in self:
43 | if self.gradient_checkpointing and torch.is_grad_enabled():
44 | # pack kwargs with args since gradient checkpoint does not support kwargs
45 | input = checkpoint(self._checkpoint_forward, module, input, kwarg_keys, *kwarg_values, *args)
46 | else:
47 | input = module(input, *args, **kwargs)
48 | return input
49 |
50 | def _checkpoint_forward(self, module: Callable, input: torch.Tensor, kwarg_keys: Sequence[str], *etc):
51 | kwargs = {key: etc[i] for i, key in enumerate(kwarg_keys)}
52 | args = etc[len(kwarg_keys) :]
53 | return module(input, *args, **kwargs)
54 |
--------------------------------------------------------------------------------
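
A usage sketch for the two classes above (the toy modules are assumptions, not repo code): only the first layer receives attention_mask, and the mask deliberately does not require grad, as the module docstring demands when checkpointing is enabled:

import torch
from torch import nn

from src.modules.sequence import ActiveKwargs, SequentialWithKwargs


class ToyAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, attention_mask=None):
        x = self.proj(x)
        return x if attention_mask is None else x * attention_mask


model = SequentialWithKwargs(
    ActiveKwargs(ToyAttention(16), active_keys=("attention_mask",)),
    ActiveKwargs(nn.Linear(16, 16), active_keys=()),  # plain layer: all kwargs are filtered out
)
model.gradient_checkpointing = True

x = torch.randn(2, 8, 16, requires_grad=True)
mask = torch.ones(2, 8, 1)  # kwargs must NOT require grad when checkpointing
out = model(x, attention_mask=mask)
out.sum().backward()
print(x.grad.shape)
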
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/run_trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | from pathlib import Path
5 |
6 | import scipy.stats # compatibility for internal testing environment
7 | import torch.distributed
8 | import transformers
9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler
10 | from transformers import HfArgumentParser
11 |
12 | from arguments import CollaborativeArguments, HFTrainerArguments, TrainingPeerArguments
13 | from src import utils
14 | from src.training.callback import CollaborativeCallback
15 | from src.training.hf_trainer import CollaborativeHFTrainer, NOPtimizer
16 | from src.training.sync import SynchronizationCallback, is_main_process
17 | from tasks.mlm.task import MLMTrainingTask
18 |
19 | use_hivemind_log_handler("in_root_logger")
20 | logger = get_logger()
21 | torch.set_num_threads(1) # avoid quadratic number of threads
22 |
23 |
24 | def main():
25 | parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
26 | training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
27 | if torch.distributed.is_initialized():
 28 |         assert not collab_args.reuse_grad_buffers, "reuse_grad_buffers is not supported in distributed mode"
29 |
30 | logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
31 | if len(training_peer_args.initial_peers) == 0:
32 | logger.warning("Please specify initial peers or let others join your peer")
33 |
34 | utils.setup_logging(trainer_args)
35 | task = MLMTrainingTask(training_peer_args, trainer_args, collab_args)
36 | model = task.model.to(trainer_args.device)
37 | for param in model.parameters():
38 | if param.grad is None:
39 | param.grad = torch.zeros_like(param)
40 |
41 | callbacks = [(CollaborativeCallback if is_main_process() else SynchronizationCallback)(task, training_peer_args)]
42 | assert trainer_args.do_train and not trainer_args.do_eval
43 |
44 | # Note: the code below creates the trainer with dummy scheduler and removes some callbacks.
45 | # This is done because collaborative training has its own callbacks that take other peers into account.
46 | trainer = CollaborativeHFTrainer(
47 | model=model,
48 | args=trainer_args,
49 | tokenizer=task.tokenizer,
50 | data_collator=task.data_collator,
51 | data_seed=hash(task.local_public_key),
52 | train_dataset=task.training_dataset,
53 | reuse_grad_buffers=collab_args.reuse_grad_buffers,
54 | eval_dataset=None,
55 | optimizer=task.optimizer if is_main_process() else NOPtimizer(task._make_param_groups()),
56 | callbacks=callbacks,
57 | )
58 | trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
59 | trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
60 |
61 | latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
 62 |     trainer.train(resume_from_checkpoint=latest_checkpoint_dir)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import transformers.utils.logging
4 | from hivemind import choose_ip_address
5 | from hivemind.dht.crypto import RSASignatureValidator
6 | from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
7 | from hivemind.dht.validation import RecordValidatorBase
8 | from hivemind.utils.logging import get_logger
9 | from multiaddr import Multiaddr
10 | from pydantic import BaseModel, StrictFloat, confloat, conint
11 | from transformers.trainer_utils import is_main_process
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | class LocalMetrics(BaseModel):
17 | epoch: conint(ge=0, strict=True)
18 | samples_per_second: confloat(ge=0.0, strict=True)
19 | samples_accumulated: conint(ge=0, strict=True)
20 | loss: StrictFloat
21 | mini_steps: conint(ge=0, strict=True) # queries
22 |
23 |
24 | class MetricSchema(BaseModel):
25 | metrics: Dict[BytesWithPublicKey, LocalMetrics]
26 |
27 |
28 | def make_validators(run_id: str) -> Tuple[List[RecordValidatorBase], bytes]:
29 | signature_validator = RSASignatureValidator()
30 | validators = [SchemaValidator(MetricSchema, prefix=run_id), signature_validator]
31 | return validators, signature_validator.local_public_key
32 |
33 |
34 | class TextStyle:
35 | BOLD = "\033[1m"
36 | BLUE = "\033[34m"
37 | RESET = "\033[0m"
38 |
39 |
40 | def log_visible_maddrs(visible_maddrs: List[Multiaddr], only_p2p: bool) -> None:
41 | if only_p2p:
42 | unique_addrs = {addr["p2p"] for addr in visible_maddrs}
43 | initial_peers_str = " ".join(f"/p2p/{addr}" for addr in unique_addrs)
44 | else:
45 | available_ips = [Multiaddr(addr) for addr in visible_maddrs if "ip4" in addr or "ip6" in addr]
46 | if available_ips:
47 | preferred_ip = choose_ip_address(available_ips)
48 | selected_maddrs = [addr for addr in visible_maddrs if preferred_ip in str(addr)]
49 | else:
50 | selected_maddrs = visible_maddrs
51 | initial_peers_str = " ".join(str(addr) for addr in selected_maddrs)
52 |
53 | logger.info(
54 | f"Running a DHT peer. To connect other peers to this one over the Internet, use "
55 | f"{TextStyle.BOLD}{TextStyle.BLUE}--initial_peers {initial_peers_str}{TextStyle.RESET}"
56 | )
57 | logger.info(f"Full list of visible multiaddresses: {' '.join(str(addr) for addr in visible_maddrs)}")
58 |
59 |
60 | def setup_logging(training_args):
61 | # Log on each process the small summary:
62 | logger.warning(
63 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
 64 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
65 | )
66 | # Set the verbosity to info of the Transformers logger (on main process only):
67 | if is_main_process(training_args.local_rank):
68 | transformers.utils.logging.set_verbosity_info()
69 | transformers.utils.logging.enable_default_handler()
70 | transformers.utils.logging.enable_explicit_format()
71 | logger.info("Training/evaluation parameters %s", training_args)
72 |
--------------------------------------------------------------------------------
/field_collator.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from dataclasses import dataclass
3 | from typing import Iterable, Tuple, Optional, Union, List, Dict, Any
4 |
5 | import torch
6 | from transformers import PreTrainedTokenizerBase
7 | from transformers.file_utils import PaddingStrategy
8 |
9 |
10 | @dataclass
11 | class FieldDataCollatorWithPadding:
12 | """
13 | A general-purpose data collator that can handle batches with arbitrary fields.
14 | Supports only PyTorch tensors as outputs.
15 | """
16 | tokenizer: PreTrainedTokenizerBase
17 | # (field name, pad token index, index of the sequence axis or None)
18 | fields_to_pad: Iterable[Tuple[str, int, Optional[int]]] = ()
19 | padding: Union[bool, str, PaddingStrategy] = True
20 | max_length: Optional[int] = None
21 | pad_to_multiple_of: Optional[int] = None
22 |
23 | def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
24 | field_values = defaultdict(list)
25 | for field_name, field_pad_idx, field_seq_idx in self.fields_to_pad:
26 | for example in features:
27 | if field_name in example:
28 | field_values[field_name].append(example.pop(field_name))
29 |
30 | batch = self.tokenizer.pad(
31 | features,
32 | padding=self.padding,
33 | max_length=self.max_length,
34 | pad_to_multiple_of=self.pad_to_multiple_of,
35 | return_tensors="pt",
36 | )
37 | if "label" in batch:
38 | batch["labels"] = batch.pop("label")
39 | if "label_ids" in batch:
40 | batch["labels"] = batch.pop("label_ids")
41 |
42 | sequence_length = batch["input_ids"].size(1)
43 | batch_size = batch["input_ids"].size(0)
44 |
45 | for field_name, field_pad_idx, field_seq_idx in self.fields_to_pad:
46 |
47 | if field_values[field_name]:
48 | field_value_arrays = [torch.as_tensor(values) for values in field_values[field_name]]
49 | assert len(set(arr.ndim for arr in field_value_arrays)) == 1
50 |
51 | max_dim_lengths = [
52 | max(arr.size(dim) for arr in field_value_arrays)
53 | for dim in range(field_value_arrays[0].ndim)
54 | ]
55 |
56 | if field_seq_idx is not None:
57 | max_dim_lengths[field_seq_idx] = sequence_length
58 |
59 | padded_tensor = torch.full([batch_size] + max_dim_lengths, field_pad_idx, dtype=torch.long)
60 |
61 | for i, ar in enumerate(field_value_arrays):
62 | paste_inds = [i]
63 |
64 | for dim in range(field_value_arrays[0].ndim):
65 | if dim == field_seq_idx and self.tokenizer.padding_side == "left":
 66 |                             inds = slice(-ar.size(dim), None)
67 | else:
68 | inds = slice(ar.size(dim))
69 | paste_inds.append(inds)
70 |
71 | padded_tensor[paste_inds] = ar
72 |
73 | batch[field_name] = padded_tensor
74 |
75 | return batch
76 |
--------------------------------------------------------------------------------
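
A hypothetical usage sketch for the collator above (the tokenizer checkpoint and the entity_mask field are assumptions): the extra field is padded with zeros to the batch's sequence length along axis 1, while the standard fields go through tokenizer.pad:

from transformers import AutoTokenizer

from field_collator import FieldDataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
collator = FieldDataCollatorWithPadding(
    tokenizer=tokenizer,
    fields_to_pad=[("entity_mask", 0, 1)],  # (field name, pad token index, sequence axis)
    pad_to_multiple_of=8,
)

features = [
    {**tokenizer("a short example"), "entity_mask": [[0, 1, 1, 0], [0, 0, 1, 1]]},
    {**tokenizer("a slightly longer example sentence"), "entity_mask": [[0, 1, 1, 0, 0, 0]]},
]
batch = collator(features)
print(batch["input_ids"].shape, batch["entity_mask"].shape)  # entity_mask: [batch, max_entities, seq_len]
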
/src/modules/rotary.py:
--------------------------------------------------------------------------------
1 | """
2 | Auxiliary modules for implementing Rotary Position Embeddings
3 | Original paper: https://arxiv.org/abs/2104.09864
4 | Based on reference implementation from https://blog.eleuther.ai/rotary-embeddings
5 | """
6 |
7 | import torch
8 | import torch.distributed
9 | import torch.nn as nn
10 |
11 | from logging import getLogger
12 | from src.modules.functional import maybe_script
13 |
14 | logger = getLogger(__file__)
15 |
16 |
17 | class RotaryEmbeddings(nn.Module):
18 | """Applies rotary position embeddings to a tensor, uses caching to improve performance"""
19 |
20 | def __init__(self, dim: int, base: int = 10_000):
21 | super().__init__()
22 | self.dim, self.base = dim, base
23 | self._rotate = maybe_script(rotate)
24 | self._get_auxiliary_tensors = maybe_script(get_auxiliary_tensors)
25 | self.register_buffer("cos", torch.empty(1, dim), persistent=False)
26 | self.register_buffer("sin", torch.empty(1, dim), persistent=False)
27 |
28 | def forward(self, x: torch.Tensor, offset: int = 0):
29 | """
30 | :param x: tensor of shape [batch_size, seq_len, nhead, hid_size]
31 | :param offset: add this value to all position indices
32 | """
33 | seq_len = x.shape[1]
34 | if seq_len + offset > self.cos.shape[0] or x.device != self.cos.device:
35 | if torch.distributed.is_initialized():
 36 |                 logger.warning("Rebuilding auxiliary tensors for rotary embeddings, this may cause DDP to freeze. To "
 37 |                                "avoid this, call _set_auxiliary_buffers for the maximum length before training starts")
38 | self._set_auxiliary_buffers(max_len=seq_len + offset, device=x.device)
39 | cosines_for_position = self.cos[None, offset : seq_len + offset, None, :]
40 | sines_for_position = self.sin[None, offset : seq_len + offset, None, :]
41 | return self._rotate(x, cosines_for_position, sines_for_position)
42 |
43 | def _set_auxiliary_buffers(self, max_len: int, device: torch.device):
44 | _cos, _sin = self._get_auxiliary_tensors(max_len, self.dim, torch.float32, device, self.base)
45 | self.register_buffer("cos", _cos, persistent=False)
46 | self.register_buffer("sin", _sin, persistent=False)
47 |
48 |
49 | @torch.no_grad()
50 | def get_auxiliary_tensors(seq_len: int, dim: int, dtype: torch.dtype, device: torch.device, base: int):
51 | """
52 | Compute auxiliary sine and cosine tensors for rotary position embedding
53 | :returns: a tuple of (cos, sin) tensors of shape [seq_len, hid_size]
54 | """
55 | _buf = torch.linspace(0, -1 + 2 / dim, dim // 2, dtype=torch.float32, device=device)
56 | inv_freq = torch.pow(base, _buf, out=_buf).repeat(2)
57 | time_ix = torch.arange(seq_len, dtype=inv_freq.dtype, device=device)
58 |
59 | freqs = time_ix[:, None] * inv_freq[None, :]
60 | cos = torch.cos(freqs)
61 | sin = freqs.sin_()
62 | return cos.to(dtype), sin.to(dtype)
63 |
64 |
65 | def rotate(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
66 | """rotate pairwise coordinate using precomputed cos & sin tensors"""
67 | dim = x.shape[-1]
68 | x_left, x_right = x.split(split_size=dim // 2, dim=x.ndim - 1)
69 | x_rotated = torch.cat([x_right.neg(), x_left], dim=x.ndim - 1)
70 | return x * cos + x_rotated * sin
71 |
72 |
73 |
--------------------------------------------------------------------------------
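
A minimal usage sketch for the module above (shapes are illustrative; assumes the repo root is on the import path). Pre-building the cache with _set_auxiliary_buffers avoids the rebuild warning inside forward:

import torch

from src.modules.rotary import RotaryEmbeddings

rotary = RotaryEmbeddings(dim=64)
rotary._set_auxiliary_buffers(max_len=512, device=torch.device("cpu"))  # optional warm-up

q = torch.randn(2, 128, 8, 64)  # [batch_size, seq_len, nhead, head_dim]
k = torch.randn(2, 128, 8, 64)
q_rot, k_rot = rotary(q), rotary(k)
print(q_rot.shape)  # rotation preserves the input shape
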
/predict_on_lidirus.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser
4 | from functools import partial
5 | from pathlib import Path
6 |
7 | from datasets import load_dataset
8 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, \
9 | TrainingArguments, AlbertTokenizerFast
10 |
11 | from data import TASK_TO_CONFIG, TASK_TO_NAME
12 | from src import LeanAlbertForSequenceClassification
13 |
14 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
15 |
16 | BATCH_SIZE = 8
17 | MAX_LENGTH = 512
18 |
19 | TASK = "lidirus"
20 |
21 |
22 | def main(model_path: Path, arch):
23 | run_dirname = model_path.parts[-1]
24 | assert 'terra' in run_dirname
25 |
26 | if arch == 'lean_albert':
27 | tokenizer = AlbertTokenizerFast.from_pretrained('tokenizer')
28 | else:
29 | tokenizer = AutoTokenizer.from_pretrained(arch)
30 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
31 |
32 | dataset = load_dataset("russian_super_glue", TASK)
33 | config = TASK_TO_CONFIG[TASK](dataset)
34 |
 35 |     processed_dataset = dataset.map(partial(config.process_data, tokenizer=tokenizer, max_length=MAX_LENGTH),
36 | batched=True, remove_columns=["label"])
37 | test_without_labels = processed_dataset['test'].remove_columns(['labels'])
38 |
39 | last_path = max(Path(model_path).glob("checkpoint*"), default=None, key=os.path.getctime)
40 | with open(last_path / 'trainer_state.json') as f:
41 | trainer_state = json.load(f)
42 | best_path = Path(trainer_state['best_model_checkpoint']).parts[-1]
43 |
44 | if arch == 'lean_albert':
45 | model = LeanAlbertForSequenceClassification.from_pretrained(model_path / best_path)
46 | else:
47 | model = AutoModelForSequenceClassification.from_pretrained(model_path / best_path)
48 |
49 | training_args = TrainingArguments(
50 | output_dir=model_path, overwrite_output_dir=True,
51 | evaluation_strategy='epoch', logging_strategy='epoch', per_device_train_batch_size=BATCH_SIZE,
52 | per_device_eval_batch_size=BATCH_SIZE, save_strategy='epoch', save_total_limit=1,
53 | fp16=True, dataloader_num_workers=4, group_by_length=True,
54 | report_to='none', load_best_model_at_end=True, metric_for_best_model=config.best_metric
55 | )
56 |
57 | trainer = Trainer(
58 | model=model,
59 | args=training_args,
60 | train_dataset=None,
61 | eval_dataset=None,
62 | compute_metrics=config.compute_metrics,
63 | tokenizer=tokenizer,
64 | data_collator=data_collator,
65 | )
66 |
67 | predictions = trainer.predict(test_dataset=test_without_labels)
68 | processed_predictions = config.process_predictions(predictions.predictions, split="test")
69 |
70 | preds_dir = f"preds/{run_dirname.replace('_terra', '')}"
71 |
72 | os.makedirs(preds_dir, exist_ok=True)
73 |
74 | with open(f"{preds_dir}/{TASK_TO_NAME[TASK]}.jsonl", 'w+') as outf:
75 | for prediction in processed_predictions:
76 | print(json.dumps(prediction, ensure_ascii=True), file=outf)
77 |
78 |
79 | if __name__ == '__main__':
80 | parser = ArgumentParser()
81 | parser.add_argument("-m", '--model', type=Path, required=True)
82 | parser.add_argument("-a", '--arch')
83 | args = parser.parse_args()
84 | main(args.model, args.arch)
85 |
--------------------------------------------------------------------------------
/src/training/sync.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Sequence
3 |
4 | import torch
5 | import transformers
6 | from hivemind.utils import get_logger
7 | from torch.distributed.distributed_c10d import _get_default_group, _get_default_store
8 | from transformers import TrainerControl, TrainerState, TrainingArguments
9 |
10 | import arguments
11 | import tasks
12 |
13 | AUTHORITATIVE_RANK = 0
14 | BROADCAST_BUFFER_SIZE: int = 250 * 1024 * 1024
15 | logger = get_logger(__name__)
16 |
17 |
18 | def is_main_process() -> bool:
19 | """Whether this is the main process on **this peer's** distributed run. Non-distributed process is always main."""
20 | return (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == AUTHORITATIVE_RANK
21 |
22 |
23 | class SynchronizationCallback(transformers.TrainerCallback):
24 | """Minimalistic callback for non-master DDP workers"""
25 |
26 | def __init__(self, task: "tasks.TrainingTaskBase", args: "arguments.TrainingPeerArguments"):
27 | self.task = task
28 | self.is_master = is_main_process()
29 | self._checksum_counter = 0
30 | self._state_tensors = None
31 | self._prev_version = self._prev_epoch = -1
32 |
33 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
34 | if torch.distributed.is_initialized():
35 | self._maybe_sync_model_state()
36 |
37 | def on_step_end(
38 | self,
39 | args: TrainingArguments,
40 | state: transformers.TrainerState,
41 | control: transformers.TrainerControl,
42 | **kwargs,
43 | ):
44 | control.should_log = True
45 | model = self.task.model
46 | if torch.distributed.is_initialized():
47 | self._maybe_sync_model_state()
48 |
49 | self._checksum_counter += 1
50 | if self._checksum_counter % 100 == 0:
51 | rank = torch.distributed.get_rank()
52 | print(end=f"CHECKSUM({rank})={float(sum(p.sum().item() for p in model.parameters()))}\n")
53 | self.task.on_step_end()
54 |
55 | @property
56 | def state_tensors(self) -> Sequence[torch.Tensor]:
57 | if self._state_tensors is None:
58 | self._state_tensors = list(self.task.model.state_dict().values())
59 | return self._state_tensors
60 |
61 | def _compute_state_version(self) -> int:
62 | """return a non-decreasing integer that goes up whenever model params and/or buffers were updated"""
63 | assert self.is_master
64 | return sum(state["step"] for state in self.task.optimizer.opt.state.values())
65 |
66 | def _should_broadcast_state(self):
67 | store = _get_default_store()
68 | if self.is_master:
69 | current_version = self._compute_state_version()
70 | if current_version == self._prev_version and self.task.optimizer.local_epoch > self._prev_epoch + 1:
 71 |                 logger.warning("Model state version has not changed during a full epoch; "
 72 |                                "parameter broadcasting used for torch.distributed synchronization may be broken")
73 |
74 | if current_version != self._prev_version or self.task.optimizer.local_epoch > self._prev_epoch + 1:
75 | should_broadcast = True
76 | else:
77 | should_broadcast = False
78 |
79 | store.set(f"_hivemind_should_broadcast_state", str(int(should_broadcast)))
80 | torch.distributed.barrier()
81 | return should_broadcast
82 | else:
83 | torch.distributed.barrier()
84 | raw_should_broadcast = store.get(f"_hivemind_should_broadcast_state")
85 | return bool(int(raw_should_broadcast))
86 |
87 | def _maybe_sync_model_state(self):
88 | """Synchronize model params and buffers from master"""
89 | if self.state_tensors and self._should_broadcast_state():
90 | t_start = time.perf_counter()
91 | with torch.no_grad():
92 | torch.distributed._broadcast_coalesced(
93 | _get_default_group(), self.state_tensors, BROADCAST_BUFFER_SIZE, AUTHORITATIVE_RANK
94 | )
95 | if self.is_master:
96 | self._prev_version = self._compute_state_version()
97 | self._prev_epoch = self.task.optimizer.local_epoch
98 | logger.info(f"Broadcasting master params took {time.perf_counter() - t_start} seconds")
99 | else:
100 | logger.debug("Not broadcasting")
101 |
--------------------------------------------------------------------------------
/tasks/mlm/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from functools import partial
4 | from typing import Optional
5 |
6 | import torch.utils.data
7 | from datasets import interleave_datasets, load_dataset
8 | from hivemind.utils.logging import get_logger
9 | from razdel import sentenize
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def make_training_dataset(
15 | tokenizer,
16 | shuffle_buffer_size: int = 10 ** 4,
17 | shuffle_seed: Optional[int] = None,
18 | preprocessing_batch_size: int = 256,
19 | max_sequence_length: int = 512,
20 | ):
21 | dataset1 = load_dataset(YOUR_DATASET_HERE)
22 | dataset2 = load_dataset(YOUR_DATASET_HERE)
23 | dataset3 = load_dataset(YOUR_DATASET_HERE)
24 |
25 | datasets = dict(dataset1=dataset1, dataset2=dataset2, dataset3=dataset3)
26 | weights = dict(dataset1=0.6, dataset2=0.2, dataset3=0.2)
27 |
28 | datasets = {
29 | key: dataset.map(
30 | partial(tokenize_function, tokenizer, max_sequence_length=max_sequence_length),
31 | batched=True,
32 | batch_size=preprocessing_batch_size,
33 | )
34 | for key, dataset in datasets.items()
35 | }
36 |
37 | dataset = interleave_datasets(
38 | [datasets[k] for k in sorted(datasets.keys())],
39 | probabilities=[weights[k] for k in sorted(datasets.keys())],
40 | )
41 |
42 | dataset = dataset.shuffle(seed=shuffle_seed, buffer_size=shuffle_buffer_size)
43 | dataset = dataset.with_format("torch")
44 | return WrappedIterableDataset(dataset)
45 |
46 |
47 | def create_instances_from_document(tokenizer, document, max_sequence_length):
48 | """Creates `TrainingInstance`s for a single document."""
49 | # We DON'T just concatenate all of the tokens from a document into a long
50 | # sequence and choose an arbitrary split point because this would make the
51 | # next sentence prediction task too easy. Instead, we split the input into
52 | # segments "A" and "B" based on the actual "sentences" provided by the user
53 | # input.
54 | instances = []
55 | current_chunk = []
56 | current_length = 0
 57 |     max_sequence_length = int(max_sequence_length)
58 | segmented_sents = [s.text for s in sentenize(document)]
59 |
60 | for i, sent in enumerate(segmented_sents):
61 | current_chunk.append(sent)
62 | current_length += len(tokenizer.tokenize(sent))
63 | if i == len(segmented_sents) - 1 or current_length >= max_sequence_length:
64 | if len(current_chunk) > 1:
65 | # `a_end` is how many segments from `current_chunk` go into the `A`
66 | # (first) sentence.
67 | a_end = random.randint(1, len(current_chunk) - 1)
68 |
69 | tokens_a = []
70 | for j in range(a_end):
71 | tokens_a.append(current_chunk[j])
72 |
73 | tokens_b = []
74 |
75 | for j in range(a_end, len(current_chunk)):
76 | tokens_b.append(current_chunk[j])
77 |
78 | if random.random() < 0.5:
79 | # Random next
80 | is_random_next = True
81 | # in this case, we just swap tokens_a and tokens_b
82 | tokens_a, tokens_b = tokens_b, tokens_a
83 | else:
84 | # Actual next
85 | is_random_next = False
86 |
87 | assert len(tokens_a) >= 1
88 | assert len(tokens_b) >= 1
89 |
90 | instance = tokenizer(
91 | " ".join(tokens_a),
92 | " ".join(tokens_b),
93 | padding="max_length",
94 | truncation="longest_first",
95 | max_length=max_sequence_length,
96 | # We use this option because DataCollatorForLanguageModeling
97 | # is more efficient when it receives the `special_tokens_mask`.
98 | return_special_tokens_mask=True,
99 | )
100 | assert len(instance["input_ids"]) <= max_sequence_length
101 | instance["sentence_order_label"] = 1 if is_random_next else 0
102 | instances.append(instance)
103 |
104 | current_chunk = []
105 | current_length = 0
106 | return instances
107 |
108 |
109 | def tokenize_function(tokenizer, examples, max_sequence_length):
110 | # Remove empty texts
111 | texts = [text for text in examples["text"] if len(text) > 0 and not text.isspace()]
112 | new_examples = defaultdict(list)
113 |
114 | for text in texts:
115 | instances = create_instances_from_document(tokenizer, text, max_sequence_length)
116 | for instance in instances:
117 | for key, value in instance.items():
118 | new_examples[key].append(value)
119 |
120 | return new_examples
121 |
122 |
123 | class WrappedIterableDataset(torch.utils.data.IterableDataset):
124 |     """Wraps a huggingface IterableDataset as a pytorch IterableDataset, implementing default methods for DataLoader"""
125 |
126 | def __init__(self, hf_iterable, verbose: bool = True):
127 | self.hf_iterable = hf_iterable
128 | self.verbose = verbose
129 |
130 | def __iter__(self):
131 | started = False
132 | logger.info("Pre-fetching training samples...")
133 | while True:
134 | for sample in self.hf_iterable:
135 | if not started:
136 | logger.info("Began iterating minibatches!")
137 | started = True
138 | yield sample
139 |
--------------------------------------------------------------------------------
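
For illustration, create_instances_from_document can also be exercised on its own (the tokenizer checkpoint and the sample text are assumptions): each instance carries padded input_ids, a special_tokens_mask, and a sentence_order_label:

from transformers import AutoTokenizer

from tasks.mlm.data import create_instances_from_document

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
document = (
    "The first sentence of a document. The second sentence is a bit longer than the first one. "
    "A third sentence closes the paragraph."
)
instances = create_instances_from_document(tokenizer, document, max_sequence_length=128)
for instance in instances:
    print(len(instance["input_ids"]), instance["sentence_order_label"])
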
/src/training/hf_trainer.py:
--------------------------------------------------------------------------------
1 | """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
2 |
3 | import hivemind
4 | import torch
5 | import torch.distributed
6 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler
7 | from torch import nn
8 | from torch.utils.data import DataLoader
9 | from transformers.trainer import Trainer
10 |
11 | from arguments import HFTrainerArguments
12 | from src.modules.rotary import RotaryEmbeddings
13 | from src.training.sync import is_main_process
14 |
15 | use_hivemind_log_handler("in_root_logger")
16 | logger = get_logger(__name__)
17 | LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
18 |
19 |
20 | class CollaborativeHFTrainer(Trainer):
21 | """
22 | A version of HuggingFace trainer that shuffles the dataset using a separate random seed.
23 | Used to ensure that peers don't process batches in the same order.
24 | """
25 |
26 | def __init__(self, *, data_seed: int, optimizer: hivemind.Optimizer, reuse_grad_buffers: bool, **kwargs):
27 | self.data_seed = data_seed
28 | self.optimizer = optimizer
29 | self.reuse_grad_buffers = reuse_grad_buffers
30 | assert not torch.distributed.is_initialized() or not reuse_grad_buffers, "DDP with reuse_grad_buffers is not implemented (yet)"
31 | super().__init__(optimizers=(optimizer, NoOpScheduler(optimizer)), **kwargs)
32 |
33 | if self.args.fp16 and self.reuse_grad_buffers:
34 | self.scaler = hivemind.GradScaler()
35 |
36 | def get_train_dataloader(self) -> DataLoader:
37 | """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
38 | seed = self.data_seed
39 | if torch.distributed.is_initialized():
40 | seed += torch.distributed.get_rank()
41 | torch.manual_seed(seed)
42 | return super().get_train_dataloader()
43 |
44 | def _wrap_model(self, model: nn.Module, training=True):
45 | assert training, "Evaluation (training=False) should be run on a separate dedicated worker."
46 | model = super()._wrap_model(model, training)
47 | if torch.distributed.is_initialized():
48 | assert isinstance(model, nn.parallel.DistributedDataParallel)
49 | assert model.require_forward_param_sync
50 | logger.info("Pre-populating rotary embedding cache up to maximum length to enforce static graph")
51 | assert isinstance(self.args, HFTrainerArguments)
52 | device = f"{model.device_type}:{model.output_device}"
53 | for module in model.modules():
54 | if isinstance(module, RotaryEmbeddings):
55 | module._set_auxiliary_buffers(max_len=self.args.max_sequence_length, device=device)
56 |
57 | logger.warning("DistributedDataParallel: triggering _set_static_graph() to allow checkpointing")
58 | model._set_static_graph()
59 |
60 | # if reuse_grad_buffers is True, we should accumulate gradients in .grad without zeroing them after each step
61 | should_override_zero_grad = self.reuse_grad_buffers if is_main_process() else False # replicas can reset grad
62 | return IgnoreGradManipulations(model, override_zero_grad=should_override_zero_grad)
63 |
64 |
65 | class NOPtimizer(torch.optim.SGD):
66 | def __init__(self, params):
67 | super().__init__(params, lr=0)
68 |
69 | def step(self, *args, **kwargs):
70 | pass
71 |
72 |
73 | class NoOpScheduler(LRSchedulerBase):
74 | """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
75 |
76 | def get_lr(self):
77 | return [group["lr"] for group in self.optimizer.param_groups]
78 |
79 | def print_lr(self, *args, **kwargs):
80 | if self.optimizer.scheduler:
81 | return self.optimizer.scheduler.print_lr(*args, **kwargs)
82 |
83 | def step(self):
84 | logger.debug("Called NoOpScheduler.step")
85 | self._last_lr = self.get_lr()
86 |
87 | def state_dict(self):
88 | return {}
89 |
90 | def load_state_dict(self, *args, **kwargs):
91 | logger.debug("Called NoOpScheduler.load_state_dict")
92 |
93 |
94 | class IgnoreGradManipulations(nn.Module):
95 | """Wrapper for model that blocks gradient manipulations in huggingface Trainer (e.g. zero_grad, clip_grad)"""
96 |
97 | def __init__(self, module, override_clipping: bool = True, override_zero_grad: bool = True):
98 | super().__init__()
99 | self.module = module
100 | self.override_clipping = override_clipping
101 | self.override_zero_grad = override_zero_grad
102 |
103 | def forward(self, *args, **kwargs):
104 | return self.module.forward(*args, **kwargs)
105 |
106 | def zero_grad(self, set_to_none: bool = False) -> None:
107 | if self.override_zero_grad:
108 |             grads_are_finite = all(param.grad.isfinite().all() for param in self.parameters())
109 |             if grads_are_finite:
110 | logger.debug("Successfully bypassed zero_grad")
111 | else:
112 | logger.debug("Encountered non-finite value in gradients!")
113 | self.module.zero_grad(set_to_none=set_to_none)
114 | else:
115 |
116 | self.module.zero_grad(set_to_none=set_to_none)
117 |
118 | def clip_grad_norm_(self, max_norm: float, norm_type: int = 2):
119 | """ignore clip_grad_norm on each step, clip in optimizer instead"""
120 | if self.override_clipping:
121 | logger.debug("Successfully bypassed clip_grad_norm_")
122 | else:
123 | return torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm, norm_type=norm_type)
124 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import PreTrainedModel
5 |
6 |
7 | class FCLayer(nn.Module):
8 | def __init__(self, input_dim, output_dim, dropout_rate=0., use_activation=True):
9 | super().__init__()
10 |
11 | layers = [nn.Dropout(dropout_rate)]
12 | if use_activation:
13 | layers.append(nn.Tanh())
14 | layers.append(nn.Linear(input_dim, output_dim))
15 |
16 | self.layers = nn.Sequential(*layers)
17 |
18 | def forward(self, x):
19 | return self.layers(x)
20 |
21 |
22 | class SpanClassificationModel(nn.Module):
23 | def __init__(self, backbone: PreTrainedModel, num_labels: int):
24 | super().__init__()
25 | self.backbone = backbone
26 | self.num_labels = num_labels
27 | self.config = backbone.config
28 |
29 | hidden_size = self.config.hidden_size
30 |
31 | if hasattr(self.config, "classifier_dropout"):
32 | dropout_rate = (
33 | self.config.classifier_dropout if self.config.classifier_dropout is not None
34 | else self.config.hidden_dropout_prob
35 | )
36 | else:
37 | dropout_rate = self.config.classifier_dropout_prob
38 |
39 | self.cls_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate)
40 | self.e1_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate)
41 | self.e2_fc_layer = FCLayer(hidden_size, hidden_size, dropout_rate)
42 | self.label_classifier = FCLayer(hidden_size * 3, num_labels, dropout_rate, use_activation=False)
43 |
44 | @staticmethod
45 | def _entity_average(hidden_output, e_mask):
46 | """
47 | Average the entity hidden state vectors (H_i ~ H_j)
48 | :param hidden_output: [batch_size, j-i+1, dim]
49 | :param e_mask: [batch_size, seq_len]
50 | e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
51 | :return: [batch_size, dim]
52 | """
53 | e_mask_unsqueeze = e_mask.unsqueeze(2) # [batch_size, seq_len, 1]
54 | length_tensor = (e_mask_unsqueeze != 0).sum(dim=1) # [batch_size, 1]
55 |
56 | sum_vector = (hidden_output * e_mask_unsqueeze).sum(1)
57 | avg_vector = sum_vector / length_tensor
58 | return avg_vector
59 |
60 | def forward(self, input_ids, attention_mask, e1_mask, e2_mask, token_type_ids, labels=None):
61 | # sequence_output, pooled_output, (hidden_states), (attentions)
62 | sequence_output, pooled_output, *outputs = self.backbone(
63 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False,
64 | )
65 |
66 | # Average
67 | hidden_first = self._entity_average(sequence_output, e1_mask)
68 | hidden_second = self._entity_average(sequence_output, e2_mask)
69 |
70 | # Dropout -> tanh -> fc_layer
71 | pooled_output = self.cls_fc_layer(pooled_output)
72 | hidden_first = self.e1_fc_layer(hidden_first)
73 | hidden_second = self.e2_fc_layer(hidden_second)
74 |
75 | # Concat -> fc_layer
76 | concat_h = torch.cat([pooled_output, hidden_first, hidden_second], dim=-1)
77 | logits = self.label_classifier(concat_h)
78 |
79 | outputs = (logits,) + tuple(outputs) # add hidden states and attention if they are here
80 |
 81 |         # Loss (softmax + cross-entropy via CrossEntropyLoss)
82 | if labels is not None:
83 | loss_fct = nn.CrossEntropyLoss()
84 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
85 |
86 | outputs = (loss,) + outputs
87 |
88 | return outputs # (loss), logits, (hidden_states), (attentions)
89 |
90 |
91 | class EntityChoiceModel(nn.Module):
92 | def __init__(self, backbone: PreTrainedModel):
93 | super().__init__()
94 | self.backbone = backbone
95 |
96 | hidden_size = backbone.config.hidden_size
97 |
98 | if hasattr(backbone.config, "classifier_dropout"):
99 | dropout_rate = (
100 | backbone.config.classifier_dropout if backbone.config.classifier_dropout is not None
101 | else backbone.config.hidden_dropout_prob
102 | )
103 | else:
104 | dropout_rate = backbone.config.classifier_dropout_prob
105 |
106 | self.clf = nn.Sequential(
107 | nn.Linear(hidden_size, hidden_size),
108 | nn.GELU(),
109 | nn.Dropout(dropout_rate),
110 | nn.Linear(hidden_size, 1)
111 | )
112 |
113 | self.loss = nn.BCEWithLogitsLoss(reduction="none")
114 |
115 | def forward(self, input_ids, attention_mask, entity_mask, token_type_ids, labels=None):
116 | sequence_output, pooled_output, *outputs = self.backbone(
117 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=False
118 | )
119 |
120 | entity_mask = entity_mask.unsqueeze(3)
121 | entity_lengths = entity_mask.sum(2)
122 |
123 | # [batch_size, num_entities, seq_len, hid_dim]
124 | embeds_for_entities = sequence_output.unsqueeze(1) * entity_mask
125 | aggregated_embeds = embeds_for_entities.sum(2) / (entity_lengths + 1e-8)
126 |
127 | # [batch_size, num_entities]
128 | logits = self.clf(aggregated_embeds).squeeze(2)
129 |
130 | # due to padding, we might have rows without entity indices at all
131 | # thus, we need to mask the logits/loss for them
132 |
133 | # [batch_size, num_entities]
134 | present_entities = (entity_lengths.squeeze(2) != 0).to(logits.dtype)
135 | logits = logits * present_entities - 10000.0 * (1 - present_entities)
136 |
137 | outputs = (logits,) + tuple(outputs) # add hidden states and attention if they are here
138 |
139 | if labels is not None:
140 | # do not penalize predictions for padded entities
141 | label_mask = (labels != -1)
142 |
143 | # BCEWithLogitsLoss requires targets to be from 0 to 1
144 | labels[~label_mask] = 0
145 |
146 | loss = self.loss(logits, labels.to(logits.dtype))
147 |
148 | reduced_loss = ((loss * label_mask).sum(1) / (label_mask.sum(1) + 1e-8)).mean()
149 | outputs = (reduced_loss,) + outputs
150 |
151 | return outputs
152 |
--------------------------------------------------------------------------------
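
A tiny worked example of the masked averaging used by both models above (values are made up): the mask selects positions 1 and 2, so the result equals the mean of those two hidden vectors:

import torch

from models import SpanClassificationModel

hidden = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # [batch=1, seq_len=4, dim=3]
e_mask = torch.tensor([[0, 1, 1, 0]], dtype=torch.float32)       # marks tokens 1 and 2
avg = SpanClassificationModel._entity_average(hidden, e_mask)
print(avg)  # tensor([[4.5, 5.5, 6.5]]), the mean of rows 1 and 2
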
/tasks/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 | from typing import Optional, Type
3 |
4 | import hivemind
5 | import torch
6 | import torch.nn as nn
7 | from hivemind import Float16Compression, SizeAdaptiveCompression, Uniform8BitQuantization
8 | from hivemind.utils.logging import get_logger
9 | from torch.distributed.distributed_c10d import _get_default_store
10 | from transformers.data.data_collator import DataCollatorMixin
11 |
12 | from arguments import BasePeerArguments, CollaborativeArguments, HFTrainerArguments
13 | from src.huggingface_auth import authorize_with_huggingface
14 | from src.training.sync import AUTHORITATIVE_RANK, is_main_process
15 |
16 | try:
17 | from hivemind.optim.experimental.state_averager import LRSchedulerBase, ParamGroups
18 | except ImportError:
19 | from hivemind.optim.state_averager import LRSchedulerBase, ParamGroups
20 |
21 | from src import utils
22 |
23 | logger = get_logger(__name__)
24 | TASKS = {}
25 |
26 |
27 | def register_task(name: str):
28 | def _register(cls: Type[TrainingTaskBase]):
29 | if name in TASKS:
30 | logger.warning(f"Registering task {name} a second time, previous entry will be overwritten")
31 | TASKS[name] = cls
32 | return cls
33 |
34 | return _register
35 |
36 |
37 | class TrainingTaskBase:
38 | """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""
39 |
40 | _dht = _optimizer = _authorizer = None # for caching
41 |
42 | def __init__(
43 | self,
44 | model: nn.Module,
45 | peer_args: BasePeerArguments,
46 | trainer_args: HFTrainerArguments,
47 | collab_args: CollaborativeArguments,
48 | ):
49 | self.model, self.peer_args, self.trainer_args, self.collab_args = model, peer_args, trainer_args, collab_args
50 | self.validators, self.local_public_key = utils.make_validators(self.peer_args.run_id)
51 |
52 | if self.authorizer:
53 | self.trainer_args.run_name = self.authorizer.username # For wandb
54 |
55 | @property
56 | def authorizer(self):
57 | if self._authorizer is None and self.peer_args.authorize:
58 | self._authorizer = authorize_with_huggingface()
59 | return self._authorizer
60 |
61 | @property
62 | def dht(self):
63 | if self._dht is None:
64 | assert is_main_process()
65 | self._dht = hivemind.DHT(
66 | start=True,
67 | initial_peers=self.peer_args.initial_peers,
68 | client_mode=self.peer_args.client_mode,
69 | host_maddrs=self.peer_args.host_maddrs,
70 | announce_maddrs=self.peer_args.announce_maddrs,
71 | use_ipfs=self.peer_args.use_ipfs,
72 | record_validators=self.validators,
73 | identity_path=self.peer_args.identity_path,
74 | authorizer=self.authorizer,
75 | )
76 | if self.peer_args.client_mode:
77 | logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
78 | else:
79 | utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
80 | return self._dht
81 |
82 | @property
83 | def optimizer(self) -> hivemind.Optimizer:
84 | if self._optimizer is None:
85 | assert is_main_process()
86 | averaging_compression = SizeAdaptiveCompression(
87 | threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
88 | )
89 |
90 | self._optimizer = hivemind.Optimizer(
91 | dht=self.dht,
92 | params=self._make_param_groups(),
93 | run_id=self.peer_args.run_id,
94 | optimizer=self._make_base_optimizer,
95 | scheduler=self._make_scheduler,
96 | grad_compression=averaging_compression,
97 | state_averaging_compression=averaging_compression,
98 | batch_size_per_step=self.trainer_args.batch_size_per_step,
99 | client_mode=self.peer_args.client_mode,
100 | verbose=True,
101 | averager_opts=dict(min_vector_size=self.peer_args.min_vector_size, bandwidth=self.peer_args.bandwidth),
102 | **asdict(self.collab_args),
103 | )
104 | return self._optimizer
105 |
106 | def _make_param_groups(self) -> ParamGroups:
107 | """Return optimizer param groups: either list of parameters or a list of dictionaries in torch.optim format"""
108 | raise NotImplementedError()
109 |
110 | def _make_base_optimizer(self, param_groups: ParamGroups) -> torch.optim.Optimizer:
111 | """Return PyTorch optimizer to be wrapped with hivemind.Optimizer. Use only the specified param groups."""
112 | raise NotImplementedError()
113 |
114 |     def _make_scheduler(self, optimizer: torch.optim.Optimizer) -> Optional[LRSchedulerBase]:
115 |         """Return a PyTorch scheduler that will be run on each synchronization epoch (every target_batch_size samples)"""
116 | return None # default = no scheduler
117 |
118 | @property
119 | def training_dataset(self):
120 | raise NotImplementedError()
121 |
122 | @property
123 | def data_collator(self) -> DataCollatorMixin:
124 | raise NotImplementedError()
125 |
126 | def on_step_end(self):
127 | """This will be called after each local step (in callback.py)"""
128 | pass
129 |
130 | def get_current_epoch(self):
131 | """
132 | Return current epoch in a ddp-friendly way. When implementing your own task, please use this instead of
133 | directly accessing self.optimizer.global_epoch because non-main DDP workers will not run hivemind.Optimizer
134 | """
135 | if not torch.distributed.is_initialized():
136 | return self.optimizer.tracker.global_epoch
137 | else:
138 | store = _get_default_store()
139 | if torch.distributed.get_rank() == AUTHORITATIVE_RANK:
140 | current_epoch = self.optimizer.tracker.global_epoch
141 | store.set("_hivemind_current_epoch", str(current_epoch))
142 | torch.distributed.barrier()
143 | return current_epoch
144 | else:
145 | torch.distributed.barrier()
146 | return int(store.get("_hivemind_current_epoch"))
147 |
--------------------------------------------------------------------------------
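
A minimal sketch of what a concrete task could look like (hypothetical names; the real MLM task lives in tasks/mlm/task.py, which is not shown here). Only the optimizer-related hooks are filled in; training_dataset and data_collator would still need to be provided:

import torch

from tasks.base import TrainingTaskBase, register_task


@register_task("toy")
class ToyTask(TrainingTaskBase):
    def _make_param_groups(self):
        return [param for param in self.model.parameters() if param.requires_grad]

    def _make_base_optimizer(self, param_groups):
        # this local optimizer gets wrapped by hivemind.Optimizer in the `optimizer` property above
        return torch.optim.AdamW(param_groups, lr=1e-3)

    def _make_scheduler(self, optimizer):
        return None  # keep the default: no LR scheduler
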
/tasks/mlm/whole_word_mask.py:
--------------------------------------------------------------------------------
1 | import random
2 | from dataclasses import dataclass
3 | from typing import Dict, List, Optional, Tuple, Union
4 |
5 | import torch
6 | from transformers import DataCollatorForLanguageModeling
7 |
8 | try:
9 | from transformers.data.data_collator import _torch_collate_batch as collate_batch
10 | from transformers.data.data_collator import tolist
11 | except ImportError:
12 | from transformers.data.data_collator import _collate_batch as collate_batch, tolist
13 |
14 | from transformers.tokenization_utils_base import BatchEncoding
15 |
16 |
17 | def _is_start_piece_sp(piece):
18 | """Check if the current word piece is the starting piece (sentence piece)."""
19 | special_pieces = set(list('!"#$%&"()*+,-./:;?@[\\]^_`{|}~'))
 20 |     special_pieces.add("€")  # note: str, not bytes, so membership checks against token pieces work
 21 |     special_pieces.add("£")
22 | if piece.startswith("▁") or piece.startswith("<") or piece in special_pieces:
23 | return True
24 | else:
25 | return False
26 |
27 |
28 | @dataclass
29 | class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
30 | """
31 | Data collator used for language modeling that masks entire words.
32 |
33 | - collates batches of tensors, honoring their tokenizer's pad_token
34 | - preprocesses batches for masked language modeling
35 |
36 | .. note::
37 |
38 | This collator relies on details of the implementation of subword tokenization by
39 | :class:`~transformers.AlbertTokenizer`, specifically that start-of-word tokens are prefixed with `▁`.
40 | For tokenizers that do not adhere to this scheme, this collator will produce an output that is roughly
41 | equivalent to :class:`.DataCollatorForLanguageModeling`.
42 | """
43 |
44 | def __call__(
45 | self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
46 | ) -> Dict[str, torch.Tensor]:
47 | if isinstance(examples[0], (dict, BatchEncoding)):
48 | batch = self.tokenizer.pad(
49 | examples,
50 | return_tensors="pt",
51 | pad_to_multiple_of=self.pad_to_multiple_of,
52 | )
53 | else:
54 | batch = {"input_ids": collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
55 |
56 | # If special token mask has been preprocessed, pop it from the dict.
57 | special_tokens_mask = batch.pop("special_tokens_mask", None)
58 |
59 | mask_labels = []
60 | for example in batch["input_ids"]:
61 | ref_tokens = self.tokenizer.convert_ids_to_tokens(tolist(example))
62 | mask_labels.append(self._whole_word_mask(ref_tokens))
63 | batch_mask = collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
64 |
65 | batch["input_ids"], batch["labels"] = self.mask_tokens(
66 | batch["input_ids"], batch_mask, special_tokens_mask=special_tokens_mask
67 | )
68 | return batch
69 |
70 | def _whole_word_mask(self, input_tokens: List[str]):
71 | """
72 | Get 0/1 labels for masked tokens with whole word mask proxy
73 | """
74 |
75 | cand_indexes = []
76 | num_tokens_exc_pad = 0
77 | for i, token in enumerate(input_tokens):
78 | if token in (
79 | self.tokenizer.cls_token,
80 | self.tokenizer.sep_token,
81 | self.tokenizer.pad_token,
82 | ):
83 | continue
84 | num_tokens_exc_pad += 1
85 | if len(cand_indexes) >= 1 and not _is_start_piece_sp(token):
86 | cand_indexes[-1].append(i)
87 | else:
88 | cand_indexes.append([i])
89 |
90 | random.shuffle(cand_indexes)
91 |
92 | mask_labels = torch.zeros((len(input_tokens),), dtype=torch.long)
93 | num_tokens_to_mask = min(num_tokens_exc_pad, max(1, int(round(num_tokens_exc_pad * self.mlm_probability))))
94 | covered_indexes = set()
95 |
96 | for index_set in cand_indexes:
97 | if len(covered_indexes) >= num_tokens_to_mask:
98 | break
99 |
100 | # If adding a whole-word mask would exceed the maximum number of
101 | # predictions, then just skip this candidate.
102 | if len(covered_indexes) + len(index_set) > num_tokens_to_mask:
103 | continue
104 |
105 | is_any_index_covered = any(index in covered_indexes for index in index_set)
106 | if is_any_index_covered:
107 | continue
108 |
109 | for index in index_set:
110 | covered_indexes.add(index)
111 | mask_labels[index] = 1
112 |
113 | return mask_labels
114 |
115 | def mask_tokens(
116 | self,
117 | inputs: torch.Tensor,
118 | mask_labels: torch.Tensor,
119 | special_tokens_mask: Optional[torch.Tensor] = None,
120 | ) -> Tuple[torch.Tensor, torch.Tensor]:
121 | """
122 |         Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
123 |         Setting 'mask_labels' means we use whole-word masking (WWM): indices are masked directly according to this reference.
124 | """
125 | assert self.mlm
126 | labels = inputs.clone()
127 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
128 |
129 | probability_matrix = mask_labels
130 |
131 | if special_tokens_mask is None:
132 | special_tokens_mask = [
133 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
134 | ]
135 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
136 | else:
137 | special_tokens_mask = special_tokens_mask.bool()
138 |
139 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
140 |
141 | masked_indices = probability_matrix.bool()
142 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
143 |
144 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
145 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
146 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
147 |
148 | # 10% of the time, we replace masked input tokens with random word
149 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
150 | random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
151 | inputs[indices_random] = random_words[indices_random]
152 |
153 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
154 | return inputs, labels
155 |
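An illustrative usage sketch for the collator above (not part of the repository). It mirrors how MLMTrainingTask constructs the collator; "albert-base-v2" is only an example of a SentencePiece tokenizer whose start-of-word pieces are prefixed with `▁`:

    from transformers import AlbertTokenizerFast

    from tasks.mlm.whole_word_mask import DataCollatorForWholeWordMask

    # any "▁"-prefixed SentencePiece tokenizer works; the checkpoint name is a placeholder
    tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
    collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, pad_to_multiple_of=128)

    examples = [tokenizer("collaborative training over the internet"), tokenizer("masked language modeling")]
    batch = collator(examples)
    # batch["input_ids"]: whole words replaced by [MASK] / random tokens / kept intact (80% / 10% / 10%)
    # batch["labels"]: -100 everywhere except the masked positions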
--------------------------------------------------------------------------------
/src/training/callback.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from typing import Any
3 |
4 | import hivemind
5 | import torch
6 | import torch.distributed
7 | import transformers
8 | from transformers import TrainingArguments
9 |
10 | from arguments import TrainingPeerArguments
11 | from src.training.sync import SynchronizationCallback
12 | from src.utils import LocalMetrics, logger
13 | from tasks.base import TrainingTaskBase
14 |
15 |
16 | class CollaborativeCallback(SynchronizationCallback):
17 | """
18 |     This callback monitors and reports collaborative training progress.
19 |     In case of a catastrophic failure, it can also revert training to a backup.
20 | """
21 |
22 | def __init__(self, task: TrainingTaskBase, args: TrainingPeerArguments):
23 | super().__init__(task, args)
24 | self.dht, self.optimizer = task.dht, task.optimizer
25 | self.statistics_expiration = args.statistics_expiration
26 | self.last_reported_epoch = -1
27 | self.samples = 0
28 | self.mini_steps = 0 # number of calls to optimizer.step, NOT equal to the number of global steps
29 | self.loss = 0
30 | self.total_samples_processed = 0
31 | self.backup_every_epochs = args.backup_every_epochs
32 | self.state_path = args.state_path
33 |
34 | def on_train_begin(
35 | self,
36 | args: TrainingArguments,
37 | state: transformers.TrainerState,
38 | control: transformers.TrainerControl,
39 | **kwargs,
40 | ):
41 | if os.path.isfile(self.state_path):
42 | self.restore_from_backup(self.state_path)
43 | logger.info("Loaded state")
44 | else:
45 | logger.info("Loading state from peers")
46 | self.optimizer.load_state_from_peers()
47 | super().on_train_begin(args, state, control, **kwargs)
48 |
49 | def on_step_end(
50 | self,
51 | args: TrainingArguments,
52 | state: transformers.TrainerState,
53 | control: transformers.TrainerControl,
54 | **kwargs,
55 | ):
56 | super().on_step_end(args, state, control, **kwargs)
57 | if not self.params_are_finite():
58 | if not os.path.exists(self.state_path):
59 | raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.")
60 | logger.warning("Parameters are invalid, reloading model from earlier state")
61 | self.restore_from_backup(self.state_path)
62 | return control
63 |
64 | if state.log_history:
65 | self.loss += state.log_history[-1]["loss"]
66 | self.mini_steps += 1
67 | if self.optimizer.local_epoch != self.last_reported_epoch:
68 | self.last_reported_epoch = self.optimizer.local_epoch
69 | self.total_samples_processed += self.samples
70 | samples_per_second = self.optimizer.tracker.performance_ema.samples_per_second
71 | statistics = LocalMetrics(
72 | epoch=self.optimizer.local_epoch,
73 | samples_per_second=samples_per_second,
74 | samples_accumulated=self.samples,
75 | loss=self.loss,
76 | mini_steps=self.mini_steps,
77 | )
78 | logger.info(f"Current epoch: {self.optimizer.local_epoch}")
79 | logger.info(f"Your current contribution: {self.total_samples_processed} samples")
80 | logger.info(f"Performance: {samples_per_second} samples/sec")
81 | if self.mini_steps:
82 | logger.info(f"Local loss: {self.loss / self.mini_steps}")
83 |
84 | self.loss = 0
85 | self.mini_steps = 0
86 | if self.optimizer.local_epoch == self.optimizer.tracker.global_epoch:
87 | self.dht.store(
88 | key=self.optimizer.run_id + "_metrics",
89 | subkey=self.task.local_public_key,
90 | value=statistics.dict(),
91 | expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
92 | return_future=True,
93 | )
94 | if self.backup_every_epochs is not None and self.optimizer.local_epoch % self.backup_every_epochs == 0:
95 | self.backup_state()
96 |
97 | self.samples = self.optimizer.grad_averager.local_samples_accumulated
98 |
99 | return control
100 |
101 | @torch.no_grad()
102 | def params_are_finite(self):
103 | for param in self.task.model.parameters():
104 | if not torch.all(torch.isfinite(param)):
105 | return False
106 | return True
107 |
108 | @torch.no_grad()
109 | def backup_state(self) -> Any:
110 | logger.info("Saving backup")
111 | return torch.save(
112 | {
113 | "model": self.task.model.state_dict(),
114 | "optimizer": self.optimizer.state_dict(),
115 | "scheduler": self.optimizer.state_averager.scheduler.state_dict(),
116 | "local_epoch": self.optimizer.local_epoch,
117 | },
118 | self.state_path,
119 | )
120 |
121 | @torch.no_grad()
122 | def restore_from_backup(self, path, check_epoch=False):
123 | state = torch.load(path)
124 | current_epoch = self.optimizer.local_epoch
125 | if "model" not in state:
126 | logger.warning("Found weights-only checkpoint")
127 | self.task.model.load_state_dict(state, strict=True)
128 | return
129 | backup_epoch = state["local_epoch"]
130 |
131 | if not check_epoch or backup_epoch >= current_epoch:
132 | self.task.model.load_state_dict(state["model"], strict=False)
133 | self.optimizer.load_state_dict(state["optimizer"])
134 | self.optimizer.state_averager.scheduler.load_state_dict(state["scheduler"])
135 |
136 | if self.optimizer.offload_optimizer:
137 | state_averager = self.optimizer.state_averager
138 | offloaded_parameters = [
139 | param for group in state_averager.optimizer.param_groups for param in group["params"]
140 | ]
141 | assert len(offloaded_parameters) == len(state_averager.main_parameters)
142 | for main_param, offloaded_param in zip(state_averager.main_parameters, offloaded_parameters):
143 | offloaded_param.copy_(main_param, non_blocking=True)
144 |
145 | self.optimizer.state_averager.local_epoch = backup_epoch
146 |
147 | if not self.optimizer.client_mode:
148 | self.optimizer.state_averager.state_sharing_priority = self.optimizer.local_epoch
149 |
150 | if self.optimizer.use_gradient_averaging:
151 | self.optimizer.grad_averager.reset_accumulated_grads_()
152 | if not self.optimizer.client_mode:
153 | self.optimizer.grad_averager.state_sharing_priority = self.optimizer.local_epoch
154 |
155 | logger.info("Restored from a backup")
156 | else:
157 | logger.info("Bypassed restoring state from local backup: backup state is too old.")
158 |
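For illustration only: a backup written by backup_state above is a regular torch checkpoint, so it can be inspected offline. The path below is an assumption matching the default TrainingPeerArguments.state_path:

    import torch

    state = torch.load("state.zip", map_location="cpu")  # default --state_path; adjust if you changed it
    print("backup epoch:", state["local_epoch"])
    print("model tensors:", len(state["model"]))
    # the weights alone can be restored into a freshly built model, mirroring restore_from_backup:
    # model.load_state_dict(state["model"], strict=False)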
--------------------------------------------------------------------------------
/tasks/mlm/task.py:
--------------------------------------------------------------------------------
1 | import ctypes
2 | import multiprocessing as mp
3 | import os
4 | from pathlib import Path
5 |
6 | import hivemind
7 | import torch.distributed
8 | import transformers
9 | from torch.optim.lr_scheduler import LambdaLR
10 | from transformers import AutoTokenizer
11 | import torch
12 |
13 | from arguments import BasePeerArguments, CollaborativeArguments, HFTrainerArguments
14 | from src.models.albert import LeanAlbertConfig, LeanAlbertForPreTraining
15 | from src.training.lamb_8bit import CPULAMB8Bit
16 | from tasks.base import LRSchedulerBase, ParamGroups, TrainingTaskBase, register_task
17 |
18 | from .data import make_training_dataset
19 | from .whole_word_mask import DataCollatorForWholeWordMask
20 |
21 | hivemind.use_hivemind_log_handler("in_root_logger")
22 | logger = hivemind.get_logger()
23 |
24 |
25 | @register_task("mlm")
26 | class MLMTrainingTask(TrainingTaskBase):
27 | """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""
28 |
29 | _dht = _optimizer = _training_dataset = _authorizer = None
30 |
31 | def __init__(
32 | self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments
33 | ):
34 | transformers.set_seed(trainer_args.seed) # seed used for initialization
35 |
36 | self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
37 | self.tokenizer = AutoTokenizer.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)
38 |
39 | output_dir = Path(trainer_args.output_dir)
40 | logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
41 | latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
42 |
43 | if latest_checkpoint_dir is None:
44 | logger.info(f"Creating model")
45 | model = LeanAlbertForPreTraining(self.config)
46 | model.resize_token_embeddings(len(self.tokenizer))
47 | else:
48 | logger.info(f"Loading model from {latest_checkpoint_dir}")
49 | model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
50 | if trainer_args.gradient_checkpointing:
51 | model.gradient_checkpointing_enable()
52 |
53 | super().__init__(model, peer_args, trainer_args, collab_args)
54 | self.current_sequence_length = mp.Value(ctypes.c_int64, self.trainer_args.max_sequence_length)
55 | self.update_sequence_length() # updated by callback
56 |
57 | def _make_param_groups(self) -> ParamGroups:
58 | no_decay = ["bias", "norm.weight"]
59 | return [
60 | {
61 | "params": [p for n, p in self.model.named_parameters() if not any(n.endswith(nd) for nd in no_decay)],
62 | "weight_decay": self.trainer_args.weight_decay,
63 | },
64 | {
65 | "params": [p for n, p in self.model.named_parameters() if any(n.endswith(nd) for nd in no_decay)],
66 | "weight_decay": 0.0,
67 | },
68 | ]
69 |
70 | def _make_base_optimizer(self, param_groups: ParamGroups) -> torch.optim.Optimizer:
71 | return CPULAMB8Bit(
72 | param_groups,
73 | lr=self.trainer_args.learning_rate,
74 | betas=(self.trainer_args.adam_beta1, self.trainer_args.adam_beta2),
75 | min_8bit_size=self.trainer_args.min_8bit_size,
76 | max_grad_norm=self.trainer_args.max_grad_norm,
77 | clamp_value=self.trainer_args.clamp_value,
78 | eps=self.trainer_args.adam_epsilon,
79 | weight_decay=self.trainer_args.weight_decay,
80 | reuse_grad_buffers=True,
81 | bias_correction=True,
82 | )
83 |
 84 |     def _make_scheduler(self, optimizer: torch.optim.Optimizer) -> LRSchedulerBase:
 85 |         """
 86 |         Create a LambdaLR schedule with three phases:
 87 |
 88 |         - steps [0, 50): the learning rate is held at zero;
 89 |         - steps [50, warmup_steps): the multiplier grows linearly from 0 to 1;
 90 |         - steps [warmup_steps, total_steps]: the multiplier decays linearly, floored at 0.02.
 91 |
 92 |         LambdaLR multiplies the base learning rate by the value returned from lr_lambda.
 93 |         """
 94 |         num_warmup_steps = self.trainer_args.warmup_steps
 95 |         num_training_steps = self.trainer_args.total_steps
 96 |
 97 |         def lr_lambda(current_step: int):
 98 |             if current_step < 50:
 99 |                 return 0.0
100 |             if current_step < num_warmup_steps:
101 |                 return float(current_step - 50) / float(max(1, num_warmup_steps - 50))
102 |             decaying = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
103 |             return max(0.02, decaying)
104 |
105 |         return LambdaLR(optimizer, lr_lambda)
106 |
107 |     @property
108 | def training_dataset(self):
109 | if self._training_dataset is None:
110 | self._training_dataset = make_training_dataset(
111 | self.tokenizer,
112 | shuffle_seed=hash(self.local_public_key) % 2 ** 31,
113 | max_sequence_length=self.current_sequence_length, # this is a mp.Value that will be changed later
114 | )
115 | return self._training_dataset
116 |
117 | def on_step_end(self):
118 | return self.update_sequence_length()
119 |
120 | def update_sequence_length(self):
121 | """
122 |         If ramp-up is enabled, start with smaller sequences of initial_sequence_length tokens, then increase
123 |         linearly to max_sequence_length over the first sequence_length_warmup_steps epochs
124 | """
125 | current_epoch = self.get_current_epoch()
126 | if (
127 | self.trainer_args.sequence_length_warmup_steps == 0
128 | or current_epoch > self.trainer_args.sequence_length_warmup_steps
129 | ):
130 | current_sequence_length = self.trainer_args.max_sequence_length
131 | else:
132 | increment_size = self.trainer_args.pad_to_multiple_of
133 | max_sequence_length = self.trainer_args.max_sequence_length
134 | initial_sequence_length = self.trainer_args.initial_sequence_length or increment_size
135 | sequence_length_warmup_steps = self.trainer_args.sequence_length_warmup_steps
136 | assert sequence_length_warmup_steps > 0 and max_sequence_length >= initial_sequence_length
137 | length_range = max_sequence_length - initial_sequence_length
138 | warmup_relative = min(1, current_epoch / sequence_length_warmup_steps)
139 | current_sequence_length = initial_sequence_length + warmup_relative * length_range
140 | current_sequence_length = (current_sequence_length // increment_size) * increment_size
141 | current_sequence_length = min(max(current_sequence_length, initial_sequence_length), max_sequence_length)
142 |
143 | current_sequence_length = int(current_sequence_length)
144 | if current_sequence_length != self.current_sequence_length.value:
145 | logger.info(f"Transitioning to sequence length {current_sequence_length}")
146 | self.current_sequence_length.value = current_sequence_length
147 | # note: it may take time for new sequence length to take effect due to buffering
148 |
149 | @property
150 | def data_collator(self):
151 | return DataCollatorForWholeWordMask(
152 | tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
153 | )
154 |
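To make the ramp-up arithmetic in update_sequence_length easier to follow, here is a self-contained sketch of the same formula using the default HFTrainerArguments values (initial 256, maximum 2048, warmup over 7000 epochs, increments of 128); it duplicates the logic above purely for illustration:

    def ramped_sequence_length(current_epoch: int, initial: int = 256, maximum: int = 2048,
                               warmup_epochs: int = 7_000, increment: int = 128) -> int:
        # same formula as MLMTrainingTask.update_sequence_length, extracted for illustration
        if warmup_epochs == 0 or current_epoch > warmup_epochs:
            return maximum
        warmup_relative = min(1, current_epoch / warmup_epochs)
        length = initial + warmup_relative * (maximum - initial)
        length = (length // increment) * increment
        return int(min(max(length, initial), maximum))

    assert ramped_sequence_length(0) == 256
    assert ramped_sequence_length(3_500) == 1152  # halfway through warmup: 256 + 0.5 * (2048 - 256), a multiple of 128
    assert ramped_sequence_length(10_000) == 2048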
--------------------------------------------------------------------------------
/run_aux_peer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import threading
3 | import time
4 |
5 | import scipy.stats # compatibility for internal testing environment
6 | import torch
7 | import transformers
8 | import wandb
9 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler
10 | from huggingface_hub import Repository
11 | from transformers import HfArgumentParser
12 |
13 | from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
14 | from src import utils
15 | from tasks.mlm.task import MLMTrainingTask
16 |
17 | transformers.utils.logging.set_verbosity_warning()
18 | use_hivemind_log_handler("in_root_logger")
19 | logger = get_logger(__name__)
20 | torch.set_num_threads(1) # avoid quadratic number of threads
21 |
22 |
23 | class CheckpointHandler:
24 | def __init__(self, task: MLMTrainingTask, peer_args: AuxiliaryPeerArguments):
25 | self.task, self.peer_args = task, peer_args
26 | self.save_checkpoint_epoch_interval = peer_args.save_checkpoint_epoch_interval
27 | self.prefix = peer_args.run_id
28 | self.local_path = peer_args.local_path
29 | self.upload_interval = peer_args.upload_interval
30 | if self.upload_interval is not None:
31 | assert task.authorizer is not None, "Model uploading needs Hugging Face auth to be enabled"
32 | self.repo = Repository(
33 | local_dir=self.local_path,
34 | clone_from=peer_args.repo_url,
35 | use_auth_token=task.authorizer.hf_user_access_token,
36 | )
37 | self.last_upload_time = None
38 | self.previous_epoch = -1
39 |
40 | def should_save_state(self, current_epoch: int):
41 | if self.save_checkpoint_epoch_interval is None:
42 | return False
43 | elif current_epoch - self.previous_epoch >= self.save_checkpoint_epoch_interval:
44 | return True
45 | else:
46 | return False
47 |
48 | def save_state(self, current_epoch: int):
49 | logger.info("Saving state from peers")
50 | self.task.optimizer.load_state_from_peers()
51 | self.previous_epoch = current_epoch
52 |
53 | def is_time_to_upload(self):
54 | if self.upload_interval is None:
55 | return False
56 | elif self.last_upload_time is None or time.time() - self.last_upload_time >= self.upload_interval:
57 | return True
58 | else:
59 | return False
60 |
61 | def upload_checkpoint(self, current_loss: float):
62 | self.last_upload_time = time.time()
63 |
64 | logger.info("Saving model")
65 | torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
66 | logger.info("Saving optimizer")
67 | torch.save(self.task.optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
68 | self.previous_timestamp = time.time()
69 | logger.info("Started uploading to Model Hub")
70 | try:
71 | # We start by pulling the remote changes (for example a change in the readme file)
72 | self.repo.git_pull()
73 |
74 |             # Then we add / commit and push the changes
75 | self.repo.push_to_hub(commit_message=f"Epoch {self.task.optimizer.local_epoch}, loss {current_loss:.3f}")
76 | logger.info("Finished uploading to Model Hub")
77 | except Exception:
78 | logger.exception("Uploading the checkpoint to HF Model Hub failed:")
79 | logger.warning("Ensure that your access token is valid and has WRITE permissions")
80 |
81 |
82 | def assist_averaging_in_background(
83 | lock: threading.Lock, task: MLMTrainingTask, peer_args: AuxiliaryPeerArguments, finished: threading.Event
84 | ):
85 | while not finished.is_set():
86 | try:
87 | time.sleep(peer_args.assist_refresh)
88 | with lock:
89 | task.optimizer.step()
90 | except Exception as e:
91 | logger.exception(e, exc_info=True)
92 |
93 |
94 | if __name__ == "__main__":
95 | parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
96 | peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
97 | finished, lock = threading.Event(), threading.Lock()
98 |
99 | task = MLMTrainingTask(peer_args, trainer_args, collab_args)
100 | dht, optimizer = task.dht, task.optimizer
101 |
102 | if peer_args.wandb_project is not None:
103 | wandb.init(project=peer_args.wandb_project)
104 |
105 | current_epoch = 0
106 | if peer_args.store_checkpoints:
107 | checkpoint_handler = CheckpointHandler(task, peer_args)
108 |
109 | if peer_args.assist_in_averaging:
110 | assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
111 | averaging_thread = threading.Thread(
112 | name="AveragingAuxThread",
113 | target=assist_averaging_in_background,
114 | args=[lock, task, peer_args, finished],
115 | daemon=True,
116 | )
117 | averaging_thread.start()
118 |
119 | try:
120 | while True:
121 | metrics_entry = dht.get(peer_args.run_id + "_metrics", latest=True)
122 | if metrics_entry is not None and len(metrics_entry.value) > 0:
123 | metrics_dict = metrics_entry.value
124 | metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
125 | latest_epoch = max(item.epoch for item in metrics)
126 |
127 | if latest_epoch != current_epoch:
128 | logger.debug(f"Got metrics from {len(metrics)} peers")
129 |
130 | for i, metrics_for_peer in enumerate(metrics):
131 | logger.debug(f"{i} peer {metrics_for_peer}")
132 |
133 | current_epoch = latest_epoch
134 | alive_peers = 0
135 | sum_loss = 0
136 | num_samples = 0
137 | sum_perf = 0
138 | sum_mini_steps = 0
139 |
140 | for item in metrics:
141 | sum_loss += item.loss
142 | alive_peers += 1
143 | sum_perf += item.samples_per_second
144 | num_samples += item.samples_accumulated
145 | sum_mini_steps += item.mini_steps
146 | current_loss = sum_loss / sum_mini_steps
147 | logger.info(f"Epoch #{current_epoch}\tloss = {current_loss:.5f}")
148 |
149 | if peer_args.wandb_project is not None:
150 | wandb.log(
151 | {
152 | "loss": current_loss,
153 | "alive peers": alive_peers,
154 | "samples": num_samples,
155 | "performance": sum_perf,
156 | "optimizer_step": latest_epoch,
157 | },
158 | step=latest_epoch,
159 | )
160 |
161 | if peer_args.store_checkpoints:
162 | if checkpoint_handler.should_save_state(current_epoch):
163 | with lock:
164 | checkpoint_handler.save_state(current_epoch)
165 | if checkpoint_handler.is_time_to_upload():
166 | checkpoint_handler.upload_checkpoint(current_loss)
167 | logger.debug("Peer is still alive...")
168 | time.sleep(peer_args.refresh_period)
169 | finally:
170 | finished.set()
171 |
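As an illustration of the reduction performed in the monitoring loop above, per-peer LocalMetrics records are combined into a single reported loss like this (the numbers are made up; the field names match how callback.py constructs src.utils.LocalMetrics):

    from src.utils import LocalMetrics

    metrics = [
        LocalMetrics(epoch=3, samples_per_second=95.0, samples_accumulated=4096, loss=12.4, mini_steps=4),
        LocalMetrics(epoch=3, samples_per_second=110.0, samples_accumulated=8192, loss=23.0, mini_steps=8),
    ]
    # same reduction as above: total loss divided by the total number of local optimizer calls across peers
    current_loss = sum(m.loss for m in metrics) / sum(m.mini_steps for m in metrics)
    assert abs(current_loss - (12.4 + 23.0) / 12) < 1e-9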
--------------------------------------------------------------------------------
/src/huggingface_auth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from datetime import datetime, timedelta
4 | from getpass import getpass
5 |
6 | import requests
7 | from hivemind.proto.auth_pb2 import AccessToken
8 | from hivemind.utils.auth import TokenAuthorizerBase
9 | from hivemind.utils.crypto import RSAPublicKey
10 | from hivemind.utils.logging import get_logger
11 | from huggingface_hub import HfApi
12 | from termcolor import colored
13 |
14 | logger = get_logger("root." + __name__)
15 |
16 |
17 | class NonRetriableError(Exception):
18 | pass
19 |
20 |
21 | def call_with_retries(func, n_retries=10, initial_delay=1.0):
22 | for i in range(n_retries):
23 | try:
24 | return func()
25 | except NonRetriableError:
26 | raise
27 | except Exception as e:
28 | if i == n_retries - 1:
29 | raise
30 |
31 | delay = initial_delay * (2 ** i)
32 | logger.warning(f"Failed to call `{func.__name__}` with exception: {e}. Retrying in {delay:.1f} sec")
33 | time.sleep(delay)
34 |
35 |
36 | class InvalidCredentialsError(NonRetriableError):
37 | pass
38 |
39 |
40 | class NotInAllowlistError(NonRetriableError):
41 | pass
42 |
43 |
44 | class HuggingFaceAuthorizer(TokenAuthorizerBase):
45 | _AUTH_SERVER_URL = "https://collaborative-training-auth.huggingface.co"
46 |
47 | def __init__(self, organization_name: str, model_name: str, hf_user_access_token: str):
48 | super().__init__()
49 |
50 | self.organization_name = organization_name
51 | self.model_name = model_name
52 | self.hf_user_access_token = hf_user_access_token
53 |
54 | self._authority_public_key = None
55 | self.coordinator_ip = None
56 | self.coordinator_port = None
57 |
58 | self._hf_api = HfApi()
59 |
60 | async def get_token(self) -> AccessToken:
61 | """
62 | Hivemind calls this method to refresh the token when necessary.
63 | """
64 |
65 | self.join_experiment()
66 | return self._local_access_token
67 |
68 | @property
69 | def username(self):
70 | return self._local_access_token.username
71 |
72 | def join_experiment(self) -> None:
73 | call_with_retries(self._join_experiment)
74 |
75 | def _join_experiment(self) -> None:
76 | try:
77 | url = f"{self._AUTH_SERVER_URL}/api/experiments/join"
78 | headers = {"Authorization": f"Bearer {self.hf_user_access_token}"}
79 | response = requests.put(
80 | url,
81 | headers=headers,
82 | params={
83 | "organization_name": self.organization_name,
84 | "model_name": self.model_name,
85 | },
86 | json={
87 | "experiment_join_input": {
88 | "peer_public_key": self.local_public_key.to_bytes().decode(),
89 | },
90 | },
91 | )
92 |
93 | response.raise_for_status()
94 | response = response.json()
95 |
96 | self._authority_public_key = RSAPublicKey.from_bytes(response["auth_server_public_key"].encode())
97 | self.coordinator_ip = response["coordinator_ip"]
98 | self.coordinator_port = response["coordinator_port"]
99 |
100 | token_dict = response["hivemind_access"]
101 | access_token = AccessToken()
102 | access_token.username = token_dict["username"]
103 | access_token.public_key = token_dict["peer_public_key"].encode()
104 | access_token.expiration_time = str(datetime.fromisoformat(token_dict["expiration_time"]))
105 | access_token.signature = token_dict["signature"].encode()
106 | self._local_access_token = access_token
107 | logger.info(
108 | f"Access for user {access_token.username} " f"has been granted until {access_token.expiration_time} UTC"
109 | )
110 | except requests.exceptions.HTTPError as e:
111 | if e.response.status_code == 401: # Unauthorized
112 | raise NotInAllowlistError()
113 | raise
114 |
115 | def is_token_valid(self, access_token: AccessToken) -> bool:
116 | data = self._token_to_bytes(access_token)
117 | if not self._authority_public_key.verify(data, access_token.signature):
118 | logger.exception("Access token has invalid signature")
119 | return False
120 |
121 | try:
122 | expiration_time = datetime.fromisoformat(access_token.expiration_time)
123 | except ValueError:
124 | logger.exception(f"datetime.fromisoformat() failed to parse expiration time: {access_token.expiration_time}")
125 | return False
126 | if expiration_time.tzinfo is not None:
127 | logger.exception(f"Expected to have no timezone for expiration time: {access_token.expiration_time}")
128 | return False
129 | if expiration_time < datetime.utcnow():
130 | logger.exception("Access token has expired")
131 | return False
132 |
133 | return True
134 |
135 | _MAX_LATENCY = timedelta(minutes=1)
136 |
137 | def does_token_need_refreshing(self, access_token: AccessToken) -> bool:
138 | expiration_time = datetime.fromisoformat(access_token.expiration_time)
139 | return expiration_time < datetime.utcnow() + self._MAX_LATENCY
140 |
141 | @staticmethod
142 | def _token_to_bytes(access_token: AccessToken) -> bytes:
143 | return f"{access_token.username} {access_token.public_key} {access_token.expiration_time}".encode()
144 |
145 |
146 | def authorize_with_huggingface() -> HuggingFaceAuthorizer:
147 | while True:
148 | organization_name = os.getenv("HF_ORGANIZATION_NAME")
149 | if organization_name is None:
150 | organization_name = input("HuggingFace organization name: ")
151 |
152 | model_name = os.getenv("HF_MODEL_NAME")
153 | if model_name is None:
154 | model_name = input("HuggingFace model name: ")
155 |
156 | hf_user_access_token = os.getenv("HF_USER_ACCESS_TOKEN")
157 | if hf_user_access_token is None:
158 | print(
159 | "\nCopy a token from 🤗 Hugging Face settings page at "
160 | f"{colored('https://huggingface.co/settings/token', attrs=['bold'])} "
161 | "and paste it here.\n\n"
162 | f"💡 {colored('Tip:', attrs=['bold'])} "
163 | "If you don't already have one, you can create a dedicated user access token.\n"
164 | f"Go to {colored('https://huggingface.co/settings/token', attrs=['bold'])}, "
165 | f"click the {colored('New token', attrs=['bold'])} button, "
166 | f"and choose the {colored('read', attrs=['bold'])} role.\n"
167 | )
168 | os.environ["HF_USER_ACCESS_TOKEN"] = hf_user_access_token = getpass(
169 | "🤗 Hugging Face user access token (characters will be hidden): "
170 | )
171 |
172 | authorizer = HuggingFaceAuthorizer(organization_name, model_name, hf_user_access_token)
173 |
174 | try:
175 | authorizer.join_experiment()
176 | print(f"🚀 You will contribute to the collaborative training under the username {authorizer.username}")
177 | return authorizer
178 | except InvalidCredentialsError:
179 | print("Invalid user access token, please try again")
180 | except NotInAllowlistError:
181 | print(
182 | "\n😥 Authentication has failed.\n\n"
183 |                 "This error may be due to one of the following:\n"
184 | " 1. Your user access token is not valid. You can try to delete the previous token and "
185 | "recreate one. Be careful, organization tokens can't be used to join a collaborative "
186 | "training.\n"
187 | f" 2. You have not yet joined the {organization_name} organization. You can request to "
188 | "join this organization by clicking on the 'request to join this org' button at "
189 | f"https://huggingface.co/{organization_name}.\n"
190 | f" 3. The {organization_name} organization doesn't exist at https://huggingface.co/{organization_name}.\n"
191 | f" 4. No {organization_name}'s admin has created a collaborative training for the {organization_name} "
192 | f"organization and the {model_name} model."
193 | )
194 |
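A minimal, self-contained sketch of the retry helper defined above; flaky_request is a made-up function used only to show the behaviour:

    import itertools

    from src.huggingface_auth import call_with_retries

    _attempts = itertools.count(1)

    def flaky_request():
        # hypothetical endpoint: fails on the first two calls, then succeeds
        if next(_attempts) < 3:
            raise ConnectionError("temporary network hiccup")
        return "ok"

    # retried with exponential backoff (1 s, 2 s, 4 s, ...); a NonRetriableError subclass is re-raised immediately
    print(call_with_retries(flaky_request, n_retries=5, initial_delay=1.0))  # -> "ok" on the third attempt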
--------------------------------------------------------------------------------
/src/modules/attn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.checkpoint import checkpoint
7 |
8 | from src.modules.rotary import RotaryEmbeddings
9 |
10 |
11 | class LeanSelfAttention(nn.Module):
12 | def __init__(
13 | self,
14 | hidden_size: int,
15 | num_attention_heads: int,
16 | dropout: float = 0,
17 | layer_norm_eps: float = 1e-12,
18 | sandwich_norm: bool = False,
19 | dense_qkv: Optional[nn.Linear] = None,
20 | dense_out: Optional[nn.Linear] = None,
21 | residual: bool = True,
22 | attention_core: Optional[nn.Module] = None,
23 | checkpoint_attention_core: bool = True,
24 | **kwargs,
25 | ):
26 | """
27 | Attention layer that does not hog GPU memory. Re-computes pairwise attention weights instead of storing them.
28 |
29 | :note: this code is relatively less optimized than FFN because attention is usually not a memory bottleneck
30 | for typical sequence lengths (e.g. 2048 in language or 1024 in vision). If training with longer sequences,
31 | one can use query chunking: running one chunk of queries at a time, without storing the full QxK matrix.
32 | This technique runs in O(length) memory instead of O(length^2), making it not-a-bottleneck compared to FFN
33 |
34 | :param hidden_size: base hidden size of the transformer, before q/k/v projections
35 | :param num_attention_heads: number of heads, as defined in the original transformer
36 | :param dropout: hidden dropout probability, applied to the output projection (before adding residual)
37 | :param layer_norm_eps: see torch.nn.functional.layer_norm
38 | :param sandwich_norm: if set, applies an additional layer norm to projected attention outputs before residuals,
39 | as proposed in the CogView paper ( arXiv:2105.13290 ). This is meant to make fp16 training
40 | more stable for deep transformers. This technique is also a part of NormFormer ( arXiv:2110.09456 )
41 | :param residual: if True, adds the original layer input to the final layer output
42 | :param attention_core: optionally provide custom attention function. See SimpleAttentionCore for inspiration.
43 | :param checkpoint_attention_core: re-compute attention weights during backward pass instead of storing them
44 | :param dense_qkv: custom QKV projection layer (hidden_size -> 3 * hidden_size)
45 | :param dense_out: custom output projection layer (hidden_size -> hidden_size)
46 | :param kwargs: additional kwargs are passed to the chosen attention core
47 | """
48 | super().__init__()
49 | if attention_core is None:
50 | attention_core = SimpleAttentionCore(hidden_size, num_attention_heads, **kwargs)
51 | else:
52 | assert len(kwargs) == 0, f"Unexpected parameters: {kwargs}"
53 |
54 | self.hidden_size = hidden_size
55 | self.attention_core = attention_core
56 | self.dense_qkv = nn.Linear(hidden_size, hidden_size * 3) if dense_qkv is None else dense_qkv
57 | self.dense_out = nn.Linear(hidden_size, hidden_size) if dense_out is None else dense_out
58 | assert self.dense_qkv.in_features == self.dense_out.in_features == self.dense_out.out_features == hidden_size
59 | assert self.dense_qkv.out_features == hidden_size * 3
60 |
61 | self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
62 | self.sandwich_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if sandwich_norm else None
63 | self.output_dropout = nn.Dropout(dropout, inplace=False)
64 | self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core
65 |
66 | def forward(self, hidden_states, attention_mask=None, output_attentions=False):
67 | hidden_states_ln = self.layer_norm(hidden_states)
68 | qkv_output = self.dense_qkv(hidden_states_ln)
69 | query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
70 | attention_output, attention_probs = self._maybe_checkpoint(
71 | self.attention_core, query, key, value, attention_mask
72 | )
73 | outputs = self.dense_out(attention_output)
74 | if self.sandwich_norm:
75 | outputs = self.sandwich_norm(outputs)
76 | outputs = self.output_dropout(outputs)
77 | if self.residual:
78 | outputs = outputs + hidden_states.to(torch.float32, copy=False)
79 | return (outputs, attention_probs) if output_attentions else (outputs,)
80 |
81 | def _maybe_checkpoint(self, func, *args):
82 | return checkpoint(func, *args) if torch.is_grad_enabled() and self.checkpoint_attention_core else func(*args)
83 |
84 |
85 | class SimpleAttentionCore(nn.Module):
86 | def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0):
87 | super().__init__()
88 | assert hidden_size % num_attention_heads == 0
89 | self.attention_dropout = nn.Dropout(attention_probs_dropout, inplace=False)
90 | self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
91 | self.attention_head_size = hidden_size // num_attention_heads
92 |
93 | def forward(self, query, key, value, attention_mask):
94 | """
95 | :param query: [batch_size, query_seq_len, hidden_size]
96 | :param key: [batch_size, kv_seq_len, hidden_size]
97 | :param value: [batch_size, kv_seq_len, hidden_size]
98 | :param attention_mask: float [(optional heads), batch, query_seq_len, kv_seq_length]
99 | :note: attention_mask should be equal to zero for non-masked tokens and a large negative value for masked ones
100 | :return: (outputs, probs)
101 | - outputs shape: [batch_size, query_seq_len, hidden_size]
102 | - probs shape: [batch_size, num_heads, query_seq_len, kv_seq_len]
103 | """
104 | if attention_mask is not None:
105 | assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items"
106 | return self._attention_core_forward(
107 | query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p, self.training
108 | )
109 |
110 | @staticmethod
111 | def _attention_core_forward(
112 | query: torch.Tensor,
113 | key: torch.Tensor,
114 | value: torch.Tensor,
115 | attention_mask: Optional[torch.Tensor],
116 | num_attention_heads: int, attention_dropout: float, training: bool
117 | ) -> Tuple[torch.Tensor, torch.Tensor]:
118 | # transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size]
119 | new_query_shape = query.shape[:-1] + (num_attention_heads, -1)
120 | new_kv_shape = key.shape[:-1] + (num_attention_heads, -1)
121 |
122 | query = query.view(new_query_shape).permute(0, 2, 1, 3)
123 | key_transposed = key.view(new_kv_shape).permute(0, 2, 3, 1) # swap to [..., head_size, seq_length]
124 | value = value.view(new_kv_shape).permute(0, 2, 1, 3)
125 | del key # not to confuse with key_transposed
126 |
127 | # Take the dot product between "query" and "key" to get the raw attention scores.
128 | attention_scores = torch.matmul(query, key_transposed / math.sqrt(query.shape[-1]))
129 |
130 | if attention_mask is not None:
131 | attention_scores += attention_mask
132 |
133 | # Normalize the attention scores to probabilities.
134 | attention_probs = torch.softmax(attention_scores, dim=-1)
135 |
136 | # This is actually dropping out entire tokens to attend to, which might
137 | # seem a bit unusual, but is taken from the original Transformer paper.
138 | attention_probs = torch.dropout(attention_probs, attention_dropout, training)
139 | attention_output = torch.matmul(attention_probs, value)
140 | attention_output = attention_output.transpose(2, 1).flatten(2)
141 |
142 | return attention_output, attention_probs
143 |
144 |
145 | class RotaryAttentionCore(SimpleAttentionCore):
146 | """Attention core that applies rotary embeddings to queries and keys before computing dot products"""
147 |
148 | def __init__(
149 | self, hidden_size: int, num_attention_heads: int, rotary_emb: Optional[RotaryEmbeddings] = None, **kwargs
150 | ):
151 | super().__init__(hidden_size, num_attention_heads, **kwargs)
152 | if rotary_emb is None:
153 | rotary_emb = RotaryEmbeddings(self.attention_head_size)
154 | self.rotary_emb = rotary_emb
155 |
156 | def rotate(self, tensor: torch.Tensor):
157 | """:param tensor: query or key, shape: [batch_size, query_seq_len, hidden_size]"""
158 | tensor_split_heads = tensor.view(*(tensor.shape[:-1] + (self.num_attention_heads, self.attention_head_size)))
159 | return self.rotary_emb(tensor_split_heads).view(*tensor.shape)
160 |
161 | def forward(self, query, key, value, attention_mask):
162 | return super().forward(self.rotate(query), self.rotate(key), value, attention_mask)
163 |
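The LeanSelfAttention docstring above mentions query chunking as a way to keep attention memory roughly linear in sequence length. The class below is an illustrative sketch, not part of this module: a chunked core that could be passed via the attention_core argument, with chunk_size as an assumed parameter:

    import torch

    from src.modules.attn import SimpleAttentionCore


    class ChunkedAttentionCore(SimpleAttentionCore):
        """Sketch: process queries chunk by chunk so only a chunk-by-key score matrix is alive at a time."""

        def __init__(self, hidden_size: int, num_attention_heads: int, chunk_size: int = 1024, **kwargs):
            super().__init__(hidden_size, num_attention_heads, **kwargs)
            self.chunk_size = chunk_size

        def forward(self, query, key, value, attention_mask):
            outputs, probs = [], []
            for start in range(0, query.shape[1], self.chunk_size):
                query_chunk = query[:, start: start + self.chunk_size, :]
                mask_chunk = None
                if attention_mask is not None:
                    # the query dimension is second-to-last in the mask, so slice it to match the chunk
                    mask_chunk = attention_mask[..., start: start + self.chunk_size, :]
                chunk_output, chunk_probs = super().forward(query_chunk, key, value, mask_chunk)
                outputs.append(chunk_output)
                probs.append(chunk_probs)  # a real implementation would skip probs to actually save memory
            return torch.cat(outputs, dim=1), torch.cat(probs, dim=2)

Such a core could then be supplied as LeanSelfAttention(hidden_size, num_heads, attention_core=ChunkedAttentionCore(hidden_size, num_heads, chunk_size=512)).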
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional
3 |
4 | import torch
5 | from transformers import TrainingArguments
6 |
7 |
8 | @dataclass
9 | class CollaborativeArguments:
10 | """Configuration for CollaborativeOptimizer and its internals"""
11 |
12 | target_batch_size: int = field(
13 | default=16384,
14 | metadata={"help": "Perform optimizer step after all peers collectively accumulate this many samples"},
15 | )
16 | matchmaking_time: float = field(
17 | default=60.0,
18 | metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"},
19 | )
20 | next_chunk_timeout: float = field(
21 | default=60.0,
22 | metadata={"help": "Consider allreduce peer failed if it does not respond in this many seconds"},
23 | )
24 | averaging_timeout: float = field(
25 | default=600.0,
26 | metadata={"help": "Give up on averaging step after this many seconds"},
27 | )
28 | offload_optimizer: bool = field(default=True, metadata={"help": "Whether or not to offload optimizer into RAM"})
29 | delay_optimizer_step: bool = field(
30 | default=True,
31 | metadata={"help": "Whether or not to run optimizer step in background"},
32 | )
33 | delay_grad_averaging: bool = field(
34 | default=True,
35 | metadata={"help": "Whether or not to run gradient averaging in background"},
36 | )
37 | average_state_every: int = field(default=5, metadata={"help": "Average parameters every this many epochs"})
38 | reuse_grad_buffers: bool = field(
39 | default=True,
40 | metadata={
41 | "help": "Whether or not to use model's .grad buffers for accumulating gradients across local steps. This "
42 | "optimization reduces GPU memory consumption but may result in incorrect gradients when using some "
43 | "advanced techniques (e.g. changing loss scaler to a custom one)."
44 | },
45 | )
46 |
47 |
48 | @dataclass
49 | class HFTrainerArguments(TrainingArguments):
50 | """Arguments for huggingface/transformers.Trainer"""
51 |
52 | per_device_train_batch_size: int = 1
53 | per_device_eval_batch_size: int = 1
54 | gradient_accumulation_steps: int = 1
55 |
56 | learning_rate: float = 2.5e-3 # based on https://arxiv.org/abs/1904.00962
57 | total_steps: int = 15625 # total number of collaborative optimizer updates, used for learning rate schedule
58 | warmup_steps: int = 3125 # based on https://arxiv.org/abs/1904.00962
59 | min_learning_rate: float = 1e-5 # learning rate after total_steps have passed
60 | adam_beta1: float = 0.9
61 | adam_beta2: float = 0.95
62 | adam_epsilon: float = 1e-6
63 | weight_decay: float = 0.01
64 | max_grad_norm: float = 1.0 # clipping performed by the optimizer; trainer is modified to disable builtin clipping
65 | clamp_value: float = 1e9 # no clipping by value
66 | min_8bit_size: int = 2 ** 20
67 |
68 | gradient_checkpointing: bool = False # can be enabled to save memory at the cost of ~30% slower training
69 | fp16: bool = False # can be enabled depending on the device
70 |
71 | max_sequence_length: int = 2048
72 | initial_sequence_length: Optional[int] = 256 # used only if warmup > 0, default = pad_to_multiple_of
73 | sequence_length_warmup_steps: int = 7_000
74 | pad_to_multiple_of: int = 128 # sequence length will be divisible by this value
75 |
76 | output_dir: str = "outputs"
77 | logging_steps: int = 100
78 |
79 |     # params that should *not* be changed
80 | do_train: bool = True
81 | do_eval: bool = False
82 | logging_first_step = True
83 | dataloader_num_workers: int = 0 # temporary fix for https://github.com/huggingface/datasets/issues/3148
84 | max_steps: int = 10 ** 30
85 | save_steps: int = 10 ** 30
86 | save_total_limit: int = 2
87 | ddp_find_unused_parameters: bool = False
88 |
89 | @property
90 | def batch_size_per_step(self):
91 | """Compute the number of training sequences contributed by each .step() from this peer"""
92 | total_batch_size_per_step = self.per_device_train_batch_size * self.gradient_accumulation_steps
93 | if torch.cuda.device_count() > 0:
94 | total_batch_size_per_step *= torch.cuda.device_count()
95 | return total_batch_size_per_step
96 |
97 |
98 | @dataclass
99 | class BasePeerArguments:
100 | """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
101 |
102 | run_id: str = field(metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
103 | model_config_path: Optional[str] = field(default="./config.json", metadata={"help": "Path to the model config"})
104 | tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"})
105 | cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
106 | authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
107 | client_mode: bool = field(
108 | default=False,
109 | metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
110 | )
111 | bandwidth: Optional[float] = field(
112 | default=None,
113 | metadata={"help": "Min(upload & download speed) in megabits/s, used to assign averaging tasks between peers"},
114 | )
115 | min_vector_size: int = 4_000_000 # minimum slice of gradients assigned to one reducer, should be same across peers
116 | initial_peers: List[str] = field(
117 | default_factory=list,
118 | metadata={
119 | "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
120 | "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
121 | },
122 | )
123 | use_ipfs: bool = field(
124 | default=False,
125 | metadata={
126 | "help": "Use IPFS to find initial_peers. If enabled, you only need to provide /p2p/XXXX part of multiaddrs "
127 | "for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
128 | },
129 | )
130 | host_maddrs: List[str] = field(
131 | default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
132 | metadata={
133 | "help": "Multiaddrs to listen for external connections from other p2p instances. "
134 | "Defaults to all IPv4 interfaces with TCP protocol: /ip4/0.0.0.0/tcp/0"
135 | },
136 | )
137 | announce_maddrs: List[str] = field(
138 | default_factory=list,
139 | metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
140 | )
141 | identity_path: Optional[str] = field(
142 | default=None,
143 | metadata={
144 | "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
145 | "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
146 | },
147 | )
148 |
149 |
150 | @dataclass
151 | class TrainingPeerArguments(BasePeerArguments):
152 | statistics_expiration: float = field(
153 | default=600,
154 | metadata={"help": "Statistics will be removed if not updated in this many seconds"},
155 | )
156 | backup_every_epochs: Optional[int] = field(
157 | default=None,
158 | metadata={
159 | "help": "Update training state backup on disk once in this many global steps "
160 | "(default = do not update local state)"
161 | },
162 | )
163 | state_path: str = field(
164 | default="state.zip",
165 | metadata={"help": "Load this state upon init and when recovering from NaN parameters"},
166 | )
167 |
168 |
169 | @dataclass
170 | class AuxiliaryPeerArguments(BasePeerArguments):
171 | """
172 | Arguments for run_aux_peer.py that is responsible for connecting peers to one another, tracking
173 | learning curves, assisting in all-reduce and uploading checkpoints to the hub
174 | """
175 |
176 | refresh_period: float = field(
177 | default=10,
178 | metadata={"help": "Period (in seconds) for fetching the keys from DHT"},
179 | )
180 | wandb_project: Optional[str] = field(
181 | default=None,
182 | metadata={"help": "Name of Weights & Biases project to report the training progress to"},
183 | )
184 | save_checkpoint_epoch_interval: int = field(
185 | default=5,
186 |         metadata={"help": "Frequency (in epochs) of fetching and saving state from peers"},
187 | )
188 | repo_url: Optional[str] = field(
189 | default=None,
190 | metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"},
191 | )
192 | local_path: Optional[str] = field(
193 | default="Repo",
194 | metadata={"help": "Path to local repository to store the model and optimizer states"},
195 | )
196 | upload_interval: Optional[float] = field(
197 | default=None,
198 | metadata={"help": "Frequency (in seconds) of uploading the model to Hub"},
199 | )
200 | store_checkpoints: bool = field(default=True, metadata={"help": "If True, enables CheckpointHandler"})
201 | assist_in_averaging: bool = field(
202 | default=False,
203 | metadata={"help": "If True, this peer will facilitate averaging for other (training) peers"},
204 | )
205 | assist_refresh: float = field(
206 | default=1.0,
207 |         metadata={"help": "Period (in seconds) for trying to assist averaging"},
208 | )
209 |
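For reference, these dataclasses are consumed via transformers.HfArgumentParser (as in run_aux_peer.py), which turns every field into a command-line flag. A minimal sketch; the run id and values below are placeholders:

    from transformers import HfArgumentParser

    from arguments import CollaborativeArguments, HFTrainerArguments, TrainingPeerArguments

    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses(
        ["--run_id", "calm-demo", "--per_device_train_batch_size", "4", "--target_batch_size", "8192"]
    )
    print(peer_args.run_id, collab_args.target_batch_size)
    print(trainer_args.batch_size_per_step)  # per-device batch size * grad accumulation (* GPU count, if any)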
--------------------------------------------------------------------------------
/src/modules/pixelfly.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 | from typing import Optional, Tuple
  4 |
  5 | import einops
  6 | import torch
  7 | import torch.nn.functional as F
  8 |
  9 | from src.modules.functional import maybe_script
10 |
11 | @functools.lru_cache()
12 | def get_butterfly_indices(
13 | out_features: int,
14 | in_features: int,
15 | block_size: int = 256,
16 | butterfly_size: Optional[int] = None,
17 | n_factors: Optional[int] = None,
18 | stretch: bool = False,
19 | ) -> Tuple[torch.IntTensor, torch.IntTensor]:
20 | """
21 | Get a matrix [num_output_blocks, num_active_input_blocks] with int32 indices for additive butterfly.
 22 |     The values in the matrix are indices of the input blocks (and their active slots) summed to form each output block.
23 | Based on the original implementation from https://arxiv.org/abs/2112.00029 .
24 |
25 | :param stretch: by default, non-square matrices will have stretched butterfly patterns,
26 | otherwise the square pattern will be repeated a given number of times
27 |
28 | :returns: tuple (forward_indices, backward_indices), where
29 | - (forward) indices of non-zero blocks that contribute to each output -- assuming all input blocks are flattened
30 | - (backward) indices of output blocks to which a given input block contributes
31 | """
32 | if butterfly_size is None:
33 | butterfly_size = 2 ** int(math.ceil(math.log2(min(in_features, out_features) / block_size)))
34 | assert out_features % in_features == 0 or in_features % out_features == 0, \
35 | "if matrix is not square, the longer dimension must be a multiple of the shorter dimension"
36 | assert out_features % block_size == 0 and in_features % block_size == 0
37 | log_n = int(math.log2(butterfly_size))
38 | n_factors = log_n if n_factors is None else n_factors
39 | if butterfly_size != 2 ** log_n or butterfly_size < 2:
40 | raise NotImplementedError("butterfly_size must be a power of 2")
41 | if not (1 <= n_factors <= log_n):
 42 |         raise NotImplementedError("n_factors must be between 1 and log_2(butterfly_size)")
43 |
44 | twiddle = torch.ones(butterfly_size // 2, 2, 2)
45 | layout = sum(butterfly_factor_to_matrix(twiddle, index) for index in range(n_factors)).bool().int()
46 | # Convert from (butterfly_size, butterfly_size) mask to (out_features, in_features) mask
47 | layout = einops.repeat(
48 | layout,
49 | "b b1 -> (b f) (b1 f1)",
50 | f=out_features // butterfly_size,
51 | f1=in_features // butterfly_size,
52 | )
53 | # Convert from (out_features, in_features) mask to
54 | # (out_features // block_size, in_features // block_size) mask
55 | layout = einops.rearrange(
56 | layout,
57 | "(p blksz) (r blksz1) -> p r (blksz blksz1)",
58 | blksz=block_size,
59 | blksz1=block_size,
60 | )
61 |
62 | layout = (layout > 0).any(dim=-1) # [out_features // block_size, in_features // block_size]
63 | if not stretch:
64 | out_blocks, in_blocks = layout.shape
65 | if out_blocks > in_blocks:
66 | ratio = out_blocks // in_blocks
67 | layout = layout.view(out_blocks // ratio, ratio, in_blocks).permute(1, 0, 2).reshape_as(layout)
68 | elif out_blocks < in_blocks:
69 | ratio = in_blocks // out_blocks
70 | layout = layout.view(out_blocks, in_blocks // ratio, ratio).permute(0, 2, 1).reshape_as(layout)
71 |
72 | # convert boolean layout to indices for F.embedding_bag
73 | num_output_blocks = out_features // block_size
74 | num_input_blocks = in_features // block_size
75 | active_blocks_per_output = layout.sum(1).unique()
76 | assert len(active_blocks_per_output) == 1, "butterfly layout must have the same number of blocks per row"
77 | active_blocks_per_output = active_blocks_per_output.item()
78 |
79 | active_blocks_per_input = layout.sum(0).unique()
 80 |     assert len(active_blocks_per_input) == 1, "butterfly layout must have the same number of blocks per column"
81 | active_blocks_per_input = active_blocks_per_input.item()
82 |
83 | # which input blocks should be added for i-th output
84 | input_block_index = layout.nonzero()[:, 1].view(num_output_blocks, active_blocks_per_output)
85 | # which output blocks does j-th input contribute to
86 | output_block_index = layout.t().nonzero()[:, 1].view(num_input_blocks, active_blocks_per_input)
87 |
88 | # which of the active blocks from the corresponding input_block should be used for i-th output
89 | active_block_index = torch.where(
90 | torch.eq(
91 | output_block_index[input_block_index],
92 | torch.arange(len(input_block_index))[:, None, None],
93 | )
94 | )[-1].view(input_block_index.shape)
95 |
96 | forward_indices = input_block_index * active_blocks_per_input + active_block_index
97 | backward_indices = output_block_index
98 | return forward_indices.to(torch.int32), backward_indices.to(torch.int64) # dtypes tuned for max throughput
99 |
100 |
101 | def butterfly_factor_to_matrix(twiddle: torch.Tensor, factor_index: int) -> torch.Tensor:
102 | """
103 | Let b be the base (most commonly 2).
104 | Parameters:
105 | twiddle: (n // b, b, b)
106 | factor_index: an int from 0 to log_b(n) - 1
107 | """
108 | n_div_b, b, _ = twiddle.shape
109 | n = b * n_div_b
110 | log_b_n = int(math.log(n) / math.log(b))
111 | assert n == b ** log_b_n, f"n must be a power of {b}"
112 | assert twiddle.shape == (n // b, b, b)
113 | assert 0 <= factor_index <= log_b_n
114 | stride = b ** factor_index
115 | x = einops.rearrange(torch.eye(n), "bs (diagblk j stride) -> bs diagblk j stride", stride=stride, j=b)
116 | t = einops.rearrange(twiddle, "(diagblk stride) i j -> diagblk stride i j", stride=stride)
117 | out = torch.einsum("d s i j, b d j s -> b d i s", t, x)
118 | out = einops.rearrange(out, "b diagblk i stride -> b (diagblk i stride)")
119 | return out.t() # Transpose because we assume the 1st dimension of x is the batch dimension
120 |
121 |
122 | @maybe_script
123 | def butterfly_matmul(input: torch.Tensor, weight: torch.Tensor, forward_indices: torch.Tensor) -> torch.Tensor:
124 | """
125 | :param input: tensor [*batch_dims, in_features]
126 | :param weight: tensor [in_features, active_blocks_per_input, block_size]
127 | :param forward_indices: the first output of get_butterfly_indices(...)
128 | :returns: tensor [*batch_dims, out_features]
129 | """
130 | assert input.shape[-1] == weight.shape[0]
131 | in_features, active_blocks_per_input, block_size = weight.shape
132 | num_input_blocks = in_features // block_size
133 | batch_dims = input.shape[:-1]
134 | input = input.flatten(0, -2)
135 |
136 | input_permuted = input.t().view(input.shape[1] // block_size, block_size, input.shape[0])
137 | output_blocks = torch.matmul(weight.view(num_input_blocks, -1, block_size), input_permuted)
138 | # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims]
139 |
140 | blocks_for_indexing = output_blocks.view(num_input_blocks * active_blocks_per_input, block_size * input.shape[0])
141 | # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size, flat_batch_dims)]
142 |
143 | aggregated_blocks = F.embedding_bag(forward_indices, blocks_for_indexing, mode="sum")
144 |     # ^-- shape: [num_output_blocks, (block_size, flat_batch_dims)]
145 |
146 | outputs = aggregated_blocks.view(-1, input.shape[0]).t()
147 | # ^-- shape: [flat_batch_dims, (num_output_blocks * block_size)] aka [flat_batch_dims, out_features]
148 | return outputs.view(batch_dims + outputs.shape[-1:])
149 |
150 |
151 | @maybe_script
152 | def butterfly_matmul_backward(
153 | grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, backward_indices: torch.Tensor,
154 | input_requires_grad: bool = True, weight_requires_grad: bool = True):
155 | """Compute gradients of butterfly_matmul w.r.t. input and/or weight without relying on pytorch autograd"""
156 | assert input_requires_grad or weight_requires_grad, "computing backward but none of the inputs requires grad"
157 | grad_input = grad_weight = torch.empty(0)
158 | out_features = grad_output.shape[-1]
159 | in_features, active_blocks_per_input, block_size = weight.shape
160 | num_input_blocks = input.shape[-1] // block_size
161 | num_output_blocks = out_features // block_size
162 | grad_output_flat = grad_output.flatten(0, -2)
163 | input_flat = input.flatten(0, -2)
164 |
165 | flat_batch_dims = grad_output_flat.shape[0]
166 |
167 | grad_aggregated_blocks = grad_output_flat.t().reshape(num_output_blocks, (block_size * flat_batch_dims))
168 | # [num_output_blocks, (block_size, flat_batch_dims)]
169 |
170 | grad_blocks_for_indexing = F.embedding(backward_indices, grad_aggregated_blocks).flatten(0, -2)
171 | # ^-- shape: [(num_input_blocks * active_blocks_per_input), (block_size, flat_batch_dims)]
172 |
173 | grad_output_blocks = grad_blocks_for_indexing.view(
174 | num_input_blocks, active_blocks_per_input * block_size, flat_batch_dims
175 | )
176 | # ^-- shape: [num_input_blocks, (active_blocks_per_input * block_size), flat_batch_dims]
177 |
178 | if input_requires_grad:
179 | grad_input_permuted = torch.matmul(
180 | weight.view(num_input_blocks, -1, block_size).permute(0, 2, 1), grad_output_blocks
181 | )
182 | grad_input = grad_input_permuted.flatten(0, -2).t().view(grad_output.shape[:-1] + input.shape[-1:])
183 |
184 | if weight_requires_grad:
185 | grad_weight = torch.matmul(
186 | grad_output_blocks, input_flat.t().view(num_input_blocks, block_size, flat_batch_dims).permute(0, 2, 1)
187 | ).view_as(weight)
188 |
189 | return grad_input, grad_weight
190 |
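A small shape-checking sketch for get_butterfly_indices and butterfly_matmul above; the dimensions are arbitrary and chosen only to satisfy the block-size constraints:

    import torch

    from src.modules.pixelfly import butterfly_matmul, get_butterfly_indices

    out_features, in_features, block_size = 2048, 1024, 256
    forward_indices, backward_indices = get_butterfly_indices(out_features, in_features, block_size)
    active_blocks_per_input = backward_indices.shape[1]

    # weight layout expected by butterfly_matmul: [in_features, active_blocks_per_input, block_size]
    weight = torch.randn(in_features, active_blocks_per_input, block_size) * 0.02
    inputs = torch.randn(4, 128, in_features)  # [*batch_dims, in_features]

    outputs = butterfly_matmul(inputs, weight, forward_indices)
    assert outputs.shape == (4, 128, out_features)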
--------------------------------------------------------------------------------
/finetune_mlm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser
4 | from functools import partial
5 | from pathlib import Path
6 | from shutil import rmtree
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import transformers.utils.logging
12 | import wandb
13 | from datasets import load_dataset, Dataset, DatasetDict
14 | from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, \
15 | TrainingArguments, AlbertTokenizerFast, AutoModel
16 | from transformers.trainer_utils import set_seed
17 |
18 | from data import TASK_TO_CONFIG, TASK_TO_NAME
19 | from field_collator import FieldDataCollatorWithPadding
20 | from models import SpanClassificationModel, EntityChoiceModel
21 | from src import LeanAlbertConfig, LeanAlbertForSequenceClassification, LeanAlbertForPreTraining
22 |
23 |
24 | class NumpyEncoder(json.JSONEncoder):
25 | def default(self, obj):
26 | if isinstance(obj, np.integer):
27 | return int(obj)
28 | if isinstance(obj, np.floating):
29 | return float(obj)
30 | if isinstance(obj, np.ndarray):
31 | return obj.tolist()
32 | return super().default(obj)
33 |
34 |
35 | MODEL_TO_HUB_NAME = {
36 | 'ruroberta-large': "sberbank-ai/ruRoberta-large",
37 | }
38 |
39 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
40 |
41 | N_EPOCHS = 40
42 | LEARNING_RATE = 1e-5
43 | MAX_LENGTH = 512
44 |
45 |
46 | def main(task, model_name, checkpoint_path, data_dir, batch_size, grad_acc_steps, dropout, weight_decay, num_seeds):
47 | if checkpoint_path is not None:
48 | tokenizer = AlbertTokenizerFast.from_pretrained('tokenizer')
49 | assert model_name is None
50 | model_name = 'lean_albert'
51 | else:
52 | tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_TO_HUB_NAME[model_name],
53 | cache_dir=data_dir / 'transformers_cache')
54 |
55 | if task == "russe":
56 | data_collator = FieldDataCollatorWithPadding(tokenizer, fields_to_pad=(("e1_mask", 0, 0), ("e2_mask", 0, 0)),
57 | pad_to_multiple_of=8)
58 | elif task == "rucos":
59 | data_collator = FieldDataCollatorWithPadding(tokenizer=tokenizer,
60 | fields_to_pad=(
61 | ("entity_mask", 0, 1),
62 | ("labels", -1, None)
63 | ))
64 | else:
65 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
66 |
67 | if task == "rucos":
68 | # we're using a custom dataset, because the HF hub version has no information about entity indices
69 | train = Dataset.from_json(["RuCoS/train.jsonl"])
70 | val = Dataset.from_json(["RuCoS/val.jsonl"])
71 | test = Dataset.from_json(["RuCoS/test.jsonl"])
72 | dataset = DatasetDict(train=train, validation=val, test=test)
73 | elif task == "rucola":
74 | train_df, in_domain_dev_df, out_of_domain_dev_df, test_df = map(
75 | pd.read_csv, ("RuCoLA/in_domain_train.csv", "RuCoLA/in_domain_dev.csv", "RuCoLA/out_of_domain_dev.csv",
76 | "RuCoLA/test.csv")
77 | )
78 |
 79 |         # concatenate the in-domain and out-of-domain dev sets to compute aggregate validation metrics
80 | dev_df = pd.concat((in_domain_dev_df, out_of_domain_dev_df))
81 | train, dev, test = map(Dataset.from_pandas, (train_df, dev_df, test_df))
82 | dataset = DatasetDict(train=train, validation=dev, test=test)
83 | else:
84 | dataset = load_dataset("russian_super_glue", task)
85 |
86 | config = TASK_TO_CONFIG[task](dataset)
87 |
88 | processed_dataset = dataset.map(partial(config.process_data, tokenizer=tokenizer, max_length=MAX_LENGTH),
89 | num_proc=32, keep_in_memory=True, batched=True)
90 |
91 | if "labels" in processed_dataset["test"].column_names:
92 | test_without_labels = processed_dataset['test'].remove_columns(['labels'])
93 | else:
94 | test_without_labels = processed_dataset["test"]
95 |
96 | transformers.utils.logging.enable_progress_bar()
97 |
98 | model_prefix = f"{model_name}_" \
99 | f"{task}_" \
100 | f"dr{dropout}_" \
101 | f"wd{weight_decay}_" \
102 | f"bs{batch_size * grad_acc_steps}"
103 |
104 | dev_metrics_per_run = []
105 | predictions_per_run = []
106 |
107 | for seed in range(num_seeds):
108 | set_seed(seed)
109 |
110 | if checkpoint_path is not None:
111 | model_config = LeanAlbertConfig.from_pretrained('config.json')
112 | model_config.num_labels = config.num_classes
113 | model_config.classifier_dropout_prob = dropout
114 |
115 | if task in ("russe", "rucos"):
116 | model = LeanAlbertForPreTraining(model_config)
117 | else:
118 | model = LeanAlbertForSequenceClassification(model_config)
119 |
120 | model.resize_token_embeddings(len(tokenizer))
121 | checkpoint = torch.load(checkpoint_path, map_location='cpu')['model']
122 | incompat_keys = model.load_state_dict(checkpoint, strict=False)
123 | print("missing", incompat_keys.missing_keys)
124 | print("unexpected", incompat_keys.unexpected_keys)
125 |
126 | if task in ("russe", "rucos"):
127 | model = model.albert
128 | else:
129 | if task in ("russe", "rucos"):
130 | model = AutoModel.from_pretrained(MODEL_TO_HUB_NAME[model_name],
131 | attention_probs_dropout_prob=dropout,
132 | hidden_dropout_prob=dropout,
133 | cache_dir=data_dir / 'transformers_cache')
134 | else:
135 | model = AutoModelForSequenceClassification.from_pretrained(MODEL_TO_HUB_NAME[model_name],
136 | num_labels=config.num_classes,
137 | attention_probs_dropout_prob=dropout,
138 | hidden_dropout_prob=dropout,
139 | cache_dir=data_dir / 'transformers_cache')
140 |
141 | if task == "russe":
142 | model = SpanClassificationModel(model, num_labels=config.num_classes)
143 | elif task == "rucos":
144 | model = EntityChoiceModel(model)
145 |
146 | run_base_dir = f"{model_prefix}_{seed}"
147 |
148 | run = wandb.init(project='brbert', name=run_base_dir)
149 | run.config.update({"task": task, "model": model_name, "checkpoint": str(checkpoint_path)})
150 |
151 | training_args = TrainingArguments(
152 | output_dir=data_dir / 'checkpoints' / run_base_dir, overwrite_output_dir=True,
153 | evaluation_strategy='epoch', logging_strategy='epoch', logging_first_step=True,
154 | per_device_train_batch_size=batch_size,
155 | per_device_eval_batch_size=batch_size, gradient_accumulation_steps=grad_acc_steps,
156 | optim="adamw_torch", learning_rate=LEARNING_RATE, weight_decay=weight_decay,
157 | num_train_epochs=N_EPOCHS, warmup_ratio=0.1, save_strategy='epoch',
158 | seed=seed, fp16=True, dataloader_num_workers=4, group_by_length=True,
159 | report_to='wandb', run_name=run_base_dir, save_total_limit=1,
160 | load_best_model_at_end=True, metric_for_best_model=config.best_metric
161 | )
162 |
163 | trainer = Trainer(
164 | model=model,
165 | args=training_args,
166 | train_dataset=processed_dataset['train'],
167 | eval_dataset=processed_dataset['validation'],
168 | compute_metrics=partial(config.compute_metrics, split="validation",
169 | processed_dataset=processed_dataset["validation"]),
170 | tokenizer=tokenizer,
171 | data_collator=data_collator,
172 | )
173 |
174 | train_result = trainer.train()
175 | print(run_base_dir)
176 | print('train', train_result.metrics)
177 |
178 | dev_predictions = trainer.predict(test_dataset=processed_dataset['validation'])
179 | print('dev', dev_predictions.metrics)
180 |
181 | run.summary.update(dev_predictions.metrics)
182 | wandb.finish()
183 |
184 | dev_metrics_per_run.append(dev_predictions.metrics[f"test_{config.best_metric}"])
185 |
186 | predictions = trainer.predict(test_dataset=test_without_labels)
187 | predictions_per_run.append(predictions.predictions)
188 |
189 | if task != "terra":
190 | rmtree(data_dir / 'checkpoints' / run_base_dir)
191 |
192 | best_run = np.argmax(dev_metrics_per_run)
193 | best_predictions = predictions_per_run[best_run]
194 | processed_predictions = config.process_predictions(best_predictions, split="test",
195 | processed_dataset=processed_dataset["test"])
196 |
197 | prefix_without_task = model_prefix.replace(f"{task}_", "")
198 |
199 | os.makedirs(f"preds/{prefix_without_task}", exist_ok=True)
200 |
201 | if task == "rucola":
202 | result_df = pd.DataFrame.from_records(processed_predictions, index="id")
203 | result_df.to_csv(f"preds/{prefix_without_task}/{TASK_TO_NAME[task]}.csv")
204 | else:
205 | with open(f"preds/{prefix_without_task}/{TASK_TO_NAME[task]}.jsonl", 'w+') as outf:
206 | for prediction in processed_predictions:
207 | print(json.dumps(prediction, ensure_ascii=True, cls=NumpyEncoder), file=outf)
208 |
209 |
210 | if __name__ == '__main__':
211 | parser = ArgumentParser()
212 | parser.add_argument("-t", '--task', choices=TASK_TO_CONFIG.keys())
213 | parser.add_argument("-m", '--model-name', choices=MODEL_TO_HUB_NAME.keys())
214 | parser.add_argument("-c", '--checkpoint', type=Path)
215 | parser.add_argument("-d", "--data-dir", type=Path)
216 | parser.add_argument("--batch-size", required=True, type=int)
217 | parser.add_argument("--grad-acc-steps", required=True, type=int)
218 | parser.add_argument("--dropout", required=True, type=float)
219 | parser.add_argument("--weight-decay", required=True, type=float)
220 | parser.add_argument("--num-seeds", required=True, type=int)
221 | args = parser.parse_args()
222 | main(args.task, args.model_name, args.checkpoint, args.data_dir, args.batch_size, args.grad_acc_steps,
223 | args.dropout, args.weight_decay, args.num_seeds)
224 |
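For reference, a hypothetical invocation sketch (paths and hyperparameters are illustrative; it assumes the auxiliary data module providing TASK_TO_CONFIG/TASK_TO_NAME, the task files, and a configured wandb environment). It is equivalent to running the script with the CLI flags defined above:

from pathlib import Path

from finetune_mlm import main

# fine-tune the ruRoberta-large baseline on RuCoLA over 5 seeds;
# the best run's test predictions are written under preds/
main(task="rucola", model_name="ruroberta-large", checkpoint_path=None, data_dir=Path("data"),
     batch_size=32, grad_acc_steps=1, dropout=0.1, weight_decay=0.01, num_seeds=5)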
--------------------------------------------------------------------------------
/src/models/lean_albert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch ALBERT modules that do not hog your GPU memory """
16 | from functools import lru_cache
17 |
18 | import torch
19 | import torch.nn as nn
20 | from transformers.file_utils import add_start_docstrings
21 | from transformers.modeling_outputs import BaseModelOutput
22 | from transformers.modeling_utils import PreTrainedModel
23 | from transformers.models.albert import AlbertConfig
24 | from transformers.models.albert.modeling_albert import (
25 | ACT2FN,
26 | ALBERT_START_DOCSTRING,
27 | AlbertForPreTraining,
28 | AlbertLayerGroup,
29 | AlbertMLMHead,
30 | AlbertModel,
31 | AlbertSOPHead,
32 | AlbertTransformer
33 | )
34 | from transformers.utils import logging
35 |
36 | from src.models.config import LeanAlbertConfig
37 | from src.modules.attn import LeanSelfAttention, RotaryAttentionCore
38 | from src.modules.ffn import LeanFFN
39 | from src.modules.rotary import RotaryEmbeddings
40 |
41 | logger = logging.get_logger(__name__)
42 |
43 | _CONFIG_FOR_DOC = "LeanAlbertConfig"
44 | _TOKENIZER_FOR_DOC = "AlbertTokenizer"
45 |
46 |
47 | def get_input_embedding(config: LeanAlbertConfig):
48 | if config.position_embedding_type == "absolute":
49 | return nn.Embedding(config.max_position_embeddings, config.embedding_size)
50 | elif config.position_embedding_type == "rotary":
51 | return None
52 | else:
53 |         raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")
54 |
55 |
56 | @lru_cache()
57 | def get_attention_core(config: LeanAlbertConfig):
58 | if config.position_embedding_type == "absolute":
59 | return None
60 | elif config.position_embedding_type == "rotary":
61 | rotary_emb = RotaryEmbeddings(config.hidden_size // config.num_attention_heads, config.rotary_embedding_base)
62 | return RotaryAttentionCore(
63 | config.hidden_size,
64 | config.num_attention_heads,
65 | rotary_emb,
66 | attention_probs_dropout=config.attention_probs_dropout_prob,
67 | )
68 | else:
69 | raise NotImplementedError(f"Unsupported embedding type: {config.position_embedding_type}")
70 |
71 |
72 | class LeanAlbertEmbeddings(nn.Module):
73 | """
74 | Construct the embeddings from word, position and token_type embeddings.
75 | """
76 |
77 | def __init__(self, config):
78 | super().__init__()
79 | self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
80 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
81 | self.position_embeddings = get_input_embedding(config)
82 |
83 | self.layernorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
84 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
85 |
86 | if self.position_embeddings is not None:
87 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
88 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
89 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
90 |
91 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
92 | def forward(
93 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
94 | ):
95 | if input_ids is not None:
96 | input_shape = input_ids.size()
97 | else:
98 | input_shape = inputs_embeds.size()[:-1]
99 |
100 | seq_length = input_shape[1]
101 |
102 | if token_type_ids is None:
103 |             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=(input_ids if input_ids is not None else inputs_embeds).device)
104 |
105 | if inputs_embeds is None:
106 | inputs_embeds = self.word_embeddings(input_ids)
107 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
108 |
109 | embeddings = inputs_embeds + token_type_embeddings
110 |
111 | if self.position_embeddings is not None:
112 | if position_ids is None:
113 | position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
114 | position_embeddings = self.position_embeddings(position_ids)
115 | embeddings += position_embeddings
116 |
117 | embeddings = self.layernorm(embeddings)
118 | embeddings = self.dropout(embeddings)
119 | return embeddings
120 |
121 |
122 | class LeanAlbertLayer(nn.Module):
123 | def __init__(self, config: LeanAlbertConfig):
124 | super().__init__()
125 |
126 | self.config = config
127 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
128 | self.seq_len_dim = 1
129 | self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
130 |
131 | self.attention = LeanSelfAttention(
132 | config.hidden_size,
133 | config.num_attention_heads,
134 | attention_core=get_attention_core(config),
135 | hidden_dropout_prob=config.hidden_dropout_prob,
136 | layer_norm_eps=config.layer_norm_eps,
137 | )
138 |
139 | self.ffn = LeanFFN(
140 | config.hidden_size,
141 | config.intermediate_size,
142 | activation=ACT2FN[config.hidden_act],
143 | gated=config.hidden_act_gated,
144 | layer_norm_eps=config.layer_norm_eps,
145 | dropout=config.hidden_dropout_prob,
146 | )
147 |
148 | def forward(self, hidden_states, attention_mask=None, output_attentions=False):
149 | attention_output, *extras = self.attention(hidden_states, attention_mask, output_attentions)
150 | ffn_output = self.ffn(attention_output)
151 | return (ffn_output, attention_output, *extras)
152 |
153 |
154 | class LeanAlbertLayerGroup(AlbertLayerGroup):
155 | def __init__(self, config):
156 | nn.Module.__init__(self)
157 | self.albert_layers = nn.ModuleList([LeanAlbertLayer(config) for _ in range(config.inner_group_num)])
158 |
159 | def forward(
160 | self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
161 | ):
162 | if any(head_mask):
163 | raise NotImplementedError(f"head mask was provided, but it is not supported")
164 |
165 | layer_hidden_states = ()
166 | layer_attentions = ()
167 |
168 | for layer_index, albert_layer in enumerate(self.albert_layers):
169 | layer_output = albert_layer(hidden_states, attention_mask, output_attentions)
170 | hidden_states = layer_output[0]
171 |
172 | if output_attentions:
173 | layer_attentions = layer_attentions + (layer_output[1],)
174 |
175 | if output_hidden_states:
176 | layer_hidden_states = layer_hidden_states + (hidden_states,)
177 |
178 | outputs = (hidden_states,)
179 | if output_hidden_states:
180 | outputs = outputs + (layer_hidden_states,)
181 | if output_attentions:
182 | outputs = outputs + (layer_attentions,)
183 | return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
184 |
185 |
186 | class LeanAlbertTransformer(AlbertTransformer):
187 | def __init__(self, config):
188 | nn.Module.__init__(self)
189 | self.config = config
190 | self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
191 | self.albert_layer_groups = nn.ModuleList(
192 | [LeanAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
193 | )
194 | self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps)
195 |
196 | def forward(
197 | self,
198 | hidden_states,
199 | attention_mask=None,
200 | head_mask=None,
201 | output_attentions=False,
202 | output_hidden_states=False,
203 | return_dict=True,
204 | ):
205 | hidden_states = self.embedding_hidden_mapping_in(hidden_states)
206 |
207 | all_hidden_states = (hidden_states,) if output_hidden_states else None
208 | all_attentions = () if output_attentions else None
209 |
210 | for i in range(self.config.num_hidden_layers):
211 | # Number of layers in a hidden group
212 | layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
213 |
214 | # Index of the hidden group
215 | group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
216 |
217 | layer_group_output = self.albert_layer_groups[group_idx](
218 | hidden_states,
219 | attention_mask,
220 | head_mask[group_idx * layers_per_group: (group_idx + 1) * layers_per_group],
221 | output_attentions,
222 | output_hidden_states,
223 | )
224 | hidden_states = layer_group_output[0]
225 |
226 | if output_attentions:
227 | all_attentions = all_attentions + layer_group_output[-1]
228 |
229 | if output_hidden_states:
230 | all_hidden_states = all_hidden_states + (hidden_states,)
231 |
232 | if not return_dict:
233 | return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
234 |
235 | return BaseModelOutput(
236 | last_hidden_state=self.post_layer_norm(hidden_states),
237 | hidden_states=all_hidden_states,
238 | attentions=all_attentions,
239 | )
240 |
241 |
242 | @add_start_docstrings(
243 | "The bare LeanALBERT Model transformer outputting raw hidden-states without any specific head on top.",
244 | ALBERT_START_DOCSTRING,
245 | )
246 | class LeanAlbertModel(AlbertModel):
247 | config_class = LeanAlbertConfig
248 |
249 | def __init__(self, config: AlbertConfig, add_pooling_layer=True):
250 | PreTrainedModel.__init__(self, config)
251 |
252 | self.config = config
253 | self.embeddings = LeanAlbertEmbeddings(config)
254 | self.encoder = LeanAlbertTransformer(config)
255 |
256 | if add_pooling_layer:
257 | self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
258 | self.pooler_activation = nn.Tanh()
259 | else:
260 | self.pooler = None
261 | self.pooler_activation = None
262 |
263 | self.init_weights()
264 |
265 |
266 | class LeanAlbertForPreTraining(AlbertForPreTraining, PreTrainedModel):
267 | config_class = LeanAlbertConfig
268 | base_model_prefix = "albert"
269 |
270 | def __init__(self, config: AlbertConfig):
271 | PreTrainedModel.__init__(self, config)
272 |
273 | self.albert = LeanAlbertModel(config)
274 | self.predictions = AlbertMLMHead(config)
275 | self.sop_classifier = AlbertSOPHead(config)
276 |
277 | self.init_weights()
278 |
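A minimal construction sketch, assuming the src package is importable; the sizes are illustrative. Rotary attention and the gated FFN are selected purely through the config:

import torch

from src.models.config import LeanAlbertConfig
from src.models.lean_albert import LeanAlbertForPreTraining

config = LeanAlbertConfig(
    vocab_size=30_000, embedding_size=128, hidden_size=768, intermediate_size=3072,
    num_attention_heads=12, num_hidden_layers=12,
    position_embedding_type="rotary",  # use RotaryAttentionCore instead of absolute position embeddings
    hidden_act_gated=True,             # gated activation inside LeanFFN
)
model = LeanAlbertForPreTraining(config)
outputs = model(input_ids=torch.randint(0, config.vocab_size, (2, 16)))
print(outputs.prediction_logits.shape)  # torch.Size([2, 16, 30000])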
--------------------------------------------------------------------------------
/src/modules/linear.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements weight matrix sharing for linear layers: full sharing and sharing with low-rank adapters
3 | """
4 | import math
5 | from itertools import zip_longest
6 | from typing import Optional, Tuple, List
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.cuda.amp import custom_bwd, custom_fwd
12 |
13 | from src.modules.pixelfly import butterfly_matmul, butterfly_matmul_backward, get_butterfly_indices
14 | from src.modules.functional import maybe_script
15 |
16 |
17 | class GeneralizedMatrix(nn.Module):
18 | """A module that stores a shared pytorch tensor for use in GeneralizedLinear layers"""
19 |
20 | def __init__(self, in_features: int, out_features: int, block_size: int = 0, lowrank_dim: int = 0):
21 | super().__init__()
22 | self.out_features, self.in_features = out_features, in_features
23 |
24 | if block_size == 0:
25 | # fully-connected weight matrix
26 | self.weight = nn.Parameter(torch.empty(out_features, in_features))
27 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
28 | # note: this is usually overwritten by the model-wide initialization
29 | self.forward_indices = self.backward_indices = None
30 | else:
31 | # block-sparse weights with additive butterfly pattern
32 | forward_indices, backward_indices = get_butterfly_indices(
33 | out_features, in_features, block_size, stretch=False
34 | )
35 | self.register_buffer("forward_indices", forward_indices)
36 | self.register_buffer("backward_indices", backward_indices)
37 | active_blocks_per_input = self.forward_indices.numel() // (in_features // block_size)
38 | self.weight = nn.Parameter(torch.empty(in_features, active_blocks_per_input, block_size))
39 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
40 | # note: this is usually overwritten by the model-wide init
41 |
42 | if lowrank_dim:
43 | self.lowrank_first = nn.Parameter(torch.zeros(lowrank_dim, self.in_features))
44 | self.lowrank_second = nn.Parameter(torch.zeros(self.out_features, lowrank_dim))
45 | nn.init.normal_(self.lowrank_first, std=math.sqrt(2.0 / (5 * min(out_features, in_features))))
46 | nn.init.normal_(self.lowrank_second, std=math.sqrt(2.0 / (5 * min(out_features, in_features))))
47 | else:
48 | self.lowrank_first = self.lowrank_second = None
49 |
50 | @property
51 | def shape(self):
52 | return (self.out_features, self.in_features)
53 |
54 | def __repr__(self):
55 | return f"{self.__class__.__name__}{tuple(self.shape)}"
56 |
57 | def forward(self, input: torch.Tensor, *, ignore_lowrank: bool = False):
58 | """
59 | Multiply input tensor by this matrix with the same semantics as in torch.nn.Linear(..., bias=False)
60 |
61 | :param ignore_lowrank: if True, the low-rank components (lowrank_dim) will not be used in matrix multiplication
62 | """
63 | if self.forward_indices is not None:
64 | output = butterfly_matmul(input, self.weight, self.forward_indices)
65 | else:
66 | output = F.linear(input, self.weight)
67 | if self.lowrank_first is not None and not ignore_lowrank:
68 | output = F.linear(F.linear(input, self.lowrank_first), self.lowrank_second, output)
69 | return output
70 |
71 |
72 | class GeneralizedLinear(nn.Linear):
73 | """A linear layer with a shared full-rank matrix and an individual low-rank adapter"""
74 |
75 | def __init__(self, shared_matrix: GeneralizedMatrix, adapter_dim: int = 0, bias: bool = True):
76 | nn.Module.__init__(self)
77 | self.shared_matrix = shared_matrix
78 | self.out_features, self.in_features = self.shared_matrix.shape
79 | self.bias = nn.Parameter(torch.zeros(self.out_features)) if bias else None
80 |
81 | if adapter_dim != 0:
82 | self.adapter_first = nn.Parameter(torch.zeros(adapter_dim, self.in_features))
83 | self.adapter_second = nn.Parameter(torch.zeros(self.out_features, adapter_dim))
84 |
85 | # initialize in accordance with https://arxiv.org/pdf/2106.09685.pdf
86 | nn.init.xavier_normal_(self.adapter_first)
87 | nn.init.zeros_(self.adapter_second)
88 | else:
89 | self.adapter_first = self.adapter_second = None
90 |
91 | def get_combined_lowrank_components(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
92 | """Group together low-rank matrix components from this layer's adapter and GeneralizedMatrix for faster matmul"""
93 | if self.adapter_first is not None and self.shared_matrix.lowrank_first is None:
94 | return self.adapter_first, self.adapter_second
95 | elif self.adapter_first is None and self.shared_matrix.lowrank_first is not None:
96 | return self.shared_matrix.lowrank_first, self.shared_matrix.lowrank_second
97 | elif self.adapter_first is not None and self.shared_matrix.lowrank_first is not None:
98 | combined_first = torch.cat([self.shared_matrix.lowrank_first, self.adapter_first], dim=0)
99 | # ^-- cat0[(lowrank_dim x input_dim), (adapter_dim, input_dim)] -> (combined_dim, input_dim)
100 | combined_second = torch.cat([self.shared_matrix.lowrank_second, self.adapter_second], dim=1)
101 |             # ^-- cat1[(output_dim x lowrank_dim), (output_dim, adapter_dim)] -> (output_dim, combined_dim)
102 | return combined_first, combined_second
103 | else:
104 | assert self.adapter_first is None and self.adapter_second is None
105 | assert self.shared_matrix.lowrank_first is None and self.shared_matrix.lowrank_second is None
106 | return None, None
107 |
108 | @property
109 | def weight(self):
110 | return self.shared_matrix.weight
111 |
112 | def forward(self, input: torch.Tensor) -> torch.Tensor:
113 | return _GeneralizedLinear.apply(
114 | input, self.weight, self.bias, *self.get_combined_lowrank_components(),
115 | self.shared_matrix.forward_indices, self.shared_matrix.backward_indices)
116 |
117 |
118 | class _GeneralizedLinear(torch.autograd.Function):
119 | @staticmethod
120 | @custom_fwd
121 | def forward(ctx, *args):
122 | output, *tensors_to_save = _GeneralizedLinear._forward_impl(*args)
123 | ctx.save_for_backward(*tensors_to_save)
124 | return output
125 |
126 | @staticmethod
127 | @maybe_script
128 | def _forward_impl(
129 | input: torch.Tensor,
130 | main_weight: torch.Tensor,
131 | bias: Optional[torch.Tensor],
132 | lowrank_first: Optional[torch.Tensor],
133 | lowrank_second: Optional[torch.Tensor],
134 | forward_indices: Optional[torch.Tensor],
135 | backward_indices: Optional[torch.Tensor],
136 | ):
137 | input_flat = input.view(-1, input.shape[-1])
138 | if forward_indices is not None:
139 | output = butterfly_matmul(input_flat, main_weight, forward_indices)
140 | if bias is not None:
141 | output.add_(bias.to(output.dtype))
142 | else:
143 | output = F.linear(input_flat, main_weight, bias)
144 |
145 | if lowrank_first is not None and lowrank_second is not None:
146 | lowrank_hid = F.linear(input_flat, lowrank_first)
147 | if "xla" in output.device.type: # xla does not support in-place ops
148 | output = torch.addmm(output, lowrank_hid, lowrank_second.t().to(output.dtype))
149 | else:
150 | output = torch.addmm(output, lowrank_hid, lowrank_second.t().to(output.dtype), out=output)
151 | else:
152 | lowrank_hid = None
153 | output = output.view(input.shape[:-1] + output.shape[-1:])
154 | return output, input, lowrank_hid, main_weight, lowrank_first, lowrank_second, backward_indices
155 |
156 | @staticmethod
157 | @custom_bwd
158 | def backward(ctx, grad_output: torch.Tensor):
159 | grads = _GeneralizedLinear._backward_impl(
160 | grad_output, *ctx.saved_tensors, needs_input_grad=ctx.needs_input_grad
161 | )
162 | return tuple(grad if needed else None for grad, needed in zip_longest(grads, ctx.needs_input_grad))
163 |
164 | @staticmethod
165 | @maybe_script
166 | def _backward_impl(grad_output: torch.Tensor,
167 | input: torch.Tensor,
168 | lowrank_hid: Optional[torch.Tensor],
169 | main_weight: torch.Tensor,
170 | lowrank_first: Optional[torch.Tensor],
171 | lowrank_second: Optional[torch.Tensor],
172 | backward_indices: Optional[torch.Tensor],
173 | needs_input_grad: List[bool]):
174 | grad_input = grad_input_flat = grad_main_weight = grad_lowrank_first = grad_lowrank_second = grad_bias \
175 | = grad_output_flat_transposed = grad_lowrank_hid_flat = lowrank_hid_flat = torch.empty(0)
176 | input_flat = input.flatten(0, -2) # [etc, in_features]
177 | grad_output_flat = grad_output.flatten(0, -2) # [etc, out_features]
178 |
179 | if lowrank_hid is not None:
180 | lowrank_hid_flat = lowrank_hid.flatten(0, -2) # [etc, lowrank_dim]
181 | if lowrank_first is not None and (needs_input_grad[0] or needs_input_grad[3]):
182 | assert lowrank_second is not None
183 | grad_lowrank_hid_flat = torch.matmul(grad_output_flat, lowrank_second) # [etc, lowrank_dim]
184 | if needs_input_grad[1] or needs_input_grad[4]:
185 | grad_output_flat_transposed = grad_output_flat.t() # [out_features, etc]
186 |
187 | if needs_input_grad[4]:
188 | assert lowrank_second is not None
189 | grad_lowrank_second = torch.matmul(grad_output_flat_transposed, lowrank_hid_flat)
190 | # ^-- [out_features, lowrank_dim]
191 | if needs_input_grad[3]:
192 | grad_lowrank_hid_flat_transposed = grad_lowrank_hid_flat.t() # [lowrank_dim, etc]
193 | grad_lowrank_first = torch.matmul(grad_lowrank_hid_flat_transposed, input_flat)
194 | # ^-- [lowrank_dim, in_features]
195 | if needs_input_grad[2]:
196 | grad_bias = grad_output_flat.sum(dim=0) # [out_features]
197 | if backward_indices is None:
198 | # dense shared matrix
199 | if needs_input_grad[1]:
200 | grad_main_weight = torch.matmul(grad_output_flat_transposed, input_flat)
201 | # ^-- [out_features, in_features]
202 | if needs_input_grad[0]:
203 | grad_input_flat = torch.matmul(grad_output_flat, main_weight)
204 | else:
205 | # block-sparse shared matrix
206 | grad_input_flat, grad_main_weight = butterfly_matmul_backward(
207 | grad_output_flat, input_flat, main_weight, backward_indices,
208 | input_requires_grad=needs_input_grad[0], weight_requires_grad=needs_input_grad[1])
209 |
210 | if needs_input_grad[0] and lowrank_first is not None:
211 | # grad w.r.t. input through low-rank components
212 | if 'xla' not in grad_output.device.type:
213 | grad_input_flat = grad_input_flat.addmm_(
214 | grad_lowrank_hid_flat.to(grad_output_flat.dtype),
215 | lowrank_first.to(grad_output_flat.dtype)
216 | )
217 | else:
218 | grad_input_flat = torch.addmm(grad_input_flat, grad_lowrank_hid_flat, lowrank_first)
219 | if needs_input_grad[0]:
220 | grad_input = grad_input_flat.view_as(input)
221 | return grad_input, grad_main_weight, grad_bias, grad_lowrank_first, grad_lowrank_second, None, None
222 |
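A minimal sharing sketch (sizes are illustrative): two GeneralizedLinear layers reuse one dense GeneralizedMatrix while keeping independent biases and low-rank adapters; constructing the matrix with block_size > 0 would switch the shared weight to the block-sparse butterfly pattern instead:

import torch

from src.modules.linear import GeneralizedLinear, GeneralizedMatrix

shared = GeneralizedMatrix(in_features=768, out_features=768)  # block_size=0 -> dense shared weight
layer_a = GeneralizedLinear(shared, adapter_dim=8)
layer_b = GeneralizedLinear(shared, adapter_dim=8)

x = torch.randn(2, 16, 768)
assert layer_a(x).shape == layer_b(x).shape == (2, 16, 768)
assert layer_a.weight is layer_b.weight                      # the main weight is shared
assert layer_a.adapter_first is not layer_b.adapter_first    # the adapters are per-layer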
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 Yandex Research
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/src/training/lamb_8bit.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is a joint work with Tim Dettmers, based on his library https://github.com/facebookresearch/bitsandbytes ;
3 | Unlike the rest of the bnb optimizers, CPULAMB8Bit is tuned to further reduce the memory footprint at the cost of performance.
4 | The intended use case of CPULAMB8Bit is to run in the background on CPU while training with large batches.
5 | """
6 | import math
7 | from typing import Any, Dict, Optional
8 |
9 | import torch
10 | from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise
11 | from bitsandbytes.optim.optimizer import Optimizer2State
12 | from torch_optimizer.types import Betas2, Params
13 |
14 | __all__ = ("CPULAMB8Bit",)
15 |
16 | from hivemind.utils.logging import get_logger, use_hivemind_log_handler
17 |
18 | use_hivemind_log_handler("in_root_logger")
19 | logger = get_logger(__name__)
20 |
21 | class CPULAMB8Bit(Optimizer2State):
22 | r"""
23 | Implements Lamb with quantized 8-bit statistics. The statistics are stored in host memory in the quantized form.
24 | The LAMB optimizer and block-wise quantization are described in the following papers:
25 | - LAMB: "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes" https://arxiv.org/abs/1904.00962
26 | - Quantization: "8-bit Optimizers via Block-wise Quantization" https://arxiv.org/abs/2110.02861
27 | This specific implementation of LAMB is based on https://github.com/cybertronai/pytorch-lamb
28 |     - bias correction is optional (paper v3 of LAMB does not use debiasing) and is enabled by default here
29 | - it has baked in clipping by global max_grad_norm
30 | Arguments:
31 | params: iterable of parameters to optimize or dicts defining
32 | parameter groups
33 | lr: learning rate (default: 1e-3)
34 | betas: coefficients used for computing
35 | running averages of gradient and its square (default: (0.9, 0.999))
36 | eps: term added to the denominator to improve
37 | numerical stability (default: 1e-8)
38 | weight_decay: weight decay (L2 penalty) (default: 0)
39 | clamp_value: clamp weight_norm in (0,clamp_value) (default: 10)
40 |             set to a high value to avoid it (e.g. 10e3)
41 | bias_correction: debias statistics by (1 - beta**step) (default: True)
42 | min_8bit_size: statistics for parameters with fewer than this many elements will not be quantized
43 | reuse_grad_buffers: if True, optimizer will modify gradients in-place to save memory.
44 | If enabled, one must ensure that .zero_grad() is called after each optimizer step.
45 | update_chunk_size: quantized statistics will be de-quantized in chunks of up to this many elements.
46 | """
47 |
48 | def __init__(
49 | self,
50 | params: Params,
51 | lr: float = 1e-3,
52 | betas: Betas2 = (0.9, 0.999),
53 | eps: float = 1e-6,
54 | weight_decay: float = 0,
55 | clamp_value: float = 10,
56 | bias_correction: bool = True,
57 | min_8bit_size: int = 65536,
58 | reuse_grad_buffers: bool = False,
59 | update_chunk_size: int = 2 ** 24,
60 | max_grad_norm: Optional[float] = None,
61 | ) -> None:
62 | if eps < 0.0:
63 | raise ValueError("Invalid epsilon value: {}".format(eps))
64 | if not 0.0 <= betas[0] < 1.0:
65 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
66 | if not 0.0 <= betas[1] < 1.0:
67 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
68 | if weight_decay < 0:
69 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
70 | if clamp_value < 0.0:
71 | raise ValueError("Invalid clamp value: {}".format(clamp_value))
72 |
73 | self.clamp_value = clamp_value
74 | self.bias_correction = bias_correction
75 | self.reuse_grad_buffers = reuse_grad_buffers
76 | self.update_chunk_size = update_chunk_size
77 | self.max_grad_norm = max_grad_norm
78 |
79 | super(CPULAMB8Bit, self).__init__(
80 | "cpu-lamb",
81 | params,
82 | lr,
83 | betas,
84 | eps,
85 | weight_decay,
86 | optim_bits=8,
87 | min_8bit_size=min_8bit_size,
88 | args=None,
89 | percentile_clipping=100,
90 | block_wise=4096,
91 | max_unorm=0,
92 | )
93 |
94 | @torch.no_grad()
95 | def step(self, closure=None):
96 | if self.max_grad_norm is not None:
97 | iter_params = (param for group in self.param_groups for param in group["params"])
98 | torch.nn.utils.clip_grad_norm_(iter_params, self.max_grad_norm)
99 | return super().step(closure=closure)
100 |
101 | @torch.no_grad()
102 | def init_state(self, group, p, gindex, pindex):
103 | config = self.get_config(gindex, pindex, group)
104 | assert config["percentile_clipping"] == 100, "percentile clipping is not implemented on CPU"
105 | assert config["max_unorm"] == 0
106 |
107 | if config["optim_bits"] == 32:
108 | dtype = torch.float32
109 | elif config["optim_bits"] == 8:
110 | dtype = torch.uint8
111 | else:
112 | raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
113 |
114 | if p.numel() < config["min_8bit_size"]:
115 | dtype = torch.float32
116 |
117 | state = self.state[p]
118 | state["step"] = 0
119 |
120 | if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
121 | state["state1"] = torch.zeros_like(
122 | p,
123 | memory_format=torch.preserve_format,
124 | dtype=torch.float32,
125 | device=p.device,
126 | )
127 | state["state2"] = torch.zeros_like(
128 | p,
129 | memory_format=torch.preserve_format,
130 | dtype=torch.float32,
131 | device=p.device,
132 | )
133 | elif dtype == torch.uint8:
134 | if state["step"] == 0:
135 | if "dynamic" not in self.name2qmap:
136 | self.fill_qmap()
137 | self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
138 | self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
139 |
140 | n = p.numel()
141 | blocks = (n - 1) // config["block_wise"] + 1
142 |
143 | state["state1"] = torch.zeros_like(
144 | p,
145 | memory_format=torch.preserve_format,
146 | dtype=torch.uint8,
147 | device=p.device,
148 | )
149 | state["qmap1"] = self.name2qmap["dynamic"]
150 |
151 | state["state2"] = torch.zeros_like(
152 | p,
153 | memory_format=torch.preserve_format,
154 | dtype=torch.uint8,
155 | device=p.device,
156 | )
157 | state["qmap2"] = self.name2qmap["udynamic"]
158 |
159 | state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
160 | state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
161 |
162 | @torch.no_grad()
163 | def update_step(self, group: Dict[str, Any], p: torch.Tensor, gindex: int, pindex: int):
164 | state = self.state[p]
165 | config = self.get_config(gindex, pindex, group)
166 |
167 | p_cpu, grad_cpu = p.cpu(), p.grad.cpu()
168 | # this is a no-op if parameters are already on CPU
169 |
170 | step = state["step"] = state["step"] + 1
171 | beta1, beta2 = group["betas"]
172 |
173 | param_delta = self._update_moments_and_compute_delta(
174 | state, config, p_cpu, grad_cpu, beta1, beta2, group["eps"], group["weight_decay"]
175 | )
176 | del grad_cpu # grad_cpu is no longer needed and may be modified if self.reuse_grad_buffers
177 |
178 | step_norm = torch.norm(param_delta)
179 | weight_norm = p_cpu.norm().clamp(0, self.clamp_value)
180 |
181 | trust_ratio = weight_norm / step_norm if weight_norm != 0 and step_norm != 0 else 1.0
182 | state["weight_norm"], state["step_norm"], state["trust_ratio"] = (weight_norm, step_norm, trust_ratio)
183 |
184 |         # Fold the bias correction into the learning rate to avoid broadcasting it over the parameter tensor.
185 | bias_correction = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step) if self.bias_correction else 1
186 | step_size = group["lr"] * bias_correction
187 | p.data.add_(param_delta.to(p.device), alpha=-step_size * trust_ratio)
188 |
189 | def _update_moments_and_compute_delta(
190 | self,
191 | state: Dict,
192 | config: Dict,
193 | p_cpu: torch.Tensor,
194 | grad_cpu: torch.Tensor,
195 | beta1: float,
196 | beta2: float,
197 | eps: float,
198 | weight_decay: float,
199 | ) -> torch.Tensor:
200 | step, block_size, chunk_size = (state["step"], config["block_wise"], self.update_chunk_size)
201 |
202 | if state["state1"].dtype != torch.uint8:
203 | # not quantized: update normally
204 | exp_avg, exp_avg_sq = state["state1"], state["state2"]
205 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
206 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
207 |
208 | sqrt_out = grad_cpu if self.reuse_grad_buffers else None
209 | _denominator = torch.sqrt(exp_avg_sq, out=sqrt_out).add_(eps)
210 | param_delta = torch.div(exp_avg, _denominator, out=_denominator)
211 | if weight_decay != 0:
212 | param_delta.add_(p_cpu, alpha=weight_decay)
213 | return param_delta
214 | elif p_cpu.numel() <= chunk_size:
215 | # quantized tensor within chunk size
216 | exp_avg = dequantize_blockwise(state["state1"], (state["absmax1"], state["qmap1"]), blocksize=block_size)
217 | exp_avg_sq = dequantize_blockwise(state["state2"], (state["absmax2"], state["qmap2"]), blocksize=block_size)
218 |
219 | exp_avg.mul_(beta1).add_(grad_cpu, alpha=1 - beta1)
220 | exp_avg_sq.mul_(beta2).addcmul_(grad_cpu, grad_cpu, value=1 - beta2)
221 |
222 | quantize_blockwise(exp_avg, state["qmap1"], state["absmax1"], out=state["state1"])
223 | quantize_blockwise(exp_avg_sq, state["qmap2"], state["absmax2"], out=state["state2"])
224 | # note: quantize_blockwise also modifies qmap and absmax in-place
225 |
226 | param_delta = exp_avg.div_(exp_avg_sq.sqrt_().add_(eps))
227 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version
228 |
229 | if weight_decay != 0:
230 | param_delta.add_(p_cpu, alpha=weight_decay)
231 | return param_delta
232 |
233 | else:
234 | # very large quantized tensor, compute updates in chunks to save RAM
235 | flat_p, flat_grad, flat_state1, flat_state2 = (
236 | tensor.view(-1) for tensor in (p_cpu, grad_cpu, state["state1"], state["state2"])
237 | )
238 | output_buffer = flat_grad if self.reuse_grad_buffers else torch.empty_like(flat_grad)
239 |
240 | for chunk_index, chunk_start in enumerate(range(0, len(flat_p), chunk_size)):
241 | chunk = slice(chunk_start, chunk_start + chunk_size)
242 | chunk_blocks = slice(chunk_start // block_size, (chunk_start + chunk_size) // block_size)
243 |
244 | chunk_p, chunk_grad = flat_p[chunk], flat_grad[chunk]
245 | chunk_state1, chunk_state2 = flat_state1[chunk], flat_state2[chunk]
246 | chunk_absmax1, chunk_absmax2 = (
247 | state["absmax1"][chunk_blocks],
248 | state["absmax2"][chunk_blocks],
249 | )
250 | if chunk_state1.storage_offset() != 0:
251 | # clone chunks to ensure that tensors do not have offsets (bnb hack, possibly no longer needed)
252 | chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2 = map(
253 | torch.clone, (chunk_state1, chunk_state2, chunk_absmax1, chunk_absmax2),
254 | )
255 |
256 | exp_avg_chunk = dequantize_blockwise(
257 | chunk_state1, (chunk_absmax1, state["qmap1"]), blocksize=block_size
258 | )
259 | exp_avg_sq_chunk = dequantize_blockwise(
260 | chunk_state2, (chunk_absmax2, state["qmap2"]), blocksize=block_size
261 | )
262 |
263 | exp_avg_chunk.mul_(beta1).add_(chunk_grad, alpha=1 - beta1)
264 | exp_avg_sq_chunk.mul_(beta2).addcmul_(chunk_grad, chunk_grad, value=1 - beta2)
265 |
266 | # note: output_buffer cannot be modified until this line because it shares memory with grad_cpu
267 | del chunk_grad
268 |
269 | flat_state1[chunk], (
270 | state["absmax1"][chunk_blocks],
271 | state["qmap1"],
272 | ) = quantize_blockwise(exp_avg_chunk, state["qmap1"], chunk_absmax1, out=chunk_state1)
273 | flat_state2[chunk], (
274 | state["absmax2"][chunk_blocks],
275 | state["qmap2"],
276 | ) = quantize_blockwise(exp_avg_sq_chunk, state["qmap2"], chunk_absmax2, out=chunk_state2)
277 | # note: we need to explicitly assign new quantized tensors because of cloning earlier
278 |
279 | torch.div(
280 | exp_avg_chunk,
281 | exp_avg_sq_chunk.sqrt_().add_(eps),
282 | out=output_buffer[chunk],
283 | )
284 | # note: this changes statistics in-place, but it's okay b/c we saved quantized version
285 |
286 | if weight_decay != 0:
287 | output_buffer[chunk].add_(flat_p[chunk], alpha=weight_decay)
288 |
289 | param_delta = output_buffer.view_as(grad_cpu)
290 |
291 | return param_delta
292 |
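A minimal usage sketch, assuming bitsandbytes is installed; the model and hyperparameters are illustrative. The weight matrix below exceeds min_8bit_size, so its optimizer statistics are stored as quantized 8-bit tensors:

import torch

from src.training.lamb_8bit import CPULAMB8Bit

model = torch.nn.Linear(1024, 1024)
optimizer = CPULAMB8Bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                        weight_decay=0.01, max_grad_norm=1.0)

loss = model(torch.randn(8, 1024)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()  # mandatory after every step if reuse_grad_buffers=True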
--------------------------------------------------------------------------------
/src/modules/ffn.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.cuda.amp import custom_bwd, custom_fwd
8 |
9 | from src.modules.functional import ACT2FN
10 | from src.modules.linear import GeneralizedLinear, _GeneralizedLinear
11 |
12 |
13 | class LeanFFN(nn.Module):
14 | """
15 | A transformer FFN module that doesn't hog your GPU memory. Uses a manually optimized differentiation algorithm.
16 |
17 | :param hidden_size: base hidden size of the transformer
18 | :param intermediate_size: a (typically larger) hidden dimension where activation is applied
19 | :param activation: a pytorch nonlinearity to use in the intermediate layer
20 | :param gated: use gated activations based on https://arxiv.org/abs/2002.05202 and https://arxiv.org/abs/2102.11972
21 | note: gated activations require 1.5x more parameters compared to their non-gated variants.
22 | :param layer_norm_eps: see torch.nn.functional.layer_norm
23 | :param sandwich_norm: if set, applies an additional layer norm to projected attention outputs before residuals,
24 | as proposed in the CogView paper ( arXiv:2105.13290 ). This is meant to make fp16 training
25 | more stable for deep transformers. This technique is also a part of NormFormer ( arXiv:2110.09456 )
26 | :param dropout: hidden dropout probability, applied to the output projection (before adding residual)
27 | :param residual: if True, adds the original layer input to the final layer output
28 |
29 |     :param dense_i2h: custom *first* linear layer (hidden_size -> intermediate_size or 2x intermediate_size)
30 | :param dense_h2o: custom *second* linear layer (intermediate_size -> hidden_size)
31 | """
32 |
33 | def __init__(
34 | self,
35 | hidden_size: int,
36 | intermediate_size: int,
37 | activation=ACT2FN["gelu_fused"],
38 | gated: bool = False,
39 | layer_norm_eps: float = 1e-12,
40 | dropout: float = 0.0,
41 | sandwich_norm: bool = False,
42 | dense_i2h: Optional[nn.Linear] = None,
43 | dense_h2o: Optional[nn.Linear] = None,
44 | residual: bool = True,
45 | ):
46 | super().__init__()
47 | i2h_out_features = intermediate_size * 2 if gated else intermediate_size
48 | self.dense_i2h = nn.Linear(hidden_size, i2h_out_features) if dense_i2h is None else dense_i2h
49 | self.dense_h2o = nn.Linear(intermediate_size, hidden_size) if dense_h2o is None else dense_h2o
50 | assert type(self.dense_i2h) in (
51 | nn.Linear, GeneralizedLinear), "only Linear and GeneralizedLinear are supported"
52 | assert type(self.dense_h2o) in (
53 | nn.Linear, GeneralizedLinear), "only Linear and GeneralizedLinear are supported"
54 | assert self.dense_i2h.in_features == self.dense_h2o.out_features == hidden_size
55 | assert self.dense_i2h.out_features == i2h_out_features and self.dense_h2o.in_features == intermediate_size
56 | self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
57 | self.sandwich_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if sandwich_norm else None
58 | self.activation = activation
59 | self.gated = gated
60 | self.dropout = dropout
61 | self.residual = residual
62 |
63 | def forward(self, input):
64 | sandwich_ln_weight = sandwich_ln_bias = None
65 | if self.sandwich_norm is not None:
66 | sandwich_ln_weight, sandwich_ln_bias = self.sandwich_norm.weight, self.sandwich_norm.bias
67 | i2h_lowrank_first = i2h_lowrank_second = h2o_lowrank_first = h2o_lowrank_second = None
68 | i2h_forward_indices = i2h_backward_indices = h2o_forward_indices = h2o_backward_indices = None
69 | if isinstance(self.dense_i2h, GeneralizedLinear):
70 | i2h_lowrank_first, i2h_lowrank_second = self.dense_i2h.get_combined_lowrank_components()
71 | i2h_forward_indices = self.dense_i2h.shared_matrix.forward_indices
72 | i2h_backward_indices = self.dense_i2h.shared_matrix.backward_indices
73 | if isinstance(self.dense_h2o, GeneralizedLinear):
74 | h2o_lowrank_first, h2o_lowrank_second = self.dense_h2o.get_combined_lowrank_components()
75 | h2o_forward_indices = self.dense_h2o.shared_matrix.forward_indices
76 | h2o_backward_indices = self.dense_h2o.shared_matrix.backward_indices
77 |
78 | output = _LeanFFN.apply(
79 | input,
80 | self.layer_norm.weight,
81 | self.layer_norm.bias,
82 | self.dense_i2h.weight,
83 | self.dense_i2h.bias,
84 | i2h_lowrank_first,
85 | i2h_lowrank_second,
86 | i2h_forward_indices,
87 | i2h_backward_indices,
88 | self.dense_h2o.weight,
89 | self.dense_h2o.bias,
90 | h2o_lowrank_first,
91 | h2o_lowrank_second,
92 | h2o_forward_indices,
93 | h2o_backward_indices,
94 | sandwich_ln_weight,
95 | sandwich_ln_bias,
96 | self.activation,
97 | self.gated,
98 | self.dropout,
99 | self.training,
100 | self.layer_norm.eps,
101 | self.residual,
102 | )
103 | return output
104 |
105 |
106 | class _LeanFFN(torch.autograd.Function):
107 | """Autograd function for transformer FFN, manually optimized to reduce memory without affecting performance"""
108 |
109 | @staticmethod
110 | def _apply_activation(pre_activation: torch.Tensor, activation: callable, gated: bool):
111 | if not gated:
112 | return activation(pre_activation)
113 | else:
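            # gated mode: dense_i2h produced 2 * intermediate_size features; the activated first half
            # (the "gate") multiplies the second, purely linear half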
114 | pre_gate, lin = pre_activation.split(pre_activation.shape[-1] // 2, dim=-1)
115 | return activation(pre_gate).mul_(lin)
116 |
117 | @staticmethod
118 | @custom_fwd
119 | def forward(
120 | ctx,
121 | input: torch.Tensor,
122 | ln_weight: torch.Tensor,
123 | ln_bias: torch.Tensor,
124 | i2h_weight: torch.Tensor,
125 | i2h_bias: Optional[torch.Tensor],
126 | i2h_lowrank_first: Optional[torch.Tensor],
127 | i2h_lowrank_second: Optional[torch.Tensor],
128 | i2h_forward_indices: Optional[torch.IntTensor],
129 | i2h_backward_indices: Optional[torch.IntTensor],
130 | h2o_weight: torch.Tensor,
131 | h2o_bias: Optional[torch.Tensor],
132 | h2o_lowrank_first: Optional[torch.Tensor],
133 | h2o_lowrank_second: Optional[torch.Tensor],
134 | h2o_forward_indices: Optional[torch.IntTensor],
135 | h2o_backward_indices: Optional[torch.IntTensor],
136 | sandwich_ln_weight: Optional[torch.Tensor],
137 | sandwich_ln_bias: Optional[torch.Tensor],
138 | activation: callable,
139 | gated: bool,
140 | dropout: float,
141 | training: bool,
142 | ln_eps: float,
143 | residual: bool,
144 | ):
145 | ctx._dropout, ctx._training, ctx._ln_eps = dropout, training, ln_eps
146 | ctx._activation, ctx._gated, ctx._residual = activation, gated, residual
147 | ctx._use_sandwich = sandwich_ln_weight is not None
148 |
149 | dropout_mask, pre_sandwich = None, None # optional tensors to save
150 | input_2d = input.view(-1, input.shape[-1])
151 |
152 | input_ln = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ln_eps)
153 |
154 | pre_activation, *i2h_tensors = _GeneralizedLinear._forward_impl(
155 | input_ln, i2h_weight, i2h_bias, i2h_lowrank_first, i2h_lowrank_second, i2h_forward_indices,
156 | i2h_backward_indices
157 | )
158 |
159 | hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, ctx._gated)
160 |
161 | out, *h2o_tensors = _GeneralizedLinear._forward_impl(
162 | hid_act, h2o_weight, h2o_bias, h2o_lowrank_first, h2o_lowrank_second, h2o_forward_indices,
163 | h2o_backward_indices
164 | )
165 |
166 | if ctx._use_sandwich:
167 | pre_sandwich = out
168 | out = F.layer_norm(pre_sandwich, pre_sandwich.shape[-1:], sandwich_ln_weight, sandwich_ln_bias, eps=ln_eps)
169 |
170 | out = F.dropout(out, dropout, training, inplace=True)
171 | if training and dropout:
172 | dropout_mask = (out == 0.0).to(torch.int8)
173 |
174 | if residual:
175 | out = torch.add(out, input_2d, out=out if 'xla' not in out.device.type else None)
176 |
177 | assert i2h_tensors[0] is input_ln and h2o_tensors[0] is hid_act # we can rematerialize these tensors
178 | tensors_to_save = [
179 | input, pre_activation, ln_weight, ln_bias, pre_sandwich, sandwich_ln_weight, sandwich_ln_bias, dropout_mask
180 | ]
181 | tensors_to_save.extend((*i2h_tensors[1:], *h2o_tensors[1:]))
182 | ctx.save_for_backward(*tensors_to_save)
183 | ctx._num_i2h_tensors = len(i2h_tensors)
184 | ctx._num_h2o_tensors = len(h2o_tensors)
185 | return out.view(*input.shape)
186 |
187 | @staticmethod
188 | def _h2o_backward(ctx, grad_output: torch.Tensor, hid_act: torch.Tensor):
189 | h2o_tensors = ctx.saved_tensors[-ctx._num_h2o_tensors + 1:]
190 | needs_input_grad = [hid_act.requires_grad, *ctx.needs_input_grad[9:15]]
191 | grads = _GeneralizedLinear._backward_impl(grad_output, hid_act, *h2o_tensors,
192 | needs_input_grad=needs_input_grad)
193 | return tuple(grad if needed else None for grad, needed in zip_longest(grads, needs_input_grad))
194 |
195 | @staticmethod
196 | def _i2h_backward(ctx, grad_output: torch.Tensor, input_ln: torch.Tensor):
197 | i2h_tensors = ctx.saved_tensors[-ctx._num_i2h_tensors - ctx._num_h2o_tensors + 2: -ctx._num_h2o_tensors + 1]
198 | needs_input_grad = [input_ln.requires_grad, *ctx.needs_input_grad[3:9]]
199 | grads = _GeneralizedLinear._backward_impl(grad_output, input_ln, *i2h_tensors,
200 | needs_input_grad=needs_input_grad)
201 | return tuple(grad if needed else None for grad, needed in zip_longest(grads, needs_input_grad))
202 |
203 | @staticmethod
204 | @custom_bwd
205 | def backward(ctx, grad_output):
206 | grad_input = grad_ln_weight = grad_ln_bias = grad_sandwich_ln_weight = grad_sandwich_ln_bias = None
207 | input, pre_activation, ln_weight, ln_bias, = ctx.saved_tensors[:4]
208 | pre_sandwich, sandwich_ln_weight, sandwich_ln_bias, dropout_mask = ctx.saved_tensors[4: 8]
209 | grad_output_2d = grad_output.view(-1, grad_output.shape[-1])
210 |
211 | # backward(... -> sandwich_norm -> dropout -> residual)
212 | grad_residual_2d = grad_output_2d if ctx._residual else None
213 | if dropout_mask is not None:
214 | grad_output_2d = grad_output_2d.mul(dropout_mask.to(grad_output_2d.dtype))
215 | if ctx._use_sandwich:
216 | assert pre_sandwich is not None
217 | with torch.enable_grad():
218 | required_grad = pre_sandwich.requires_grad
219 | pre_sandwich.requires_grad_(True)
220 | sandwich_out = F.layer_norm(
221 | pre_sandwich, pre_sandwich.shape[-1:], sandwich_ln_weight, sandwich_ln_bias, eps=ctx._ln_eps
222 | )
223 | grad_output, grad_sandwich_ln_weight, grad_sandwich_ln_bias = torch.autograd.grad(
224 | sandwich_out, [pre_sandwich, sandwich_ln_weight, sandwich_ln_bias], grad_outputs=grad_output_2d
225 | )
226 | pre_sandwich.requires_grad_(required_grad)
227 | del pre_sandwich, sandwich_out
228 |
229 | # backward(... -> nonlinearity -> intermediate_layernorm -> linear_h2o -> ...)
230 | input_2d = input.view(-1, input.shape[-1])
231 | grad_h2o_output_2d = grad_output.view(-1, grad_output.shape[-1])
232 |
233 | with torch.enable_grad():
234 | # rematerialize activation
235 | pre_activation.requires_grad_(True)
236 | hid_act = _LeanFFN._apply_activation(pre_activation, ctx._activation, ctx._gated)
237 |
238 | with torch.no_grad():
239 | (grad_hid_act, grad_h2o_weight, grad_h2o_bias, grad_h2o_lowrank_first, grad_h2o_lowrank_second,
240 | unused_grad_forward_indices, unused_grad_backward_indices) = \
241 | _LeanFFN._h2o_backward(ctx, grad_h2o_output_2d, hid_act)
242 |
243 | (grad_hid,) = torch.autograd.grad(hid_act, pre_activation, grad_outputs=grad_hid_act)
244 | pre_activation.requires_grad_(False)
245 | del hid_act
246 |
247 | # backward(... -> input_layernorm -> linear_i2h -> ...)
248 | with torch.enable_grad():
249 | # rematerialize input_ln
250 | input_2d.requires_grad_(True)
251 | input_ln_2d = F.layer_norm(input_2d, input.shape[-1:], ln_weight, ln_bias, ctx._ln_eps)
252 |
253 | with torch.no_grad():
254 | (grad_input_ln_2d, grad_i2h_weight, grad_i2h_bias, grad_i2h_lowrank_first, grad_i2h_lowrank_second,
255 | unused_grad_forward_indices, unused_grad_backward_indices) = \
256 | _LeanFFN._i2h_backward(ctx, grad_hid, input_ln_2d)
257 |
258 | if any(ctx.needs_input_grad[0:3]):
259 | partial_grad_input_2d, grad_ln_weight, grad_ln_bias = torch.autograd.grad(
260 | outputs=input_ln_2d, inputs=[input_2d, ln_weight, ln_bias], grad_outputs=grad_input_ln_2d
261 | )
262 | del input_2d, input_ln_2d, grad_input_ln_2d
263 |
264 | # add up residual grads
265 | if ctx.needs_input_grad[0]:
266 | grad_input = partial_grad_input_2d
267 | if ctx._residual:
268 | grad_input = grad_input.add_(grad_residual_2d)
269 | grad_input = grad_input.view(*input.shape)
270 |
271 | return (grad_input, grad_ln_weight, grad_ln_bias,
272 | grad_i2h_weight, grad_i2h_bias, grad_i2h_lowrank_first, grad_i2h_lowrank_second, None, None,
273 | grad_h2o_weight, grad_h2o_bias, grad_h2o_lowrank_first, grad_h2o_lowrank_second, None, None,
274 | grad_sandwich_ln_weight, grad_sandwich_ln_bias, None, None, None, None, None, None)
275 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RuLeanALBERT
2 |
3 | RuLeanALBERT is a pretrained masked language model for the Russian language using a memory-efficient architecture.
4 |
5 | ## Using the model
6 | You can download the pretrained weights, the tokenizer and the config file used for pretraining from the Hugging Face Hub: [huggingface.co/yandex/RuLeanALBERT](https://huggingface.co/yandex/RuLeanALBERT).
7 | Download them directly by running the following code:
8 | ```bash
9 | wget https://huggingface.co/yandex/RuLeanALBERT/resolve/main/state.pth -O state.pth
10 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/config.json -O config.json
11 | mkdir tokenizer
12 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/config.json -O tokenizer/config.json
13 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/special_tokens_map.json -O tokenizer/special_tokens_map.json
14 | wget https://huggingface.co/yandex/RuLeanALBERT/raw/main/tokenizer/tokenizer.json -O tokenizer/tokenizer.json
15 | ```
16 |
17 | As the model uses custom code (see [`src/models`](./src/models)), the simplest solution right now is to clone the repository and use the relevant classes (`LeanAlbertForPreTraining` and its dependencies) in your code.
18 |
19 | Loading the model is as simple as
20 | ```python
import torch
from transformers import AlbertTokenizerFast
from src.models.albert import LeanAlbertConfig, LeanAlbertForPreTraining

checkpoint_path = 'state.pth'  # downloaded above

21 | tokenizer = AlbertTokenizerFast.from_pretrained('tokenizer')
22 | config = LeanAlbertConfig.from_pretrained('config.json')
23 | model = LeanAlbertForPreTraining(config)
24 | model.resize_token_embeddings(len(tokenizer))
25 | checkpoint = torch.load(checkpoint_path, map_location='cpu')['model']
26 | model.load_state_dict(checkpoint)
27 | ```
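
Once loaded, the model can be used for masked-token prediction in the usual Hugging Face way. A minimal sketch (the example sentence and decoding logic are illustrative only):

```python
text = "Столица России - [MASK]."  # "The capital of Russia is [MASK]."
batch = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**batch).prediction_logits
# find the masked position(s) and print the top prediction for each
masked_positions = (batch["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
print(tokenizer.decode(logits[0, masked_positions].argmax(dim=-1)))
```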
28 |
29 | ## Fine-tuning guide
30 |
31 | Once you have downloaded the model, you can use [`finetune_mlm.py`](./finetune_mlm.py) to evaluate it on [RussianSuperGLUE](https://russiansuperglue.com/) and [RuCoLA](https://rucola-benchmark.com/).
32 |
33 | To do this, you can run the command as follows:
34 | ```
35 | python finetune_mlm.py -t TASK_NAME \
36 | --checkpoint state.pth \
37 | --data-dir . \
38 | --batch-size 32 \
39 | --grad-acc-steps 1 \
40 | --dropout 0.1 \
41 | --weight-decay 0.01 \
42 | --num-seeds 1
43 | ```
44 |
45 | Most datasets will be loaded from the Hugging Face Hub. However, you need to download [RuCoLA](https://github.com/RussianNLP/RuCoLA) and [RuCoS](https://russiansuperglue.com/tasks/task_info/RuCoS) from their respective sources and place them in `RuCoLA` and `RuCoS` directories respectively.
46 |
47 | For reference, finetuning with a batch size of 32 on RuCoLA should take approximately 10 GB of GPU memory. If you exceed the GPU memory limits for a specific task, you can reduce the batch size by reducing `--batch-size` and increasing `--grad-acc-steps` accordingly.
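
For example, the following keeps the effective batch size at 32 while roughly halving the per-step memory footprint:

```bash
python finetune_mlm.py -t TASK_NAME \
    --checkpoint state.pth \
    --data-dir . \
    --batch-size 16 \
    --grad-acc-steps 2 \
    --dropout 0.1 \
    --weight-decay 0.01 \
    --num-seeds 1
```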
48 |
49 | If you want to finetune an existing masked language model from the Hugging Face Hub, you can do it with the same code. The script directly supports fine-tuning RuRoBERTa-large: simply change `--checkpoint` to `--model-name ruroberta-large`.
50 |
51 | For LiDiRus, you should use the model trained on the TERRa dataset with `predict_on_lidirus.py`.
52 | For RWSD, all ML-based solutions perform worse than the most-frequent class baseline: in our experiments, we found that the majority of runs with RuLeanALBERT or RuRoBERTa converge to a similar solution.
53 |
54 | ## Pretraining a new model
55 |
56 | Here you can find the best practices that we learned from running the experiment. These may help you set up your own collaborative experiment.
57 |
58 | If your training run is not confidential, feel free to ask for help on the [Hivemind discord server](https://discord.gg/vRNN9ua2).
59 |
60 |
61 | 1. Choose and verify your training configuration
62 |
63 | Depending on your use case, you may want to change:
64 | - Dataset and preprocessing ([`data.py`](tasks/mlm/data.py));
65 | - Tokenizer ([`arguments.py`](arguments.py));
66 | - Model config
67 |
68 | In particular, you need to specify the datasets in the `make_training_dataset` function from [`data.py`](tasks/mlm/data.py).
69 | One solution is to use the [`datasets`](https://github.com/huggingface/datasets) library and stream one of the existing large datasets for your target domain. **A working example** of `data.py` can be found in the [NCAI-research/CALM](https://github.com/NCAI-Research/CALM/blob/main/tasks/mlm/data.py) project.
70 |
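For illustration, here is a minimal sketch of what `make_training_dataset` could look like when streaming a public Russian corpus with `datasets` (the dataset name, column names and buffer size are placeholders, not the configuration used for RuLeanALBERT):

```python
from datasets import load_dataset


def make_training_dataset(tokenizer, max_sequence_length: int, shuffle_buffer_size: int = 10_000):
    # stream the corpus instead of downloading it to disk first
    corpus = load_dataset("oscar", "unshuffled_deduplicated_ru", split="train", streaming=True)
    corpus = corpus.shuffle(buffer_size=shuffle_buffer_size, seed=42)

    def tokenize(examples):
        return tokenizer(examples["text"], truncation=True, max_length=max_sequence_length)

    # drop the raw columns so the collator only sees tokenized fields (input_ids, attention_mask, ...)
    return corpus.map(tokenize, batched=True, remove_columns=["id", "text"])
```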
71 |
72 | When transitioning to a new language or new dataset, it is important to check that the tokenizer/collator works as intended **before** you begin training.
73 | The best way to do that is to manually look at training minibatches:
74 | ```python
import torch
from tqdm.auto import tqdm

75 | from tasks.mlm.data import make_training_dataset
76 | from tasks.mlm.whole_word_mask import DataCollatorForWholeWordMask
77 |
78 | tokenizer = create_tokenizer_here(...)
79 | dataset = make_training_dataset(tokenizer, max_sequence_length=...) # see arguments.py
80 | collator = DataCollatorForWholeWordMask(tokenizer, pad_to_multiple_of=...) # see arguments.py
81 | data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collator, batch_size=4)
82 |
83 | # generate a few batches
84 | rows = []
85 | with tqdm(enumerate(data_loader)) as progress:
86 | for i, row in progress:
87 | rows.append(row)
88 | if i > 10:
89 | break
90 |
91 | # look into the training data
92 | row_ix, sample_ix = 0, 1
93 | sources = [tokenizer.decode([i]) for i in rows[row_ix]['input_ids'][sample_ix].data.numpy()]
94 | print("MASK RATE:", (rows[row_ix]['input_ids'][sample_ix] == 4).data.numpy().sum() / (rows[row_ix]['input_ids'][sample_ix] != 0).data.numpy().sum())
95 |
96 | for i in range(len(sources)):
97 | if sources[i] == '[MASK]':
98 |         sources[i] = '[[' + tokenizer.decode(rows[row_ix]['labels'][sample_ix][i].item()) + ']]'  # show the original token behind each mask
99 |
100 | print(' '.join(sources))
101 | ```
102 |
103 | If you make many changes, it also helps to train a small model on your own device first to check that everything works as intended (a good initial configuration is 6 layers, 512 hidden units, 2048 intermediate units).
104 |
105 | If you're training with volunteers, the most convenient way is to set up a Hugging Face organization; for instructions, see the "make your own" section of https://training-transformers-together.github.io . We use WandB for tracking logs and training progress: we've set up a [WandB team](https://docs.wandb.ai/ref/app/features/teams) for this experiment. Alternatively, you can run hivemind standalone (even without internet access) by setting `--authorize False` and `WANDB_DISABLED=true`, or by manually removing the corresponding options from the code.
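
In the standalone case, the launch reduces to something like this (a sketch; the remaining flags are explained in the sections below):

```bash
export WANDB_DISABLED=true
python run_aux_peer.py --run_id my-exp --host_maddrs /ip4/0.0.0.0/tcp/12345 --authorize False
```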
106 |
107 |
108 |
109 |
110 | 2. Setting up auxiliary peers
111 |
112 | Auxiliary peers are low-end servers without a GPU that keep track of the latest model checkpoint, report metrics, and assist in communication.
113 | You will need 1-3 such workers to track metrics, upload statistics, etc. These peers do not use a GPU.
114 | If many participants are behind firewalls (i.e. run in `--client_mode`), it helps to add more auxiliary servers, as they can serve as relays and help with all-reduce.
115 |
116 | __Minimum requirements:__ 15+ GB RAM, at least 100Mbit/s download/upload speed, at least one port opened to incoming connections;
117 |
118 | __Where to get:__ cloud providers that have cheap ingress/egress pricing. Good examples: [pebblehost](https://pebblehost.com/dedicated/) "Essential-01" and [hetzner](https://www.hetzner.com/cloud) CX41. Path of the true jedi: use your homelab or university server -- but that may require networking experience. AWS/GCP/Azure have similar offers, but they cost more due to [egress pricing](https://cloud.google.com/vpc/network-pricing).
119 |
120 |
121 | __Setup env:__
122 |
123 | ```
124 | sudo apt install -y git tmux
125 | curl https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh > Anaconda3-2021.11-Linux-x86_64.sh
126 | bash Anaconda3-2021.11-Linux-x86_64.sh -b -p ~/anaconda3
127 | source ~/anaconda3/bin/activate
128 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
129 |
130 | git clone https://github.com/yandex-research/RuLeanALBERT
131 | cd RuLeanALBERT && pip install -q -r requirements.txt &> log
132 |
133 | # re-install bitsandbytes for the actual CUDA version
134 | pip uninstall -y bitsandbytes-cuda111
135 | pip install bitsandbytes-cuda113==0.26.0
136 |
137 | curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
138 | sudo apt-get install git-lfs
139 | git lfs install
140 | ```
141 |
142 |
143 | __Run auxiliary worker:__
144 |
145 | 1. Open a tmux (or screen) session that will stay up after you logout. (`tmux new` , [about tmux](https://tmuxcheatsheet.com/))
146 |
147 | 2. Measure internet bandwidth and set `$BANDWIDTH` variable
148 | ```bash
149 |
150 | # You can measure bandwidth automatically:
151 | curl -s https://gist.githubusercontent.com/justheuristic/5467799d8f2ad59b36fa75f642cc9b87/raw/c5a4b9b66987c2115e6c54a07d97e0104dfbcd97/speedtest.py | python - --json > speedtest.json
152 | export BANDWIDTH=`python -c "import json; speedtest = json.load(open('speedtest.json')); print(int(max(1, min(speedtest['upload'], speedtest['download']) / 1e6)))"`
153 | echo "Internet Bandwidth (Mb/s) = $BANDWIDTH"
154 |
155 | # If that doesn't work, you can simply `export BANDWIDTH=TODOyour_bandwidth_mbits_here` using the minimum of download and upload speed.
156 | ```
157 |
158 |
159 | 3. Run the auxiliary peer
160 | ```bash
161 | export MY_IP=`curl --ipv4 -s http://whatismyip.akamai.com/`
162 | echo "MY IP (check not empty):" $MY_IP
163 | # If empty, please set this manually: export MY_IP=...
164 | # When training on internal infrastructure, feel free to use internal IP.
165 | # If using IPv6, please replace /ip4/ with /ip6/ in subsequent lines
166 |
167 | export PORT=12345 # please choose a port where you can accept incoming tcp connections (or open that port if you're on a cloud)
168 | export LISTEN_ON=/ip4/0.0.0.0/tcp/$PORT
169 | export ANNOUNCE_ON=/ip4/$MY_IP/tcp/$PORT
170 | export WANDB_START_METHOD=thread
171 | export CUDA_VISIBLE_DEVICES= # do not use GPUs even if they are available
172 |
173 | # organizations
174 | export WANDB_ENTITY=YOUR_USERNAME_HERE
175 | export HF_ORGANIZATION_NAME=YOUR_ORG_HERE
176 |
177 | # experiment name
178 | export EXP_NAME=my-exp
179 | export WANDB_PROJECT=$EXP_NAME
180 | export HF_MODEL_NAME=$EXP_NAME
181 |
182 | export WANDB_API_KEY=TODO_get_your_wandb_key_here_wandb.ai/authorize
183 | export HF_USER_ACCESS_TOKEN=TODO_create_user_access_token_here_with_WRITE_permissions_https://huggingface.co/settings/token
184 | # note: you can avoid setting the two tokens above: in that case, the script will ask you to login to wandb and huggingface
185 |
186 | # activate your anaconda environment
187 | source ~/anaconda3/bin/activate
188 |
189 |
190 | ulimit -n 16384 # this line is important, ignoring it may cause a "Too Many Open Files" error
191 |
192 | python run_aux_peer.py --run_id $EXP_NAME --host_maddrs $LISTEN_ON --announce_maddrs $ANNOUNCE_ON --wandb_project $WANDB_PROJECT --store_checkpoints --upload_interval 43200 --repo_url $HF_ORGANIZATION_NAME/$HF_MODEL_NAME --assist_in_averaging --bandwidth $BANDWIDTH
193 | # Optionally, add more peers to the training via `--initial_peers ONE_OR_MORE PEERS_HERE`
194 | ```
195 |
196 | If everything went right, it will print its address as such:
197 | 
198 |
199 | Please copy this address and use it as ``--initial_peers`` with GPU trainers and other auxiliary peers.
200 |
201 |
202 |
203 |
204 | 3. Setting up a trainer

205 | Trainers are peers with GPUs (or other compute accelerators) that compute gradients, average them via all-reduce, and perform optimizer steps.
206 | There are two broad types of trainers: normal (full) peers and client-mode peers. Client peers rely on others to average their gradients, but otherwise behave the same as full peers. You can designate your trainer as client-only using the `--client_mode` flag.
207 |
208 | __When do I need client mode?__ If a peer is unreliable (e.g. will likely be gone in 1 hour), OR sits behind a firewall that blocks incoming connections, OR has a very unstable internet connection, it should be a client. For instance, it is recommended to run Colab / Kaggle peers as clients. In turn, cloud GPUs (even spot instances!) are generally more reliable and should be full peers.
209 |
210 | Participating as a client is easy: you can find the code for that in **this colab notebook (TODO)**. Setting up a full peer takes a few more steps, described below.
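
In the meantime, a client-mode launch is roughly the trainer command from the end of this section with `--client_mode` added (a sketch, assuming `$EXP_NAME` and `$INITIAL_PEERS` are set as shown below):

```bash
python run_trainer.py --run_id $EXP_NAME --initial_peers $INITIAL_PEERS --client_mode \
    --per_device_train_batch_size 1 --gradient_accumulation_steps 1
```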
211 | ### Set up environment:
212 |
213 | This part is the same as for the auxiliary peer, except we don't need Git LFS (it was only needed to upload checkpoints).
214 | ```bash
215 | sudo apt install -y git tmux
216 | curl https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh > Anaconda3-2021.11-Linux-x86_64.sh
217 | bash Anaconda3-2021.11-Linux-x86_64.sh -b -p ~/anaconda3
218 | source ~/anaconda3/bin/activate
219 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
220 |
221 | git clone https://github.com/yandex-research/RuLeanALBERT
222 | cd RuLeanALBERT && pip install -q -r requirements.txt &> log
223 |
224 | # re-install bitsandbytes for the actual CUDA version
225 | pip uninstall -y bitsandbytes-cuda111
226 | pip install bitsandbytes-cuda113==0.26.0
227 |
228 | # note: we use bitsandbytes for 8-bit LAMB, and in turn, bitsandbytes needs cuda -- even if you run on a non-CUDA device.
229 | ```
230 |
231 | ```bash
232 | export MY_IP=`curl --ipv4 -s http://whatismyip.akamai.com/`
233 | echo "MY IP (check not empty):" $MY_IP
234 | # If empty, please set this manually: export MY_IP=...
235 | # When training on internal infrastructure, feel free to use internal IP.
236 | # If using IPv6, please replace /ip4/ with /ip6/ in subsequent lines
237 |
238 | export PORT=31337 # same requirements as for aux peer
239 | export LISTEN_ON=/ip4/0.0.0.0/tcp/$PORT
export ANNOUNCE_ON=/ip4/$MY_IP/tcp/$PORT  # used by run_trainer.py below
240 |
241 | export CUDA_VISIBLE_DEVICES=0 # supports multiple cuda devices!
242 |
243 | # organization & experiment name
244 | export WANDB_ENTITY=YOUR_USERNAME_HERE
245 | export HF_ORGANIZATION_NAME=YOUR_ORG_HERE
246 | export EXP_NAME=my-exp
247 | export WANDB_PROJECT=$EXP_NAME-hivemind-trainers
248 | export HF_MODEL_NAME=$EXP_NAME
249 |
250 | export WANDB_API_KEY=get_your_wandb_key_here_https://wandb.ai/authorize_OR_just_login_on_wandb
251 | export HF_USER_ACCESS_TOKEN=create_user_access_token_here_with_WRITE_permissions_https://huggingface.co/settings/token
252 | # note: you can avoid setting the two tokens above: in that case, the script will ask you to login to wandb and huggingface
253 |
254 | export INITIAL_PEERS="/ip4/IP_ADDR/tcp/12345/p2p/PEER_ID"
255 | # ^-- If you're running an independent experiment, these must be your own initial peers (either auxiliary peers or full GPU peers).
256 |
257 |
258 | curl -s https://raw.githubusercontent.com/sivel/speedtest-cli/master/speedtest.py | python - --json > speedtest.json
259 | export BANDWIDTH=`python -c "import json; speedtest = json.load(open('speedtest.json')); print(int(max(1, min(speedtest['upload'], speedtest['download']) / 1e6)))"`
260 | echo "Internet Bandwidth (Mb/s) = $BANDWIDTH"
261 |
262 | ulimit -n 16384 # this line is important, ignoring it may cause a "Too Many Open Files" error
263 |
264 | python run_trainer.py --run_id $EXP_NAME --host_maddrs $LISTEN_ON --announce_maddrs $ANNOUNCE_ON --initial_peers $INITIAL_PEERS --bandwidth $BANDWIDTH \
265 | --per_device_train_batch_size 1 --gradient_accumulation_steps 1
266 | # you can tune --per_device_train_batch_size, --gradient_accumulation_steps, --fp16 and --gradient_checkpointing based on the device. A good rule of thumb is that the device should compute (batch size x num accumulations) gradients over 1-10 seconds. Setting very large gradient_accumulation_steps can cause your peer to miss an averaging round.
267 |
268 | ```
269 |
270 |
271 |
272 |
273 |
274 | Best (and worst) practices
275 |
276 | - __Hardware requirements:__ The code is meant to run with the following specs: a 2-core CPU and 12 GB RAM (more if you train a bigger model). Peers used as `--initial_peers` must be accessible by others, so you may need to open a network port for incoming connections. The rest depends on what role you're playing:
277 |
278 |   - __Auxiliary peers:__ If you use `--upload_interval X --repo_url Y`, the peer must have enough disk space to store all checkpoints. For instance, assuming that training takes 1 month and the model+optimizer state takes 1GB, you will need 30GB with `--upload_interval 86400` (one checkpoint per day), 60GB with `--upload_interval 43200` (two per day), etc. If you enable `--assist_in_averaging`, ensure you have at least 100Mbit/s bandwidth; more is better.
279 |
280 |   - __Trainers__ need *some* source of compute: a GPU with at least 6GB memory or a TPU - as long as you can run pytorch on it. You will need to tune `--per_device_train_batch_size X` to fit into memory. Also, you can use `--fp16` even on old GPUs to save memory. Finally, `--gradient_checkpointing` can reduce memory usage at the cost of 30-40% slower training. Non-client-mode peers must have at least 100Mbit/s network bandwidth; more is better.
281 |
282 |
283 | - __Swarm composition:__ you will need 2-3 peers with a public IP as `--initial_peers` for redundancy. If some participants are behind firewalls, we recommend having at least one non-firewalled participant per 5 firewalled peers.
284 |
285 |
286 |
287 | ## Acknowledgements
288 |
289 | Many of the best practices from this guide were found in the [CALM](https://github.com/NCAI-Research/CALM) project for collaborative training by NCAI.
290 |
--------------------------------------------------------------------------------
/src/models/albert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch ALBERT modules that do not hog your GPU memory """
16 |
17 | import torch
18 | from torch import nn as nn
19 | from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss
20 | from transformers import AlbertPreTrainedModel
21 | from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
22 | from transformers.modeling_utils import PreTrainedModel
23 | from transformers.models.albert.modeling_albert import AlbertForPreTrainingOutput
24 | from transformers.utils import logging
25 |
26 | from src.models.transformer import GradientCheckpointingMixin, LeanTransformer, LeanTransformerConfig
27 |
28 | logger = logging.get_logger(__name__)
29 |
30 | _CONFIG_FOR_DOC = "LeanAlbertConfig"
31 | _TOKENIZER_FOR_DOC = "AlbertTokenizer"
32 |
33 |
34 | class LeanAlbertConfig(LeanTransformerConfig):
35 | def __init__(
36 | self,
37 | *args,
38 | vocab_size: int = 30000,
39 | embedding_size: int = 128,
40 | classifier_dropout_prob: float = 0.1,
41 | type_vocab_size: int = 2,
42 | pad_token_id: int = 0,
43 | bos_token_id: int = 2,
44 | eos_token_id: int = 3,
45 | **kwargs
46 | ):
47 | super().__init__(
48 | *args,
49 | pad_token_id=pad_token_id,
50 | bos_token_id=bos_token_id,
51 | eos_token_id=eos_token_id,
52 | type_vocab_size=type_vocab_size,
53 | **kwargs
54 | )
55 | self.vocab_size = vocab_size
56 | self.embedding_size = embedding_size
57 | self.classifier_dropout_prob = classifier_dropout_prob
58 | self.type_vocab_size = type_vocab_size
59 |
60 |
61 | class LeanAlbertEmbeddings(nn.Module):
62 | """
63 | Construct the embeddings from word, position and token_type embeddings.
64 | """
65 |
66 | def __init__(self, config: LeanTransformerConfig):
67 | super().__init__()
68 | self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
69 |
70 | self.token_type_embeddings = config.get_token_type_embeddings()
71 | self.position_embeddings = config.get_input_position_embeddings()
72 |
73 | self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
74 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
75 | if config.embedding_size != config.hidden_size:
76 | self.embedding_hidden_mapping = nn.Linear(config.embedding_size, config.hidden_size)
77 |
78 | if self.position_embeddings is not None:
79 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
80 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
81 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
82 |
83 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
84 | def forward(
85 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
86 | ):
87 | if input_ids is not None:
88 | input_shape = input_ids.size()
89 | else:
90 | input_shape = inputs_embeds.size()[:-1]
91 |
92 | seq_length = input_shape[1]
93 |
94 | if token_type_ids is None:
95 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
96 |
97 | if inputs_embeds is None:
98 | inputs_embeds = self.word_embeddings(input_ids)
99 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
100 |
101 | embeddings = inputs_embeds + token_type_embeddings
102 |
103 | if self.position_embeddings is not None:
104 | if position_ids is None:
105 | position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
106 | position_embeddings = self.position_embeddings(position_ids)
107 | embeddings += position_embeddings
108 |
109 | embeddings = self.layer_norm(embeddings)
110 | embeddings = self.dropout(embeddings)
111 | if hasattr(self, "embedding_hidden_mapping"):
112 | embeddings = self.embedding_hidden_mapping(embeddings)
113 | return embeddings
114 |
115 |
116 | class LeanAlbertModel(GradientCheckpointingMixin, PreTrainedModel):
117 | config_class = LeanAlbertConfig
118 | base_model_prefix = "lean_albert"
119 | _keys_to_ignore_on_load_missing = [r"position_ids"]
120 |
121 | def __init__(self, config: config_class, add_pooling_layer=True):
122 | super().__init__(config)
123 |
124 | self.config = config
125 | self.embeddings = LeanAlbertEmbeddings(config)
126 | self.transformer = LeanTransformer(config)
127 |
128 | if add_pooling_layer:
129 | self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
130 | self.pooler_activation = nn.Tanh()
131 | else:
132 | self.pooler = None
133 | self.pooler_activation = None
134 |
135 | self.init_weights()
136 |
137 | def get_input_embeddings(self):
138 | return self.embeddings.word_embeddings
139 |
140 | def set_input_embeddings(self, value):
141 | self.embeddings.word_embeddings = value
142 |
143 | def _init_weights(self, module: nn.Module):
144 | return self.config.init_weights(module)
145 |
146 | def forward(
147 | self,
148 | input_ids=None,
149 | attention_mask=None,
150 | token_type_ids=None,
151 | position_ids=None,
152 | head_mask=None,
153 | inputs_embeds=None,
154 | output_attentions=None,
155 | output_hidden_states=None,
156 | return_dict=None,
157 | ):
158 | assert head_mask is None and output_attentions is None and output_hidden_states is None, "not implemented"
159 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
160 |
161 | if input_ids is not None and inputs_embeds is not None:
162 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
163 | elif input_ids is not None:
164 | input_shape = input_ids.size()
165 | elif inputs_embeds is not None:
166 | input_shape = inputs_embeds.size()[:-1]
167 | else:
168 | raise ValueError("You have to specify either input_ids or inputs_embeds")
169 |
170 | batch_size, seq_length = input_shape
171 | device = input_ids.device if input_ids is not None else inputs_embeds.device
172 |
173 | if attention_mask is None:
174 | attention_mask = torch.ones(input_shape, device=device, dtype=int)
175 | else:
176 | assert not torch.is_floating_point(attention_mask), "The model requires boolean or int mask with 0/1 entries"
177 |
178 | if token_type_ids is None:
179 | if hasattr(self.embeddings, "token_type_ids"):
180 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
181 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
182 | token_type_ids = buffered_token_type_ids_expanded
183 | else:
184 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
185 |
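        # convert the 0/1 padding mask to an additive bias: 0 where attention is allowed, -10000 at padded positions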
186 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
187 | extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
188 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
189 |
190 | embedding_output = self.embeddings(
191 | input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
192 | )
193 | transformer_outputs = self.transformer(embedding_output, extended_attention_mask)
194 |
195 | sequence_output = transformer_outputs[0]
196 |
197 | pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
198 |
199 | if not return_dict:
200 | return (sequence_output, pooled_output) + transformer_outputs[1:]
201 |
202 | return BaseModelOutputWithPooling(
203 | last_hidden_state=sequence_output,
204 | pooler_output=pooled_output,
205 | hidden_states=transformer_outputs.hidden_states,
206 | attentions=transformer_outputs.attentions,
207 | )
208 |
209 |
210 | class AlbertMLMHead(nn.Module):
211 | def __init__(self, config):
212 | super().__init__()
213 |
214 | self.layer_norm = nn.LayerNorm(config.embedding_size)
215 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
216 | self.dense = nn.Linear(config.hidden_size, config.embedding_size)
217 | self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
218 | self.activation = config.get_activation_callable()
219 | self.decoder.bias = self.bias
220 |
221 | def forward(self, hidden_states):
222 | hidden_states = self.dense(hidden_states)
223 | hidden_states = self.activation(hidden_states)
224 | hidden_states = self.layer_norm(hidden_states)
225 | hidden_states = self.decoder(hidden_states)
226 |
227 | prediction_scores = hidden_states
228 |
229 | return prediction_scores
230 |
231 | def _tie_weights(self):
232 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
233 | self.bias = self.decoder.bias
234 |
235 |
236 | class AlbertSOPHead(nn.Module):
237 | def __init__(self, config):
238 | super().__init__()
239 |
240 | self.dropout = nn.Dropout(config.classifier_dropout_prob)
241 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
242 |
243 | def forward(self, pooled_output):
244 | dropout_pooled_output = self.dropout(pooled_output)
245 | logits = self.classifier(dropout_pooled_output)
246 | return logits
247 |
248 |
249 | class LeanAlbertForPreTraining(GradientCheckpointingMixin, PreTrainedModel):
250 | config_class = LeanAlbertConfig
251 | base_model_prefix = "lean_albert"
252 |
253 | def __init__(self, config: config_class):
254 | super().__init__(config)
255 |
256 | self.albert = LeanAlbertModel(config)
257 | self.predictions = AlbertMLMHead(config)
258 | self.sop_classifier = AlbertSOPHead(config)
259 |
260 | def get_input_embeddings(self):
261 | return self.albert.embeddings.word_embeddings
262 |
263 | def set_input_embeddings(self, new_embeddings: nn.Module):
264 | self.albert.embeddings.word_embeddings = new_embeddings
265 |
266 | def get_output_embeddings(self):
267 | return self.predictions.decoder
268 |
269 | def set_output_embeddings(self, new_embeddings):
270 | self.predictions.decoder = new_embeddings
271 |
272 | def _init_weights(self, module: nn.Module):
273 | return self.config.init_weights(module)
274 |
275 | def forward(
276 | self,
277 | input_ids=None,
278 | attention_mask=None,
279 | token_type_ids=None,
280 | position_ids=None,
281 | head_mask=None,
282 | inputs_embeds=None,
283 | labels=None,
284 | sentence_order_label=None,
285 | output_attentions=None,
286 | output_hidden_states=None,
287 | return_dict=None,
288 | ):
289 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
290 |
291 | outputs = self.albert(
292 | input_ids,
293 | attention_mask=attention_mask,
294 | token_type_ids=token_type_ids,
295 | position_ids=position_ids,
296 | head_mask=head_mask,
297 | inputs_embeds=inputs_embeds,
298 | output_attentions=output_attentions,
299 | output_hidden_states=output_hidden_states,
300 | return_dict=return_dict,
301 | )
302 |
303 | sequence_output, pooled_output = outputs[:2]
304 |
305 | prediction_scores = self.predictions(sequence_output)
306 | sop_scores = self.sop_classifier(pooled_output)
307 |
308 | total_loss = None
309 | if labels is not None:
310 | loss_fct = nn.CrossEntropyLoss()
311 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
312 | if sentence_order_label is not None:
313 | sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
314 | total_loss = masked_lm_loss + sentence_order_loss
315 | else:
316 | total_loss = masked_lm_loss
317 |
318 | if not return_dict:
319 | output = (prediction_scores, sop_scores) + outputs[2:]
320 | return ((total_loss,) + output) if total_loss is not None else output
321 |
322 | return AlbertForPreTrainingOutput(
323 | loss=total_loss,
324 | prediction_logits=prediction_scores,
325 | sop_logits=sop_scores,
326 | hidden_states=outputs.hidden_states,
327 | attentions=outputs.attentions,
328 | )
329 |
330 |
331 | class LeanAlbertForSequenceClassification(AlbertPreTrainedModel):
332 | config_class = LeanAlbertConfig
333 | base_model_prefix = "albert"
334 |
335 | def __init__(self, config: LeanAlbertConfig):
336 | super().__init__(config)
337 | self.num_labels = config.num_labels
338 | self.config = config
339 |
340 | self.albert = LeanAlbertModel(config, add_pooling_layer=False)
341 |
342 | self.classifier = nn.Sequential(
343 | nn.Dropout(config.classifier_dropout_prob),
344 | nn.Linear(config.hidden_size, config.hidden_size),
345 | nn.Tanh(),
346 | nn.Dropout(config.classifier_dropout_prob),
347 | nn.Linear(config.hidden_size, self.config.num_labels)
348 | )
349 |
350 | # Initialize weights and apply final processing
351 | self.post_init()
352 |
353 | def forward(
354 | self,
355 | input_ids=None,
356 | attention_mask=None,
357 | token_type_ids=None,
358 | position_ids=None,
359 | head_mask=None,
360 | inputs_embeds=None,
361 | labels=None,
362 | output_attentions=None,
363 | output_hidden_states=None,
364 | return_dict=None,
365 | ):
366 | r"""
367 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
368 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
369 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
370 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
371 | """
372 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373 |
374 | outputs = self.albert(
375 | input_ids=input_ids,
376 | attention_mask=attention_mask,
377 | token_type_ids=token_type_ids,
378 | position_ids=position_ids,
379 | head_mask=head_mask,
380 | inputs_embeds=inputs_embeds,
381 | output_attentions=output_attentions,
382 | output_hidden_states=output_hidden_states,
383 | return_dict=return_dict,
384 | )
385 | sequence_output = outputs[0]
386 | logits = self.classifier(sequence_output[:, 0, :])
387 |
388 | loss = None
389 | if labels is not None:
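            # infer the problem type once: a single label means regression, integer labels mean single-label
            # classification, and float (multi-hot) labels mean multi-label classification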
390 | if self.config.problem_type is None:
391 | if self.num_labels == 1:
392 | self.config.problem_type = "regression"
393 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
394 | self.config.problem_type = "single_label_classification"
395 | else:
396 | self.config.problem_type = "multi_label_classification"
397 |
398 | if self.config.problem_type == "regression":
399 | loss_fct = MSELoss()
400 | if self.num_labels == 1:
401 | loss = loss_fct(logits.squeeze(), labels.squeeze())
402 | else:
403 | loss = loss_fct(logits, labels)
404 | elif self.config.problem_type == "single_label_classification":
405 | loss_fct = CrossEntropyLoss()
406 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
407 | elif self.config.problem_type == "multi_label_classification":
408 | loss_fct = BCEWithLogitsLoss()
409 | loss = loss_fct(logits, labels)
410 |
411 | if not return_dict:
412 | output = (logits,) + outputs[2:]
413 | return ((loss,) + output) if loss is not None else output
414 |
415 | return SequenceClassifierOutput(
416 | loss=loss,
417 | logits=logits,
418 | hidden_states=outputs.hidden_states,
419 | attentions=outputs.attentions,
420 | )
421 |
--------------------------------------------------------------------------------