├── requirements.txt
├── CITATION.cff
├── README.md
├── LICENSE.md
└── main.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | tiktoken
4 | 
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "Citations would be appreciated if you end up using this tool! I currently go by Fern, no last name given."
3 | authors:
4 | - given-names: "Fern"
5 | title: "hlb-gpt"
6 | version: 0.4.0
7 | date-released: 2023-03-05
8 | url: "https://github.com/tysam-code/hlb-gpt"
9 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Twitter URL](https://img.shields.io/twitter/url/https/twitter.com/hi_tysam.svg?style=social&label=Follow%20%40TySam_And)](https://twitter.com/hi_tysam) [![Support me on Patreon](https://img.shields.io/endpoint.svg?url=https%3A%2F%2Fshieldsio-patreon.vercel.app%2Fapi%3Fusername%3Dtysam%26type%3Dpatrons%26suffix%3Dsponsors&style=flat)](https://patreon.com/tysam)
2 | 
3 | ## hlb-GPT
4 | 
5 | Welcome to the hyperlightspeedbench-gpt (hlb-gpt) repository! This project is meant to be the best tool for researchers wanting to quickly explore new LLM ideas. It is also intended to be a good starting point for new projects and new ML developers. The code is simple, performant, and well-documented, with good default hyperparameters. It also scales from 46M parameters (the default) up to 3B parameters on a single A100 just by changing the model_scale parameter -- the rest of the hyperparameters are automatically inferred (though the scaling feature is still in alpha, as the large-model hyperparameters still need tuning).
6 | 
7 | ### How to Run
8 | 
9 | 
10 | `git clone https://github.com/tysam-code/hlb-gpt && cd hlb-gpt && python -m pip install -r requirements.txt && python main.py`
11 | 
12 | 
13 | This code was developed exclusively in Colab, but it runs in the terminal as well. If you are running it in Colab, be sure to uncomment the code block at the top.
14 | 
15 | ### Main
16 | 
17 | This code achieves a ~3.80 validation loss on WikiText-103 in about 100 seconds on a single A100 with default settings. By default, it runs for 1000 steps before ending and running a demo inference on the trained network, though you can (and should!) change this value as you begin experimenting. The learning rate schedulers are set to run infinitely; the step count is just a cutoff. As one of the design decisions to keep things simple, this code assumes that you are using a 40 GB A100, though hopefully we will be able to port to more GPU memory sizes as the scaling rules solidify.
18 | 
19 | The code is very short -- just over 300 lines. It implements a number of novel (or at least, novel-to-the-author) concepts, including a LatentAttention block that efficiently fuses the attention and MLP blocks into one, learnable linear position embeddings that let the attention layers learn a dynamic attention length, a dynamic microbatch scheduler based upon the expected gradient norm, a specific set of parameter group schedules, and several other things of various potential novelty.
20 | 
21 | I referenced nanoGPT when originally writing this code, though it has certainly become its own beast at this point!
Much appreciation to Karpathy and contributors for that codebase. 22 | 23 | One of the intents of this codebase is to minimize the time-to-result for a given experiment. My experience leads me to believe that this is a good thing to optimize for (my appreciation to Keller Jordan for conversations on this topic a little while back). 24 | 25 | If you have any questions, please let me know. My Twitter DMs should be open, as well as my email. 26 | 27 | ### Contact 28 | 29 | Much of this work is supported by a combination of being both self-funded and being funded from the support of people like you. My Patreon is at [Patreon](https://www.patreon.com/user/posts?u=83632131) if you like what I'm doing here and would like to see more work like this in the future. If you want me to work up to a part-time amount of hours with you via consulting or contract work, please feel free to reach out to me at hire.tysam@gmail.com. I'd love to hear from you. 30 | 31 | ### Citation 32 | 33 | If you use this work in your research, please cite 34 | `@software{hlb-gpt_2024, 35 | author={Fern}, 36 | month={3}, 37 | title={{hlb-gpt}}, 38 | url={https://github.com/tysam-code/hlb-gpt}, 39 | version = {0.4.0}, 40 | year = {2024}}` 41 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Colab users, uncomment the following block to help clear out notebook state when re-running the cell. 
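# (The try/except below checks for get_ipython(), so the %reset only fires when the code is actually running under IPython/Colab; in a plain terminal it falls through to the NameError branch and does nothing.)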
2 | """ 3 | # don't forget these too: 4 | # !pip3 install tiktoken 5 | # If you don't have torch 2.0 on whatever environment you're using: 6 | # !pip3 install --upgrade torch 7 | try: 8 | _ = get_ipython().__class__.__name__ 9 | ## we set -f below to avoid prompting the user before clearing the notebook state 10 | %reset -f 11 | except NameError: 12 | pass ## we're still good 13 | """ 14 | import functools 15 | from functools import partial 16 | import subprocess 17 | 18 | import zipfile 19 | import math 20 | import os 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | from torch import nn 25 | 26 | # This seems like one of the best choices right now for a fast/lightweight/simple tokenizer. 27 | import tiktoken 28 | 29 | 30 | ################ 31 | # Introduction # 32 | ################ 33 | 34 | # This code was built from the ground up to support extremely rapid experimentation for solo researchers and small teams. It's meant to 35 | # be hackable nearly anywhere with minimal effort/side effects, which is why you might see more of a flat layout. It's also quite fast. 36 | # 37 | # The codebase is specifically designed for single A100s for now, but may expand with more GPU support in the future, depending. I originally 38 | # used Karpathy's nanoGPT as well as some of my other work as a reference when writing this, though this codebase is very much 39 | # its own thing at this point. 40 | # 41 | # If you found this codebase useful or informative, please consider supporting me directly at https://www.patreon.com/tysam . If you'd like 42 | # to speak about a contract or a consulting opportunity, feel free to reach out at hi [dot] re [dot] tysam [atsymbol] gmail [dot] com. 43 | # I'd love to hear from you! 44 | # 45 | # Now, on with the code! 46 | 47 | 48 | ############################## 49 | # Hyperparameters # 50 | ############################## 51 | 52 | # Note: The automatic rescaling of hyperparameters based on batchsize/etc is currently a work in progress. 53 | # This code assumes 40 GB-limit A100s for the scale-based hyperparameters, you may have to do some tinkering if you have a different setup. 54 | # So far, most of the tested configs have been between ~46 M and 1.5B or so, and have done moderately well. 55 | 56 | # This parameter determines the final size of the model. Roughly, num_model_params ~= model_scale * 49 M (# of params in the base model), but it scales nonlinearly. (#TODO is to make this more straight in the future) 57 | # Model scales other than 1.0 are in alpha currently -- they should run okay, but are almost certainly not tuned efficiently yet! This should hopefully be addressed in a future update. 58 | model_scale = 1.0 # OOM-tested from ~.5ish (28 M) to 148 (~3 B). Sets the model size. One of the most important hyperparameters. Supports noninteger values (2.3, etc) 59 | max_sequence_length = 1024 # Can go up or down. Mostly tested up to 1024, some models can avoid OOMs even with length 8192 (not really tested) 60 | gpu_token_capacity = 114688 # This is an amount that doesn't OOM on A100 at model_scale 1, length 1024. May need to change if you have a different GPU. Note: Hyperparameter tunings are currently based on the 40 GB limit of the A100. 61 | 62 | # Approximates the amount of tokens the GPU can hold based upon the scale of the model (scaled somewhat conservatively to avoid most OOMs. May OOM in some weird edgecases.) 
63 | # Batchsize is determined automatically based upon the current sequence length and the rough token-capacity of the GPU for a given model.
64 | tokens_per_batch_capacity = math.floor(gpu_token_capacity / (1.52174 + .482 * model_scale**(.87)))
65 | 
66 | # We support fractional model factors; this picks dimensions that the A100 can use efficiently.
67 | to_nearest_64 = lambda x: round(x/64) * 64
68 | 
69 | 
70 | # The default model here below is roughly ~46M parameters or so.
71 | hyp = {
72 | 'opt': {
73 | 'lr_mult': {
74 | 'base': 2.62, # The base_lr itself is derived from a scaling equation fit to GPT-3 parameters. This multiplier impacts all parameters, including those in the default group
75 | 'position_bias': 100.,
76 | 'non_dot_products': 32.,
77 | 'output_layer': 2.,
78 | },
79 | 'weight_decay': 2.**4, # This is the weight decay when the loss = 0.; we approach it exponentially. Somewhat slows overfitting.
80 | 'total_train_steps': 1000, # We can run effectively infinitely, but it is 1000 by default for the inference demo. For infinite runs, you can use the saved checkpoints from disk.
81 | 'microbatch': { # The microbatch scheduler assumes a power law decay schedule for the grad norm, and adjusts the microbatch size (minimum 1) to enforce it.
82 | 'sample_every': 5, # Sampling the grad norm can be a bit expensive, so we do it every n steps instead.
83 | 'scale_lr': 1e-1, # Microbatch update rate
84 | },
85 | 'eval_every': 50, # How many train iterations per eval round (we don't include eval time in our performance stats). Good to set to 10-20 for larger (~800M+) networks
86 | 'save_every_n_evals': 2, # Good to set this low for larger networks
87 | 'num_eval_tokens': 153600, # Total # of tokens to eval over, divided into max_sequence_length-long sequences
88 | 'warmup_steps': 100, # For training stability in the main body of the network. (#TODO: Investigate the warmup impact a bit more)
89 | },
90 | 'net': {
91 | 'residual_depth': to_nearest_64(384 * math.log2(1.+model_scale)),
92 | 'qk_dim_div': 8,
93 | 'expand_factor': 2,
94 | 'num_blocks': round(8 * math.log2(1.+model_scale)),
95 | },
96 | 'misc': {
97 | 'num_tokens': 50304, # Rounded to the nearest multiple of 64 for efficiency
98 | 'sequence_length': {
99 | 'max': max_sequence_length,
100 | 'initial': 32, # Very short initial sequence length seems to help a lot
101 | 'growth_steps': 80, # We double the sequence length during training every n steps up to the maximum
102 | },
103 | 'device': 'cuda',
104 | 'dtype': torch.bfloat16,
105 | 'data_location': 'data.pt',
106 | }
107 | }
108 | 
109 | 
110 | #############################################
111 | # Dataloader #
112 | #############################################
113 | 
114 | if not os.path.exists(hyp['misc']['data_location']):
115 | print("downloading data and tokenizing (1-2 min)")
116 | 
117 | raw_data_source = 'https://wikitext.smerity.com/wikitext-103-raw-v1.zip'
118 | raw_data_cache = './data_raw/' # where to cache the data after downloading
119 | 
120 | if not os.path.isfile(raw_data_cache):
121 | os.makedirs(raw_data_cache, exist_ok=True)
122 | 
123 | # Needed due to the website 403-blocking python agents for download, it seems? Many thanks to Smerity for re-hosting these after the main files went down.
<3 :') 124 | subprocess.run(["wget", raw_data_source, "-O", raw_data_cache+"data.zip"], stdout=subprocess.PIPE) 125 | 126 | with zipfile.ZipFile('data_raw/data.zip', 'r') as zip_ref: 127 | zip_ref.extractall('data_raw/') 128 | 129 | with open('data_raw/wikitext-103-raw/wiki.train.raw') as data_file: 130 | raw_train_data = data_file.read() 131 | 132 | with open('data_raw/wikitext-103-raw/wiki.valid.raw') as data_file: 133 | raw_eval_data = data_file.read() 134 | 135 | 136 | tokenizer = tiktoken.get_encoding("gpt2") 137 | raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data) 138 | raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data) 139 | 140 | train_tokenized = torch.tensor(raw_tokenized_train, device=hyp['misc']['device'], dtype=torch.int) # int64 is likely overkill for the amount of tokens we have... 141 | eval_tokenized = torch.tensor(raw_tokenized_eval, device=hyp['misc']['device'], dtype=torch.int) 142 | 143 | data = { 144 | 'train': train_tokenized, 145 | 'eval': eval_tokenized 146 | } 147 | 148 | torch.save(data, hyp['misc']['data_location']) 149 | print("completed the tokenization process!") 150 | 151 | else: 152 | ## This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :) 153 | ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above 154 | ## hyp dictionary, then we should be good. :) 155 | data = torch.load(hyp['misc']['data_location']) 156 | 157 | 158 | ######################################## 159 | # Constants # 160 | ######################################## 161 | 162 | with torch.no_grad(): 163 | # Create the base arrays for the learnable linear positional bias. This helps save some memory consumption & processing time 164 | bias_range = torch.arange(-hyp['misc']['sequence_length']['max']+1, 1).to(hyp['misc']['device'], torch.bfloat16) 165 | position_bias_base = bias_range.unsqueeze(0) - bias_range.unsqueeze(1) 166 | negative_infinity_matrix_base = torch.empty_like(position_bias_base).fill_(-float("inf")) 167 | causal_mask = torch.tril(torch.ones((hyp['misc']['sequence_length']['max'], hyp['misc']['sequence_length']['max']), device=hyp['misc']['device'], dtype=torch.bool)) 168 | 169 | 170 | # Used in the dataloader to select indexes in a sequence. Preallocated for slight efficiency. 171 | batch_index_offsets = torch.arange(0, hyp['misc']['sequence_length']['max']+1, dtype=torch.long, device=hyp['misc']['device']) 172 | 173 | 174 | ############################################# 175 | # Network Components # 176 | ############################################# 177 | 178 | class LatentAttentionBlock(nn.Module): 179 | """ Efficient fused latent-space attention block. Linear keys and queries, nonlinear values.""" 180 | def __init__(self, num_dim): 181 | super().__init__() 182 | # Layer dim parameters. Play around with these, there's likely some undiscovered stuff still! 183 | self.dim = num_dim 184 | self.qk_dim = self.dim//hyp['net']['qk_dim_div'] 185 | self.v_dim = num_dim 186 | self.expand_dim = num_dim * hyp['net']['expand_factor'] 187 | 188 | # Main layer weights 189 | self.norm = nn.LayerNorm(self.dim, bias=False) 190 | self.expand = nn.Parameter(.5 * 1./hyp['net']['residual_depth']**.5 * 1./hyp['net']['expand_factor'] * torch.randn(2*self.qk_dim+2*self.expand_dim, self.dim)) 191 | self.project = nn.Parameter(1. 
* 1./hyp['net']['residual_depth']**.5 * 1./hyp['net']['expand_factor'] * 1./hyp['net']['num_blocks'] * torch.randn((self.dim, self.expand_dim)))
192 | 
193 | # Learnable linear positional encodings. Similar to, but different from, https://arxiv.org/abs/2108.12409
194 | # Has a high lr mult applied to it so that each layer can learn its own attention scale.
195 | self.position_bias_mult = nn.Parameter(torch.tensor(1., device='cuda'))
196 | 
197 | def forward(self, x):
198 | residual = x
199 | 
200 | # Make additive attention mask, scaled by a learned mult for the position bias (lets us learn dynamic attention ranges per layer as needed)
201 | attn_mask = torch.where(causal_mask[:x.shape[1], :x.shape[1]], F.softplus(self.position_bias_mult) * position_bias_base[:x.shape[1], :x.shape[1]], negative_infinity_matrix_base[:x.shape[1], :x.shape[1]])
202 | 
203 | # Shared LayerNorm for linear layers and attention
204 | x = self.norm(x)
205 | 
206 | # Fused into one kernel for memory+speed/etc
207 | query, key, linear, pre_gelu = F.linear(x, self.expand).split((self.qk_dim, self.qk_dim, self.expand_dim, self.expand_dim), dim=-1)
208 | 
209 | # Compute the GeGLU activation (one portion of the channels stays local, another becomes the nonlinear value for attention)
210 | geglu = linear * F.gelu(pre_gelu)
211 | 
212 | # Partition between the input values and the v dim values
213 | geglu_local, geglu_attention_value = geglu.split((self.expand_dim-self.v_dim, self.v_dim), -1)
214 | 
215 | # Compute attention. Something to note is that there are no attention heads here. This seemed to work a bit better, maybe due to not needing memory `.contiguous()` calls or similar
216 | attention = F.scaled_dot_product_attention(query, key, geglu_attention_value, attn_mask=attn_mask)
217 | 
218 | # Output linear layer
219 | out = F.linear(torch.cat([geglu_local, attention], dim=-1), self.project)
220 | 
221 | # Add to residual
222 | x = residual + out
223 | 
224 | return x
225 | 
226 | 
227 | #############################################
228 | # Network Definition #
229 | #############################################
230 | 
231 | # This may seem like an odd way to define a network, but it's a bit easier to hack into/make quick changes to than other methods
232 | class SpeedyLangNet(nn.Module):
233 | def __init__(self, network_dict):
234 | super().__init__()
235 | self.net_dict = network_dict
236 | 
237 | def forward(self, x):
238 | # Look up the input embeddings from the input tokens
239 | x = self.net_dict['embedding'](x)
240 | for block in range(hyp['net']['num_blocks']):
241 | x = self.net_dict['attn_layers'][block](x) # note: residuals are included in the block definitions for these layers
242 | x = self.net_dict['norm'](x)
243 | x = self.net_dict['outputs'](x)
244 | return x
245 | 
246 | 
247 | def make_net():
248 | network_dict = nn.ModuleDict({
249 | 'embedding': nn.Embedding(hyp['misc']['num_tokens'], hyp['net']['residual_depth'], scale_grad_by_freq=True),
250 | 'attn_layers': nn.ModuleList([LatentAttentionBlock(hyp['net']['residual_depth']) for _ in range(hyp['net']['num_blocks'])]),
251 | 'norm': nn.LayerNorm(hyp['net']['residual_depth'], bias=False),
252 | 'outputs': nn.Linear(hyp['net']['residual_depth'], hyp['misc']['num_tokens'], bias=False),
253 | })
254 | net = SpeedyLangNet(network_dict)
255 | net = net.to(hyp['misc']['device'], torch.bfloat16)
256 | net.train()
257 | 
258 | # Initialize the embedding and output matrices, with weights scaled based upon the dimensionality of the network.
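# (For reference: at the default model_scale of 1.0, residual_depth is 384, so the standard deviations below work out to roughly 0.0128 for the embedding and 0.0255 for the output layer.)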
259 | torch.nn.init.normal_(net.net_dict['embedding'].weight.data, std=.25*1./hyp['net']['residual_depth']**.5)
260 | torch.nn.init.normal_(net.net_dict['outputs'] .weight.data, std=.5 *1./hyp['net']['residual_depth']**.5)
261 | 
262 | return net
263 | 
264 | 
265 | ########################################
266 | # Training Helpers #
267 | ########################################
268 | 
269 | # Get a single batch item. Currently used in the training loop
270 | @torch.no_grad
271 | def get_batch(data_dict, key, batchsize, length):
272 | start_indexes = torch.randint(len(data_dict[key])-length-1, (batchsize,), device=hyp['misc']['device']) # warning: completely random sampling, not a random derangement -- using one might help performance a bit!
273 | sequence_indexes = start_indexes.unsqueeze(-1) + batch_index_offsets[:length+1].unsqueeze(0) # slice, as batch_index_offsets are pre-allocated to max length for efficiency
274 | sampled_sequences = torch.take_along_dim(data_dict[key], sequence_indexes.flatten(), dim=0).view(batchsize, length+1).long() # have to flatten and reshape due to take_along_dim being 1d
275 | 
276 | inputs, targets = sampled_sequences[:, :-1], sampled_sequences[:, 1:] # reslice to get our input tokens and our shifted-by-1 targets
277 | 
278 | return inputs, targets
279 | 
280 | # Make loss function
281 | loss_fn = nn.CrossEntropyLoss(reduction='mean', ignore_index=-1)
282 | 
283 | 
284 | ##############################
285 | # Scheduling #
286 | ##############################
287 | 
288 | # Infinite power law decay is a simple power law learning rate schedule. It seems to perform really well in practice and is simpler than OneCycle to tune.
289 | # Does a linear warmup from a min_initial lr to the max_lr at the peak_step, then decays infinitely with a 1/x**(power_value)-type shape to it.
290 | # These schedulers are multiplicative, which is why they scale from some base value to 1, which is what PyTorch's LambdaLR expects.
291 | infinite_power_law_decay = lambda step, min_initial_mult, peak_step, exponent: min_initial_mult + step/peak_step * (1 - min_initial_mult) if step < peak_step else (step + 1. - peak_step) ** exponent
292 | exp_decay_lr_scheduler_base = lambda step, decay: decay ** step
293 | 
294 | infinite_powah = partial(infinite_power_law_decay, min_initial_mult=2e-2, peak_step=hyp['opt']['warmup_steps'], exponent=-.08)
295 | infinite_powah_outputs = partial(infinite_power_law_decay, min_initial_mult=1., peak_step=0., exponent=-.2)
296 | pos_bias_decay_lr = partial(exp_decay_lr_scheduler_base, decay=.995)
297 | 
298 | def init_param_groups_dict(net, base_lr):
299 | # The 'scheduler' attribute that we create here is not used by the optimizer; here we just use it to conveniently store all of these attributes.
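# (The extra 'scheduler' entry is simply carried along on each param-group dict; it is read back out when the LambdaLR is constructed in main() below.)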
300 | param_groups = {} 301 | 302 | # Multiply by our delta over the base lr-scaling curve 303 | scaled_lr = base_lr * hyp['opt']['lr_mult']['base'] 304 | 305 | print("scaled lr: ", "{:0.8f}".format(scaled_lr)) 306 | 307 | # Decay is the default dictionary if there is no parameter name match 308 | param_groups['decay'] = {'params': [], 'lr': scaled_lr, 'eps': 1e-9, 'betas': (.9, .95), 'weight_decay': hyp['opt']['weight_decay'], 'scheduler': infinite_powah } 309 | param_groups['position_bias_mult'] = {'params': [], 'lr': hyp['opt']['lr_mult']['position_bias'] *scaled_lr, 'eps': 1e-9, 'betas': (.9, .95), 'weight_decay': 0, 'scheduler': pos_bias_decay_lr } 310 | param_groups['norm', 'bias', 'embedding'] = {'params': [], 'lr': hyp['opt']['lr_mult']['non_dot_products']*scaled_lr, 'eps': 1e-9, 'betas': (.9, .95), 'weight_decay': 0, 'scheduler': infinite_powah } 311 | param_groups['output'] = {'params': [], 'lr': hyp['opt']['lr_mult']['output_layer'] *scaled_lr, 'eps': 1e-9, 'betas': (.6, .95), 'weight_decay': 0, 'scheduler': infinite_powah_outputs} 312 | 313 | # Helper functions for matching parameters to dictionary keys 314 | in_list = lambda name, keyword_list: any(keyword in name for keyword in keyword_list) 315 | to_tuple = lambda x: x if type(x) == tuple else (x,) 316 | 317 | # In order, search through the dictionary keys, and add to that dictionary if a value in the dictionary key matches the name. 318 | # 'decay' is the name of the default group, and is the only group with weight decay. 319 | for name, p in net.named_parameters(): 320 | if p.requires_grad: 321 | target_param_dict = next(iter([k for k in param_groups.keys() if in_list(name, to_tuple(k))]), 'decay') 322 | param_groups[target_param_dict]['params'].append(p) 323 | 324 | return param_groups 325 | 326 | def get_grad_norm(net): 327 | # Gets the entire grad norm of the network. 
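# Accumulates each parameter's squared gradient L2 norm in float64, then takes the square root -- i.e. the global L2 norm over all gradients in the network.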
328 | grad_norm = torch.tensor(0., device=hyp['misc']['device'], dtype=torch.float64) 329 | for p in net.parameters(): 330 | if p.grad is not None: 331 | param_norm = p.grad.detach().data.norm(2) 332 | grad_norm += param_norm.square() 333 | grad_norm = (grad_norm ** 0.5).item() 334 | return grad_norm 335 | 336 | 337 | def grow_sequence_length(old_length, old_batchsize): 338 | # Dynamically grows the sequence length and changes the batchsize to avoid OOMs 339 | new_length = min(2*old_length, hyp['misc']['sequence_length']['max']) 340 | new_batchsize = tokens_per_batch_capacity // new_length 341 | 342 | print(f"| increasing sequence length (old: {old_length}, new: {new_length}), adjusting batchsize as necessary to fit (old: {old_batchsize}, new: {new_batchsize})") 343 | 344 | return new_length, new_batchsize 345 | 346 | 347 | ############################## 348 | # Logging # 349 | ############################## 350 | 351 | variables_to_log = ['epoch', 'curr_step', 'train_loss', 'val_loss', 'val_perplexity', 'train_acc', 'val_acc', 'grad_norm', 'microbatch_steps', 'total_seconds'] 352 | # define the printing function and print the column heads 353 | def print_training_details(columns_list, separator_left=' ', separator_right=' |', column_labels_only=False, is_final_entry=False): 354 | output_line = "|" # start with the left bar 355 | 356 | # Build the print string for the output: 357 | for column_entry in columns_list: 358 | output_line += separator_left + column_entry + separator_right 359 | 360 | if column_labels_only: 361 | print('-'*(len(output_line))) # print an initial upper dividing bar 362 | 363 | print(output_line) 364 | 365 | if column_labels_only or is_final_entry: 366 | print('-'*(len(output_line))) # print a lower divider bar 367 | 368 | # The previous function was a shorter but slightly more heinous lambda, however, this may still cause you some pain. <3 :'( 369 | def format_for_table(var_list, locals): 370 | int_format = lambda x: f"{locals[x]}".rjust(len(x)) 371 | default_format = lambda x: "{:0.4f}".format(locals[x]).rjust(len(x)) 372 | blank_format = lambda x: " "*len(x) 373 | 374 | out_list = [blank_format(v) if v not in locals else (int_format(v) if type(locals[v]) == int else default_format(v)) for v in var_list] 375 | return out_list 376 | 377 | 378 | ######################################## 379 | # Train and Eval # 380 | ######################################## 381 | 382 | def eval(net): 383 | #################### 384 | # Evaluation Mode # 385 | #################### 386 | 387 | # Do a slightly noisy fast eval over the max sequence length (should work okay as a rough general measurement of how we're doing) 388 | # Note that this is an approximation, it doesn't even necessarily use all of the requested tokens (but gets close because of the floor operation.) 389 | eval_batchsize = max(math.floor(tokens_per_batch_capacity/(hyp['misc']['sequence_length']['max'])//16), 1) # Number of sequences per batch relative to the max-length batchsize capacity, downscale factor hardcoded to help prevent OOMs. 
Tunable
390 | num_eval_sequences = hyp['opt']['num_eval_tokens']//hyp['misc']['sequence_length']['max']
391 | num_eval_steps = num_eval_sequences//eval_batchsize
392 | 
393 | # float32 here to prevent truncation errors
394 | val_loss, val_acc = torch.tensor(0., device=hyp['misc']['device'], dtype=torch.float), torch.tensor(0., device=hyp['misc']['device'], dtype=torch.float)
395 | 
396 | with torch.no_grad():
397 | # Note: We eval at the maximum sequence length so that we can get an idea of how well the sequence length growing extrapolates out
398 | for _ in range(num_eval_steps):
399 | inputs, targets = get_batch(data, key='eval', batchsize=eval_batchsize, length=hyp['misc']['sequence_length']['max'])
400 | outputs = net(inputs)
401 | val_loss += 1./num_eval_steps * loss_fn(outputs.flatten(0, 1).float(), targets.flatten(0, 1))
402 | val_acc += 1./num_eval_steps * (outputs.argmax(-1) == targets).float().mean()
403 | 
404 | val_perplexity = math.e ** val_loss
405 | 
406 | return val_acc.item(), val_loss.item(), val_perplexity.item()
407 | 
408 | def main():
409 | 
410 | #################
411 | # Init #
412 | #################
413 | # Full-run statistics variables
414 | total_seconds = 0.
415 | curr_microbatch_step = curr_step = 0
416 | tokens_seen = 0
417 | 
418 | # Microbatch growing parameters
419 | # Leaving this hardcoded for now for simplicity; this helps keep the learning process stable.
420 | microbatch_steps = 0. # The noninteger estimate of microbatches required based upon the grad norm (sampled by dithering at each step.)
421 | discrete_sampled_microbatch_steps = max(1, int(microbatch_steps))
422 | 
423 | # Start at the initial length and maximum allowable batchsize. The batchsize is adjusted so that we see roughly the same number of tokens per batch. This means that shorter sequence lengths will have much larger batch sizes.
424 | curr_length = hyp['misc']['sequence_length']['initial']
425 | curr_batchsize = tokens_per_batch_capacity // hyp['misc']['sequence_length']['initial']
426 | final_batchsize = tokens_per_batch_capacity / hyp['misc']['sequence_length']['max']
427 | assert final_batchsize > 1, f"Error: Specified configuration takes up too much memory (calculated final batchsize {final_batchsize} is not greater than 1!)"
428 | 
429 | # Validation parameters
430 | val_loss, val_acc, val_perplexity = None, None, None
431 | 
432 | # Get network
433 | net = make_net()
434 | 
435 | # Get the total number of parameters in our model and use that to generate/calculate the base lr.
436 | total_trainable_params = sum([p.data.numel() if p.requires_grad else 0 for p in net.parameters()])
437 | 
438 | print('-'*(40))
439 | print(f"total trainable params: {total_trainable_params:,}")
440 | print('-'*(40))
441 | 
442 | # Briefly log some details up front. (TODO: Condense nicely later.)
443 | print("curr_batchsize: ", curr_batchsize)
444 | print("final_batchsize: ", tokens_per_batch_capacity // hyp['misc']['sequence_length']['max'])
445 | print("max_sequence_length:", max_sequence_length)
446 | 
447 | 
448 | #####################
449 | # Scaling Equations #
450 | #####################
451 | 
452 | # These equations are a result of rough general exponential/power law fits between parameters that worked for the 46M and 1.5B runs.
453 | # They seem to transfer not too badly when interpolating; however, they're far from perfect and assume 40 GB of memory (so if you use
454 | # a smaller card, you might struggle a bit here).
All in all -- this is still in alpha, but seems to be very useful within a limited arena 455 | # of making arbitrary models between 45M and 1.5B 456 | 457 | # A very, very pared down version of the gpt-3 training lr scaling rule roughly fit. It's used as a loose general base for the run LRs. 458 | base_lr = 9e7 / math.log(total_trainable_params)**8.8 459 | 460 | # The base value that we raise to the value of our loss in order to determine how much weight decay we need (exponentially strong as we approach 0.) 461 | weight_decay_pow_base = .007 * ((.01 * math.log(total_trainable_params))) ** (-4) 462 | 463 | # This defines how quickly we expect grad_norm drops for microbatch scheduling -- slightly faster for smaller models, slightly slower for larger models 464 | # Note: This will interact with really aggressive weight decay, some training runs may slow down a lot near the end as a result. 465 | microbatch_expected_grad_norm_pow = -.677 * math.log(total_trainable_params) ** -.2 466 | 467 | # Bit of a strange approximation, but this seemed 468 | microbatch_grad_norm_steps_scale = math.log(total_trainable_params) * total_trainable_params 469 | 470 | # Create multiple parameter groups based on parameter name, as certain kinds of parameters seem to work best 471 | # with specific combinations of learning rates and schedulers 472 | param_groups_dict = init_param_groups_dict(net, base_lr) 473 | opt = torch.optim.AdamW(param_groups_dict.values(), fused=True) 474 | scheduler = torch.optim.lr_scheduler.LambdaLR(opt, [k['scheduler'] for k in param_groups_dict.values()]) 475 | 476 | 477 | ################# 478 | # Training Mode # 479 | ################# 480 | 481 | ## print out the training column headers before each run. 482 | print_training_details(variables_to_log, column_labels_only=True) 483 | 484 | ## For accurately timing GPU code 485 | starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) 486 | torch.cuda.synchronize() ## clean up any pre-net setup operations 487 | starter.record() 488 | 489 | net.train() 490 | 491 | # Main loop. Most of the complexity here is in the dynamic growing scheduler(s). 492 | while curr_step < hyp['opt']['total_train_steps']: 493 | inputs, targets = get_batch(data, key='train', batchsize=curr_batchsize, length=curr_length) 494 | 495 | outputs = net(inputs) 496 | 497 | loss = loss_fn(outputs.flatten(0, 1), targets.flatten(0, 1)) 498 | 499 | loss.div(discrete_sampled_microbatch_steps).backward() 500 | tokens_seen += curr_batchsize * curr_length 501 | 502 | # Quick non-eval summary every N training steps, at the end of every microbatch group, if we are not doing a _full eval_ here. 503 | if curr_step % 10 == 0 and curr_microbatch_step % discrete_sampled_microbatch_steps == 0 and not curr_step % hyp['opt']['eval_every'] == 0: 504 | train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item() 505 | train_loss = loss.detach().cpu().item() 506 | train_summary_vars = {'epoch': tokens_seen//len(data['train']), 'curr_step': curr_step, 'train_loss': train_loss, 'train_acc': train_acc, 'grad_norm': grad_norm} 507 | 508 | print_training_details(format_for_table(variables_to_log, locals=train_summary_vars)) 509 | 510 | 511 | # Once we've accumulated steps over all of our microbatches, take a single full-batchsize step. 512 | if curr_microbatch_step % discrete_sampled_microbatch_steps == 0: 513 | # Step the optimizer, then scheduler 514 | opt.step() 515 | 516 | # Dynamic weight decay scheduling. 
Based upon something similar to the reciprocal of the perplexity of the network over the data [inspired by section 5 of https://arxiv.org/pdf/2204.02311.pdf] 517 | # Smaller models have a higher base, and weight decay kicks in more sharply later. For larger models, it activates more early 518 | opt.param_groups[0]['weight_decay'] = 1./weight_decay_pow_base**(loss.detach()+1e-8).item() * hyp['opt']['weight_decay'] 519 | scheduler.step() 520 | 521 | # Check if we need to double our sequence length 522 | if curr_step % hyp['misc']['sequence_length']['growth_steps'] == 0 and curr_step != 0 and curr_length < hyp['misc']['sequence_length']['max']: 523 | curr_length, curr_batchsize = grow_sequence_length(curr_length, curr_batchsize) 524 | 525 | # The next several lines calculate a dynamic batchsize, simulated through manual dithering 526 | # There could be improvements or losses in changing the dithering strategy, since determinism and gradient descent can lead to some very not-so-nice (and subtle) loss oscillations. 527 | if curr_step % hyp['opt']['microbatch']['sample_every'] == 0: 528 | grad_norm = get_grad_norm(net) 529 | 530 | grad_norm_per_param = grad_norm/(total_trainable_params**.5) # This should keep the expected grad norm per parameter roughly the same (ignoring initializations) unless I did my napkin math wrong (feel free to correct it and test it out if so! <3 :') ) 531 | grad_norm_target = (((microbatch_grad_norm_steps_scale * (curr_step + 1e-2))) ** microbatch_expected_grad_norm_pow) 532 | ratio_diff = grad_norm_per_param/(grad_norm_target) 533 | 534 | # Update the fractional number of steps based on the % difference between the grad norm and expected grad norm. 535 | microbatch_steps *= 1. + (hyp['opt']['microbatch']['sample_every'] * hyp['opt']['microbatch']['scale_lr'] * (ratio_diff - 1)) 536 | microbatch_steps = max(microbatch_steps, 1e-1) # Clamp to keep this from going to zero, so that we can bounce back if needed 537 | 538 | # simple bernoulli dithering with probabilities based on how close we are to each integer 539 | base, dither_prob = divmod(microbatch_steps, 1) 540 | 541 | # Randomly sample next accumulate steps to use. This is the dithered operation, the 'microbatch_steps' is the noninteger accumulator between steps. 542 | discrete_sampled_microbatch_steps = max(1, int(base + torch.bernoulli(torch.tensor(dither_prob)).item())) # bernoulli via torch to save an unnecesary import :) 543 | 544 | opt.zero_grad() 545 | 546 | # reset microbatch steps and increment current step 547 | curr_microbatch_step = 0 548 | curr_step += 1 549 | 550 | # Since we're not running over epochs anymore, we have to manually calculate roughly what epoch it is. This is different than the standard random derangement of sampled sequences and has different pros/cons, is my understanding. 
:thumbsup: 551 | epoch = tokens_seen//len(data['train']) 552 | 553 | if curr_step % hyp['opt']['eval_every'] == 0: 554 | ender.record() 555 | torch.cuda.synchronize() 556 | 557 | total_seconds += 1e-3 * starter.elapsed_time(ender) 558 | train_loss = loss.detach().cpu().item() # Update the loss for the training details printout 559 | 560 | net.eval() 561 | val_acc, val_loss, val_perplexity = eval(net) 562 | 563 | if (curr_step//hyp['opt']['eval_every']) % hyp['opt']['save_every_n_evals'] == 0: 564 | torch.save(net, 'model.pt') 565 | 566 | # Print out our training details 567 | ## We also check to see if we're on our final eval loop (assum that max_curr_step lines up with the eval_every value) so we can print the 'bottom' of the table for each round. 568 | is_final_eval = (curr_step >= hyp['opt']['total_train_steps']) # If we're at the end of training, add a line after the end of the run 569 | print_training_details(format_for_table(variables_to_log, locals=locals()), is_final_entry=is_final_eval) 570 | 571 | torch.cuda.synchronize() 572 | starter.record() 573 | net.train() 574 | curr_microbatch_step += 1 575 | 576 | return net, val_loss # Return the final validation loss achieved (not using the 'best validation loss' selection strategy, which I think is okay here....) 577 | 578 | 579 | if __name__ == "__main__": 580 | final_val_loss_list = [] 581 | for _ in range(1): 582 | net, val_loss = main() 583 | final_val_loss_list.append(val_loss) 584 | print(f"Average final val loss: {sum(final_val_loss_list)/len(final_val_loss_list)}") # TODO add variance as well, later 585 | 586 | 587 | ######################## 588 | # Inference Test # 589 | ######################## 590 | net = torch.load('model.pt') 591 | 592 | net.eval() 593 | demo_sentence = "In 1856, Abraham Lincoln" 594 | 595 | tokenizer = tiktoken.get_encoding("gpt2") 596 | tokenized_demo_sentence = torch.tensor(tokenizer.encode_ordinary(demo_sentence), dtype=torch.int, device='cuda').unsqueeze(0) 597 | 598 | import sys; sys.setrecursionlimit(max_sequence_length*2) 599 | inference = lambda x, length=512, temp=1.: inference(torch.cat((x, torch.multinomial(net(x)[:, -1].div(temp).softmax(-1), 1)), dim=-1), length-1) if length > 0 else x 600 | 601 | import textwrap 602 | with torch.no_grad(): 603 | print("\nprompt: \n", textwrap.fill(tokenizer.decode(tokenized_demo_sentence.squeeze().cpu().numpy()), 80, replace_whitespace=False)) 604 | print("decoded result: \n", textwrap.fill(tokenizer.decode(inference(tokenized_demo_sentence).squeeze().cpu().numpy()), 80, replace_whitespace=False)) 605 | 606 | --------------------------------------------------------------------------------
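The recursive `inference` lambda at the bottom of main.py is why `sys.setrecursionlimit` gets raised -- every generated token adds a stack frame. If that ever becomes a nuisance, the same sampling loop can be written iteratively. The sketch below is a hypothetical equivalent (the name `generate` and its defaults are not part of the repository); it performs the same operations as the lambda: sample the next token from the temperature-scaled softmax of the final position's logits, append it to the context, and repeat.

```python
import torch

@torch.no_grad()
def generate(net, tokens, length=512, temp=1.0):
    # Iterative equivalent of the recursive `inference` lambda in main.py:
    # sample from the softmaxed logits at the last position and append the
    # sampled token to the running context, `length` times.
    for _ in range(length):
        next_token = torch.multinomial(net(tokens)[:, -1].div(temp).softmax(-1), 1)
        tokens = torch.cat((tokens, next_token), dim=-1)
    return tokens
```

As with the original lambda, the total context is still capped at max_sequence_length tokens, since the positional-bias and mask buffers are preallocated at that size.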