├── .gitignore ├── LICENSE ├── README.md ├── agorabanner.png ├── example_language.py ├── example_multimodal.py ├── gpt4 ├── __init__.py ├── attend.py ├── gpt4.py ├── model.py ├── train.py └── utils │ ├── __init__.py │ └── stable_adam.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eternal Reclaimer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)
2 | 
3 | 
4 | 
5 | # GPT4
6 | The open source implementation of the base model behind GPT-4 from OpenAI [Language + Multi-Modal]. Click here for the [Research Paper](https://arxiv.org/pdf/2303.08774.pdf)
7 | 
8 | 
9 | # Installation
10 | `pip install gpt4-torch`
11 | 
12 | 
13 | # Usage
14 | 
15 | Here's an illustrative code snippet that showcases GPT4 in action:
16 | 
17 | 
18 | ```python
19 | import torch
20 | from gpt4 import GPT4
21 | 
22 | # Generate a random input sequence
23 | x = torch.randint(0, 256, (1, 1024)).cuda()
24 | 
25 | # Initialize the GPT4 model and move it to the same device as the input
26 | model = GPT4().cuda()
27 | 
28 | # Pass the input sequence through the model
29 | output = model(x)
30 | ```
31 | 
32 | ## MultiModal Iteration
33 | * Pass text and image tensors into GPT4
34 | ```python
35 | import torch
36 | from gpt4.gpt4 import GPT4MultiModal
37 | 
38 | #usage
39 | img = torch.randn(1, 3, 256, 256)
40 | text = torch.randint(0, 20000, (1, 1024))
41 | 
42 | 
43 | model = GPT4MultiModal()
44 | output = model(img, text)
45 | 
46 | ```
47 | 
48 | 
49 | # 📚 Training
50 | 
51 | ```python
52 | from gpt4 import train
53 | 
54 | train()
55 | 
56 | ```
57 | 
58 | For further instructions, refer to the [Training SOP](DOCs/TRAINING.md).
59 | 
60 | 
61 | 1. Set the environment variables:
62 |    - `ENTITY_NAME`: Your wandb project name
63 |    - `OUTPUT_DIR`: Directory to save the weights (e.g., `./weights`)
64 |    - `MASTER_ADDR`: Master node address for distributed training
65 |    - `MASTER_PORT`: Master node port for distributed training
66 |    - `RANK`: Rank of the current node in distributed training
67 |    - `WORLD_SIZE`: Total number of GPUs across all nodes
68 | 
69 | 2. Configure the training:
70 |    - Run `accelerate config`
71 |    - Enable DeepSpeed stage 3
72 |    - Launch with `accelerate launch train_distributed_accelerate.py` (see the sketch below)
73 | 
74 | For more information, refer to the [Training SOP](DOCs/TRAINING.md).
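The snippet below is a minimal launch sketch for the steps above, assuming a single node with 8 GPUs. Only the variable names and the `train_distributed_accelerate.py` script name come from the steps above; the concrete values, node count, and the DeepSpeed choices made during `accelerate config` are placeholders to adapt to your own cluster.

```bash
# Hedged example: adjust addresses, ports, and world size to your setup.
export ENTITY_NAME="my-wandb-project"   # wandb project name
export OUTPUT_DIR="./weights"           # where checkpoints are written
export MASTER_ADDR="127.0.0.1"          # address of the master node
export MASTER_PORT="29500"              # port of the master node
export RANK=0                           # rank of this node
export WORLD_SIZE=8                     # total number of GPUs

# one-time interactive setup: choose multi-GPU and DeepSpeed ZeRO stage 3 when prompted
accelerate config

# launch distributed training
accelerate launch train_distributed_accelerate.py
```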
75 | -------------------------------------------------------------------------------- /agorabanner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/GPT4/f79f992189318c99419fd3cd29d9d955f5f67a55/agorabanner.png -------------------------------------------------------------------------------- /example_language.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gpt4.gpt4 import GPT4 3 | 4 | x = torch.randint(0, 256, (1, 1024)).cuda() 5 | 6 | model = GPT4() 7 | 8 | model(x) 9 | 10 | -------------------------------------------------------------------------------- /example_multimodal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gpt4.gpt4 import GPT4MultiModal 3 | 4 | #usage 5 | img = torch.randn(1, 3, 256, 256) 6 | caption = torch.randint(0, 20000, (1, 1024)) 7 | 8 | model = GPT4MultiModal() 9 | output = model(img, caption) 10 | print(output.shape) # (1, 1024, 20000) 11 | 12 | -------------------------------------------------------------------------------- /gpt4/__init__.py: -------------------------------------------------------------------------------- 1 | from gpt4.gpt4 import GPT4 2 | from gpt4.train import train -------------------------------------------------------------------------------- /gpt4/attend.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from dataclasses import dataclass 3 | from functools import partial, wraps 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from packaging import version 10 | from torch import Tensor, einsum, nn 11 | 12 | # constants 13 | 14 | EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 15 | 16 | @dataclass 17 | class Intermediates: 18 | qk_similarities: Optional[Tensor] = None 19 | pre_softmax_attn: Optional[Tensor] = None 20 | post_softmax_attn: Optional[Tensor] = None 21 | 22 | def to_tuple(self): 23 | return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn) 24 | 25 | # helpers 26 | 27 | def exists(val): 28 | return val is not None 29 | 30 | def default(val, d): 31 | return val if exists(val) else d 32 | 33 | def compact(arr): 34 | return [*filter(exists, arr)] 35 | 36 | def once(fn): 37 | called = False 38 | @wraps(fn) 39 | def inner(x): 40 | nonlocal called 41 | if called: 42 | return 43 | called = True 44 | return fn(x) 45 | return inner 46 | 47 | print_once = once(print) 48 | 49 | # functions for creating causal mask 50 | # need a special one for onnx cpu (no support for .triu) 51 | 52 | def create_causal_mask(i, j, device): 53 | return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) 54 | 55 | def onnx_create_causal_mask(i, j, device): 56 | r = torch.arange(i, device = device) 57 | causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j') 58 | causal_mask = F.pad(causal_mask, (j - i, 0), value = False) 59 | return causal_mask 60 | 61 | # main class 62 | 63 | class Attend(nn.Module): 64 | def __init__( 65 | self, 66 | *, 67 | dropout = 0., 68 | causal = False, 69 | heads = None, 70 | talking_heads = False, 71 | sparse_topk = None, 72 | scale = None, 73 | qk_norm = False, 74 | flash = False, 75 | add_zero_kv = False, 76 | onnxable = False 77 | ): 78 | super().__init__() 
79 | self.scale = scale 80 | self.qk_norm = qk_norm 81 | 82 | self.causal = causal 83 | self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask 84 | 85 | self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax 86 | 87 | self.dropout = dropout 88 | self.attn_dropout = nn.Dropout(dropout) 89 | 90 | # talking heads 91 | 92 | assert not (flash and talking_heads), 'talking heads not compatible with flash attention' 93 | 94 | self.talking_heads = talking_heads 95 | if talking_heads: 96 | self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) 97 | self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) 98 | 99 | # sparse topk 100 | 101 | assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention' 102 | self.sparse_topk = sparse_topk 103 | 104 | # add a key / value token composed of zeros 105 | # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html 106 | 107 | self.add_zero_kv = add_zero_kv 108 | 109 | # flash attention 110 | 111 | self.flash = flash 112 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 113 | 114 | # determine efficient attention configs for cuda and cpu 115 | 116 | self.cpu_config = EfficientAttentionConfig(True, True, True) 117 | self.cuda_config = None 118 | 119 | if not torch.cuda.is_available() or not flash: 120 | return 121 | 122 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 123 | 124 | if device_properties.major == 8 and device_properties.minor == 0: 125 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 126 | self.cuda_config = EfficientAttentionConfig(True, False, False) 127 | else: 128 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 129 | self.cuda_config = EfficientAttentionConfig(False, True, True) 130 | 131 | def flash_attn( 132 | self, 133 | q, k, v, 134 | mask = None, 135 | attn_bias = None 136 | ): 137 | batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 138 | 139 | # Recommended for multi-query single-key-value attention by Tri Dao 140 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 141 | 142 | if k.ndim == 3: 143 | k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) 144 | 145 | if v.ndim == 3: 146 | v = rearrange(v, 'b ... 
-> b 1 ...').expand_as(q) 147 | 148 | # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention 149 | 150 | if self.qk_norm: 151 | default_scale = q.shape[-1] ** -0.5 152 | q = q * (default_scale / self.scale) 153 | 154 | # Check if mask exists and expand to compatible shape 155 | # The mask is B L, so it would have to be expanded to B H N L 156 | 157 | causal = self.causal 158 | 159 | if exists(mask): 160 | assert mask.ndim == 4 161 | mask = mask.expand(batch, heads, q_len, k_len) 162 | 163 | # manually handle causal mask, if another mask was given 164 | 165 | if causal: 166 | causal_mask = self.create_causal_mask(q_len, k_len, device = device) 167 | mask = mask & ~causal_mask 168 | causal = False 169 | 170 | # handle alibi positional bias 171 | # convert from bool to float 172 | 173 | if exists(attn_bias): 174 | attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1) 175 | 176 | # if mask given, the mask would already contain the causal mask from above logic 177 | # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number 178 | 179 | mask_value = -torch.finfo(q.dtype).max 180 | 181 | if exists(mask): 182 | attn_bias = attn_bias.masked_fill(~mask, mask_value // 2) 183 | elif causal: 184 | causal_mask = self.create_causal_mask(q_len, k_len, device = device) 185 | attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) 186 | causal = False 187 | 188 | # scaled_dot_product_attention handles attn_mask either as bool or additive bias 189 | # make it an additive bias here 190 | 191 | mask = attn_bias 192 | 193 | # Check if there is a compatible device for flash attention 194 | 195 | config = self.cuda_config if is_cuda else self.cpu_config 196 | 197 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 198 | 199 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 200 | out = F.scaled_dot_product_attention( 201 | q, k, v, 202 | attn_mask = mask, 203 | dropout_p = self.dropout if self.training else 0., 204 | is_causal = causal 205 | ) 206 | 207 | return out, Intermediates() 208 | 209 | def forward( 210 | self, 211 | q, k, v, 212 | mask = None, 213 | attn_bias = None, 214 | prev_attn = None 215 | ): 216 | """ 217 | einstein notation 218 | b - batch 219 | h - heads 220 | n, i, j - sequence length (base sequence length, source, target) 221 | d - feature dimension 222 | """ 223 | 224 | n, device = q.shape[-2], q.device 225 | 226 | scale = default(self.scale, q.shape[-1] ** -0.5) 227 | 228 | if self.add_zero_kv: 229 | k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v)) 230 | 231 | if exists(mask): 232 | mask = F.pad(mask, (1, 0), value = True) 233 | 234 | if exists(attn_bias): 235 | attn_bias = F.pad(attn_bias, (1, 0), value = 0.) 
236 | 237 | if self.flash: 238 | assert not exists(prev_attn), 'residual attention not compatible with flash attention' 239 | return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias) 240 | 241 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' 242 | 243 | dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale 244 | 245 | if exists(prev_attn): 246 | dots = dots + prev_attn 247 | 248 | qk_similarities = dots.clone() 249 | 250 | if self.talking_heads: 251 | dots = self.pre_softmax_talking_heads(dots) 252 | 253 | if exists(attn_bias): 254 | dots = dots + attn_bias 255 | 256 | i, j, dtype = *dots.shape[-2:], dots.dtype 257 | 258 | mask_value = -torch.finfo(dots.dtype).max 259 | 260 | if exists(self.sparse_topk) and self.sparse_topk < j: 261 | top_values, _ = dots.topk(self.sparse_topk, dim = -1) 262 | sparse_topk_mask = dots < top_values[..., -1:] 263 | mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask 264 | 265 | if exists(mask): 266 | dots = dots.masked_fill(~mask, mask_value) 267 | 268 | if self.causal: 269 | causal_mask = self.create_causal_mask(i, j, device = device) 270 | dots = dots.masked_fill(causal_mask, mask_value) 271 | 272 | pre_softmax_attn = dots.clone() 273 | 274 | attn = self.attn_fn(dots, dim = -1) 275 | attn = attn.type(dtype) 276 | 277 | post_softmax_attn = attn.clone() 278 | 279 | attn = self.attn_dropout(attn) 280 | 281 | if self.talking_heads: 282 | attn = self.post_softmax_talking_heads(attn) 283 | 284 | out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) 285 | 286 | intermediates = Intermediates( 287 | qk_similarities = qk_similarities, 288 | pre_softmax_attn = pre_softmax_attn, 289 | post_softmax_attn = post_softmax_attn 290 | ) 291 | 292 | return out, intermediates 293 | 294 | # cascading heads logic 295 | 296 | def to_single_heads(t, dim = 1): 297 | heads = t.unbind(dim = dim) 298 | return tuple(head.unsqueeze(dim) for head in heads) 299 | 300 | class CascadingHeads(nn.Module): 301 | def __init__(self, attend: Attend): 302 | super().__init__() 303 | self.attend = attend 304 | 305 | def forward( 306 | self, 307 | q, k, v, 308 | mask = None, 309 | attn_bias = None, 310 | prev_attn = None 311 | ): 312 | assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same' 313 | 314 | # split inputs into per-head inputs 315 | 316 | heads = q.shape[1] 317 | 318 | queries = to_single_heads(q) 319 | keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads) 320 | values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads) 321 | 322 | mask = (mask,) * heads 323 | 324 | attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads) 325 | prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads) 326 | 327 | # now loop through each head, without output of previous head summed with the next head 328 | # thus cascading 329 | 330 | all_outs = [] 331 | all_intermediates = [] 332 | 333 | prev_head_out = None 334 | 335 | for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn): 336 | 337 | if exists(prev_head_out): 338 | h_q = h_q + prev_head_out 339 | 340 | out, intermediates = self.attend( 341 | h_q, h_k, h_v, 342 | mask = h_mask, 343 | attn_bias = h_attn_bias, 344 | prev_attn = h_prev_attn 345 | ) 346 | 347 | prev_head_out = out 348 | 349 | all_outs.append(out) 350 | all_intermediates.append(intermediates) 351 | 352 | # cat all output heads 353 | 354 | all_outs 
= torch.cat(all_outs, dim = 1)
355 | 
356 |         # cat all intermediates, if they exist
357 | 
358 |         qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
359 | 
360 |         qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
361 | 
362 |         aggregated_intermediates = Intermediates(
363 |             qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
364 |             pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
365 |             post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
366 |         )
367 | 
368 |         return all_outs, aggregated_intermediates
--------------------------------------------------------------------------------
/gpt4/gpt4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from gpt4.model import (
5 |     AutoregressiveWrapper,
6 |     Decoder,
7 |     Encoder,
8 |     Transformer,
9 |     ViTransformerWrapper,
10 | )
11 | 
12 | 
13 | class GPT4(nn.Module):
14 |     """
15 |     GPT4 is a transformer-based model architecture. It initializes with
16 |     a Transformer decoder wrapped in an AutoregressiveWrapper, using default
17 |     or user-specified parameters.
18 |     Args:
19 |     - num_tokens: Number of tokens in the vocabulary
20 |     - max_seq_len: Maximum sequence length
21 |     - dim: Dimension of the model
22 |     - depth: Depth of the model
23 |     - dim_head: Dimension of the model head
24 |     - heads: Number of heads
25 |     - use_abs_pos_emb: Whether to use absolute position embedding
26 |     - alibi_pos_bias: Alibi position bias
27 |     - alibi_num_heads: Number of alibi heads
28 |     - rotary_xpos: Rotary position
29 |     - attn_flash: Attention flash
30 |     - attn_one_kv_head: Use a single key/value head (multi-query attention)
31 |     - qk_norm: Query-key normalization
32 |     - attn_qk_norm: Attention query-key normalization
33 |     - attn_qk_norm_dim_scale: Attention query-key normalization dimension scale
34 | 
35 |     (deepnorm, shift_tokens and embedding_provider are not currently accepted
36 |     by this constructor; see the commented-out lines below.)
37 |     """
38 |     def __init__(self,
39 |                  num_tokens=50432,
40 |                  max_seq_len=8192,
41 |                  dim=2560,
42 |                  depth=32,
43 |                  dim_head=128,
44 |                  heads=24,
45 |                  use_abs_pos_emb=False,
46 |                  alibi_pos_bias=True,
47 |                  alibi_num_heads=12,
48 |                  rotary_xpos=True,
49 |                  attn_flash=True,
50 |                  # shift_tokens=1,
51 |                  attn_one_kv_head=True,  # multiquery attention
52 |                  qk_norm=True,
53 |                  attn_qk_norm=True,
54 |                  attn_qk_norm_dim_scale=True,
55 |                  ):
56 |         super().__init__()
57 | 
58 |         try:
59 |             self.decoder = Transformer(
60 |                 num_tokens=num_tokens,
61 |                 max_seq_len=max_seq_len,
62 |                 use_abs_pos_emb=use_abs_pos_emb,
63 |                 attn_layers=Decoder(
64 |                     dim=dim,
65 |                     depth=depth,
66 |                     dim_head=dim_head,
67 |                     heads=heads,
68 |                     alibi_pos_bias=alibi_pos_bias,
69 |                     alibi_num_heads=alibi_num_heads,
70 |                     rotary_xpos=rotary_xpos,
71 |                     attn_flash=attn_flash,
72 |                     # deepnorm=deepnorm,
73 |                     # shift_tokens=shift_tokens,
74 |                     attn_one_kv_head=attn_one_kv_head,
75 |                     qk_norm=qk_norm,
76 |                     attn_qk_norm=attn_qk_norm,
77 |                     attn_qk_norm_dim_scale=attn_qk_norm_dim_scale
78 |                 )
79 |             )
80 | 
81 |             self.decoder = AutoregressiveWrapper(self.decoder)
82 | 
83 |         except Exception as e:
84 |             print("Failed to initialize GPT4: ", e)
85 |             raise
86 | 
87 |     def forward(self, text_tokens, **kwargs):
88 |         try:
89 |             # the autoregressive wrapper returns (logits, loss) for the input tokens
90 |             return self.decoder(text_tokens, **kwargs)
91 |         except Exception as e:
92 |             print("Failed in forward method: ", e)
93 |             raise
94 | 
95 | 
96 | 
97 | class GPT4MultiModal(torch.nn.Module):
98 |     def __init__(self,
99 |                  image_size=256,
100 |                  patch_size=32,
101 |                  encoder_dim=512,
102 |                  encoder_depth=6,
103 |                  encoder_heads=8,
104 |                  num_tokens=20000,
105 |                  max_seq_len=1024,
106 |                  decoder_dim=512,
107 |                  decoder_depth=6,
108 |                  decoder_heads=8,
109 |                  alibi_num_heads=4,
110 |                  use_abs_pos_emb=False,
111 |                  cross_attend=True,
112 |                  alibi_pos_bias=True,
113 |                  rotary_xpos=True,
114 |                  attn_flash=True,
115 |                  qk_norm=True):
116 | 
117 |         super(GPT4MultiModal, self).__init__()
118 | 
119 |         self.encoder = ViTransformerWrapper(
120 |             image_size=image_size,
121 |             patch_size=patch_size,
122 |             attn_layers=Encoder(
123 |                 dim=encoder_dim,
124 |                 depth=encoder_depth,
125 |                 heads=encoder_heads
126 |             )
127 |         )
128 | 
129 |         self.decoder = Transformer(
130 |             num_tokens=num_tokens,
131 |             max_seq_len=max_seq_len,
132 |             use_abs_pos_emb=use_abs_pos_emb,
133 |             attn_layers=Decoder(
134 |                 dim=decoder_dim,
135 |                 depth=decoder_depth,
136 |                 heads=decoder_heads,
137 |                 cross_attend=cross_attend,
138 |                 alibi_pos_bias=alibi_pos_bias,
139 |                 alibi_num_heads=alibi_num_heads,
140 |                 rotary_xpos=rotary_xpos,
141 |                 attn_flash=attn_flash,
142 |                 qk_norm=qk_norm,
143 |             )
144 |         )
145 | 
146 |     def forward(self, img, text):
147 |         try:
148 |             encoded = self.encoder(img, return_embeddings=True)
149 |             return self.decoder(text, context=encoded)
150 |         except Exception as error:
151 |             print(f"Failed in forward method: {error}")
152 |             raise
153 | 
154 | 
155 | 
--------------------------------------------------------------------------------
/gpt4/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from functools import partial, wraps
4 | from inspect import isfunction
5 | 
6 | # constants
7 | from math import ceil
8 | from random import random
9 | from typing import Callable, List, Optional
10 | 
11 | import torch
12 | import torch.nn.functional as F
13 | from einops import pack, rearrange, reduce, repeat, unpack
14 | from torch import Tensor, einsum, nn
15 | 
16 | from gpt4.attend import Attend, Intermediates
17 | 
18 | 
19 | def exists(val):
20 |     return val is not None
21 | 
22 | def eval_decorator(fn):
23 |     def inner(self, *args, **kwargs):
24 |         was_training = self.training
25 |         self.eval()
26 |         out = fn(self, *args, **kwargs)
27 |         self.train(was_training)
28 |         return out
29 |     return inner
30 | 
31 | # nucleus
32 | 
33 | def top_p(logits, thres = 0.9):
34 |     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
35 |     cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
36 | 
37 |     sorted_indices_to_remove = cum_probs > (1 - thres)
38 |     sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
39 |     sorted_indices_to_remove[:, 0] = 0
40 | 
41 |     sorted_logits[sorted_indices_to_remove] = float('-inf')
42 |     return sorted_logits.scatter(1, sorted_indices, sorted_logits)
43 | 
44 | # topk
45 | 
46 | def top_k(logits, thres = 0.9):
47 |     k = ceil((1 - thres) * logits.shape[-1])
48 |     val, ind = torch.topk(logits, k)
49 |     probs = torch.full_like(logits, float('-inf'))
50 |     probs.scatter_(1, ind, val)
51 |     return probs
52 | 
53 | # top_a
54 | 
55 | def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
56 |     probs = F.softmax(logits, dim=-1)
57 |     limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
58 |     logits[probs < limit] =
float('-inf') 59 | logits[probs >= limit] = 1 60 | return logits 61 | 62 | # autoregressive wrapper class 63 | 64 | class AutoregressiveWrapper(nn.Module): 65 | def __init__( 66 | self, 67 | net, 68 | ignore_index = -100, 69 | pad_value = 0, 70 | mask_prob = 0. 71 | ): 72 | super().__init__() 73 | self.pad_value = pad_value 74 | self.ignore_index = ignore_index 75 | 76 | self.net = net 77 | self.max_seq_len = net.max_seq_len 78 | 79 | # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432 80 | assert mask_prob < 1. 81 | self.mask_prob = mask_prob 82 | 83 | @torch.no_grad() 84 | @eval_decorator 85 | def generate( 86 | self, 87 | start_tokens, 88 | seq_len, 89 | eos_token = None, 90 | temperature = 1., 91 | filter_logits_fn = top_k, 92 | filter_thres = 0.9, 93 | min_p_pow = 2.0, 94 | min_p_ratio = 0.02, 95 | **kwargs 96 | ): 97 | 98 | start_tokens, ps = pack([start_tokens], '* n') 99 | 100 | b, t = start_tokens.shape 101 | 102 | out = start_tokens 103 | 104 | for _ in range(seq_len): 105 | x = out[:, -self.max_seq_len:] 106 | 107 | logits = self.net(x, **kwargs)[:, -1] 108 | 109 | if filter_logits_fn in {top_k, top_p}: 110 | filtered_logits = filter_logits_fn(logits, thres = filter_thres) 111 | probs = F.softmax(filtered_logits / temperature, dim=-1) 112 | 113 | elif filter_logits_fn is top_a: 114 | filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio) 115 | probs = F.softmax(filtered_logits / temperature, dim=-1) 116 | 117 | sample = torch.multinomial(probs, 1) 118 | 119 | out = torch.cat((out, sample), dim=-1) 120 | 121 | if exists(eos_token): 122 | is_eos_tokens = (out == eos_token) 123 | 124 | if is_eos_tokens.any(dim = -1).all(): 125 | # mask out everything after the eos tokens 126 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 127 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 128 | out = out.masked_fill(mask, self.pad_value) 129 | break 130 | 131 | out = out[:, t:] 132 | 133 | out, = unpack(out, ps, '* n') 134 | 135 | return out 136 | 137 | def forward(self, x, return_loss=True, **kwargs): 138 | seq, ignore_index = x.shape[1], self.ignore_index 139 | 140 | inp, target = x[:, :-1], x[:, 1:] 141 | 142 | if self.mask_prob > 0.: 143 | rand = torch.randn(inp.shape, device = x.device) 144 | rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out 145 | num_mask = min(int(seq * self.mask_prob), seq - 1) 146 | indices = rand.topk(num_mask, dim = -1).indices 147 | mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() 148 | kwargs.update(self_attn_context_mask = mask) 149 | 150 | logits = self.net(inp, **kwargs) 151 | 152 | loss = F.cross_entropy( 153 | rearrange(logits, 'b n c -> b c n'), 154 | target, 155 | ignore_index = ignore_index 156 | ) 157 | 158 | if return_loss: 159 | return logits, loss 160 | 161 | return logits 162 | 163 | 164 | 165 | DEFAULT_DIM_HEAD = 64 166 | 167 | @dataclass 168 | class LayerIntermediates: 169 | hiddens: Optional[List[Tensor]] = None 170 | attn_intermediates: Optional[List[Intermediates]] = None 171 | layer_hiddens: Optional[List[Tensor]] = None 172 | attn_z_loss: Optional[Tensor] = None 173 | 174 | # helpers 175 | 176 | def exists(val): 177 | return val is not None 178 | 179 | def default(val, d): 180 | if exists(val): 181 | return val 182 | return d() if isfunction(d) else d 183 | 184 | def cast_tuple(val, depth): 185 | return val if isinstance(val, tuple) else (val,) * depth 186 | 
187 | def maybe(fn): 188 | @wraps(fn) 189 | def inner(x, *args, **kwargs): 190 | if not exists(x): 191 | return x 192 | return fn(x, *args, **kwargs) 193 | return inner 194 | 195 | class always(): 196 | def __init__(self, val): 197 | self.val = val 198 | def __call__(self, *args, **kwargs): 199 | return self.val 200 | 201 | class not_equals(): 202 | def __init__(self, val): 203 | self.val = val 204 | def __call__(self, x, *args, **kwargs): 205 | return x != self.val 206 | 207 | class equals(): 208 | def __init__(self, val): 209 | self.val = val 210 | def __call__(self, x, *args, **kwargs): 211 | return x == self.val 212 | 213 | def Sequential(*modules): 214 | return nn.Sequential(*filter(exists, modules)) 215 | 216 | # tensor helpers 217 | 218 | def max_neg_value(tensor): 219 | return -torch.finfo(tensor.dtype).max 220 | 221 | def l2norm(t, groups = 1): 222 | t = rearrange(t, '... (g d) -> ... g d', g = groups) 223 | t = F.normalize(t, p = 2, dim = -1) 224 | return rearrange(t, '... g d -> ... (g d)') 225 | 226 | def pad_at_dim(t, pad, dim = -1, value = 0.): 227 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 228 | zeros = ((0, 0) * dims_from_right) 229 | return F.pad(t, (*zeros, *pad), value = value) 230 | 231 | def or_reduce(masks): 232 | head, *body = masks 233 | for rest in body: 234 | head = head | rest 235 | return head 236 | 237 | # auxiliary loss helpers 238 | 239 | def calc_z_loss( 240 | pre_softmax_attns: List[Tensor], 241 | mask = None, 242 | weight = 1. 243 | ): 244 | # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906 245 | # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects 246 | # also used in PaLM as one of the measures 247 | 248 | lse = 0. 249 | 250 | for attn in pre_softmax_attns: 251 | lse = lse + attn.logsumexp(dim = -1) 252 | 253 | loss = torch.square(lse) 254 | loss = reduce(loss, 'b h n -> b n', 'sum') 255 | 256 | if not exists(mask): 257 | return loss.mean() * weight 258 | 259 | loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5) 260 | return loss * weight 261 | 262 | # init helpers 263 | 264 | def init_zero_(layer): 265 | nn.init.constant_(layer.weight, 0.) 266 | if exists(layer.bias): 267 | nn.init.constant_(layer.bias, 0.) 
268 | 269 | # keyword argument helpers 270 | 271 | def pick_and_pop(keys, d): 272 | values = list(map(lambda key: d.pop(key), keys)) 273 | return dict(zip(keys, values)) 274 | 275 | def group_dict_by_key(cond, d): 276 | return_val = [dict(),dict()] 277 | for key in d.keys(): 278 | match = bool(cond(key)) 279 | ind = int(not match) 280 | return_val[ind][key] = d[key] 281 | return (*return_val,) 282 | 283 | def string_begins_with(prefix, str): 284 | return str.startswith(prefix) 285 | 286 | def group_by_key_prefix(prefix, d): 287 | return group_dict_by_key(partial(string_begins_with, prefix), d) 288 | 289 | def groupby_prefix_and_trim(prefix, d): 290 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 291 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 292 | return kwargs_without_prefix, kwargs 293 | 294 | # initializations 295 | 296 | def deepnorm_init( 297 | transformer, 298 | beta, 299 | module_name_match_list = ['.ff.', '.to_v', '.to_out'] 300 | ): 301 | for name, module in transformer.named_modules(): 302 | if type(module) != nn.Linear: 303 | continue 304 | 305 | needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list)) 306 | gain = beta if needs_beta_gain else 1 307 | nn.init.xavier_normal_(module.weight.data, gain = gain) 308 | 309 | if exists(module.bias): 310 | nn.init.constant_(module.bias.data, 0) 311 | 312 | # structured dropout, more effective than traditional attention dropouts 313 | 314 | def dropout_seq(seq, mask, dropout): 315 | b, n, *_, device = *seq.shape, seq.device 316 | logits = torch.randn(b, n, device = device) 317 | 318 | if exists(mask): 319 | mask_value = max_neg_value(logits) 320 | logits = logits.masked_fill(~mask, mask_value) 321 | 322 | keep_prob = 1. - dropout 323 | num_keep = max(1, int(keep_prob * n)) 324 | keep_indices = logits.topk(num_keep, dim = 1).indices 325 | 326 | batch_indices = torch.arange(b, device = device) 327 | batch_indices = rearrange(batch_indices, 'b -> b 1') 328 | 329 | seq = seq[batch_indices, keep_indices] 330 | 331 | if exists(mask): 332 | seq_counts = mask.sum(dim = -1) 333 | seq_keep_counts = torch.ceil(seq_counts * keep_prob).int() 334 | keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1') 335 | 336 | mask = mask[batch_indices, keep_indices] & keep_mask 337 | 338 | return seq, mask 339 | 340 | # activations 341 | 342 | class ReluSquared(nn.Module): 343 | def forward(self, x): 344 | return F.relu(x) ** 2 345 | 346 | # embedding 347 | 348 | class TokenEmbedding(nn.Module): 349 | def __init__(self, dim, num_tokens, l2norm_embed = False): 350 | super().__init__() 351 | self.l2norm_embed = l2norm_embed 352 | self.emb = nn.Embedding(num_tokens, dim) 353 | 354 | def forward(self, x): 355 | token_emb = self.emb(x) 356 | return l2norm(token_emb) if self.l2norm_embed else token_emb 357 | 358 | # positional embeddings 359 | 360 | class AbsolutePositionalEmbedding(nn.Module): 361 | def __init__(self, dim, max_seq_len, l2norm_embed = False): 362 | super().__init__() 363 | self.scale = dim ** -0.5 if not l2norm_embed else 1. 
364 | self.max_seq_len = max_seq_len 365 | self.l2norm_embed = l2norm_embed 366 | self.emb = nn.Embedding(max_seq_len, dim) 367 | 368 | def forward(self, x, pos = None): 369 | seq_len, device = x.shape[1], x.device 370 | assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' 371 | 372 | if not exists(pos): 373 | pos = torch.arange(seq_len, device = device) 374 | 375 | pos_emb = self.emb(pos) 376 | pos_emb = pos_emb * self.scale 377 | return l2norm(pos_emb) if self.l2norm_embed else pos_emb 378 | 379 | class ScaledSinusoidalEmbedding(nn.Module): 380 | def __init__(self, dim, theta = 10000): 381 | super().__init__() 382 | assert (dim % 2) == 0 383 | self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) 384 | 385 | half_dim = dim // 2 386 | freq_seq = torch.arange(half_dim).float() / half_dim 387 | inv_freq = theta ** -freq_seq 388 | self.register_buffer('inv_freq', inv_freq, persistent = False) 389 | 390 | def forward(self, x, pos = None): 391 | seq_len, device = x.shape[1], x.device 392 | 393 | if not exists(pos): 394 | pos = torch.arange(seq_len, device = device) 395 | 396 | emb = einsum('i, j -> i j', pos, self.inv_freq) 397 | emb = torch.cat((emb.sin(), emb.cos()), dim = -1) 398 | return emb * self.scale 399 | 400 | class RelativePositionBias(nn.Module): 401 | def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8): 402 | super().__init__() 403 | self.scale = scale 404 | self.causal = causal 405 | self.num_buckets = num_buckets 406 | self.max_distance = max_distance 407 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 408 | 409 | @staticmethod 410 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): 411 | ret = 0 412 | n = -relative_position 413 | if not causal: 414 | num_buckets //= 2 415 | ret += (n < 0).long() * num_buckets 416 | n = torch.abs(n) 417 | else: 418 | n = torch.max(n, torch.zeros_like(n)) 419 | 420 | max_exact = num_buckets // 2 421 | is_small = n < max_exact 422 | 423 | val_if_large = max_exact + ( 424 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 425 | ).long() 426 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 427 | 428 | ret += torch.where(is_small, n, val_if_large) 429 | return ret 430 | 431 | @property 432 | def device(self): 433 | return next(self.parameters()).device 434 | 435 | def forward(self, i, j): 436 | device = self.device 437 | q_pos = torch.arange(j - i, j, dtype = torch.long, device = device) 438 | k_pos = torch.arange(j, dtype = torch.long, device = device) 439 | rel_pos = k_pos[None, :] - q_pos[:, None] 440 | rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) 441 | values = self.relative_attention_bias(rp_bucket) 442 | bias = rearrange(values, 'i j h -> h i j') 443 | return bias * self.scale 444 | 445 | class DynamicPositionBias(nn.Module): 446 | def __init__(self, dim, *, heads, depth, log_distance = False, norm = False): 447 | super().__init__() 448 | assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1' 449 | self.log_distance = log_distance 450 | 451 | self.mlp = nn.ModuleList([]) 452 | 453 | self.mlp.append(Sequential( 454 | nn.Linear(1, dim), 455 | nn.LayerNorm(dim) if norm else None, 456 | nn.SiLU() 457 | )) 
458 | 459 | for _ in range(depth - 1): 460 | self.mlp.append(Sequential( 461 | nn.Linear(dim, dim), 462 | nn.LayerNorm(dim) if norm else None, 463 | nn.SiLU() 464 | )) 465 | 466 | self.mlp.append(nn.Linear(dim, heads)) 467 | 468 | @property 469 | def device(self): 470 | return next(self.parameters()).device 471 | 472 | def forward(self, i, j): 473 | assert i == j 474 | n, device = j, self.device 475 | 476 | # get the (n x n) matrix of distances 477 | seq_arange = torch.arange(n, device = device) 478 | context_arange = torch.arange(n, device = device) 479 | indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j') 480 | indices += (n - 1) 481 | 482 | # input to continuous positions MLP 483 | pos = torch.arange(-n + 1, n, device = device).float() 484 | pos = rearrange(pos, '... -> ... 1') 485 | 486 | if self.log_distance: 487 | pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) 488 | 489 | for layer in self.mlp: 490 | pos = layer(pos) 491 | 492 | # get position biases 493 | bias = pos[indices] 494 | bias = rearrange(bias, 'i j h -> h i j') 495 | return bias 496 | 497 | class AlibiPositionalBias(nn.Module): 498 | def __init__(self, heads, total_heads, **kwargs): 499 | super().__init__() 500 | self.heads = heads 501 | self.total_heads = total_heads 502 | 503 | slopes = Tensor(self._get_slopes(heads)) 504 | slopes = rearrange(slopes, 'h -> h 1 1') 505 | self.register_buffer('slopes', slopes, persistent = False) 506 | self.register_buffer('bias', None, persistent = False) 507 | 508 | def get_bias(self, i, j, device): 509 | i_arange = torch.arange(j - i, j, device = device) 510 | j_arange = torch.arange(j, device = device) 511 | bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) 512 | return bias 513 | 514 | @staticmethod 515 | def _get_slopes(heads): 516 | def get_slopes_power_of_2(n): 517 | start = (2**(-2**-(math.log2(n)-3))) 518 | ratio = start 519 | return [start*ratio**i for i in range(n)] 520 | 521 | if math.log2(heads).is_integer(): 522 | return get_slopes_power_of_2(heads) 523 | 524 | closest_power_of_2 = 2 ** math.floor(math.log2(heads)) 525 | return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] 526 | 527 | @property 528 | def device(self): 529 | return next(self.buffers()).device 530 | 531 | def forward(self, i, j): 532 | h, device = self.total_heads, self.device 533 | 534 | if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i: 535 | return self.bias[..., :i, :j] 536 | 537 | bias = self.get_bias(i, j, device) 538 | bias = bias * self.slopes 539 | 540 | num_heads_unalibied = h - bias.shape[0] 541 | bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0) 542 | self.register_buffer('bias', bias, persistent = False) 543 | 544 | return self.bias 545 | 546 | class RotaryEmbedding(nn.Module): 547 | def __init__( 548 | self, 549 | dim, 550 | use_xpos = False, 551 | scale_base = 512, 552 | interpolation_factor = 1., 553 | base = 10000, 554 | base_rescale_factor = 1. 555 | ): 556 | super().__init__() 557 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 558 | # has some connection to NTK literature 559 | # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ 560 | base *= base_rescale_factor ** (dim / (dim - 2)) 561 | 562 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) 563 | self.register_buffer('inv_freq', inv_freq) 564 | 565 | assert interpolation_factor >= 1. 566 | self.interpolation_factor = interpolation_factor 567 | 568 | if not use_xpos: 569 | self.register_buffer('scale', None) 570 | return 571 | 572 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 573 | 574 | self.scale_base = scale_base 575 | self.register_buffer('scale', scale) 576 | 577 | def forward(self, seq_len, device): 578 | t = torch.arange(seq_len, device = device).type_as(self.inv_freq) 579 | t = t / self.interpolation_factor 580 | 581 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 582 | freqs = torch.cat((freqs, freqs), dim = -1) 583 | 584 | if not exists(self.scale): 585 | return freqs, 1. 586 | 587 | power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base 588 | scale = self.scale ** rearrange(power, 'n -> n 1') 589 | scale = torch.cat((scale, scale), dim = -1) 590 | 591 | return freqs, scale 592 | 593 | 594 | def rotate_half(x): 595 | x = rearrange(x, '... (j d) -> ... j d', j = 2) 596 | x1, x2 = x.unbind(dim = -2) 597 | return torch.cat((-x2, x1), dim = -1) 598 | 599 | def apply_rotary_pos_emb(t, freqs, scale = 1): 600 | seq_len = t.shape[-2] 601 | freqs = freqs[-seq_len:, :] 602 | return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) 603 | 604 | # norms 605 | 606 | class Scale(nn.Module): 607 | def __init__(self, value, fn): 608 | super().__init__() 609 | self.value = value 610 | self.fn = fn 611 | 612 | def forward(self, x, **kwargs): 613 | out = self.fn(x, **kwargs) 614 | scale_fn = lambda t: t * self.value 615 | 616 | if not isinstance(out, tuple): 617 | return scale_fn(out) 618 | 619 | return (scale_fn(out[0]), *out[1:]) 620 | 621 | class ScaleNorm(nn.Module): 622 | def __init__(self, dim, eps = 1e-5): 623 | super().__init__() 624 | self.eps = eps 625 | self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5)) 626 | 627 | def forward(self, x): 628 | norm = torch.norm(x, dim = -1, keepdim = True) 629 | return x / norm.clamp(min = self.eps) * self.g 630 | 631 | class RMSNorm(nn.Module): 632 | def __init__(self, dim): 633 | super().__init__() 634 | self.scale = dim ** 0.5 635 | self.g = nn.Parameter(torch.ones(dim)) 636 | 637 | def forward(self, x): 638 | return F.normalize(x, dim = -1) * self.scale * self.g 639 | 640 | class SimpleRMSNorm(nn.Module): 641 | def __init__(self, dim): 642 | super().__init__() 643 | self.scale = dim ** 0.5 644 | 645 | def forward(self, x): 646 | return F.normalize(x, dim = -1) * self.scale 647 | 648 | # residual and residual gates 649 | 650 | class Residual(nn.Module): 651 | def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.): 652 | super().__init__() 653 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None 654 | self.scale_residual_constant = scale_residual_constant 655 | 656 | def forward(self, x, residual): 657 | if exists(self.residual_scale): 658 | residual = residual * self.residual_scale 659 | 660 | if self.scale_residual_constant != 1: 661 | residual = residual * self.scale_residual_constant 662 | 663 | return x + residual 664 | 665 | class GRUGating(nn.Module): 666 | def __init__(self, dim, scale_residual = False, **kwargs): 667 | super().__init__() 668 | self.gru = nn.GRUCell(dim, dim) 669 | self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None 670 | 671 | def forward(self, x, residual): 672 | if exists(self.residual_scale): 673 | residual = 
residual * self.residual_scale 674 | 675 | gated_output = self.gru( 676 | rearrange(x, 'b n d -> (b n) d'), 677 | rearrange(residual, 'b n d -> (b n) d') 678 | ) 679 | 680 | return gated_output.reshape_as(x) 681 | 682 | # token shifting 683 | 684 | def shift(t, amount, mask = None): 685 | if amount == 0: 686 | return t 687 | else: 688 | amount = min(amount, t.shape[1]) 689 | 690 | if exists(mask): 691 | t = t.masked_fill(~mask[..., None], 0.) 692 | 693 | return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.) 694 | 695 | class ShiftTokens(nn.Module): 696 | def __init__(self, shifts, fn): 697 | super().__init__() 698 | self.fn = fn 699 | self.shifts = tuple(shifts) 700 | 701 | def forward(self, x, **kwargs): 702 | mask = kwargs.get('mask', None) 703 | shifts = self.shifts 704 | segments = len(shifts) 705 | feats_per_shift = x.shape[-1] // segments 706 | splitted = x.split(feats_per_shift, dim = -1) 707 | segments_to_shift, rest = splitted[:segments], splitted[segments:] 708 | segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) 709 | x = torch.cat((*segments_to_shift, *rest), dim = -1) 710 | return self.fn(x, **kwargs) 711 | 712 | # feedforward 713 | 714 | class GLU(nn.Module): 715 | def __init__( 716 | self, 717 | dim_in, 718 | dim_out, 719 | activation: Callable, 720 | mult_bias = False 721 | ): 722 | super().__init__() 723 | self.act = activation 724 | self.proj = nn.Linear(dim_in, dim_out * 2) 725 | self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1. 726 | 727 | def forward(self, x): 728 | x, gate = self.proj(x).chunk(2, dim = -1) 729 | return x * self.act(gate) * self.mult_bias 730 | 731 | class FeedForward(nn.Module): 732 | def __init__( 733 | self, 734 | dim, 735 | dim_out = None, 736 | mult = 4, 737 | glu = False, 738 | glu_mult_bias = False, 739 | swish = False, 740 | relu_squared = False, 741 | post_act_ln = False, 742 | dropout = 0., 743 | no_bias = False, 744 | zero_init_output = False 745 | ): 746 | super().__init__() 747 | inner_dim = int(dim * mult) 748 | dim_out = default(dim_out, dim) 749 | 750 | if relu_squared: 751 | activation = ReluSquared() 752 | elif swish: 753 | activation = nn.SiLU() 754 | else: 755 | activation = nn.GELU() 756 | 757 | if glu: 758 | project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias) 759 | else: 760 | project_in = nn.Sequential( 761 | nn.Linear(dim, inner_dim, bias = not no_bias), 762 | activation 763 | ) 764 | 765 | self.ff = Sequential( 766 | project_in, 767 | nn.LayerNorm(inner_dim) if post_act_ln else None, 768 | nn.Dropout(dropout), 769 | nn.Linear(inner_dim, dim_out, bias = not no_bias) 770 | ) 771 | 772 | # init last linear layer to 0 773 | if zero_init_output: 774 | init_zero_(self.ff[-1]) 775 | 776 | def forward(self, x): 777 | return self.ff(x) 778 | 779 | # attention. 
it is all we need 780 | 781 | class Attention(nn.Module): 782 | def __init__( 783 | self, 784 | dim, 785 | dim_head = DEFAULT_DIM_HEAD, 786 | heads = 8, 787 | causal = False, 788 | flash = False, 789 | talking_heads = False, 790 | head_scale = False, 791 | sparse_topk = None, 792 | num_mem_kv = 0, 793 | dropout = 0., 794 | on_attn = False, 795 | gate_values = False, 796 | zero_init_output = False, 797 | max_attend_past = None, 798 | qk_norm = False, 799 | qk_norm_groups = 1, 800 | qk_norm_scale = 10, 801 | qk_norm_dim_scale = False, 802 | one_kv_head = False, 803 | shared_kv = False, 804 | value_dim_head = None, 805 | tensor_product = False, # https://arxiv.org/abs/2208.06061 806 | cascading_heads = False, 807 | add_zero_kv = False, # same as add_zero_attn in pytorch 808 | onnxable = False 809 | ): 810 | super().__init__() 811 | self.scale = dim_head ** -0.5 812 | 813 | self.heads = heads 814 | self.causal = causal 815 | self.max_attend_past = max_attend_past 816 | 817 | value_dim_head = default(value_dim_head, dim_head) 818 | q_dim = k_dim = dim_head * heads 819 | v_dim = out_dim = value_dim_head * heads 820 | 821 | self.one_kv_head = one_kv_head 822 | if one_kv_head: 823 | k_dim = dim_head 824 | v_dim = value_dim_head 825 | out_dim = v_dim * heads 826 | 827 | self.to_q = nn.Linear(dim, q_dim, bias = False) 828 | self.to_k = nn.Linear(dim, k_dim, bias = False) 829 | 830 | # shared key / values, for further memory savings during inference 831 | assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values' 832 | self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None 833 | 834 | # relations projection from tp-attention 835 | self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None 836 | 837 | # add GLU gating for aggregated values, from alphafold2 838 | self.to_v_gate = None 839 | if gate_values: 840 | self.to_v_gate = nn.Linear(dim, out_dim) 841 | nn.init.constant_(self.to_v_gate.weight, 0) 842 | nn.init.constant_(self.to_v_gate.bias, 1) 843 | 844 | # cosine sim attention 845 | self.qk_norm = qk_norm 846 | self.qk_norm_groups = qk_norm_groups 847 | self.qk_norm_scale = qk_norm_scale 848 | 849 | # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442 850 | self.qk_norm_dim_scale = qk_norm_dim_scale 851 | 852 | self.qk_norm_q_scale = self.qk_norm_k_scale = 1 853 | if qk_norm and qk_norm_dim_scale: 854 | self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head)) 855 | self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head)) 856 | 857 | assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups' 858 | assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)' 859 | 860 | # attend class - includes core attention algorithm + talking heads 861 | 862 | self.attend = Attend( 863 | heads = heads, 864 | causal = causal, 865 | talking_heads = talking_heads, 866 | dropout = dropout, 867 | sparse_topk = sparse_topk, 868 | qk_norm = qk_norm, 869 | scale = qk_norm_scale if qk_norm else self.scale, 870 | add_zero_kv = add_zero_kv, 871 | flash = flash, 872 | onnxable = onnxable 873 | ) 874 | 875 | # head scaling 876 | self.head_scale = head_scale 877 | if head_scale: 878 | self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) 879 | 880 | # explicit topk 
sparse attention 881 | self.sparse_topk = sparse_topk 882 | 883 | # add memory key / values 884 | self.num_mem_kv = num_mem_kv 885 | if num_mem_kv > 0: 886 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 887 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 888 | 889 | # attention on attention 890 | self.attn_on_attn = on_attn 891 | self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False) 892 | 893 | # init output projection 0 894 | if zero_init_output: 895 | init_zero_(self.to_out) 896 | 897 | def forward( 898 | self, 899 | x, 900 | context = None, 901 | mask = None, 902 | context_mask = None, 903 | attn_mask = None, 904 | rel_pos = None, 905 | rotary_pos_emb = None, 906 | prev_attn = None, 907 | mem = None 908 | ): 909 | b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context) 910 | kv_input = default(context, x) 911 | 912 | q_input = x 913 | k_input = kv_input 914 | v_input = kv_input 915 | r_input = x 916 | 917 | if exists(mem): 918 | k_input = torch.cat((mem, k_input), dim = -2) 919 | v_input = torch.cat((mem, v_input), dim = -2) 920 | 921 | q = self.to_q(q_input) 922 | k = self.to_k(k_input) 923 | v = self.to_v(v_input) if exists(self.to_v) else k 924 | r = self.to_r(r_input) if exists(self.to_r) else None 925 | 926 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 927 | 928 | if not self.one_kv_head: 929 | k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r)) 930 | 931 | if self.qk_norm: 932 | qk_l2norm = partial(l2norm, groups = self.qk_norm_groups) 933 | q, k = map(qk_l2norm, (q, k)) 934 | scale = self.qk_norm_scale 935 | 936 | q = q * self.qk_norm_q_scale 937 | k = k * self.qk_norm_k_scale 938 | 939 | if exists(rotary_pos_emb) and not has_context: 940 | freqs, xpos_scale = rotary_pos_emb 941 | l = freqs.shape[-1] 942 | 943 | q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.) 
944 | (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) 945 | 946 | ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale))) 947 | q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr))) 948 | 949 | input_mask = context_mask if has_context else mask 950 | 951 | if self.num_mem_kv > 0: 952 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v)) 953 | 954 | if self.qk_norm: 955 | mem_k = l2norm(mem_k) 956 | mem_k = mem_k * self.qk_norm_k_scale 957 | 958 | k = torch.cat((mem_k, k), dim = -2) 959 | v = torch.cat((mem_v, v), dim = -2) 960 | 961 | if exists(input_mask): 962 | input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True) 963 | 964 | i, j = map(lambda t: t.shape[-2], (q, k)) 965 | 966 | # determine masking 967 | 968 | mask_value = max_neg_value(q) 969 | masks = [] 970 | final_attn_mask = None 971 | 972 | if exists(input_mask): 973 | input_mask = rearrange(input_mask, 'b j -> b 1 1 j') 974 | masks.append(~input_mask) 975 | 976 | if exists(attn_mask): 977 | assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' 978 | if attn_mask.ndim == 2: 979 | attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j') 980 | elif attn_mask.ndim == 3: 981 | attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j') 982 | masks.append(~attn_mask) 983 | 984 | if exists(self.max_attend_past): 985 | range_q = torch.arange(j - i, j, device = device) 986 | range_k = torch.arange(j, device = device) 987 | dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j') 988 | max_attend_past_mask = dist > self.max_attend_past 989 | masks.append(max_attend_past_mask) 990 | 991 | if len(masks) > 0: 992 | final_attn_mask = ~or_reduce(masks) 993 | 994 | # prepare relative positional bias, if needed 995 | 996 | attn_bias = None 997 | if exists(rel_pos): 998 | attn_bias = rel_pos(i, j) 999 | 1000 | # attention is all we need 1001 | 1002 | out, intermediates = self.attend( 1003 | q, k, v, 1004 | mask = final_attn_mask, 1005 | attn_bias = attn_bias, 1006 | prev_attn = prev_attn 1007 | ) 1008 | 1009 | # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients 1010 | 1011 | if exists(r): 1012 | out = out * r + out 1013 | 1014 | # normformer scaling of heads 1015 | 1016 | if head_scale: 1017 | out = out * self.head_scale_params 1018 | 1019 | # merge heads 1020 | 1021 | out = rearrange(out, 'b h n d -> b n (h d)') 1022 | 1023 | # alphafold2 styled gating of the values 1024 | 1025 | if exists(self.to_v_gate): 1026 | gates = self.to_v_gate(x) 1027 | out = out * gates.sigmoid() 1028 | 1029 | # combine the heads 1030 | 1031 | out = self.to_out(out) 1032 | 1033 | if exists(mask): 1034 | mask = rearrange(mask, 'b n -> b n 1') 1035 | out = out.masked_fill(~mask, 0.) 
1036 | 1037 | return out, intermediates 1038 | 1039 | class AttentionLayers(nn.Module): 1040 | def __init__( 1041 | self, 1042 | dim, 1043 | depth, 1044 | heads = 8, 1045 | causal = False, 1046 | cross_attend = False, 1047 | only_cross = False, 1048 | use_scalenorm = False, 1049 | use_rmsnorm = False, 1050 | use_simple_rmsnorm = False, 1051 | alibi_pos_bias = False, 1052 | alibi_num_heads = None, 1053 | rel_pos_bias = False, 1054 | rel_pos_num_buckets = 32, 1055 | rel_pos_max_distance = 128, 1056 | dynamic_pos_bias = False, 1057 | dynamic_pos_bias_log_distance = False, 1058 | dynamic_pos_bias_mlp_depth = 2, 1059 | dynamic_pos_bias_norm = False, 1060 | rotary_pos_emb = False, 1061 | rotary_emb_dim = None, 1062 | rotary_xpos = False, 1063 | rotary_interpolation_factor = 1., 1064 | rotary_xpos_scale_base = 512, 1065 | rotary_base_rescale_factor = 1., 1066 | custom_layers = None, 1067 | sandwich_coef = None, 1068 | par_ratio = None, 1069 | residual_attn = False, 1070 | cross_residual_attn = False, 1071 | macaron = False, 1072 | pre_norm = True, 1073 | pre_norm_has_final_norm = True, 1074 | gate_residual = False, 1075 | scale_residual = False, 1076 | scale_residual_constant = 1., 1077 | deepnorm = False, 1078 | shift_tokens = 0, 1079 | sandwich_norm = False, 1080 | resi_dual = False, 1081 | resi_dual_scale = 1., 1082 | zero_init_branch_output = False, 1083 | layer_dropout = 0., 1084 | cross_attn_tokens_dropout = 0., 1085 | **kwargs 1086 | ): 1087 | super().__init__() 1088 | rotary_pos_emb = rotary_pos_emb or rotary_xpos 1089 | 1090 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 1091 | attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs) 1092 | 1093 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 1094 | 1095 | self.dim = dim 1096 | self.depth = depth 1097 | self.layers = nn.ModuleList([]) 1098 | 1099 | self.has_pos_emb = rel_pos_bias or rotary_pos_emb 1100 | 1101 | rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) 1102 | 1103 | assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention' 1104 | self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None 1105 | 1106 | assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' 1107 | assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 1108 | 1109 | # relative positional bias 1110 | 1111 | flash_attn = attn_kwargs.get('flash', False) 1112 | assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias' 1113 | 1114 | self.rel_pos = None 1115 | if rel_pos_bias: 1116 | assert not flash_attn, 'flash attention not compatible with t5 relative positional bias' 1117 | self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) 1118 | elif dynamic_pos_bias: 1119 | assert not flash_attn, 'flash attention not compatible with dynamic positional bias' 1120 | self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm) 1121 | 
elif alibi_pos_bias: 1122 | alibi_num_heads = default(alibi_num_heads, heads) 1123 | assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' 1124 | self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads) 1125 | 1126 | # determine deepnorm and residual scale 1127 | 1128 | if deepnorm: 1129 | assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings' 1130 | pre_norm = sandwich_norm = resi_dual = False 1131 | scale_residual = True 1132 | scale_residual_constant = (2 * depth) ** 0.25 1133 | 1134 | assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both' 1135 | assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' 1136 | 1137 | if resi_dual: 1138 | pre_norm = False 1139 | 1140 | self.pre_norm = pre_norm 1141 | self.sandwich_norm = sandwich_norm 1142 | 1143 | self.resi_dual = resi_dual 1144 | assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.' 1145 | self.resi_dual_scale = resi_dual_scale 1146 | 1147 | self.residual_attn = residual_attn 1148 | self.cross_residual_attn = cross_residual_attn 1149 | assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention' 1150 | 1151 | self.cross_attend = cross_attend 1152 | 1153 | assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm' 1154 | 1155 | if use_scalenorm: 1156 | norm_class = ScaleNorm 1157 | elif use_rmsnorm: 1158 | norm_class = RMSNorm 1159 | elif use_simple_rmsnorm: 1160 | norm_class = SimpleRMSNorm 1161 | else: 1162 | norm_class = nn.LayerNorm 1163 | 1164 | norm_fn = partial(norm_class, dim) 1165 | 1166 | if cross_attend and not only_cross: 1167 | default_block = ('a', 'c', 'f') 1168 | elif cross_attend and only_cross: 1169 | default_block = ('c', 'f') 1170 | else: 1171 | default_block = ('a', 'f') 1172 | 1173 | if macaron: 1174 | default_block = ('f',) + default_block 1175 | 1176 | # zero init 1177 | 1178 | if zero_init_branch_output: 1179 | attn_kwargs = {**attn_kwargs, 'zero_init_output': True} 1180 | ff_kwargs = {**ff_kwargs, 'zero_init_output': True} 1181 | 1182 | # calculate layer block order 1183 | 1184 | if exists(custom_layers): 1185 | layer_types = custom_layers 1186 | elif exists(par_ratio): 1187 | par_depth = depth * len(default_block) 1188 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 1189 | default_block = tuple(filter(not_equals('f'), default_block)) 1190 | par_attn = par_depth // par_ratio 1191 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 1192 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 1193 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 1194 | par_block = default_block + ('f',) * (par_width - len(default_block)) 1195 | par_head = par_block * par_attn 1196 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 1197 | elif exists(sandwich_coef): 1198 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 1199 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 1200 | else: 1201 | layer_types = default_block * depth 1202 | 1203 | self.layer_types = layer_types 
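# illustrative examples of the resulting block order (not from the original file):
#   depth = 2, defaults              -> ('a', 'f', 'a', 'f')
#   depth = 2, cross_attend = True   -> ('a', 'c', 'f', 'a', 'c', 'f')
#   depth = 2, sandwich_coef = 1     -> ('a', 'a', 'f', 'f')
# where 'a' = self attention, 'c' = cross attention, 'f' = feedforward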
1204 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 1205 | 1206 | # stochastic depth 1207 | 1208 | self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types)) 1209 | 1210 | # structured dropout for cross attending 1211 | 1212 | self.cross_attn_tokens_dropout = cross_attn_tokens_dropout 1213 | 1214 | # calculate token shifting 1215 | 1216 | shift_tokens = cast_tuple(shift_tokens, len(layer_types)) 1217 | 1218 | # whether it has post norm 1219 | 1220 | self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity() 1221 | 1222 | # iterate and construct layers 1223 | 1224 | for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): 1225 | is_last_layer = ind == (len(self.layer_types) - 1) 1226 | 1227 | if layer_type == 'a': 1228 | layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs) 1229 | elif layer_type == 'c': 1230 | layer = Attention(dim, heads = heads, **attn_kwargs) 1231 | elif layer_type == 'f': 1232 | layer = FeedForward(dim, **ff_kwargs) 1233 | layer = layer if not macaron else Scale(0.5, layer) 1234 | else: 1235 | raise Exception(f'invalid layer type {layer_type}') 1236 | 1237 | if layer_shift_tokens > 0: 1238 | shift_range_upper = layer_shift_tokens + 1 1239 | shift_range_lower = -layer_shift_tokens if not causal else 0 1240 | layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) 1241 | 1242 | residual_fn = GRUGating if gate_residual else Residual 1243 | residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant) 1244 | 1245 | pre_branch_norm = norm_fn() if pre_norm else None 1246 | post_branch_norm = norm_fn() if sandwich_norm else None 1247 | post_main_norm = norm_fn() if not pre_norm else None 1248 | 1249 | norms = nn.ModuleList([ 1250 | pre_branch_norm, 1251 | post_branch_norm, 1252 | post_main_norm 1253 | ]) 1254 | 1255 | self.layers.append(nn.ModuleList([ 1256 | norms, 1257 | layer, 1258 | residual 1259 | ])) 1260 | 1261 | if deepnorm: 1262 | init_gain = (8 * depth) ** -0.25 1263 | deepnorm_init(self, init_gain) 1264 | 1265 | def forward( 1266 | self, 1267 | x, 1268 | context = None, 1269 | mask = None, 1270 | context_mask = None, 1271 | attn_mask = None, 1272 | self_attn_context_mask = None, 1273 | mems = None, 1274 | return_hiddens = False 1275 | ): 1276 | assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' 1277 | 1278 | hiddens = [] 1279 | layer_hiddens = [] 1280 | intermediates = [] 1281 | 1282 | prev_attn = None 1283 | prev_cross_attn = None 1284 | 1285 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 1286 | 1287 | rotary_pos_emb = None 1288 | if exists(self.rotary_pos_emb): 1289 | max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) 1290 | rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) 1291 | 1292 | outer_residual = x * self.resi_dual_scale 1293 | 1294 | for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)): 1295 | is_last = ind == (len(self.layers) - 1) 1296 | 1297 | if self.training and layer_dropout > 0. 
and random() < layer_dropout: 1298 | continue 1299 | 1300 | if layer_type == 'a': 1301 | if return_hiddens: 1302 | hiddens.append(x) 1303 | layer_mem = mems.pop(0) if mems else None 1304 | 1305 | if layer_type == 'c': 1306 | if self.training and self.cross_attn_tokens_dropout > 0.: 1307 | context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout) 1308 | 1309 | inner_residual = x 1310 | 1311 | if return_hiddens: 1312 | layer_hiddens.append(x) 1313 | 1314 | pre_norm, post_branch_norm, post_main_norm = norm 1315 | 1316 | if exists(pre_norm): 1317 | x = pre_norm(x) 1318 | 1319 | if layer_type == 'a': 1320 | out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem) 1321 | elif layer_type == 'c': 1322 | out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn) 1323 | elif layer_type == 'f': 1324 | out = block(x) 1325 | 1326 | if self.resi_dual: 1327 | outer_residual = outer_residual + out * self.resi_dual_scale 1328 | 1329 | if exists(post_branch_norm): 1330 | out = post_branch_norm(out) 1331 | 1332 | x = residual_fn(out, inner_residual) 1333 | 1334 | if layer_type in ('a', 'c') and return_hiddens: 1335 | intermediates.append(inter) 1336 | 1337 | if layer_type == 'a' and self.residual_attn: 1338 | prev_attn = inter.pre_softmax_attn 1339 | elif layer_type == 'c' and self.cross_residual_attn: 1340 | prev_cross_attn = inter.pre_softmax_attn 1341 | 1342 | if exists(post_main_norm): 1343 | x = post_main_norm(x) 1344 | 1345 | if return_hiddens: 1346 | layer_hiddens.append(x) 1347 | 1348 | if self.resi_dual: 1349 | x = x + self.final_norm(outer_residual) 1350 | else: 1351 | x = self.final_norm(x) 1352 | 1353 | if return_hiddens: 1354 | intermediates = LayerIntermediates( 1355 | hiddens = hiddens, 1356 | attn_intermediates = intermediates, 1357 | layer_hiddens = layer_hiddens 1358 | ) 1359 | 1360 | return x, intermediates 1361 | 1362 | return x 1363 | 1364 | class Encoder(AttentionLayers): 1365 | def __init__(self, **kwargs): 1366 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 1367 | super().__init__(causal = False, **kwargs) 1368 | 1369 | class Decoder(AttentionLayers): 1370 | def __init__(self, **kwargs): 1371 | assert 'causal' not in kwargs, 'cannot set causality on decoder' 1372 | super().__init__(causal = True, **kwargs) 1373 | 1374 | class CrossAttender(AttentionLayers): 1375 | def __init__(self, **kwargs): 1376 | super().__init__(cross_attend = True, only_cross = True, **kwargs) 1377 | 1378 | class ViTransformerWrapper(nn.Module): 1379 | def __init__( 1380 | self, 1381 | *, 1382 | image_size, 1383 | patch_size, 1384 | attn_layers, 1385 | channels = 3, 1386 | num_classes = None, 1387 | post_emb_norm = False, 1388 | emb_dropout = 0. 
1389 | ): 1390 | super().__init__() 1391 | assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' 1392 | assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' 1393 | dim = attn_layers.dim 1394 | num_patches = (image_size // patch_size) ** 2 1395 | patch_dim = channels * patch_size ** 2 1396 | 1397 | self.patch_size = patch_size 1398 | 1399 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) 1400 | 1401 | self.patch_to_embedding = nn.Sequential( 1402 | nn.LayerNorm(patch_dim), 1403 | nn.Linear(patch_dim, dim), 1404 | nn.LayerNorm(dim) 1405 | ) 1406 | 1407 | self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity() 1408 | self.dropout = nn.Dropout(emb_dropout) 1409 | 1410 | self.attn_layers = attn_layers 1411 | 1412 | self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() 1413 | 1414 | def forward( 1415 | self, 1416 | img, 1417 | return_embeddings = False 1418 | ): 1419 | p = self.patch_size 1420 | 1421 | x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) 1422 | x = self.patch_to_embedding(x) 1423 | n = x.shape[1] 1424 | 1425 | x = x + self.pos_embedding[:, :n] 1426 | 1427 | x = self.post_emb_norm(x) 1428 | x = self.dropout(x) 1429 | 1430 | x = self.attn_layers(x) 1431 | 1432 | if not exists(self.mlp_head) or return_embeddings: 1433 | return x 1434 | 1435 | x = x.mean(dim = -2) 1436 | return self.mlp_head(x) 1437 | 1438 | class Transformer(nn.Module): 1439 | def __init__( 1440 | self, 1441 | *, 1442 | num_tokens, 1443 | max_seq_len, 1444 | attn_layers, 1445 | emb_dim = None, 1446 | max_mem_len = 0, 1447 | shift_mem_down = 0, 1448 | emb_dropout = 0., 1449 | post_emb_norm = False, 1450 | num_memory_tokens = None, 1451 | tie_embedding = False, 1452 | logits_dim = None, 1453 | use_abs_pos_emb = True, 1454 | scaled_sinu_pos_emb = False, 1455 | l2norm_embed = False, 1456 | emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1 1457 | attn_z_loss_weight = 1e-4 1458 | ): 1459 | super().__init__() 1460 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 1461 | 1462 | dim = attn_layers.dim 1463 | emb_dim = default(emb_dim, dim) 1464 | self.emb_dim = emb_dim 1465 | self.num_tokens = num_tokens 1466 | 1467 | self.max_seq_len = max_seq_len 1468 | self.max_mem_len = max_mem_len 1469 | self.shift_mem_down = shift_mem_down 1470 | 1471 | self.l2norm_embed = l2norm_embed 1472 | self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed) 1473 | 1474 | if not (use_abs_pos_emb and not attn_layers.has_pos_emb): 1475 | self.pos_emb = always(0) 1476 | elif scaled_sinu_pos_emb: 1477 | self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) 1478 | else: 1479 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed) 1480 | 1481 | self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 1482 | 1483 | self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity() 1484 | self.emb_dropout = nn.Dropout(emb_dropout) 1485 | 1486 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 1487 | self.attn_layers = attn_layers 1488 | 1489 | self.init_() 1490 | 1491 | logits_dim = default(logits_dim, num_tokens) 1492 | self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t() 
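# note on the line above: with tie_embedding the output head reuses the token
# embedding matrix (weight tying), i.e. roughly
#
#     logits = h @ self.token_emb.emb.weight.t()    # (b, n, num_tokens)
#
# which implicitly assumes emb_dim == dim, since token_emb lives in emb_dim
# while the hidden states coming out of attn_layers live in dim.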
1493 | 1494 | # memory tokens (like [cls]) from Memory Transformers paper 1495 | num_memory_tokens = default(num_memory_tokens, 0) 1496 | self.num_memory_tokens = num_memory_tokens 1497 | if num_memory_tokens > 0: 1498 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 1499 | 1500 | def init_(self): 1501 | if self.l2norm_embed: 1502 | nn.init.normal_(self.token_emb.emb.weight, std = 1e-5) 1503 | if not isinstance(self.pos_emb, always): 1504 | nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5) 1505 | return 1506 | 1507 | nn.init.kaiming_normal_(self.token_emb.emb.weight) 1508 | 1509 | def forward( 1510 | self, 1511 | x, 1512 | return_embeddings = False, 1513 | return_logits_and_embeddings = False, 1514 | return_intermediates = False, 1515 | mask = None, 1516 | return_mems = False, 1517 | return_attn = False, 1518 | mems = None, 1519 | pos = None, 1520 | prepend_embeds = None, 1521 | sum_embeds = None, 1522 | return_attn_z_loss = False, 1523 | attn_z_loss_weight = 1e-4, 1524 | **kwargs 1525 | ): 1526 | b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient 1527 | return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss 1528 | 1529 | # absolute positional embedding 1530 | 1531 | external_pos_emb = exists(pos) and pos.dtype != torch.long 1532 | pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos 1533 | x = self.token_emb(x) + pos_emb 1534 | 1535 | # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training 1536 | 1537 | if exists(sum_embeds): 1538 | x = x + sum_embeds 1539 | 1540 | # post embedding norm, purportedly leads to greater stabilization 1541 | 1542 | x = self.post_emb_norm(x) 1543 | 1544 | # whether to append embeds, as in PaLI, for image embeddings 1545 | 1546 | if exists(prepend_embeds): 1547 | prepend_seq, prepend_dim = prepend_embeds.shape[1:] 1548 | assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions' 1549 | 1550 | x = torch.cat((prepend_embeds, x), dim = -2) 1551 | 1552 | # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model 1553 | 1554 | if emb_frac_gradient < 1: 1555 | assert emb_frac_gradient > 0 1556 | x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient) 1557 | 1558 | # embedding dropout 1559 | 1560 | x = self.emb_dropout(x) 1561 | 1562 | x = self.project_emb(x) 1563 | 1564 | if num_mem > 0: 1565 | mem = repeat(self.memory_tokens, 'n d -> b n d', b = b) 1566 | x = torch.cat((mem, x), dim = 1) 1567 | 1568 | # auto-handle masking after appending memory tokens 1569 | if exists(mask): 1570 | mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True) 1571 | 1572 | if self.shift_mem_down and exists(mems): 1573 | mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] 1574 | mems = [*mems_r, *mems_l] 1575 | 1576 | if return_hiddens: 1577 | x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) 1578 | else: 1579 | x = self.attn_layers(x, mask = mask, mems = mems, **kwargs) 1580 | 1581 | mem, x = x[:, :num_mem], x[:, num_mem:] 1582 | 1583 | if return_logits_and_embeddings: 1584 | out = (self.to_logits(x), x) 1585 | elif return_embeddings: 1586 | out = x 1587 | else: 1588 | out = self.to_logits(x) 1589 | 1590 | if return_attn_z_loss: 1591 | pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, 
intermediates.attn_intermediates)) 1592 | intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight) 1593 | return_intermediates = True 1594 | 1595 | if return_intermediates: 1596 | return out, intermediates 1597 | 1598 | if return_mems: 1599 | hiddens = intermediates.hiddens 1600 | new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens 1601 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 1602 | return out, new_mems 1603 | 1604 | if return_attn: 1605 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 1606 | return out, attn_maps 1607 | 1608 | return out -------------------------------------------------------------------------------- /gpt4/train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import os 4 | from datetime import timedelta 5 | from functools import partial 6 | from itertools import chain 7 | 8 | import torch 9 | 10 | ########### SETUP CONFIG 11 | import torch.distributed as dist 12 | from accelerate import Accelerator 13 | from accelerate.logging import get_logger 14 | from accelerate.state import AcceleratorState 15 | from accelerate.utils import InitProcessGroupKwargs 16 | from datasets import load_dataset 17 | from lion_pytorch import Lion 18 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 19 | CheckpointImpl, 20 | apply_activation_checkpointing, 21 | checkpoint_wrapper, 22 | ) 23 | from torch.distributed.fsdp import ( 24 | BackwardPrefetch, 25 | FullyShardedDataParallel, 26 | MixedPrecision, 27 | ShardingStrategy, 28 | ) 29 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 30 | from torch.nn import LayerNorm 31 | from torch.optim import AdamW 32 | from torch.utils.data import DataLoader 33 | from tqdm import tqdm 34 | from transformers import ( 35 | AutoTokenizer, 36 | default_data_collator, 37 | get_cosine_schedule_with_warmup, 38 | get_linear_schedule_with_warmup, 39 | set_seed, 40 | ) 41 | 42 | from gpt4.gpt4 import GPT4 43 | from gpt4.model import Transformer 44 | from gpt4.utils.stable_adam import StableAdamWUnfused 45 | 46 | # state = AcceleratorState() 47 | 48 | 49 | logger = get_logger(__name__, log_level="INFO") 50 | 51 | class CFG: 52 | BATCH_SIZE = 1 53 | GRADIENT_ACCUMULATE_EVERY: int = 1 54 | SEED: int = 42 55 | LEARNING_RATE: float = 1e-4 #3e-4 # 1e-4 for lion 56 | WEIGHT_DECAY: float = 0.1 57 | SEQ_LEN: int = 8192 58 | NUM_CPU: int = multiprocessing.cpu_count() 59 | USE_DEEPSPEED: bool = True 60 | USE_FSDP: bool = True 61 | USE_PRETOKENIZED: bool = True 62 | USE_ACTIVATION_CHECKPOINTING: bool = True 63 | RESUME_FROM_CHECKPOINT: str = False 64 | CHECKPOINTING_STEPS: int = 1000 65 | OUTPUT_DIR: str = 'checkpoints/' # Folder 66 | ENTITY_NAME: str = "Andromeda" 67 | LOGGING_STEPS: int = 100 68 | 69 | 70 | # helpers 71 | 72 | 73 | def print_num_params(model, accelerator: Accelerator): 74 | # n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 75 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 76 | accelerator.print(f"Number of parameters in model: {n_params}") 77 | 78 | 79 | # activation checkpointing 80 | 81 | 82 | def activation_checkpointing( 83 | model: torch.nn.Module, 84 | offload_to_cpu: bool = False, 85 | accelerator: Accelerator = None, 86 | ): 87 | """ 88 | Apply activation checkpointing to a model. 
89 | 90 | Args: 91 | model (Module): The model to which to apply activation checkpointing. 92 | offload_to_cpu (bool, optional): Whether to offload the activations to CPU. Defaults to False. 93 | accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None. 94 | """ 95 | if accelerator is not None: 96 | accelerator.print("Using activation checkpointing") 97 | def check_fn(submodule): 98 | return isinstance(submodule, Transformer) 99 | non_reentrant_wrapper = partial( 100 | checkpoint_wrapper, 101 | offload_to_cpu=offload_to_cpu, 102 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 103 | ) 104 | apply_activation_checkpointing( 105 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 106 | ) 107 | 108 | 109 | # FSDP 110 | 111 | 112 | def fsdp( 113 | model: torch.nn.Module, 114 | auto_wrap: bool = False, 115 | mp: str = "fp32", 116 | shard_strat: str = "NO_SHARD", 117 | ): 118 | """ 119 | This function wraps a given PyTorch model with the FullyShardedDataParallel (FSDP) wrapper to enable efficient data parallelism and model sharding. 120 | 121 | Args: 122 | model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP. 123 | auto_wrap (bool, optional): If True, it enables automatic wrapping of the model's layers according to the transformer_auto_wrap_policy. Default is False. 124 | mp (str, optional): The mixed precision mode to be used. Can be 'bf16' for BFloat16, 'fp16' for Float16 or 'fp32' for Float32 precision. Default is 'fp32'. 125 | shard_strat (str, optional): The sharding strategy to be used. Can be 'SHARD_GRAD' for sharding at gradient computation, 'FULL_SHARD' for full model sharding or 'NO_SHARD' for no sharding. Default is 'NO_SHARD'. 126 | 127 | Raises: 128 | ValueError: If the provided mp (mixed precision mode) is not 'bf16', 'fp16' or 'fp32'. 129 | ValueError: If the provided shard_strat (sharding strategy) is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'. 130 | 131 | Returns: 132 | torch.nn.Module: The input model wrapped with FSDP. 133 | """ 134 | if auto_wrap: 135 | Andromeda_auto_wrap_policy = partial( 136 | transformer_auto_wrap_policy, 137 | transformer_layer_cls={ 138 | Transformer, 139 | }, 140 | ) 141 | else: 142 | Andromeda_auto_wrap_policy = None 143 | 144 | if mp == "bf16": 145 | mp_fsdp = MixedPrecision( 146 | param_dtype=torch.bfloat16, 147 | # Gradient communication precision. 148 | reduce_dtype=torch.bfloat16, 149 | # Buffer precision. 150 | buffer_dtype=torch.bfloat16, 151 | ) 152 | elif mp == "fp16": 153 | mp_fsdp = MixedPrecision( 154 | param_dtype=torch.float16, 155 | # Gradient communication precision. 156 | reduce_dtype=torch.float16, 157 | # Buffer precision. 158 | buffer_dtype=torch.float16, 159 | ) 160 | elif mp == "fp32": 161 | mp_fsdp = MixedPrecision( 162 | param_dtype=torch.float32, 163 | # Gradient communication precision. 164 | reduce_dtype=torch.float32, 165 | # Buffer precision. 166 | buffer_dtype=torch.float32, 167 | ) 168 | else: 169 | raise ValueError( 170 | "Invalid scheduler_type. Expected 'bf16', 'fp16' or 'fp32', got: {}".format( 171 | mp 172 | ) 173 | ) 174 | 175 | if shard_strat == "SHARD_GRAD": 176 | sharding_strat_fsdp = ShardingStrategy.SHARD_GRAD_OP 177 | elif shard_strat == "FULL_SHARD": 178 | sharding_strat_fsdp = ShardingStrategy.FULL_SHARD 179 | elif shard_strat == "NO_SHARD": 180 | sharding_strat_fsdp = ShardingStrategy.NO_SHARD 181 | else: 182 | raise ValueError( 183 | "Invalid scheduler_type. 
Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format( 184 | shard_strat 185 | ) 186 | ) 187 | 188 | model = FullyShardedDataParallel( 189 | model, 190 | auto_wrap_policy=Andromeda_auto_wrap_policy, 191 | mixed_precision=mp_fsdp, 192 | backward_prefetch=BackwardPrefetch.BACKWARD_PRE, 193 | sharding_strategy=sharding_strat_fsdp, 194 | forward_prefetch=True, 195 | use_orig_params=True, 196 | ) 197 | 198 | return model 199 | 200 | 201 | # learning rate scheduler 202 | 203 | 204 | def get_lr_scheduler_with_warmup( 205 | optimizer: torch.optim.Optimizer, 206 | scheduler_type: str, 207 | num_warmup_steps: int, 208 | max_train_steps: int, 209 | grad_accumulate_every: int = 1, 210 | accelerator: Accelerator = None, 211 | ): 212 | """ 213 | Get a learning rate scheduler with warmup. 214 | 215 | Args: 216 | optimizer (Optimizer): The optimizer for which to create the learning rate scheduler. 217 | scheduler_type (str): The type of learning rate scheduler to create, either "linear" or "cosine". 218 | num_warmup_steps (int): The number of warmup steps for the learning rate scheduler. 219 | max_train_steps (int): The maximum number of training steps. 220 | grad_accumulate_every (int, optional): The gradient accumulation factor. Defaults to 1. 221 | accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None. 222 | 223 | Returns: 224 | The learning rate scheduler with warmup. 225 | 226 | Raises: 227 | ValueError: If scheduler_type is not "linear" or "cosine". 228 | """ 229 | NUM_WARMUP_STEPS = num_warmup_steps 230 | GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every 231 | if accelerator is not None: 232 | accelerator.print(f"Using {scheduler_type} lr scheduler") 233 | if scheduler_type == "linear": 234 | return get_linear_schedule_with_warmup( 235 | optimizer=optimizer, 236 | num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY, 237 | num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY, 238 | ) 239 | elif scheduler_type == "cosine": 240 | return get_cosine_schedule_with_warmup( 241 | optimizer=optimizer, 242 | num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY, 243 | num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY, 244 | ) 245 | else: 246 | raise ValueError( 247 | "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format( 248 | scheduler_type 249 | ) 250 | ) 251 | 252 | 253 | # optimizers 254 | 255 | 256 | def decoupled_optimizer( 257 | model: torch.nn.Module, 258 | learning_rate: float, 259 | weight_decay: float, 260 | beta_1: float, 261 | beta_2: float, 262 | optimizer_type: str, 263 | use_fsdp: bool = True, 264 | accelerator: Accelerator = None, 265 | ): 266 | """ 267 | Decouples the optimizer from the training process. 268 | 269 | This function sets up the optimizer for the model by creating two groups of parameters: 270 | one for weight decay and one without weight decay. Then, it initializes the optimizer 271 | with these two groups of parameters. 272 | 273 | Args: 274 | model (Module): The model whose parameters are optimized. 275 | learning_rate (float): The learning rate for the optimizer. 276 | weight_decay (float): The weight decay for the optimizer. 277 | beta_1 (float): The exponential decay rate for the 1st moment estimates. 278 | beta_2 (float): The exponential decay rate for the 2nd moment estimates. 279 | optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'. 
280 | use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True. 281 | accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None. 282 | 283 | Returns: 284 | Optimizer: The initialized optimizer. 285 | 286 | Raises: 287 | ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'. 288 | """ 289 | accelerator.print(f"Using {optimizer_type} optimizer") 290 | # Create an empty dictionary called param_dict to store the model's named parameters. 291 | param_dict = {} 292 | # Iterate over the model's named parameters and populate the param_dict with key-value pairs. 293 | for param_name, param in model.named_parameters(): 294 | param_dict[param_name] = param 295 | 296 | # Separate the model's named modules into two groups: decay and no_decay. 297 | 298 | # Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay. 299 | no_decay = [] 300 | 301 | if use_fsdp: 302 | exclude_module = "_fsdp_wrapped_module.token_emb" 303 | else: 304 | exclude_module = "token_emb" 305 | 306 | # Iterate through the named modules of the model. 307 | for module_name, module in model.named_modules(): 308 | # Check if the current module is an instance of any of the desired types (LayerNorm or torch.nn.Embedding). 309 | for ndim in [LayerNorm, torch.nn.Embedding]: 310 | if isinstance(module, ndim): 311 | # If torch.nn.Embedding, append its name with a ".weight" suffix to the no_decay list. 312 | if module_name == exclude_module: 313 | no_decay.append(f"{module_name}.weight") 314 | else: 315 | # If the module is an instance of LayerNorm 316 | no_decay.append(f"{module_name}.gamma") 317 | # Exit the inner loop since the desired module has been found. 318 | break 319 | 320 | # Create an empty list to store the names of the Linear layer weights with weight decay. 321 | decay = [] 322 | 323 | # Iterate through the named modules of the model. 324 | for module_name, module in model.named_modules(): 325 | # Check if the current module is an instance of the desired type (torch.nn.Linear). 326 | for ndim in [torch.nn.Linear]: 327 | if isinstance(module, ndim): 328 | # If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list. 329 | decay.append(f"{module_name}.weight") 330 | # Exit the inner loop since the desired module has been found. 331 | break 332 | 333 | # Create two separate lists of model parameters: decay_param and no_decay_param. 334 | # The decay_param list contains the parameters that should have weight decay applied. 335 | # The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter. 336 | 337 | # Create an empty list called decay_param to store the parameters with weight decay. 338 | decay_param = [] 339 | 340 | if use_fsdp: 341 | exclude_param = "_fsdp_wrapped_module.to_logits.weight" 342 | else: 343 | exclude_param = "to_logits.weight" 344 | 345 | # Iterate over the decay list, which contains the names of the parameters with weight decay. 346 | for param in decay: 347 | # Check if the current parameter is not 'to_logits.weight'. 348 | # Append the corresponding parameter from param_dict to the decay_param list. 349 | 350 | if param != exclude_param: 351 | decay_param.append(param_dict[param]) 352 | 353 | # Create an empty list called no_decay_param to store the parameters without weight decay. 
354 | no_decay_param = [] 355 | 356 | # Iterate over the no_decay list, which contains the names of the parameters without weight decay. 357 | for param in no_decay: 358 | try: 359 | 360 | # Append the corresponding parameter from param_dict to the no_decay_param list. 361 | no_decay_param.append(param_dict[param]) 362 | except KeyError: 363 | # print(f"Parameter {param_name} does not exist in the model") 364 | pass 365 | 366 | # Create a list called grouped_params that contains two dictionaries. 367 | # The first dictionary has the decay_param list and the corresponding weight_decay value. 368 | # The second dictionary has the no_decay_param list and a weight_decay value of 0.0. 369 | grouped_params = [ 370 | {"params": decay_param, "weight_decay": weight_decay}, 371 | {"params": no_decay_param, "weight_decay": 0.0}, 372 | ] 373 | 374 | # Create a variable called optimizer that stores an instance of the optimizer. 375 | if optimizer_type == "lion": 376 | optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),) 377 | elif optimizer_type == "adamw": 378 | optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),) 379 | elif optimizer_type == "stable_adamw": 380 | optimizer = StableAdamWUnfused( 381 | grouped_params, lr=learning_rate, betas=(beta_1, beta_2), 382 | ) 383 | else: 384 | raise ValueError( 385 | "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format( 386 | optimizer_type 387 | ) 388 | ) 389 | 390 | # Return the optimizer. 391 | return optimizer 392 | 393 | 394 | # dataloaders 395 | 396 | 397 | def build_dataloaders(): 398 | """ 399 | Build data loaders for training. 400 | 401 | This function performs the following steps: 402 | 1. Load the tokenizer from the pretrained "EleutherAI/gpt-neox-20b" model. 403 | 2. Load the "openwebtext" dataset. 404 | 3. Tokenize the dataset, adding the end-of-sentence token to each text. 405 | 4. Process the tokenized dataset into chunks of a specified block size. 406 | 407 | Returns: 408 | Dataset: The processed dataset ready for training. 409 | """ 410 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 411 | dataset = load_dataset("openwebtext", split="train") 412 | 413 | tokenized_dataset = dataset.map( 414 | lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]), 415 | batched=True, 416 | num_proc=CFG.NUM_CPU, 417 | remove_columns=["text"], 418 | ) 419 | 420 | block_size = CFG.SEQ_LEN 421 | 422 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 423 | def group_texts(examples): 424 | # Concatenate all texts. 425 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 426 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 427 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 428 | # customize this part to your needs. 429 | if total_length >= block_size: 430 | total_length = (total_length // block_size) * block_size 431 | # Split by chunks of max_len. 
432 | result = { 433 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 434 | for k, t in concatenated_examples.items() 435 | } 436 | return result 437 | 438 | train_dataset = tokenized_dataset.map( 439 | group_texts, batched=True, num_proc=CFG.NUM_CPU, 440 | ) 441 | 442 | return train_dataset 443 | 444 | #switch to falconwebdataset 445 | def build_pre_tokenized(): 446 | d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train[:10]") 447 | # d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train") 448 | # d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train") 449 | # d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train") 450 | # d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train") 451 | # train_dataset = concatenate_datasets([d0, d1, d2, d3, d4]) 452 | return d0 453 | 454 | 455 | 456 | def Train(): 457 | # accelerator 458 | 459 | timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000)) 460 | 461 | accelerator = Accelerator( 462 | gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY, 463 | mixed_precision="fp16", 464 | log_with="wandb", 465 | kwargs_handlers=[timeout], 466 | ) 467 | 468 | state = AcceleratorState() 469 | 470 | state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE #?????? 471 | 472 | accelerator.init_trackers( 473 | project_name="Andromeda", 474 | config={ 475 | "batch_size": CFG.BATCH_SIZE, 476 | "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY, 477 | "learning_rate": CFG.LEARNING_RATE, 478 | "seq_len": CFG.SEQ_LEN, 479 | }, 480 | # init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}}, 481 | ) 482 | 483 | accelerator.print(f"Total GPUS: {accelerator.num_processes}") 484 | 485 | # set seed 486 | 487 | set_seed(CFG.SEED) 488 | 489 | # model = Andromeda( 490 | # num_tokens=50432, 491 | # max_seq_len=8192, 492 | # dim=3072, 493 | # depth=24, 494 | # dim_head=128, 495 | # heads=12, 496 | # use_abs_pos_emb=False, 497 | # alibi_pos_bias=True, 498 | # alibi_num_heads=6, 499 | # rotary_xpos=True, 500 | # attn_flash=True, 501 | # shift_tokens=1, 502 | # attn_one_kv_head=True, 503 | # qk_norm=True, 504 | # attn_qk_norm=True, 505 | # attn_qk_norm_dim_scale=True, 506 | # embedding_provider=AndromedaEmbedding() 507 | # ) 508 | model = GPT4() 509 | 510 | print_num_params(model, accelerator) 511 | 512 | if CFG.USE_FSDP: 513 | model = fsdp( 514 | model, 515 | mp="fp16", 516 | shard_strat="SHARD_GRAD" 517 | ) 518 | 519 | if CFG.USE_ACTIVATION_CHECKPOINTING: 520 | activation_checkpointing(model, accelerator) 521 | 522 | model = accelerator.prepare(model) 523 | 524 | # dataloaders 525 | 526 | if CFG.USE_PRETOKENIZED: 527 | train_dataset = build_pre_tokenized() 528 | else: 529 | train_dataset = build_dataloaders() 530 | 531 | train_loader = DataLoader( 532 | train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator, 533 | ) 534 | 535 | 536 | # optimizer 537 | optim = decoupled_optimizer( 538 | model=model, 539 | learning_rate=CFG.LEARNING_RATE, 540 | weight_decay=CFG.WEIGHT_DECAY, 541 | beta_1=0.90, 542 | beta_2=0.95, 543 | optimizer_type='lion', 544 | use_fsdp=True, 545 | accelerator=accelerator 546 | ) 547 | 548 | # Determine number of training steps 549 | 550 | max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY) 551 | accelerator.print(f"Max train steps: {max_train_steps}") 552 | 553 | # lr scheduler 554 | 555 | NUM_WARMUP_STEPS = int(max_train_steps 
* 0.01) 556 | accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}") 557 | 558 | # if False: # if CFG.USE_DEEPSPEED: 559 | # lr_scheduler = DummyScheduler( 560 | # optim, 561 | # total_num_steps=max_train_steps * accelerator.num_processes, 562 | # warmup_num_steps=NUM_WARMUP_STEPS 563 | # ) 564 | # else: 565 | lr_scheduler = get_lr_scheduler_with_warmup( 566 | optimizer=optim, 567 | scheduler_type="cosine", 568 | num_warmup_steps=NUM_WARMUP_STEPS, 569 | max_train_steps=max_train_steps, 570 | grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY, 571 | ) 572 | 573 | # prepare 574 | 575 | optim, train_loader, lr_scheduler = accelerator.prepare( 576 | optim, train_loader, lr_scheduler 577 | ) 578 | 579 | # checkpoint scheduler 580 | 581 | accelerator.register_for_checkpointing(lr_scheduler) 582 | 583 | # I do not know why Huggingface recommends recalculation of max_train_steps 584 | 585 | max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY) 586 | accelerator.print(f"Max train steps recalculated: {max_train_steps}") 587 | 588 | # Total batch size for logging 589 | 590 | total_batch_size = ( 591 | CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY 592 | ) 593 | accelerator.print(f"Total batch size: {total_batch_size}") 594 | 595 | # resume training 596 | 597 | progress_bar = tqdm( 598 | range(max_train_steps), disable=not accelerator.is_local_main_process 599 | ) 600 | completed_steps = 0 601 | 602 | if CFG.RESUME_FROM_CHECKPOINT: 603 | if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "": 604 | accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}") 605 | accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT) 606 | path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT) 607 | training_difference = os.path.splitext(path)[0] 608 | 609 | # need to multiply `gradient_accumulation_steps` to reflect real steps 610 | resume_step = ( 611 | int(training_difference.replace("step_", "")) 612 | * CFG.GRADIENT_ACCUMULATE_EVERY 613 | ) 614 | 615 | if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None: 616 | train_loader = accelerator.skip_first_batches(train_loader, resume_step) 617 | completed_steps += resume_step 618 | progress_bar.update(resume_step) 619 | 620 | # training 621 | 622 | model.train() 623 | for step, batch in enumerate(train_loader): 624 | with accelerator.accumulate(model): 625 | inputs = batch["input_ids"].to(accelerator.device) 626 | loss = model(inputs, return_loss=True) 627 | accelerator.backward(loss) 628 | 629 | accelerator.log({"loss": loss.item()}, step=step) 630 | 631 | if accelerator.sync_gradients: 632 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 633 | 634 | optim.step() 635 | lr_scheduler.step() 636 | optim.zero_grad() 637 | 638 | if accelerator.sync_gradients: 639 | progress_bar.update(1) 640 | completed_steps += 1 641 | 642 | if isinstance(CFG.CHECKPOINTING_STEPS, int): 643 | if completed_steps % CFG.CHECKPOINTING_STEPS == 0: 644 | output_dir = f"step_{completed_steps }" 645 | if CFG.OUTPUT_DIR is not None: 646 | output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir) 647 | accelerator.save_state(output_dir) 648 | 649 | if completed_steps >= max_train_steps: 650 | break 651 | 652 | #logging every CFG.LOGGING STEPS 653 | if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0: 654 | logger.info( 655 | f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}" 656 | ) 657 | 658 | # end training 659 | 660 | # accelerator.print(f"Training Finished") 661 | 
accelerator.end_training() 662 | 663 | # save final model 664 | 665 | # accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}") 666 | if CFG.OUTPUT_DIR is not None: 667 | accelerator.wait_for_everyone() 668 | unwrapped_model = accelerator.unwrap_model(model) 669 | with accelerator.main_process_first(): 670 | accelerator.save( 671 | unwrapped_model.state_dict(), f"{CFG.OUTPUT_DIR}/final/final_model.pt" 672 | ) 673 | 674 | 675 | def train(): 676 | os.environ['MASTER_ADDR'] #'localhost' 677 | os.environ['MASTER_PORT'] #= '9994' 678 | 679 | # # [CRITICAL] Pay attention to this when scaling to multiple GPUs and clusters 680 | 681 | # # Pay attention to this, use "accelerate config" 682 | 683 | os.environ['RANK'] #= str(0) # Number of nodes (servers) 684 | os.environ['WORLD_SIZE'] # = str(torch.cuda.device_count()) 685 | 686 | dist.init_process_group(backend='nccl') #init_method="env://") 687 | 688 | Train() 689 | 690 | if __name__ == '__main__': 691 | train() -------------------------------------------------------------------------------- /gpt4/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyegomez/GPT4/f79f992189318c99419fd3cd29d9d955f5f67a55/gpt4/utils/__init__.py -------------------------------------------------------------------------------- /gpt4/utils/stable_adam.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | # This is the unfused version of StableAdamW. It is slower than the fused version (coming). 5 | 6 | 7 | class StableAdamWUnfused(torch.optim.Optimizer): 8 | def __init__( 9 | self, 10 | params, 11 | lr=0.002, 12 | weight_decay=0.2, 13 | betas=(0.9, 0.99), 14 | eps=1e-8, 15 | clip_thresh=1.0, 16 | precision="amp_bfloat16", 17 | custom_scalar=65536, 18 | ): 19 | beta1, beta2 = betas[0], betas[1] 20 | defaults = dict(lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2) 21 | super(StableAdamWUnfused, self).__init__(params, defaults) 22 | 23 | self.eps = eps 24 | self.d = clip_thresh 25 | 26 | # Set precision to "custom_fp16" if you want to use a fixed loss scalar, custom_scalar, which is divided out in the update step. 27 | # If you do this, call (custom_scalar * loss).backward() instead of loss.backward(). 
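# e.g. (illustrative only), with precision="custom_fp16" and the default
# custom_scalar of 65536, a training step would look like:
#
#     loss = model(inputs, return_loss=True)
#     (65536 * loss).backward()     # gradients arrive pre-scaled
#     optimizer.step()              # step() divides the gradients back down by custom_scaler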
28 | self.precision = precision 29 | self.custom_scaler = custom_scalar 30 | 31 | for group in self.param_groups: 32 | group["step"] = 1.0 33 | 34 | print("Using StableAdamWUnfused-v1") 35 | 36 | def __setstate__(self, state): 37 | super(StableAdamWUnfused, self).__setstate__(state) 38 | 39 | def step(self, closure=None): 40 | if closure is not None: 41 | closure() 42 | 43 | for group in self.param_groups: 44 | lr = group["lr"] 45 | weight_decay = group["weight_decay"] 46 | beta1 = group["beta1"] 47 | beta2 = group["beta2"] 48 | step = group["step"] 49 | 50 | for p in group["params"]: 51 | if p.grad is None: 52 | continue 53 | theta = p.data 54 | param_state = self.state[p] 55 | 56 | if self.precision == "custom_fp16": 57 | g = p.grad.data / self.custom_scaler 58 | if torch.any(torch.isnan(g) | torch.isinf(g)): 59 | continue 60 | else: 61 | g = p.grad.data 62 | 63 | if "exp_avg" not in param_state: 64 | v = param_state["exp_avg"] = torch.zeros_like(theta) 65 | u = param_state["exp_avg_sq"] = torch.zeros_like(theta) 66 | else: 67 | v = param_state["exp_avg"] 68 | u = param_state["exp_avg_sq"] 69 | 70 | beta1hat = beta1 * (1 - beta1 ** (step - 1)) / (1 - beta1**step) 71 | beta2hat = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2**step) 72 | 73 | v = v.mul_(beta1hat).add_(g, alpha=1.0 - beta1hat) 74 | u = u.mul_(beta2hat).addcmul_(g, g, value=1.0 - beta2hat) 75 | 76 | denominator = u.sqrt().add_(self.eps) 77 | 78 | # StableAdamW = AdamW + update clipping (https://arxiv.org/abs/1804.04235) applied tensor-wise. 79 | rms = ( 80 | torch.div( 81 | g.pow(2), torch.maximum(u, (self.eps**2) * torch.ones_like(u)) 82 | ) 83 | .mean() 84 | .sqrt() 85 | .item() 86 | ) 87 | 88 | theta = theta.mul_(1.0 - lr * weight_decay).addcdiv_( 89 | v, denominator, value=-lr * (1.0 / max(1.0, rms / self.d)) 90 | ) 91 | 92 | # save current params 93 | param_state["exp_avg"] = v 94 | param_state["exp_avg_sq"] = u 95 | 96 | group["step"] = step + 1 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "gpt4-torch" 7 | version = "0.0.3" 8 | description = "GPT4 - Pytorch" 9 | authors = ["Kye Gomez "] 10 | license = "MIT" 11 | readme = "README.md" 12 | homepage = "https://github.com/kyegomez/gpt3" 13 | keywords = ["artificial intelligence", "attention mechanism", "transformers"] 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.6" 17 | torch = "*" 18 | lion-pytorch = "*" 19 | numpy = "*" 20 | einops = "*" 21 | accelerate = "*" 22 | transformers = "*" 23 | SentencePiece = "*" 24 | datasets = "*" 25 | matplotlib = "*" 26 | deepspeed = "*" 27 | 28 | [tool.poetry.dev-dependencies] 29 | 30 | [[tool.poetry.packages]] 31 | include = "gpt4/*" --------------------------------------------------------------------------------
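For quick reference, the pieces defined in gpt4/model.py compose as follows. This is a minimal, illustrative sketch (not part of the repository) that assumes the package is importable as `gpt4` per the tree at the top of this document, with flags mirrored from the commented-out Andromeda config in gpt4/train.py:

import torch
from gpt4.model import Decoder, Transformer

# small decoder-only language model built from the AttentionLayers stack above
model = Transformer(
    num_tokens = 50432,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_xpos = True,   # rotary + xpos positions, applied in Attention.forward
        attn_flash = True,    # routed to Attention via the attn_ kwarg prefix
    ),
)

tokens = torch.randint(0, 50432, (1, 256))
logits = model(tokens)        # -> (1, 256, 50432)

gpt4/train.py, by contrast, is written to run under a distributed launcher: it reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment and calls dist.init_process_group(backend='nccl'), so it is expected to be started via `accelerate launch gpt4/train.py` (after `accelerate config`) or an equivalent torchrun invocation rather than plain `python`.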