├── .vscode └── launch.json ├── Datasets ├── Alpaca │ ├── download.py │ └── setup.sh └── Gutenberg │ ├── download.py │ ├── prepare_dataset.py │ └── setup.sh ├── LICENSE.txt ├── Models ├── GPT2 │ ├── GPT2.py │ ├── config.py │ └── load_weights.py └── Llama │ ├── Llama2.py │ ├── Llama3.py │ ├── common_components.py │ ├── config.py │ ├── load_weights_llama2.py │ └── load_weights_llama3.py ├── README.md ├── args.py ├── build_components.py ├── config_hf.json ├── datautils ├── dataloader.py ├── dataloader_instruction_finetune.py ├── dataset.py ├── dataset_instruction_finetune.py └── mixed_precision.py ├── generate.py ├── logger.py ├── lora.py ├── main.py ├── requirements.txt ├── train.py └── utils.py /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | 5 | { 6 | "name": "Python Debugger: Current File", 7 | "type": "debugpy", 8 | "request": "launch", 9 | "program": "${file}", 10 | "console": "integratedTerminal", 11 | "subProcess": true, 12 | "justMyCode": false, 13 | "env": { 14 | "CUDA_VISIBLE_DEVICES": "0,1" // Use GPU 0,1 only 15 | }, 16 | "args": [] 17 | 18 | } 19 | ] 20 | } -------------------------------------------------------------------------------- /Datasets/Alpaca/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import urllib 4 | from urllib import request 5 | 6 | def download_and_load_file(file_path, url): 7 | 8 | if not os.path.exists(file_path): 9 | with request.urlopen(url) as response: 10 | text_data = response.read().decode("utf-8") 11 | with open(file_path, "w", encoding="utf-8") as file: 12 | file.write(text_data) 13 | 14 | with open(file_path, "r") as file: 15 | data = json.load(file) 16 | 17 | return data 18 | 19 | if __name__ == "__main__": 20 | 21 | file_path = "instruction-data-alpaca.json" 22 | url = "https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json" 23 | data = download_and_load_file(file_path, url) 24 | -------------------------------------------------------------------------------- /Datasets/Alpaca/setup.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | 3 | python download.py 4 | 5 | mv instruction-data-alpaca.json ./data/ -------------------------------------------------------------------------------- /Datasets/Gutenberg/download.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | 3 | # Replace with the Google Drive file ID or full link 4 | file_id = '1i8eeP79dN2TwIK7H4qr_Y-ji1cB19SMU' 5 | url = f'https://drive.google.com/uc?id={file_id}&export=download' 6 | 7 | # Specify the output path where you want to save the file 8 | output_path = './downloaded_file.zip' 9 | 10 | # Download the file 11 | gdown.download(url, output_path, quiet=False) -------------------------------------------------------------------------------- /Datasets/Gutenberg/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | # Add the path of the cloned repo to sys.path 6 | repo_path = os.path.abspath(r'./gutenberg') 7 | sys.path.append(repo_path) 8 | 9 | from tqdm import tqdm 10 | from gutenberg.src.cleanup import strip_headers 11 | import re 12 | 13 | def is_english(text, threshold=0.9): 14 | ascii_chars = sum(1 for c in text if ord(c) < 128) 15 | return ascii_chars / 
len(text) > threshold 16 | 17 | #Combine files into separate text blocks 18 | def combine_files(file_paths, target_dir, max_size_mb=500, separator="<|endoftext|>", fallback_encoding="latin1"): 19 | 20 | 21 | if not os.path.exists(target_dir): 22 | os.makedirs(target_dir) 23 | 24 | current_content = [] 25 | current_size = 0 26 | file_counter = 1 27 | print("Generating combined files.") 28 | for file_path in tqdm(file_paths): 29 | try: 30 | with open(file_path, "r", encoding="utf-8") as file: 31 | content = file.read() 32 | except UnicodeDecodeError: 33 | # Attempt to read the file with a fallback encoding 34 | tqdm.write(f"Warning: UnicodeDecodeError encountered. Trying fallback encoding for {file_path}") 35 | with open(file_path, "r", encoding=fallback_encoding) as file: 36 | content = file.read() 37 | 38 | if not is_english(content): 39 | tqdm.write(f"Skipping {file_path} as it does not contain primarily English text.") 40 | continue 41 | content = strip_headers(content) 42 | 43 | # Regular expression to replace multiple blank lines with a single blank line 44 | content = re.sub(r'\n\s*\n', '\n\n', content) 45 | estimated_size = len(content.encode("utf-8")) 46 | 47 | if current_size + estimated_size > max_size_mb * 1024 * 1024: 48 | target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt") 49 | with open(target_file_path, "w", encoding="utf-8") as target_file: 50 | target_file.write(separator.join(current_content)) 51 | file_counter += 1 52 | current_content = [content] 53 | current_size = estimated_size 54 | else: 55 | current_content.append(content) 56 | current_size += estimated_size 57 | 58 | if current_content: 59 | target_file_path = os.path.join(target_dir, f"combined_{file_counter}.txt") 60 | with open(target_file_path, "w", encoding="utf-8") as target_file: 61 | target_file.write(separator.join(current_content)) 62 | return file_counter 63 | 64 | if __name__ == "__main__": 65 | 66 | parser = argparse.ArgumentParser(description="Preprocess and combine text files for pretraining") 67 | 68 | parser.add_argument("--data_dir", type=str, default="Gutenberg/txt", 69 | help="Directory containing the downloaded raw training data") 70 | parser.add_argument("--max_size_mb", type=int, default=500, 71 | help="The maximum file size for each concatenated file in megabytes") 72 | parser.add_argument("--output_dir", type=str, default="data_dir", 73 | help="Directory where the preprocessed data will be saved") 74 | 75 | args = parser.parse_args() 76 | 77 | all_files=[] 78 | for path, subdirs, files in os.walk(args.data_dir): 79 | for name in files: 80 | if name.endswith((".txt")): 81 | all_files.append(os.path.join(path, name)) 82 | 83 | print(f"{len(all_files)} file(s) to process.") 84 | file_counter = combine_files(all_files, args.output_dir, max_size_mb=args.max_size_mb) 85 | print(f"{file_counter} file(s) saved in {os.path.abspath(args.output_dir)}") -------------------------------------------------------------------------------- /Datasets/Gutenberg/setup.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | 3 | python download.py 4 | 5 | unzip downloaded_file.zip 6 | 7 | git clone https://github.com/pgcorpus/gutenberg.git 8 | 9 | mkdir data_dir 10 | 11 | python prepare_dataset.py -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | 
http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | explicitly excluding any books specific to this software and any related images, 29 | and includes but is not limited to software source code, 30 | documentation source (excluding books and images related to this software), 31 | and configuration files. 32 | 33 | "Object" form shall mean any form resulting from mechanical 34 | transformation or translation of a Source form, including but 35 | not limited to compiled object code, generated documentation, 36 | and conversions to other media types. 37 | 38 | "Work" shall mean the work of authorship, whether in Source or 39 | Object form, made available under the License, as indicated by a 40 | copyright notice that is included in or attached to the work 41 | (an example is provided in the Appendix below). 42 | 43 | "Derivative Works" shall mean any work, whether in Source or Object 44 | form, that is based on (or derived from) the Work and for which the 45 | editorial revisions, annotations, elaborations, or other modifications 46 | represent, as a whole, an original work of authorship. For the purposes 47 | of this License, Derivative Works shall not include works that remain 48 | separable from, or merely link (or bind by name) to the interfaces of, 49 | the Work and Derivative Works thereof. 50 | 51 | "Contribution" shall mean any work of authorship, including 52 | the original version of the Work and any modifications or additions 53 | to that Work or Derivative Works thereof, that is intentionally 54 | submitted to Licensor for inclusion in the Work by the copyright owner 55 | or by an individual or Legal Entity authorized to submit on behalf of 56 | the copyright owner. For the purposes of this definition, "submitted" 57 | means any form of electronic, verbal, or written communication sent 58 | to the Licensor or its representatives, including but not limited to 59 | communication on electronic mailing lists, source code control systems, 60 | and issue tracking systems that are managed by, or on behalf of, the 61 | Licensor for the purpose of discussing and improving the Work, but 62 | excluding communication that is conspicuously marked or otherwise 63 | designated in writing by the copyright owner as "Not a Contribution." 64 | 65 | "Contributor" shall mean Licensor and any individual or Legal Entity 66 | on behalf of whom a Contribution has been received by Licensor and 67 | subsequently incorporated within the Work. 68 | 69 | 2. 
Grant of Copyright License. Subject to the terms and conditions of 70 | this License, each Contributor hereby grants to You a perpetual, 71 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 72 | copyright license to reproduce, prepare Derivative Works of, 73 | publicly display, publicly perform, sublicense, and distribute the 74 | Work and such Derivative Works in Source or Object form. 75 | 76 | 3. Grant of Patent License. Subject to the terms and conditions of 77 | this License, each Contributor hereby grants to You a perpetual, 78 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 79 | (except as stated in this section) patent license to make, have made, 80 | use, offer to sell, sell, import, and otherwise transfer the Work, 81 | where such license applies only to those patent claims licensable 82 | by such Contributor that are necessarily infringed by their 83 | Contribution(s) alone or by combination of their Contribution(s) 84 | with the Work to which such Contribution(s) was submitted. If You 85 | institute patent litigation against any entity (including a 86 | cross-claim or counterclaim in a lawsuit) alleging that the Work 87 | or a Contribution incorporated within the Work constitutes direct 88 | or contributory patent infringement, then any patent licenses 89 | granted to You under this License for that Work shall terminate 90 | as of the date such litigation is filed. 91 | 92 | 4. Redistribution. You may reproduce and distribute copies of the 93 | Work or Derivative Works thereof in any medium, with or without 94 | modifications, and in Source or Object form, provided that You 95 | meet the following conditions: 96 | 97 | (a) You must give any other recipients of the Work or 98 | Derivative Works a copy of this License; and 99 | 100 | (b) You must cause any modified files to carry prominent notices 101 | stating that You changed the files; and 102 | 103 | (c) You must retain, in the Source form of any Derivative Works 104 | that You distribute, all copyright, patent, trademark, and 105 | attribution notices from the Source form of the Work, 106 | excluding those notices that do not pertain to any part of 107 | the Derivative Works; and 108 | 109 | (d) If the Work includes a "NOTICE" text file as part of its 110 | distribution, then any Derivative Works that You distribute must 111 | include a readable copy of the attribution notices contained 112 | within such NOTICE file, excluding those notices that do not 113 | pertain to any part of the Derivative Works, in at least one 114 | of the following places: within a NOTICE text file distributed 115 | as part of the Derivative Works; within the Source form or 116 | documentation, if provided along with the Derivative Works; or, 117 | within a display generated by the Derivative Works, if and 118 | wherever such third-party notices normally appear. The contents 119 | of the NOTICE file are for informational purposes only and 120 | do not modify the License. You may add Your own attribution 121 | notices within Derivative Works that You distribute, alongside 122 | or as an addendum to the NOTICE text from the Work, provided 123 | that such additional attribution notices cannot be construed 124 | as modifying the License. 
125 | 126 | You may add Your own copyright statement to Your modifications and 127 | may provide additional or different license terms and conditions 128 | for use, reproduction, or distribution of Your modifications, or 129 | for any such Derivative Works as a whole, provided Your use, 130 | reproduction, and distribution of the Work otherwise complies with 131 | the conditions stated in this License. 132 | 133 | 5. Submission of Contributions. Unless You explicitly state otherwise, 134 | any Contribution intentionally submitted for inclusion in the Work 135 | by You to the Licensor shall be under the terms and conditions of 136 | this License, without any additional terms or conditions. 137 | Notwithstanding the above, nothing herein shall supersede or modify 138 | the terms of any separate license agreement you may have executed 139 | with Licensor regarding such Contributions. 140 | 141 | 6. Trademarks. This License does not grant permission to use the trade 142 | names, trademarks, service marks, or product names of the Licensor, 143 | except as required for reasonable and customary use in describing the 144 | origin of the Work and reproducing the content of the NOTICE file. 145 | 146 | 7. Disclaimer of Warranty. Unless required by applicable law or 147 | agreed to in writing, Licensor provides the Work (and each 148 | Contributor provides its Contributions) on an "AS IS" BASIS, 149 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 150 | implied, including, without limitation, any warranties or conditions 151 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 152 | PARTICULAR PURPOSE. You are solely responsible for determining the 153 | appropriateness of using or redistributing the Work and assume any 154 | risks associated with Your exercise of permissions under this License. 155 | 156 | 8. Limitation of Liability. In no event and under no legal theory, 157 | whether in tort (including negligence), contract, or otherwise, 158 | unless required by applicable law (such as deliberate and grossly 159 | negligent acts) or agreed to in writing, shall any Contributor be 160 | liable to You for damages, including any direct, indirect, special, 161 | incidental, or consequential damages of any character arising as a 162 | result of this License or out of the use or inability to use the 163 | Work (including but not limited to damages for loss of goodwill, 164 | work stoppage, computer failure or malfunction, or any and all 165 | other commercial damages or losses), even if such Contributor 166 | has been advised of the possibility of such damages. 167 | 168 | 9. Accepting Warranty or Additional Liability. While redistributing 169 | the Work or Derivative Works thereof, You may choose to offer, 170 | and charge a fee for, acceptance of support, warranty, indemnity, 171 | or other liability obligations and/or rights consistent with this 172 | License. However, in accepting such obligations, You may act only 173 | on Your own behalf and on Your sole responsibility, not on behalf 174 | of any other Contributor, and only if You agree to indemnify, 175 | defend, and hold each Contributor harmless for any liability 176 | incurred by, or claims asserted against, such Contributor by reason 177 | of your accepting any such warranty or additional liability. 178 | 179 | END OF TERMS AND CONDITIONS 180 | 181 | APPENDIX: How to apply the Apache License to your work. 
182 | 183 | To apply the Apache License to your work, attach the following 184 | boilerplate notice, with the fields enclosed by brackets "[]" 185 | replaced with your own identifying information. (Don't include 186 | the brackets!) The text should be enclosed in the appropriate 187 | comment syntax for the file format. We also recommend that a 188 | file or class name and description of purpose be included on the 189 | same "printed page" as the copyright notice for easier 190 | identification within third-party archives. 191 | 192 | Copyright 2023-2024 Sebastian Raschka 193 | 194 | Licensed under the Apache License, Version 2.0 (the "License"); 195 | you may not use this file except in compliance with the License. 196 | You may obtain a copy of the License at 197 | 198 | http://www.apache.org/licenses/LICENSE-2.0 199 | 200 | Unless required by applicable law or agreed to in writing, software 201 | distributed under the License is distributed on an "AS IS" BASIS, 202 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 203 | See the License for the specific language governing permissions and 204 | limitations under the License. -------------------------------------------------------------------------------- /Models/GPT2/GPT2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.utils.checkpoint import checkpoint_sequential 4 | 5 | class MultiHeadAttention(nn.Module): 6 | def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): 7 | super().__init__() 8 | assert d_out % num_heads == 0, "d_out must be divisible by n_heads" 9 | 10 | self.d_out = d_out 11 | self.num_heads = num_heads 12 | self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim 13 | 14 | self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) 15 | self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) 16 | self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) 17 | self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs 18 | self.dropout = nn.Dropout(dropout) 19 | self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) 20 | 21 | def forward(self, x): 22 | b, num_tokens, d_in = x.shape 23 | 24 | keys = self.W_key(x) # Shape: (b, num_tokens, d_out) 25 | queries = self.W_query(x) 26 | values = self.W_value(x) 27 | 28 | # We implicitly split the matrix by adding a `num_heads` dimension 29 | # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) 30 | keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 31 | values = values.view(b, num_tokens, self.num_heads, self.head_dim) 32 | queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) 33 | 34 | # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) 35 | keys = keys.transpose(1, 2) 36 | queries = queries.transpose(1, 2) 37 | values = values.transpose(1, 2) 38 | 39 | # Compute scaled dot-product attention (aka self-attention) with a causal mask 40 | attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head 41 | 42 | # Original mask truncated to the number of tokens and converted to boolean 43 | mask_bool = self.mask.bool()[:num_tokens, :num_tokens] 44 | 45 | # Use the mask to fill attention scores 46 | attn_scores.masked_fill_(mask_bool, -torch.inf) 47 | 48 | attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) 49 | attn_weights = 
self.dropout(attn_weights) 50 | 51 | # Shape: (b, num_tokens, num_heads, head_dim) 52 | context_vec = (attn_weights @ values).transpose(1, 2) 53 | 54 | # Combine heads, where self.d_out = self.num_heads * self.head_dim 55 | context_vec = context_vec.reshape(b, num_tokens, self.d_out) 56 | context_vec = self.out_proj(context_vec) # optional projection 57 | 58 | return context_vec 59 | 60 | class LayerNorm(nn.Module): 61 | def __init__(self, emb_dim): 62 | super().__init__() 63 | self.eps = 1e-5 64 | self.scale = nn.Parameter(torch.ones(emb_dim)) 65 | self.shift = nn.Parameter(torch.zeros(emb_dim)) 66 | 67 | def forward(self, x): 68 | mean = x.mean(dim=-1, keepdim=True) 69 | var = x.var(dim=-1, keepdim=True, unbiased=False) 70 | norm_x = (x - mean) / torch.sqrt(var + self.eps) 71 | return self.scale * norm_x + self.shift 72 | 73 | 74 | class GELU(nn.Module): 75 | def __init__(self): 76 | super().__init__() 77 | 78 | def forward(self, x): 79 | return 0.5 * x * (1 + torch.tanh( 80 | torch.sqrt(torch.tensor(2.0 / torch.pi)) * 81 | (x + 0.044715 * torch.pow(x, 3)) 82 | )) 83 | 84 | 85 | class FeedForward(nn.Module): 86 | def __init__(self, cfg): 87 | super().__init__() 88 | self.layers = nn.Sequential( 89 | nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), 90 | GELU(), 91 | nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), 92 | ) 93 | 94 | def forward(self, x): 95 | return self.layers(x) 96 | 97 | 98 | class TransformerBlock(nn.Module): 99 | def __init__(self, cfg): 100 | super().__init__() 101 | self.att = MultiHeadAttention( 102 | d_in=cfg["emb_dim"], 103 | d_out=cfg["emb_dim"], 104 | context_length=cfg["context_length"], 105 | num_heads=cfg["n_heads"], 106 | dropout=cfg["drop_rate"], 107 | qkv_bias=cfg["qkv_bias"]) 108 | self.ff = FeedForward(cfg) 109 | self.norm1 = LayerNorm(cfg["emb_dim"]) 110 | self.norm2 = LayerNorm(cfg["emb_dim"]) 111 | self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) 112 | 113 | def forward(self, x): 114 | # Shortcut connection for attention block 115 | shortcut = x 116 | x = self.norm1(x) 117 | x = self.att(x) # Shape [batch_size, num_tokens, emb_size] 118 | x = self.drop_shortcut(x) 119 | x = x + shortcut # Add the original input back 120 | 121 | # Shortcut connection for feed-forward block 122 | shortcut = x 123 | x = self.norm2(x) 124 | x = self.ff(x) 125 | x = self.drop_shortcut(x) 126 | x = x + shortcut # Add the original input back 127 | 128 | return x 129 | 130 | 131 | class GPTModel(nn.Module): 132 | def __init__(self, cfg, use_actv_ckpt=False): 133 | super().__init__() 134 | self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) 135 | self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) 136 | self.drop_emb = nn.Dropout(cfg["drop_rate"]) 137 | 138 | self.use_actv_ckpt = use_actv_ckpt 139 | self.cfg=cfg 140 | 141 | self.trf_blocks = nn.Sequential( 142 | *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) 143 | 144 | self.final_norm = LayerNorm(cfg["emb_dim"]) 145 | self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) 146 | 147 | def forward(self, in_idx): 148 | batch_size, seq_len = in_idx.shape 149 | tok_embeds = self.tok_emb(in_idx) 150 | pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) 151 | x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] 152 | x = self.drop_emb(x) 153 | if(self.use_actv_ckpt): 154 | x = checkpoint_sequential(self.trf_blocks, segments=self.cfg["n_layers"], input=x, use_reentrant=False) 155 | else: 156 | x = self.trf_blocks(x) 157 | x = self.final_norm(x) 
158 | logits = self.out_head(x) 159 | return logits -------------------------------------------------------------------------------- /Models/GPT2/config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | GPT_CONFIG_124M = { 4 | "vocab_size": 50257, # Vocabulary size 5 | "context_length": 1024, # Context length 6 | "emb_dim": 768, # Embedding dimension 7 | "n_heads": 12, # Number of attention heads 8 | "n_layers": 12, # Number of layers 9 | "drop_rate": 0.1, # Dropout rate 10 | "qkv_bias": False, # Query-key-value bias 11 | "eos_id":50256, 12 | "eos_text":"<|endoftext|>" 13 | } 14 | 15 | GPT_CONFIG_355M = { 16 | "vocab_size": 50257, # Vocabulary size 17 | "context_length": 1024, # Context length 18 | "emb_dim": 1024, # Embedding dimension 19 | "n_heads": 16, # Number of attention heads 20 | "n_layers": 24, # Number of layers 21 | "drop_rate": 0.1, # Dropout rate 22 | "qkv_bias": False, # Query-key-value bias 23 | "eos_id":50256, 24 | "eos_text":"<|endoftext|>" 25 | } 26 | 27 | GPT_CONFIG_774M = { 28 | "vocab_size": 50257, # Vocabulary size 29 | "context_length": 1024, # Context length 30 | "emb_dim": 1280, # Embedding dimension 31 | "n_heads": 20, # Number of attention heads 32 | "n_layers": 36, # Number of layers 33 | "drop_rate": 0.1, # Dropout rate 34 | "qkv_bias": False, # Query-key-value bias 35 | "eos_id":50256, 36 | "eos_text":"<|endoftext|>" 37 | } 38 | 39 | GPT_CONFIG_1_5B = { 40 | "vocab_size": 50257, # Vocabulary size 41 | "context_length": 1024, # Context length 42 | "emb_dim": 1600, # Embedding dimension 43 | "n_heads": 25, # Number of attention heads 44 | "n_layers": 48, # Number of layers 45 | "drop_rate": 0.1, # Dropout rate 46 | "qkv_bias": False, # Query-key-value bias 47 | "eos_id":50256, 48 | "eos_text":"<|endoftext|>" 49 | } 50 | 51 | available_configs = { 52 | "124M":GPT_CONFIG_124M, 53 | "355M":GPT_CONFIG_355M, 54 | "774M":GPT_CONFIG_774M, 55 | "1.5B":GPT_CONFIG_1_5B 56 | } 57 | 58 | 59 | 60 | def get_config_gpt2(num_params): 61 | 62 | num_params = str(num_params) 63 | 64 | assert num_params in available_configs, "A GPT2 model for given number of parameters does not exists." 65 | 66 | return available_configs[num_params] 67 | 68 | -------------------------------------------------------------------------------- /Models/GPT2/load_weights.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Model 2 | import numpy as np 3 | import torch 4 | 5 | 6 | hf_mapping = { 7 | "124M": "openai-community/gpt2", 8 | "355M": "openai-community/gpt2-medium", 9 | "774M": "openai-community/gpt2-large", 10 | "1.5B": "openai-community/gpt2-xl" 11 | } 12 | 13 | def assign_check(left, right): 14 | if left.shape != right.shape: 15 | raise ValueError(f"Shape mismatch. 
Left: {left.shape}, Right: {right.shape}") 16 | return torch.nn.Parameter(right.clone().detach().to(left.device)) 17 | 18 | def load_weights(gpt, gpt_hf, config): 19 | 20 | d = gpt_hf.state_dict() 21 | 22 | gpt.pos_emb.weight = assign_check(gpt.pos_emb.weight, d["wpe.weight"]) 23 | gpt.tok_emb.weight = assign_check(gpt.tok_emb.weight, d["wte.weight"]) 24 | 25 | for b in range(config["n_layers"]): 26 | q_w, k_w, v_w = np.split(d[f"h.{b}.attn.c_attn.weight"], 3, axis=-1) 27 | gpt.trf_blocks[b].att.W_query.weight = assign_check(gpt.trf_blocks[b].att.W_query.weight, q_w.T) 28 | gpt.trf_blocks[b].att.W_key.weight = assign_check(gpt.trf_blocks[b].att.W_key.weight, k_w.T) 29 | gpt.trf_blocks[b].att.W_value.weight = assign_check(gpt.trf_blocks[b].att.W_value.weight, v_w.T) 30 | 31 | q_b, k_b, v_b = np.split(d[f"h.{b}.attn.c_attn.bias"], 3, axis=-1) 32 | gpt.trf_blocks[b].att.W_query.bias = assign_check(gpt.trf_blocks[b].att.W_query.bias, q_b) 33 | gpt.trf_blocks[b].att.W_key.bias = assign_check(gpt.trf_blocks[b].att.W_key.bias, k_b) 34 | gpt.trf_blocks[b].att.W_value.bias = assign_check(gpt.trf_blocks[b].att.W_value.bias, v_b) 35 | 36 | 37 | gpt.trf_blocks[b].att.out_proj.weight = assign_check(gpt.trf_blocks[b].att.out_proj.weight, d[f"h.{b}.attn.c_proj.weight"].T) 38 | gpt.trf_blocks[b].att.out_proj.bias = assign_check(gpt.trf_blocks[b].att.out_proj.bias, d[f"h.{b}.attn.c_proj.bias"]) 39 | 40 | gpt.trf_blocks[b].ff.layers[0].weight = assign_check(gpt.trf_blocks[b].ff.layers[0].weight, d[f"h.{b}.mlp.c_fc.weight"].T) 41 | gpt.trf_blocks[b].ff.layers[0].bias = assign_check(gpt.trf_blocks[b].ff.layers[0].bias, d[f"h.{b}.mlp.c_fc.bias"]) 42 | gpt.trf_blocks[b].ff.layers[2].weight = assign_check(gpt.trf_blocks[b].ff.layers[2].weight, d[f"h.{b}.mlp.c_proj.weight"].T) 43 | gpt.trf_blocks[b].ff.layers[2].bias = assign_check(gpt.trf_blocks[b].ff.layers[2].bias, d[f"h.{b}.mlp.c_proj.bias"]) 44 | 45 | gpt.trf_blocks[b].norm1.scale = assign_check(gpt.trf_blocks[b].norm1.scale, d[f"h.{b}.ln_1.weight"]) 46 | gpt.trf_blocks[b].norm1.shift = assign_check(gpt.trf_blocks[b].norm1.shift, d[f"h.{b}.ln_1.bias"]) 47 | gpt.trf_blocks[b].norm2.scale = assign_check(gpt.trf_blocks[b].norm2.scale, d[f"h.{b}.ln_2.weight"]) 48 | gpt.trf_blocks[b].norm2.shift = assign_check(gpt.trf_blocks[b].norm2.shift, d[f"h.{b}.ln_2.bias"]) 49 | 50 | gpt.final_norm.scale = assign_check(gpt.final_norm.scale, d[f"ln_f.weight"]) 51 | gpt.final_norm.shift = assign_check(gpt.final_norm.shift, d[f"ln_f.bias"]) 52 | gpt.out_head.weight = assign_check(gpt.out_head.weight, d["wte.weight"]) 53 | 54 | 55 | def load_hf_weights(model,num_params,config): 56 | 57 | assert num_params in hf_mapping, "A GPT2 model for given number of parameters does not exists." 
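    # Note: GPT2Model.from_pretrained() below fetches the checkpoint from the Hugging Face Hub
    # on first use (cached under ./hf_checkpoints via cache_dir), so network access is required
    # the first time a given model size is loaded.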
58 | 59 | gpt_hf = GPT2Model.from_pretrained(hf_mapping[num_params], cache_dir="hf_checkpoints") 60 | gpt_hf.eval() 61 | 62 | load_weights(model, gpt_hf,config) 63 | -------------------------------------------------------------------------------- /Models/Llama/Llama2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.utils.checkpoint import checkpoint_sequential 4 | 5 | import sentencepiece as spm 6 | 7 | from Models.Llama.common_components import compute_rope, FeedForward, RMSNorm 8 | 9 | 10 | class Llama2Tokenizer: 11 | def __init__(self, tokenizer_file): 12 | sp = spm.SentencePieceProcessor() 13 | sp.load(tokenizer_file) 14 | self.tokenizer = sp 15 | 16 | def encode(self, text): 17 | return self.tokenizer.encode_as_ids(text) 18 | 19 | def decode(self, ids): 20 | return self.tokenizer.decode_pieces(ids) 21 | 22 | class MultiHeadAttention(nn.Module): 23 | def __init__(self, d_in, d_out, context_length, num_heads, dtype=None): # ,dropout, num_heads, qkv_bias=False): 24 | super().__init__() 25 | assert d_out % num_heads == 0, "d_out must be divisible by n_heads" 26 | 27 | self.d_out = d_out 28 | self.num_heads = num_heads 29 | self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim 30 | 31 | self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) 32 | self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype) 33 | self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype) 34 | self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) # Linear layer to combine head outputs 35 | 36 | self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)) 37 | 38 | cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length) 39 | self.register_buffer("cos", cos) 40 | self.register_buffer("sin", sin) 41 | 42 | 43 | def forward(self, x): 44 | 45 | b, num_tokens, d_in = x.shape 46 | 47 | keys = self.W_key(x) # Shape: (b, num_tokens, d_out) 48 | queries = self.W_query(x) 49 | values = self.W_value(x) 50 | 51 | # We implicitly split the matrix by adding a `num_heads` dimension 52 | # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) 53 | keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 54 | values = values.view(b, num_tokens, self.num_heads, self.head_dim) 55 | queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) 56 | 57 | # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) 58 | keys = keys.transpose(1, 2) 59 | queries = queries.transpose(1, 2) 60 | values = values.transpose(1, 2) 61 | 62 | keys = compute_rope(keys, self.cos, self.sin) 63 | queries = compute_rope(queries, self.cos, self.sin) 64 | 65 | # Compute scaled dot-product attention (aka self-attention) with a causal mask 66 | attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head 67 | 68 | # Original mask truncated to the number of tokens and converted to boolean 69 | mask_bool = self.mask.bool()[:num_tokens, :num_tokens] 70 | 71 | # Use the mask to fill attention scores 72 | attn_scores.masked_fill_(mask_bool, -torch.inf) 73 | 74 | attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) 75 | 76 | # Shape: (b, num_tokens, num_heads, head_dim) 77 | context_vec = (attn_weights @ values).transpose(1, 2) 78 | 79 | # Combine heads, where self.d_out = self.num_heads * self.head_dim 80 | context_vec = 
context_vec.reshape(b, num_tokens, self.d_out) 81 | context_vec = self.out_proj(context_vec) # optional projection 82 | 83 | return context_vec 84 | 85 | 86 | 87 | 88 | def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096): 89 | assert head_dim % 2 == 0, "Embedding dimension must be even" 90 | 91 | # Compute the inverse frequencies 92 | inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)) 93 | 94 | # Generate position indices 95 | positions = torch.arange(context_length) 96 | 97 | # Compute the angles 98 | angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2) 99 | 100 | # Expand angles to match the head_dim 101 | angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim) 102 | 103 | # Precompute sine and cosine 104 | cos = torch.cos(angles) 105 | sin = torch.sin(angles) 106 | 107 | return cos, sin 108 | 109 | 110 | class TransformerBlock(nn.Module): 111 | def __init__(self, cfg): 112 | super().__init__() 113 | self.att = MultiHeadAttention( 114 | d_in=cfg["emb_dim"], 115 | d_out=cfg["emb_dim"], 116 | context_length=cfg["context_length"], 117 | num_heads=cfg["n_heads"], 118 | dtype=cfg["dtype"] 119 | ) 120 | self.ff = FeedForward(cfg) 121 | 122 | self.norm1 = RMSNorm(cfg["emb_dim"]) 123 | self.norm2 = RMSNorm(cfg["emb_dim"]) 124 | 125 | 126 | def forward(self, x): 127 | # Shortcut connection for attention block 128 | shortcut = x 129 | x = self.norm1(x) 130 | x = self.att(x) # Shape [batch_size, num_tokens, emb_size] 131 | x = x + shortcut # Add the original input back 132 | 133 | # Shortcut connection for feed-forward block 134 | shortcut = x 135 | x = self.norm2(x) 136 | x = self.ff(x) 137 | x = x + shortcut # Add the original input back 138 | 139 | return x 140 | 141 | 142 | class Llama2Model(nn.Module): 143 | def __init__(self, cfg): 144 | super().__init__() 145 | self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"]) 146 | 147 | self.trf_blocks = nn.Sequential( 148 | *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) 149 | 150 | self.final_norm = RMSNorm(cfg["emb_dim"]) 151 | 152 | self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) 153 | 154 | def forward(self, in_idx): 155 | batch_size, seq_len = in_idx.shape 156 | tok_embeds = self.tok_emb(in_idx) 157 | x = tok_embeds # Shape [batch_size, num_tokens, emb_size] 158 | x = self.trf_blocks(x) 159 | x = self.final_norm(x) 160 | logits = self.out_head(x) 161 | return logits -------------------------------------------------------------------------------- /Models/Llama/Llama3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from pathlib import Path 5 | import tiktoken 6 | from tiktoken.load import load_tiktoken_bpe 7 | from torch.utils.checkpoint import checkpoint_sequential 8 | 9 | from Models.Llama.common_components import compute_rope, FeedForward, RMSNorm 10 | 11 | 12 | 13 | 14 | class Llama3Tokenizer: 15 | def __init__(self, model_path): 16 | assert os.path.isfile(model_path), f"Model file {model_path} not found" 17 | mergeable_ranks = load_tiktoken_bpe(model_path) 18 | 19 | self.special_tokens = { 20 | "<|begin_of_text|>": 128000, 21 | "<|end_of_text|>": 128001, 22 | "<|start_header_id|>": 128006, 23 | "<|end_header_id|>": 128007, 24 | "<|eot_id|>": 128009, 25 | } 26 | self.special_tokens.update({ 27 | f"<|reserved_{i}|>": 128002 + i for i 
in range(256) if (128002 + i) not in self.special_tokens.values() 28 | }) 29 | 30 | self.model = tiktoken.Encoding( 31 | name=Path(model_path).name, 32 | pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", 33 | mergeable_ranks=mergeable_ranks, 34 | special_tokens=self.special_tokens 35 | ) 36 | 37 | 38 | def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()): 39 | if bos: 40 | tokens = [self.special_tokens["<|begin_of_text|>"]] 41 | else: 42 | tokens = [] 43 | 44 | tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special) 45 | 46 | if eos: 47 | tokens.append(self.special_tokens["<|end_of_text|>"]) 48 | return tokens 49 | 50 | def decode(self, tokens): 51 | return self.model.decode(tokens) 52 | 53 | class SharedBuffers: 54 | _buffers = {} 55 | 56 | @staticmethod 57 | def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32): 58 | key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype) 59 | 60 | if key not in SharedBuffers._buffers: 61 | # Create or fetch the buffers 62 | mask = torch.triu(torch.ones(context_length, context_length), diagonal=1) 63 | cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config) 64 | if dtype is not None: 65 | cos = cos.to(dtype) 66 | sin = sin.to(dtype) 67 | SharedBuffers._buffers[key] = (mask, cos, sin) 68 | 69 | return SharedBuffers._buffers[key] 70 | 71 | 72 | def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None): 73 | assert head_dim % 2 == 0, "Embedding dimension must be even" 74 | 75 | # Compute the inverse frequencies 76 | inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim)) 77 | 78 | # Frequency adjustments 79 | if freq_config is not None: 80 | low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"] 81 | high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"] 82 | 83 | wavelen = 2 * torch.pi / inv_freq 84 | 85 | inv_freq_llama = torch.where( 86 | wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq 87 | ) 88 | 89 | smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / ( 90 | freq_config["high_freq_factor"] - freq_config["low_freq_factor"] 91 | ) 92 | 93 | smoothed_inv_freq = ( 94 | (1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq 95 | ) 96 | 97 | is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen) 98 | inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) 99 | inv_freq = inv_freq_llama 100 | 101 | 102 | # Generate position indices 103 | positions = torch.arange(context_length) 104 | 105 | # Compute the angles 106 | angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2) 107 | 108 | # Expand angles to match the head_dim 109 | angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim) 110 | 111 | # Precompute sine and cosine 112 | cos = torch.cos(angles) 113 | sin = torch.sin(angles) 114 | 115 | return cos, sin 116 | 117 | 118 | class GroupedQueryAttention(nn.Module): 119 | def __init__( 120 | self, d_in, d_out, context_length, num_heads, 121 | num_kv_groups, 122 | rope_base=10_000, 123 | rope_config=None, 124 | dtype=None 125 
| ): 126 | super().__init__() 127 | assert d_out % num_heads == 0, "d_out must be divisible by num_heads" 128 | assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups" 129 | 130 | self.d_out = d_out 131 | self.num_heads = num_heads 132 | self.head_dim = d_out // num_heads 133 | 134 | 135 | self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype) 136 | self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype) 137 | self.num_kv_groups = num_kv_groups 138 | self.group_size = num_heads // num_kv_groups 139 | 140 | 141 | self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) 142 | self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) 143 | 144 | # Fetch buffers using SharedBuffers 145 | mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype) 146 | 147 | 148 | self.register_buffer("mask", mask) 149 | self.register_buffer("cos", cos) 150 | self.register_buffer("sin", sin) 151 | 152 | def forward(self, x): 153 | b, num_tokens, d_in = x.shape 154 | 155 | queries = self.W_query(x) # Shape: (b, num_tokens, d_out) 156 | keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim) 157 | values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim) 158 | 159 | # Reshape queries, keys, and values 160 | queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) 161 | 162 | keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim) 163 | values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim) 164 | ################################################ 165 | 166 | # Transpose keys, values, and queries 167 | keys = keys.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim) 168 | values = values.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim) 169 | queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim) 170 | 171 | # Apply RoPE 172 | keys = compute_rope(keys, self.cos, self.sin) 173 | queries = compute_rope(queries, self.cos, self.sin) 174 | 175 | ##################### NEW ##################### 176 | # Expand keys and values to match the number of heads 177 | # Shape: (b, num_heads, num_tokens, head_dim) 178 | 179 | keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim) 180 | values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim) 181 | # For example, before repeat_interleave along dim=1 (query groups): 182 | # [K1, K2] 183 | # After repeat_interleave (each query group is repeated group_size times): 184 | # [K1, K1, K2, K2] 185 | # If we used regular repeat instead of repeat_interleave, we'd get: 186 | # [K1, K2, K1, K2] 187 | ################################################ 188 | 189 | # Compute scaled dot-product attention (aka self-attention) with a causal mask 190 | # Shape: (b, num_heads, num_tokens, num_tokens) 191 | attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head 192 | 193 | # Original mask truncated to the number of tokens and converted to boolean 194 | mask_bool = self.mask.bool()[:num_tokens, :num_tokens] 195 | 196 | # Use the mask to fill attention scores 197 | attn_scores.masked_fill_(mask_bool, -torch.inf) 198 | 199 | attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) 200 | assert keys.shape[-1] == self.head_dim 201 | 202 | # Shape: (b, num_tokens, num_heads, head_dim) 203 | context_vec = 
(attn_weights @ values).transpose(1, 2) 204 | 205 | # Combine heads, where self.d_out = self.num_heads * self.head_dim 206 | context_vec = context_vec.reshape(b, num_tokens, self.d_out) 207 | context_vec = self.out_proj(context_vec) # optional projection 208 | 209 | return context_vec 210 | 211 | class TransformerBlock(nn.Module): 212 | def __init__(self, cfg): 213 | super().__init__() 214 | self.att = GroupedQueryAttention( # MultiHeadAttention( 215 | d_in=cfg["emb_dim"], 216 | d_out=cfg["emb_dim"], 217 | context_length=cfg["context_length"], 218 | num_heads=cfg["n_heads"], 219 | num_kv_groups=cfg["n_kv_groups"], 220 | rope_base=cfg["rope_base"], 221 | rope_config=cfg["rope_freq"], 222 | dtype=cfg["dtype"] 223 | ) 224 | self.ff = FeedForward(cfg) 225 | self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-5) 226 | self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-5) 227 | 228 | def forward(self, x): 229 | # Shortcut connection for attention block 230 | shortcut = x 231 | x = self.norm1(x) 232 | #After using normalization, make sure the output is of correct dtype. 233 | x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size] 234 | x = x + shortcut # Add the original input back 235 | 236 | # Shortcut connection for feed-forward block 237 | shortcut = x 238 | x = self.norm2(x) 239 | #After using normalization, make sure the output is of correct dtype. 240 | x = self.ff(x.to(torch.bfloat16)) 241 | x = x + shortcut # Add the original input back 242 | 243 | return x 244 | 245 | class Llama3Model(nn.Module): 246 | def __init__(self, cfg, use_actv_ckpt=False): 247 | super().__init__() 248 | self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"]) 249 | 250 | self.use_actv_ckpt = use_actv_ckpt 251 | self.cfg=cfg 252 | 253 | self.trf_blocks = nn.Sequential( 254 | *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) 255 | 256 | self.final_norm = RMSNorm(cfg["emb_dim"], eps=1e-5) 257 | self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) 258 | 259 | def forward(self, in_idx): 260 | tok_embeds = self.tok_emb(in_idx) 261 | x = tok_embeds 262 | if(self.use_actv_ckpt): 263 | x = checkpoint_sequential(self.trf_blocks, segments=self.cfg["n_layers"], input=x, use_reentrant=False) 264 | else: 265 | x = self.trf_blocks(x) 266 | x = self.final_norm(x) 267 | logits = self.out_head(x.to(torch.bfloat16))#After using normalization, make sure the output is of correct dtype. 
268 | return logits -------------------------------------------------------------------------------- /Models/Llama/common_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def compute_rope(x, cos, sin): 5 | # x: (batch_size, num_heads, seq_len, head_dim) 6 | batch_size, num_heads, seq_len, head_dim = x.shape 7 | assert head_dim % 2 == 0, "Head dimension must be even" 8 | 9 | # Split x into first half and second half 10 | x1 = x[..., : head_dim // 2] # First half 11 | x2 = x[..., head_dim // 2 :] # Second half 12 | 13 | # Adjust sin and cos shapes 14 | cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim) 15 | sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) 16 | 17 | # Apply the rotary transformation 18 | rotated = torch.cat((-x2, x1), dim=-1) 19 | x_rotated = (x * cos) + (rotated * sin) 20 | 21 | return x_rotated.to(dtype=x.dtype) 22 | 23 | def rescale_theta(theta_old, context_length_old, context_length_new): 24 | scaling_factor = context_length_new / context_length_old 25 | theta_new = theta_old * scaling_factor 26 | return theta_new 27 | 28 | 29 | # Can use torch.nn.RMSNorm() instead. 30 | class RMSNorm(nn.Module): 31 | def __init__(self, emb_dim, eps=1e-5): 32 | super().__init__() 33 | self.eps = eps 34 | self.emb_dim = emb_dim 35 | self.weight = nn.Parameter(torch.ones(emb_dim)).float() 36 | 37 | def forward(self, x): 38 | means = x.pow(2).mean(dim=-1, keepdim=True) 39 | x_normed = x * torch.rsqrt(means + self.eps) 40 | return (x_normed * self.weight).to(dtype=x.dtype) 41 | 42 | 43 | # Same as torch.nn.functional.silu 44 | class SiLU(nn.Module): 45 | def __init__(self): 46 | super(SiLU, self).__init__() 47 | 48 | def forward(self, x): 49 | return x * torch.sigmoid(x) 50 | 51 | 52 | class FeedForward(nn.Module): 53 | def __init__(self, cfg): 54 | super().__init__() 55 | self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False) 56 | self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False) 57 | self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False) 58 | self.silu = SiLU() 59 | 60 | def forward(self, x): 61 | x_fc1 = self.fc1(x) 62 | x_fc2 = self.fc2(x) 63 | x = self.silu(x_fc1) * x_fc2 64 | return self.fc3(x) -------------------------------------------------------------------------------- /Models/Llama/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Models.Llama.common_components import rescale_theta 3 | 4 | LLAMA2_CONFIG_7B = { 5 | "vocab_size": 32000, # Vocabulary size 6 | "context_length": 4096, # Context length 7 | "emb_dim": 4096, # Embedding dimension 8 | "n_heads": 32, # Number of attention heads 9 | "n_layers": 32, # Number of layers 10 | "hidden_dim": 11008, # Size of the intermediate dimension in FeedForward 11 | "dtype": torch.bfloat16, # Lower-precision dtype to save memory 12 | } 13 | 14 | LLAMA3_CONFIG_8B = { 15 | "vocab_size": 128_256, # Larger vocabulary size 16 | "context_length": 8192, # Larger context length 17 | "emb_dim": 4096, # Embedding dimension 18 | "n_heads": 32, # Number of attention heads 19 | "n_layers": 32, # Number of layers 20 | "hidden_dim": 14_336, # Larger size of the intermediate dimension in FeedForward 21 | "n_kv_groups": 8, # Key-Value groups for grouped-query attention 22 | "rope_base": 500_000.0, # The base in RoPE's "theta" was increased to 500_000 23 | 
"rope_freq": None, # Additional configuration for adjusting the RoPE frequencies 24 | "dtype": torch.bfloat16, # Lower-precision dtype to save memory 25 | "eos_id":128001, 26 | "eos_text":"<|end_of_text|>" 27 | } 28 | 29 | LLAMA31_CONFIG_8B = { 30 | "vocab_size": 128_256, # Vocabulary size 31 | "context_length": 131_072, # Larger supported context length 32 | "emb_dim": 4096, # Embedding dimension 33 | "n_heads": 32, # Number of attention heads 34 | "n_layers": 32, # Number of layers 35 | "hidden_dim": 14_336, # Size of the intermediate dimension in FeedForward 36 | "n_kv_groups": 8, # Key-Value groups for grouped-query attention 37 | "rope_base": 500_000.0, # The base in RoPE's "theta" 38 | "dtype": torch.bfloat16, # Lower-precision dtype to save memory 39 | "rope_freq": { # RoPE frequency scaling 40 | "factor": 8.0, 41 | "low_freq_factor": 1.0, 42 | "high_freq_factor": 4.0, 43 | "original_context_length": 8192, 44 | }, 45 | "eos_id":128001, 46 | "eos_text":"<|end_of_text|>" 47 | } 48 | 49 | LLAMA32_CONFIG_1B = { 50 | "vocab_size": 128_256, # Vocabulary size 51 | "context_length": 131_072, # Context length 52 | "emb_dim": 2048, # Half the embedding dimension 53 | "n_heads": 32, # Number of attention heads 54 | "n_layers": 16, # Half the number of layers 55 | "hidden_dim": 8192, # Almost half the size of the intermediate dimension in FeedForward 56 | "n_kv_groups": 8, # Key-Value groups for grouped-query attention 57 | "rope_base": 500_000.0, # The base in RoPE's "theta" 58 | "dtype": torch.bfloat16, # Lower-precision dtype to save memory 59 | "rope_freq": { # RoPE frequency scaling 60 | "factor": 32.0, # Adjustment of the rescaling factor 61 | "low_freq_factor": 1.0, 62 | "high_freq_factor": 4.0, 63 | "original_context_length": 8192, 64 | }, 65 | "eos_id":128001, 66 | "eos_text":"<|end_of_text|>" 67 | } 68 | 69 | available_configs_llama2 = { 70 | "7B":LLAMA2_CONFIG_7B 71 | } 72 | 73 | available_configs_llama3 = { 74 | "8B":LLAMA3_CONFIG_8B 75 | } 76 | 77 | available_configs_llama3_1 = { 78 | "8B":LLAMA31_CONFIG_8B 79 | } 80 | 81 | available_configs_llama3_2 = { 82 | "1B":LLAMA32_CONFIG_1B 83 | } 84 | 85 | def get_config_llama(num_params,model_name): 86 | 87 | num_params = str(num_params) 88 | 89 | available_configs = globals().get(f"available_configs_{model_name}", None) 90 | 91 | assert num_params in available_configs, f"A {model_name} model for given number of parameters does not exists." 92 | 93 | config = available_configs[num_params] 94 | 95 | old_context_length = config["context_length"] 96 | if(old_context_length!=1024): 97 | 98 | config["context_length"] = 1024 #8192 99 | 100 | config["rope_base"] = rescale_theta( 101 | config["rope_base"], 102 | old_context_length, 103 | config["context_length"] 104 | ) 105 | 106 | print("New RoPE theta:", config["rope_base"]) 107 | 108 | return config 109 | 110 | 111 | def get_config_llama2(num_params): 112 | 113 | num_params = str(num_params) 114 | 115 | assert num_params in available_configs_llama2, "A llama2 model for given number of parameters does not exists." 
116 | 117 | return available_configs_llama2[num_params] 118 | 119 | 120 | def get_config_llama3(): 121 | 122 | return LLAMA3_CONFIG_8B 123 | 124 | def get_config_llam31(): 125 | 126 | old_context_length = LLAMA31_CONFIG_8B["context_length"] 127 | LLAMA31_CONFIG_8B["context_length"] = 1024 #8192 128 | 129 | LLAMA31_CONFIG_8B["rope_base"] = rescale_theta( 130 | LLAMA31_CONFIG_8B["rope_base"], 131 | old_context_length, 132 | LLAMA31_CONFIG_8B["context_length"] 133 | ) 134 | 135 | print("New RoPE theta:", LLAMA31_CONFIG_8B["rope_base"]) 136 | 137 | return LLAMA31_CONFIG_8B 138 | 139 | def get_config_llam32(): 140 | 141 | old_context_length = LLAMA32_CONFIG_1B["context_length"] 142 | LLAMA32_CONFIG_1B["context_length"] = 1024 #8192 143 | 144 | LLAMA32_CONFIG_1B["rope_base"] = rescale_theta( 145 | LLAMA32_CONFIG_1B["rope_base"], 146 | old_context_length, 147 | LLAMA32_CONFIG_1B["context_length"] 148 | ) 149 | 150 | print("New RoPE theta:", LLAMA32_CONFIG_1B["rope_base"]) 151 | 152 | return LLAMA32_CONFIG_1B -------------------------------------------------------------------------------- /Models/Llama/load_weights_llama2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from huggingface_hub import hf_hub_download 3 | 4 | def assign(left, right): 5 | if left.shape != right.shape: 6 | raise ValueError(f"Shape mismatch. Left: {left.shape}, Right: {right.shape}") 7 | 8 | if isinstance(right, torch.Tensor): 9 | return torch.nn.Parameter(right.clone().detach()) 10 | else: 11 | return torch.nn.Parameter(torch.tensor(right)) 12 | 13 | 14 | def load_weights_into_llama(model, param_config, params): 15 | model.tok_emb.weight = assign(model.tok_emb.weight, params["tok_embeddings.weight"]) 16 | 17 | for l in range(param_config["n_layers"]): 18 | 19 | # Load attention weights 20 | model.trf_blocks[l].att.W_query.weight = assign( 21 | model.trf_blocks[l].att.W_query.weight, 22 | params[f"layers.{l}.attention.wq.weight"] 23 | ) 24 | model.trf_blocks[l].att.W_key.weight = assign( 25 | model.trf_blocks[l].att.W_key.weight, 26 | params[f"layers.{l}.attention.wk.weight"] 27 | ) 28 | model.trf_blocks[l].att.W_value.weight = assign( 29 | model.trf_blocks[l].att.W_value.weight, 30 | params[f"layers.{l}.attention.wv.weight"] 31 | ) 32 | model.trf_blocks[l].att.out_proj.weight = assign( 33 | model.trf_blocks[l].att.out_proj.weight, 34 | params[f"layers.{l}.attention.wo.weight"] 35 | ) 36 | model.trf_blocks[l].norm1.weight = assign( 37 | model.trf_blocks[l].norm1.weight, 38 | params[f"layers.{l}.attention_norm.weight"] 39 | ) 40 | 41 | # Load FeedForward weights 42 | model.trf_blocks[l].ff.fc1.weight = assign( 43 | model.trf_blocks[l].ff.fc1.weight, 44 | params[f"layers.{l}.feed_forward.w1.weight"] 45 | ) 46 | # For some reason w2 and w3 are provided in the wrong order in the weights file 47 | model.trf_blocks[l].ff.fc2.weight = assign( 48 | model.trf_blocks[l].ff.fc2.weight, 49 | params[f"layers.{l}.feed_forward.w3.weight"] 50 | ) 51 | model.trf_blocks[l].ff.fc3.weight = assign( 52 | model.trf_blocks[l].ff.fc3.weight, 53 | params[f"layers.{l}.feed_forward.w2.weight"] 54 | ) 55 | model.trf_blocks[l].norm2.weight = assign( 56 | model.trf_blocks[l].norm2.weight, 57 | params[f"layers.{l}.ffn_norm.weight"] 58 | ) 59 | 60 | # Load output layer weights 61 | model.final_norm.weight = assign(model.final_norm.weight, params["norm.weight"]) 62 | model.out_head.weight = assign(model.out_head.weight, params["output.weight"]) 63 | 64 | 65 | def load_hf_weights(model,config): 66 | 
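    # Note: meta-llama/Llama-2-7b is a gated repository on the Hugging Face Hub; you generally
    # need to accept Meta's license terms and authenticate (e.g., with the access token the README
    # asks you to place in config_hf.json) before hf_hub_download() will succeed.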
67 | weights_file = hf_hub_download( 68 | repo_id="meta-llama/Llama-2-7b", 69 | filename="consolidated.00.pth", 70 | local_dir="Llama-2-7b" 71 | ) 72 | 73 | 74 | weights = torch.load(weights_file, weights_only=True) 75 | 76 | load_weights_into_llama(model, config, weights) -------------------------------------------------------------------------------- /Models/Llama/load_weights_llama3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from huggingface_hub import hf_hub_download 3 | from safetensors.torch import load_file 4 | 5 | def assign(left, right, tensor_name="unknown"): 6 | if left.shape != right.shape: 7 | raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}") 8 | 9 | if isinstance(right, torch.Tensor): 10 | return torch.nn.Parameter(right.clone().detach()) 11 | else: 12 | return torch.nn.Parameter(torch.tensor(right)) 13 | 14 | 15 | def load_weights_into_llama(model, param_config, params): 16 | model.tok_emb.weight = assign(model.tok_emb.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") 17 | 18 | for l in range(param_config["n_layers"]): 19 | 20 | # Load attention weights 21 | model.trf_blocks[l].att.W_query.weight = assign( 22 | model.trf_blocks[l].att.W_query.weight, 23 | params[f"model.layers.{l}.self_attn.q_proj.weight"], 24 | f"model.layers.{l}.self_attn.q_proj.weight" 25 | ) 26 | model.trf_blocks[l].att.W_key.weight = assign( 27 | model.trf_blocks[l].att.W_key.weight, 28 | params[f"model.layers.{l}.self_attn.k_proj.weight"], 29 | f"model.layers.{l}.self_attn.k_proj.weight" 30 | ) 31 | model.trf_blocks[l].att.W_value.weight = assign( 32 | model.trf_blocks[l].att.W_value.weight, 33 | params[f"model.layers.{l}.self_attn.v_proj.weight"], 34 | f"model.layers.{l}.self_attn.v_proj.weight" 35 | ) 36 | model.trf_blocks[l].att.out_proj.weight = assign( 37 | model.trf_blocks[l].att.out_proj.weight, 38 | params[f"model.layers.{l}.self_attn.o_proj.weight"], 39 | f"model.layers.{l}.self_attn.o_proj.weight" 40 | ) 41 | model.trf_blocks[l].norm1.weight = assign( 42 | model.trf_blocks[l].norm1.weight, 43 | params[f"model.layers.{l}.input_layernorm.weight"], 44 | f"model.layers.{l}.input_layernorm.weight" 45 | ) 46 | 47 | # Load FeedForward weights 48 | model.trf_blocks[l].ff.fc1.weight = assign( 49 | model.trf_blocks[l].ff.fc1.weight, 50 | params[f"model.layers.{l}.mlp.gate_proj.weight"], 51 | f"model.layers.{l}.mlp.gate_proj.weight" 52 | ) 53 | model.trf_blocks[l].ff.fc2.weight = assign( 54 | model.trf_blocks[l].ff.fc2.weight, 55 | params[f"model.layers.{l}.mlp.up_proj.weight"], 56 | f"model.layers.{l}.mlp.up_proj.weight" 57 | ) 58 | model.trf_blocks[l].ff.fc3.weight = assign( 59 | model.trf_blocks[l].ff.fc3.weight, 60 | params[f"model.layers.{l}.mlp.down_proj.weight"], 61 | f"model.layers.{l}.mlp.down_proj.weight" 62 | ) 63 | model.trf_blocks[l].norm2.weight = assign( 64 | model.trf_blocks[l].norm2.weight, 65 | params[f"model.layers.{l}.post_attention_layernorm.weight"], 66 | f"model.layers.{l}.post_attention_layernorm.weight" 67 | ) 68 | 69 | # Load output layer weights 70 | model.final_norm.weight = assign(model.final_norm.weight, params["model.norm.weight"], "model.norm.weight") 71 | 72 | if "lm_head.weight" in params.keys(): 73 | model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight") 74 | else: 75 | model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight") 
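        # The checkpoint provides no separate lm_head, so the output head reuses the
        # token-embedding matrix (weight tying); the Llama 3.2 1B checkpoint, for example,
        # ships without an lm_head.weight entry.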
76 | print("Model uses weight tying.") 77 | 78 | 79 | def load_hf_weights(model,model_name,config): 80 | 81 | combined_weights = {} 82 | 83 | if(model_name=="llama3"): 84 | 85 | for i in range(1, 5): 86 | weights_file = hf_hub_download( 87 | repo_id="meta-llama/Meta-Llama-3-8B", 88 | filename=f"model-0000{i}-of-00004.safetensors", 89 | local_dir="Llama-3-8B" 90 | ) 91 | current_weights = load_file(weights_file) 92 | combined_weights.update(current_weights) 93 | 94 | elif(model_name=="llama3_1"): 95 | 96 | for i in range(1, 5): 97 | weights_file = hf_hub_download( 98 | repo_id="meta-llama/Llama-3.1-8B", 99 | filename=f"model-0000{i}-of-00004.safetensors", 100 | local_dir="Llama-3.1-8B" 101 | ) 102 | current_weights = load_file(weights_file) 103 | combined_weights.update(current_weights) 104 | 105 | elif(model_name=="llama3_2"): 106 | 107 | weights_file = hf_hub_download( 108 | repo_id="meta-llama/Llama-3.2-1B", 109 | filename=f"model.safetensors", 110 | local_dir="Llama-3.2-1B" 111 | ) 112 | combined_weights = load_file(weights_file) 113 | 114 | load_weights_into_llama(model, config, combined_weights) 115 | 116 | del combined_weights -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # LLM Training 4 | 5 | This repository provides a full implementation of GPT2 and Llama2/Llama3 models from scratch in Python. It allows for flexible training configurations, including single and multi-GPU distributed training, activation checkpointing, FSDP (Fully Sharded Data Parallel), mixed precision training, and more. The repo also supports pretraining on raw text and finetuning on instruction datasets. Additionally, LoRA (Low-Rank Adaptation) integration is included for memory-efficient model finetuning. 6 | 7 | ## Features 8 | 9 | - **GPT2 & Llama2/Llama3 Models**: Implemented from scratch. 10 | - **Pretrained Weights**: Load pretrained weights from Hugging Face. 11 | - **Training Configurations**: 12 | - Single GPU & Multi-GPU (Distributed Training). 13 | - Activation Checkpointing. 14 | - FSDP (Fully Sharded Data Parallel). 15 | - Mixed Precision Training (using `torch.cuda.amp`). 16 | - LoRA (Low-Rank Adaptation) support for efficient finetuning. 17 | - **Training Modes**: 18 | - Pretraining on raw text. 19 | - Finetuning on instruction datasets. 20 | 21 | 22 | ## Requirements 23 | 24 | - Python 3.8+ 25 | - PyTorch 1.10+ (with GPU support) 26 | - Hugging Face Transformers 27 | - CUDA (for GPU support) 28 | - Other dependencies listed in `requirements.txt` 29 | 30 | To install dependencies, run: 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | To download weights from hugging face, you need to first login into hugging face and get the access token which will go towards HF_ACCESS_TOKEN under `config_hf.json` file. 36 | 37 | The repo code has been tested on aws instances - g4dn.xlarge for single-GPU run and g4dn.12xlarge for multi-GPU run. 38 | 39 | ## Quick Start 40 | 41 | To run the training, the entry point is the `main.py` file. 
Here's how to get started: 42 | 43 | ### Download Gutenberg dataset for pretraining on raw text 44 | 45 | Run `setup.sh` in the `Datasets/Gutenberg` folder: 46 | 47 | ```bash 48 | bash setup.sh 49 | ``` 50 | 51 | ### Download Alpaca dataset for finetuning on instruction data 52 | 53 | Run `setup.sh` in the `Datasets/Alpaca` folder: 54 | 55 | ```bash 56 | bash setup.sh 57 | ``` 58 | 59 | ### Example 1: Pretrain GPT2 on Raw Text 60 | 61 | ```bash 62 | python main.py --data_dir ./Datasets/Gutenberg/data_dir --load_weights 63 | ``` 64 | 65 | ### Example 2: Distributed Data Parallel Training with Low-Rank Adaptation (LoRA) 66 | 67 | ```bash 68 | python main.py --data_dir ./Datasets/Gutenberg/data_dir --load_weights --model GPT2 --run_type multi_gpu --use_lora 69 | ``` 70 | 71 | ### Example 3: Enable Activation Checkpointing 72 | 73 | ```bash 74 | python main.py --data_dir ./Datasets/Gutenberg/data_dir --load_weights --model GPT2 --num_params 774M --run_type multi_gpu --use_lora --use_actv_ckpt 75 | ``` 76 | 77 | ### Example 4: Fully Sharded Data Parallelism 78 | 79 | ```bash 80 | python main.py --data_dir ./Datasets/Gutenberg/data_dir --load_weights --model GPT2 --num_params 774M --run_type multi_gpu --use_actv_ckpt --use_lora --use_fsdp 81 | ``` 82 | 83 | ### Example 5: Finetune Llama3.2 on an Instruction Dataset 84 | 85 | ```bash 86 | python main.py --dataset alpaca --data_dir ./Datasets/Alpaca/data --load_weights --model llama3_2 --num_params 1B --finetune --run_type multi_gpu --use_actv_ckpt --use_lora --lr 1e-5 --data_type bf16 87 | ``` 88 | 89 | --- 90 | ## Model configurations supported 91 | 92 | - GPT2 - 124M, 355M, 774M, 1.5B 93 | - Llama2 - 7B 94 | - Llama3 - 8B 95 | - Llama3.1 - 8B 96 | - Llama3.2 - 1B 97 | 98 | --- 99 | ## Arguments 100 | 101 | Here are some common arguments you can use to configure your training: 102 | 103 | - `--model` : Type of model to train (`GPT2`, `llama2`, `llama3`, `llama3_1`, `llama3_2`). 104 | - `--num_params` : Model size to train (see the supported configurations above). 105 | - `--data_type` : Datatype for the model weights (`fp32`, `fp16`, or `bf16`). 106 | - `--load_weights` : Load pretrained weights from Hugging Face. 107 | - `--n_epochs` : Number of epochs for training. 108 | - `--batch_size` : Batch size for training. 109 | - `--run_type` : `single_gpu` is the default; use `multi_gpu` to enable distributed data parallel training. 110 | - `--lr` : Learning rate to use after the warmup steps finish. Cosine annealing is used as the LR schedule by default. 111 | - `--warmup_steps` : Number of warmup steps for training. 112 | - `--dataset` : Dataset to be used for training (`gutenberg` or `alpaca`). 113 | - `--finetune` : Activate instruction finetuning. The default is to pretrain on raw text. 114 | - `--use_zero_opt` : Activate the Zero Redundancy optimizer. 115 | - `--use_lora` : Use LoRA for model finetuning. 116 | - `--lora_rank` : Rank value for LoRA. 117 | - `--lora_alpha` : Alpha value for LoRA. 118 | - `--use_fsdp` : Enable FSDP (Fully Sharded Data Parallel) for distributed training. 119 | - `--mixed_precision` : Use mixed precision training for better performance. (Only supported with FSDP at this time.) 120 | - `--use_actv_ckpt` : Enable activation checkpointing to reduce memory usage. 121 | 122 | --- 123 | 124 | ## Advanced Usage 125 | 126 | ### Mixed Precision Training 127 | 128 | This repository supports mixed precision training, which speeds up training and reduces memory consumption. You can enable it using the `--mixed_precision` flag. It is only supported with FSDP at this time.
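
As a rough sketch of what `--mixed_precision bf16` does under the hood: it selects the bf16 `MixedPrecision` policy defined in `datautils/mixed_precision.py` (`bfSixteen`), which `build_components.multigpu_setup` then passes to the FSDP wrapper. The snippet below only constructs the policy object (the variable name `bf16_policy` is chosen for illustration); the actual `FSDP(...)` call happens inside the repo once the process group has been initialized.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# bf16 policy as defined in datautils/mixed_precision.py (bfSixteen):
# parameters, gradient reduction, and buffers are all kept in bfloat16.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# build_components.multigpu_setup passes this policy to the wrapper, roughly:
# model = FSDP(model, auto_wrap_policy=..., mixed_precision=bf16_policy, ...)
print(bf16_policy)
```
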
129 | 130 | ### FSDP (Fully Sharded Data Parallel) 131 | 132 | For multi-GPU training, FSDP is supported for better memory efficiency and performance. Simply use the `--use_fsdp` argument to enable it. 133 | 134 | ### LoRA (Low-Rank Adaptation) 135 | 136 | You can enable LoRA for efficient finetuning by using the `--use_lora` argument. LoRA reduces the computational cost while maintaining good performance on downstream tasks. 137 | 138 | ### Activation Checkpointing 139 | 140 | Activation checkpointing is implemented to allow training on large models by saving memory. You can enable it with the `--use_actv_ckpt` flag. 141 | 142 | ## Notes 143 | 144 | - **GPU Setup**: Ensure that you have CUDA installed and properly configured. 145 | - **Pretrained Weights**: You can easily load pretrained weights from Hugging Face for your models. Simply pass the `--load_weights` flag to `main.py` script. 146 | - **Scalability**: This code is designed to scale across multiple GPUs and large datasets. 147 | 148 | --- 149 | 150 | ## Contributing 151 | 152 | Feel free to fork this repository and open issues for any questions. 153 | 154 | --- 155 | 156 | ## License 157 | 158 | This code is licensed under the Apache License. See the `LICENSE` file for more information. 159 | 160 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | import os 4 | import torch 5 | from utils import model_params_mapping 6 | 7 | 8 | def perform_checks(args): 9 | 10 | """Performs validation of input arguments.""" 11 | 12 | if(not args.warnings): 13 | warnings.filterwarnings("ignore") 14 | 15 | if(not os.path.exists(args.data_dir)): 16 | raise Exception("The data dir path specified does not exists.") 17 | 18 | if(args.num_params not in model_params_mapping[args.model]): 19 | raise Exception(f"You are asking to load {args.num_params} configuration for {args.model} model. 
This configuration is currently not supported.") 20 | 21 | if(args.run_type=="single_gpu" and args.use_fsdp): 22 | raise Exception("FSDP not supported on single GPU non-distributed training.") 23 | 24 | if(args.use_zero_opt and args.use_fsdp): 25 | raise Exception("Zero Optimizer is not supported with FSDP.") 26 | 27 | if(args.use_fsdp and not torch.cuda.is_available()): 28 | raise Exception("FSDP can only be activated when CUDA device is available.") 29 | 30 | if(not args.use_fsdp and args.mixed_precision): 31 | raise Exception("Mixed precision is only supported with FSDP at this time.") 32 | 33 | 34 | def get_args(): 35 | 36 | """Get command line arguments.""" 37 | 38 | parser = argparse.ArgumentParser(description='Model Training Configuration') 39 | 40 | parser.add_argument('--data_dir', type=str, default='/home/ec2-user/train-llm-from-scratch/Datasets/Gutenberg/data_dir_small', 41 | help='Directory containing the training data') 42 | parser.add_argument('--output_dir', type=str, default='model_checkpoints', 43 | help='Directory where the model checkpoints will be saved') 44 | parser.add_argument('--n_epochs', type=int, default=2, 45 | help='Number of epochs to train the model') 46 | parser.add_argument('--print_sample_iter', type=int, default=10, 47 | help='Iterations between printing sample outputs') 48 | parser.add_argument('--eval_freq', type=int, default=10, 49 | help='Frequency of evaluations during training') 50 | parser.add_argument('--save_ckpt_freq', type=int, default=100, 51 | help='Frequency of saving model checkpoints during training') 52 | parser.add_argument('--lr', type=float, default=5e-4, 53 | help='Learning rate for the optimizer') 54 | parser.add_argument('--batch_size', type=int, default=4, 55 | help='Batch size for training') 56 | parser.add_argument('--warmup_steps', type=int, default=10, 57 | help='Warmup steps for training.') 58 | parser.add_argument('--initial_lr', type=float, default=1e-05, 59 | help='Initial learning rate.') 60 | parser.add_argument('--min_lr', type=float, default=1e-6, 61 | help='Minimum learning rate.') 62 | parser.add_argument('--debug', action="store_true", 63 | help='Uses a very small model for debugging purposes') 64 | parser.add_argument('--model', type=str, default="GPT2", 65 | choices=["GPT2","llama2","llama3","llama3_1","llama3_2"], 66 | help='The model to use.') 67 | parser.add_argument('--num_params', type=str, default="124M", 68 | help='Model size.') 69 | parser.add_argument('--load_weights', action="store_true", 70 | help='Do we need to load pretrained weights?') 71 | parser.add_argument('--data_type',type=str,default="fp32", 72 | choices=["fp32","fp16","bf16"], 73 | help="Datatype to use.bf16 is better choice for training compared to fp16 due to stability reasons.") 74 | parser.add_argument('--run_type',type=str,default="single_gpu", 75 | choices=['single_gpu', 'multi_gpu'], 76 | help="How to optmize the run? Should be multi_gpu for FSDP.") 77 | parser.add_argument('--use_zero_opt',action="store_true", 78 | help="Don't use with FSDP. Use Zero Redeundancy Optimizer") 79 | parser.add_argument('--use_actv_ckpt',action="store_true", 80 | help="Activation checkpointing") 81 | parser.add_argument('--use_fsdp',action="store_true", 82 | help="Fully Sharded Data Parallelism. 
Requires multi-gpu run") 83 | parser.add_argument('--mixed_precision',type=str, 84 | choices=['fp16', 'bf16'], 85 | help="Mixed precision to be used for FSDP.") 86 | parser.add_argument('--finetune',action="store_true", 87 | help="Enable finetuning.") 88 | parser.add_argument('--dataset',type=str,default="gutenberg", 89 | choices=['gutenberg', 'alpaca'], 90 | help="Dataset to be used.") 91 | parser.add_argument('--use_lora',action="store_true", 92 | help="Enable LoRA training.") 93 | parser.add_argument('--lora_rank',type=int,default=64, 94 | help="Rank value for LoRA.") 95 | parser.add_argument('--lora_alpha',type=int,default=32, 96 | help="Alpha value for LoRA.") 97 | parser.add_argument('--warnings',action="store_true", 98 | help="Turn the warnings on.") 99 | 100 | args = parser.parse_args() 101 | 102 | perform_checks(args) 103 | 104 | return args -------------------------------------------------------------------------------- /build_components.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | import torch 3 | import torch.nn as nn 4 | 5 | import torch.multiprocessing as mp 6 | from torch.distributed import init_process_group, destroy_process_group 7 | from torch.nn.parallel import DistributedDataParallel as DDP 8 | from torch.distributed.optim import ZeroRedundancyOptimizer 9 | 10 | 11 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 12 | from torch.distributed.fsdp import FullStateDictConfig, StateDictType 13 | from torch.distributed.fsdp import ShardingStrategy, BackwardPrefetch 14 | from torch.distributed.fsdp.fully_sharded_data_parallel import ( 15 | CPUOffload, 16 | BackwardPrefetch, 17 | ) 18 | from torch.distributed.fsdp.wrap import ( 19 | transformer_auto_wrap_policy, 20 | size_based_auto_wrap_policy, 21 | ModuleWrapPolicy, 22 | enable_wrap, 23 | wrap, 24 | ) 25 | 26 | 27 | import utils 28 | from Models.GPT2.config import get_config_gpt2 29 | from Models.Llama.config import get_config_llama 30 | 31 | from Models.GPT2.GPT2 import GPTModel 32 | from Models.Llama.Llama2 import Llama2Model 33 | from Models.Llama.Llama3 import Llama3Model 34 | 35 | from Models.Llama.Llama2 import Llama2Tokenizer 36 | from Models.Llama.Llama3 import Llama3Tokenizer 37 | 38 | from Models.Llama.common_components import rescale_theta 39 | 40 | from huggingface_hub import hf_hub_download 41 | from huggingface_hub import login 42 | 43 | from logger import setup_logger 44 | 45 | # Create a logger specific to this module 46 | logger = setup_logger('build_components') 47 | 48 | 49 | def build_config(args): 50 | 51 | """Build and returns config dictionary.""" 52 | 53 | 54 | if(args.model=="GPT2"): 55 | 56 | config = get_config_gpt2(args.num_params) 57 | 58 | elif(args.model.startswith("llama")): 59 | 60 | config = get_config_llama(args.num_params,args.model) 61 | 62 | config.update({"dtype":utils.datatype_mapping[args.data_type]}) 63 | 64 | if(args.load_weights and args.model=="GPT2"): 65 | 66 | config.update({"qkv_bias":True}) 67 | 68 | if args.debug: 69 | 70 | config.update({ 71 | "context_length": 10, # Context length 72 | "emb_dim": 32, # Embedding dimension 73 | "hidden_dim": 10, # Hidden dimension of feedforward layer 74 | "n_heads": 16, # Number of attention heads 75 | "n_layers": 2, # Number of layers 76 | "qkv_bias": False # Query-key-value bias 77 | }) 78 | 79 | 80 | return config 81 | 82 | def load_model_weights(args,config,model): 83 | 84 | utils.login_hf() 85 | if(args.model=="GPT2"): 86 | 87 | from 
Models.GPT2.load_weights import load_hf_weights 88 | 89 | load_hf_weights(model,args.num_params,config) 90 | 91 | if(args.model=="llama2"): 92 | 93 | from Models.Llama.load_weights_llama2 import load_hf_weights 94 | 95 | load_hf_weights(model, config) 96 | 97 | 98 | if(args.model.startswith("llama3")): 99 | 100 | from Models.Llama.load_weights_llama3 import load_hf_weights 101 | 102 | load_hf_weights(model, args.model, config) 103 | 104 | def activate_lora(args,cfg,model): 105 | 106 | from lora import replace_linear_with_lora 107 | 108 | logger.info("Using LoRA...") 109 | 110 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 111 | logger.info(f"Total trainable parameters before: {total_params:,}") 112 | 113 | logger.info("Turning the weights off ...") 114 | 115 | for param in model.parameters(): 116 | param.requires_grad = False 117 | 118 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 119 | logger.info(f"Total trainable parameters after: {total_params:,}") 120 | 121 | replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, dtype=cfg["dtype"]) 122 | 123 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 124 | logger.info(f"Total trainable LoRA parameters: {total_params:,}") 125 | 126 | 127 | def multigpu_setup(args,rank,model): 128 | 129 | if(torch.cuda.is_available()): #FSDP is only possible with GPU. 130 | 131 | fsdp_kwargs = { 132 | "sharding_strategy" : ShardingStrategy.FULL_SHARD, 133 | "cpu_offload" : None, 134 | "backward_prefetch" : BackwardPrefetch.BACKWARD_PRE, 135 | "mixed_precision" : None, 136 | "sync_module_states" : False, 137 | "device_id":torch.cuda.current_device(), 138 | "use_orig_params":False 139 | } 140 | 141 | if(args.use_fsdp): 142 | 143 | from datautils.mixed_precision import fpSixteen,bfSixteen 144 | mixed_precision_policy=None 145 | if(args.mixed_precision): 146 | if(args.mixed_precision=="fp16"): 147 | mixed_precision_policy = fpSixteen 148 | elif(args.mixed_precision=="bf16"): 149 | mixed_precision_policy = bfSixteen 150 | 151 | fsdp_kwargs.update({"mixed_precision":mixed_precision_policy}) 152 | 153 | if(args.use_lora): 154 | fsdp_kwargs.update({"use_orig_params":True}) 155 | # #ignored_modules = [module for name, module in model.named_modules() if ".lora" not in name] 156 | # ignored_modules=[] 157 | # for name, module in model.named_modules(): 158 | # if(".lora" not in name): 159 | # ignored_modules.append(module) 160 | # fsdp_kwargs.update({"ignored_modules":ignored_modules}) 161 | 162 | from Models.GPT2.GPT2 import TransformerBlock 163 | # my_auto_wrap_policy = functools.partial( 164 | # size_based_auto_wrap_policy, min_num_params=100 165 | # ) 166 | # trf_auto_wrap_policy = functools.partial( 167 | # transformer_auto_wrap_policy, 168 | # transformer_layer_cls={ 169 | # TransformerBlock, 170 | # }, 171 | # ) 172 | #my_auto_wrap_policy=None 173 | my_auto_wrap_policy = ModuleWrapPolicy(module_classes=[nn.Embedding,TransformerBlock]) 174 | 175 | model = FSDP(model, 176 | auto_wrap_policy=my_auto_wrap_policy,**fsdp_kwargs) 177 | else: 178 | model = DDP(model, device_ids=[rank]) 179 | 180 | if(rank==0): 181 | logger.info("Model wrapped with DDP/FSDP ....") 182 | utils.print_memory_usage() 183 | 184 | return model 185 | 186 | def build_model(config,rank,device,args): 187 | 188 | """ 189 | 190 | Args: 191 | config: config dictionary object. 192 | rank: rank of process. 193 | device: CUDA device being used. 194 | args: command line input arguments. 
195 | 196 | Returns: 197 | model: Instance of model class. 198 | 199 | """ 200 | 201 | utils.start_memory_tracking() 202 | 203 | if(args.model=="GPT2"): 204 | model = GPTModel(config,args.use_actv_ckpt) 205 | elif(args.model=="llama2"): 206 | model = Llama2Model(config,args.use_actv_ckpt) 207 | elif(args.model.startswith("llama3")): 208 | model = Llama3Model(config,args.use_actv_ckpt) 209 | else: 210 | raise Exception("Invalid model Exception : This code does not support this model configuration.") 211 | 212 | total_params = utils.get_num_params(model) 213 | 214 | if(rank==0): 215 | logger.info(f"Total number of parameters in the model : {total_params:,}") 216 | 217 | utils.model_memory_size(model,config["dtype"]) 218 | 219 | if(args.load_weights): 220 | 221 | if(rank!=0 and args.run_type=="multi_gpu"): 222 | torch.distributed.barrier() 223 | 224 | load_model_weights(args,config,model) 225 | 226 | if(rank==0 and args.run_type=="multi_gpu"): 227 | torch.distributed.barrier() 228 | 229 | if(rank==0): 230 | logger.info("Weights loaded ....") 231 | utils.print_memory_usage() 232 | 233 | if(args.use_lora): 234 | 235 | activate_lora(args,config,model) 236 | 237 | # FSDP should keep model on cpu and then using FSDP wrapper, 238 | # put the weights on multiple GPUs. This will help save memory 239 | # especially when model is too large to be loaded in single GPU memory. 240 | if(not args.use_fsdp): 241 | model.to(device) 242 | 243 | if(rank==0): 244 | logger.info("Model loaded into cuda device ....") 245 | utils.print_memory_usage() 246 | 247 | if args.use_zero_opt: 248 | 249 | for param in model.parameters(): 250 | if not param.is_contiguous(): 251 | param.data = param.data.contiguous() 252 | 253 | if(args.run_type=="multi_gpu"): 254 | model = multigpu_setup(args,rank,model) 255 | 256 | if(rank==0): 257 | logger.info(f"Following is the model for rank {rank}: ") 258 | logger.info(model) 259 | 260 | return model 261 | 262 | 263 | def build_optimizer(args,model): 264 | 265 | if args.use_zero_opt: 266 | 267 | optimizer = ZeroRedundancyOptimizer( 268 | model.parameters(), 269 | optimizer_class=torch.optim.AdamW, 270 | lr=args.lr, 271 | weight_decay=0.1 272 | ) 273 | else: 274 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1) 275 | 276 | return optimizer 277 | 278 | 279 | def build_tokenizer(rank,args): 280 | 281 | if(rank!=0 and args.run_type=="multi_gpu"): 282 | torch.distributed.barrier() 283 | 284 | if(args.model=="GPT2"): 285 | 286 | tokenizer = tiktoken.get_encoding("gpt2") 287 | 288 | elif(args.model=="llama2"): 289 | 290 | utils.login_hf() 291 | 292 | tokenizer_file = hf_hub_download( 293 | repo_id="meta-llama/Llama-2-7b", 294 | filename="tokenizer.model", 295 | local_dir="Llama-2-7b" 296 | ) 297 | 298 | tokenizer = Llama2Tokenizer(tokenizer_file) 299 | 300 | elif(args.model.startswith("llama3")): 301 | 302 | utils.login_hf() 303 | 304 | if(args.model=="llama3"): 305 | 306 | tokenizer_file = hf_hub_download( 307 | repo_id="meta-llama/Meta-Llama-3-8B", 308 | filename="original/tokenizer.model", 309 | local_dir="Llama-3-8B" 310 | ) 311 | 312 | elif(args.model=="llama3_1"): 313 | 314 | tokenizer_file = hf_hub_download( 315 | repo_id="meta-llama/Llama-3.1-8B", 316 | filename="original/tokenizer.model", 317 | local_dir="Llama-3.1-8B" 318 | ) 319 | 320 | elif(args.model=="llama3_2"): 321 | 322 | tokenizer_file = hf_hub_download( 323 | repo_id="meta-llama/Llama-3.2-1B", 324 | filename="original/tokenizer.model", 325 | local_dir="Llama-3.2-1B" 326 | ) 327 | 328 | 
tokenizer = Llama3Tokenizer(tokenizer_file) 329 | 330 | else: 331 | 332 | raise Exception("Tokenizer Not Found Exception: No tokenizer found for the given model.") 333 | 334 | if(rank==0 and args.run_type=="multi_gpu"): 335 | torch.distributed.barrier() 336 | 337 | return tokenizer 338 | 339 | def build_components(rank: int, device: torch.device, args): 340 | 341 | """ 342 | 343 | Build and returns training objects such as config, model, optimizer and tokenizer. 344 | 345 | Args: 346 | 347 | rank: rank of process. 348 | device: CUDA device. 349 | args: command line arguments. 350 | 351 | Returns: 352 | config: config file. 353 | model: LLM model object. 354 | optimizer: optimizer object. 355 | tokenizer: tokenizer object 356 | 357 | """ 358 | 359 | try: 360 | 361 | config = build_config(args) 362 | 363 | model = build_model(config,rank,device,args) 364 | 365 | optimizer = build_optimizer(args,model) 366 | 367 | tokenizer = build_tokenizer(rank,args) 368 | 369 | return config, model, optimizer, tokenizer 370 | 371 | except Exception as e: 372 | 373 | logger.exception(f"\nException while building training components :\n{e}") 374 | 375 | 376 | -------------------------------------------------------------------------------- /config_hf.json: -------------------------------------------------------------------------------- 1 | { 2 | "HF_ACCESS_TOKEN":"" 3 | } -------------------------------------------------------------------------------- /datautils/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import tiktoken 3 | import torch 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from datautils.dataset import DatasetPT 7 | import utils 8 | 9 | class DataloaderPT: 10 | 11 | def __init__(self,tokenizer, batch_size, max_length,stride,eos_text="<|endoftext|>",dataset_name="gutenberg",run_type="single_gpu",train_ratio=0.90,collate_func=None): 12 | 13 | self.tokenizer = tokenizer 14 | self.batch_size = batch_size 15 | self.max_length = max_length 16 | self.stride = stride 17 | self.train_ratio = train_ratio 18 | self.run_type=run_type 19 | self.collate_func = collate_func 20 | self.dataset_name = dataset_name 21 | self.eos_text = eos_text 22 | 23 | def create_dataloader(self,txt,shuffle=True, drop_last=True, num_workers=0): 24 | 25 | if(self.dataset_name=="gutenberg"): 26 | dataset = DatasetPT(txt, self.tokenizer, self.max_length, self.stride) 27 | 28 | if(self.run_type=="multi_gpu"): 29 | dataloader = DataLoader( 30 | dataset, batch_size=self.batch_size,pin_memory=True, 31 | shuffle=False, drop_last=drop_last, 32 | sampler=DistributedSampler(dataset), #rank and num_replicas is inferred automatically. 
33 | collate_fn = self.collate_func) 34 | else: 35 | dataloader = DataLoader( 36 | dataset, batch_size=self.batch_size,pin_memory=True, 37 | shuffle=shuffle, drop_last=drop_last, 38 | num_workers=num_workers, 39 | collate_fn = self.collate_func) 40 | 41 | return dataloader 42 | 43 | def create_dataloaders(self, text_data, num_workers=0): 44 | 45 | split_idx = int(self.train_ratio * len(text_data)) 46 | 47 | train_loader = self.create_dataloader( 48 | text_data[:split_idx], 49 | drop_last=True, 50 | shuffle=True, 51 | num_workers=num_workers 52 | ) 53 | val_loader = self.create_dataloader( 54 | text_data[split_idx:], 55 | drop_last=False, 56 | shuffle=False, 57 | num_workers=num_workers 58 | ) 59 | return train_loader, val_loader 60 | 61 | def get_total_steps_epoch(self,data_files): 62 | 63 | num_steps=0 64 | for index, file_path in enumerate(data_files, 1): 65 | trailing_string = " " + self.eos_text +" " 66 | text_data = utils.read_text_file(file_path) + trailing_string 67 | train_loader, val_loader = self.create_dataloaders( 68 | text_data, 69 | num_workers=2 70 | ) 71 | num_steps = num_steps + len(train_loader) 72 | 73 | return num_steps 74 | 75 | -------------------------------------------------------------------------------- /datautils/dataloader_instruction_finetune.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import tiktoken 3 | import torch 4 | from torch.utils.data.distributed import DistributedSampler 5 | 6 | from datautils.dataset import DatasetPT 7 | from datautils.dataset_instruction_finetune import InstructionDataset 8 | import utils 9 | 10 | 11 | def custom_collate_fn( 12 | batch, 13 | pad_token_id=50256, 14 | ignore_index=-100, 15 | allowed_max_length=None 16 | ): 17 | # Find the longest sequence in the batch 18 | batch_max_length = max(len(item)+1 for instruction_length,item in batch) 19 | 20 | # Pad and prepare inputs and targets 21 | inputs_lst, targets_lst = [], [] 22 | 23 | for instruction_length, item in batch: 24 | new_item = item.copy() 25 | # Add an <|endoftext|> token 26 | new_item += [pad_token_id] 27 | # Pad sequences to max_length 28 | padded = new_item + [pad_token_id] * (batch_max_length - len(new_item)) 29 | 30 | inputs = torch.tensor(padded[:-1]) # Truncate the last token for inputs 31 | targets = torch.tensor(padded[1:]) # Shift +1 to the right for targets 32 | 33 | # Replace all but the first padding tokens in targets by ignore_index 34 | mask = targets == pad_token_id 35 | indices = torch.nonzero(mask).squeeze() 36 | if indices.numel() > 1: 37 | targets[indices[1:]] = ignore_index 38 | 39 | # instruction_length-1 since we have targets=padded[1:] i.e. it already lacks the first token. 
40 | targets[:instruction_length-1] = ignore_index 41 | 42 | # Optionally truncate to maximum sequence length 43 | if allowed_max_length is not None: 44 | inputs = inputs[:allowed_max_length] 45 | targets = targets[:allowed_max_length] 46 | 47 | inputs_lst.append(inputs) 48 | targets_lst.append(targets) 49 | 50 | # Convert list of inputs and targets to tensors and transfer to target device 51 | inputs_tensor = torch.stack(inputs_lst) 52 | targets_tensor = torch.stack(targets_lst) 53 | 54 | return inputs_tensor, targets_tensor 55 | 56 | class DataloaderIF: 57 | 58 | def __init__(self,tokenizer, batch_size, max_length, dataset_name="alpaca", run_type="single_gpu", 59 | train_ratio=0.90,collate_func=None): 60 | 61 | self.tokenizer = tokenizer 62 | self.batch_size = batch_size 63 | self.max_length = max_length 64 | self.train_ratio = train_ratio 65 | self.run_type=run_type 66 | self.collate_func = collate_func 67 | self.dataset_name = dataset_name 68 | 69 | def create_dataloader(self,txt,shuffle=True, drop_last=True, num_workers=0): 70 | 71 | if(self.dataset_name=="alpaca"): 72 | dataset = InstructionDataset(txt, self.tokenizer) 73 | 74 | if(self.run_type=="multi_gpu"): 75 | dataloader = DataLoader( 76 | dataset, batch_size=self.batch_size,pin_memory=True, 77 | shuffle=False, drop_last=drop_last, 78 | sampler=DistributedSampler(dataset), #rank and num_replicas is inferred automatically. 79 | collate_fn = self.collate_func) 80 | else: 81 | dataloader = DataLoader( 82 | dataset, batch_size=self.batch_size,pin_memory=True, 83 | shuffle=shuffle, drop_last=drop_last, 84 | num_workers=num_workers, 85 | collate_fn = self.collate_func) 86 | 87 | return dataloader 88 | 89 | def create_dataloaders(self, text_data, num_workers=0): 90 | 91 | split_idx = int(self.train_ratio * len(text_data)) 92 | 93 | train_loader = self.create_dataloader( 94 | text_data[:split_idx], 95 | drop_last=True, 96 | shuffle=True, 97 | num_workers=num_workers 98 | ) 99 | val_loader = self.create_dataloader( 100 | text_data[split_idx:], 101 | drop_last=False, 102 | shuffle=False, 103 | num_workers=num_workers 104 | ) 105 | return train_loader, val_loader 106 | 107 | def get_total_steps_epoch(self,data_files): 108 | 109 | num_steps=0 110 | for index, file_path in enumerate(data_files, 1): 111 | text_data = utils.read_json_file(file_path) 112 | train_loader, val_loader = self.create_dataloaders( 113 | text_data, 114 | num_workers=2 115 | ) 116 | num_steps = num_steps + len(train_loader) 117 | 118 | return num_steps 119 | 120 | -------------------------------------------------------------------------------- /datautils/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import tiktoken 3 | import torch 4 | 5 | class DatasetPT(Dataset): 6 | def __init__(self, txt, tokenizer, max_length, stride): 7 | self.input_ids = [] 8 | self.target_ids = [] 9 | 10 | token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'}) 11 | 12 | for i in range(0, len(token_ids) - max_length, stride): 13 | input_chunk = token_ids[i:i + max_length] 14 | target_chunk = token_ids[i + 1: i + max_length + 1] 15 | self.input_ids.append(torch.tensor(input_chunk)) 16 | self.target_ids.append(torch.tensor(target_chunk)) 17 | 18 | def __len__(self): 19 | return len(self.input_ids) 20 | 21 | def __getitem__(self, idx): 22 | return self.input_ids[idx], self.target_ids[idx] -------------------------------------------------------------------------------- 
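
As a quick illustration of how `DatasetPT` above turns raw text into (input, target) pairs with a sliding window, here is a minimal, self-contained sketch. It assumes `tiktoken` is installed and that the script is run from the repository root so `datautils` is importable; the sample text and window sizes are made up for the example.

```python
import tiktoken
from datautils.dataset import DatasetPT

tokenizer = tiktoken.get_encoding("gpt2")
text = "Every effort moves you forward. " * 20 + "<|endoftext|>"

# max_length is the context window per sample; stride controls how far the
# window slides between consecutive samples (stride == max_length means no overlap).
dataset = DatasetPT(text, tokenizer, max_length=8, stride=4)

inputs, targets = dataset[0]
print(len(dataset), "samples")
print("inputs :", inputs.tolist())
print("targets:", targets.tolist())  # the same tokens shifted by one position (next-token prediction)
```
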
/datautils/dataset_instruction_finetune.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | import tiktoken 3 | import torch 4 | 5 | 6 | def format_input(entry): 7 | instruction_text = ( 8 | f"Below is an instruction that describes a task. " 9 | f"Write a response that appropriately completes the request." 10 | f"\n\n### Instruction:\n{entry['instruction']}" 11 | ) 12 | 13 | input_text = f"\n\n### Input:\n{entry['input']}" if entry["input"] else "" 14 | 15 | return instruction_text + input_text 16 | 17 | def format_input_phi(entry): 18 | instruction_text = ( 19 | f"<|user|>\n{entry['instruction']}" 20 | ) 21 | 22 | input_text = f"\n{entry['input']}" if entry["input"] else "" 23 | 24 | return instruction_text + input_text 25 | 26 | class InstructionDataset(Dataset): 27 | 28 | def __init__(self, data, tokenizer): 29 | self.data = data 30 | 31 | # Pre-tokenize texts 32 | self.encoded_texts = [] 33 | self.instruction_lengths = [] 34 | 35 | for entry in data: 36 | 37 | instruction_plus_input = format_input(entry) 38 | response_text = f"\n\n### Response:\n{entry['output']}" 39 | full_text = instruction_plus_input + response_text 40 | 41 | self.encoded_texts.append( 42 | tokenizer.encode(full_text) 43 | ) 44 | 45 | instruction_length = len(tokenizer.encode(instruction_plus_input)) 46 | self.instruction_lengths.append(instruction_length) 47 | 48 | def __getitem__(self, index): 49 | return self.instruction_lengths[index], self.encoded_texts[index] 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | 54 | class InstructionDatasetPhi(Dataset): 55 | def __init__(self, data, tokenizer): 56 | self.data = data 57 | 58 | # Pre-tokenize texts 59 | self.encoded_texts = [] 60 | for entry in data: 61 | 62 | ################################################################### 63 | # NEW: Use `format_input_phi` and adjust the response text template 64 | instruction_plus_input = format_input_phi(entry) 65 | response_text = f"\n<|assistant|>:\n{entry['output']}" 66 | ################################################################### 67 | full_text = instruction_plus_input + response_text 68 | self.encoded_texts.append( 69 | tokenizer.encode(full_text) 70 | ) 71 | 72 | def __getitem__(self, index): 73 | return self.encoded_texts[index] 74 | 75 | def __len__(self): 76 | return len(self.data) -------------------------------------------------------------------------------- /datautils/mixed_precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.distributed.fsdp import ( 4 | # FullyShardedDataParallel as FSDP, 5 | # CPUOffload, 6 | MixedPrecision, 7 | # BackwardPrefetch, 8 | # ShardingStrategy, 9 | ) 10 | 11 | # requires grad scaler in main loop 12 | fpSixteen = MixedPrecision( 13 | param_dtype=torch.float16, 14 | # Gradient communication precision. 15 | reduce_dtype=torch.float16, 16 | # Buffer precision. 17 | buffer_dtype=torch.float16, 18 | ) 19 | 20 | bfSixteen = MixedPrecision( 21 | param_dtype=torch.bfloat16, 22 | # Gradient communication precision. 23 | reduce_dtype=torch.bfloat16, 24 | # Buffer precision. 
25 | buffer_dtype=torch.bfloat16, 26 | ) 27 | 28 | bfSixteen_working = MixedPrecision( 29 | param_dtype=torch.float32, 30 | reduce_dtype=torch.bfloat16, 31 | buffer_dtype=torch.bfloat16, 32 | ) 33 | 34 | fp32_policy = MixedPrecision( 35 | param_dtype=torch.float32, 36 | reduce_dtype=torch.float32, 37 | buffer_dtype=torch.float32, 38 | ) -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None): 5 | 6 | # idx is (B, T) array of indices in the current context 7 | # For-loop is the same as before: Get logits, and only focus on last time step 8 | for _ in range(max_new_tokens): 9 | idx_cond = idx[:, -context_size:] 10 | with torch.no_grad(): 11 | logits = model(idx_cond) 12 | logits = logits[:, -1, :] 13 | 14 | # Filter logits with top_k sampling 15 | if top_k is not None: 16 | # Keep only top_k values 17 | top_logits, _ = torch.topk(logits, top_k) 18 | min_val = top_logits[:, -1] 19 | logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits) 20 | 21 | # Apply temperature scaling 22 | if temperature > 0.0: 23 | logits = logits / temperature 24 | 25 | # Apply softmax to get probabilities 26 | probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) 27 | 28 | # Sample from the distribution 29 | idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) 30 | 31 | # Otherwise same as before: get idx of the vocab entry with the highest logits value 32 | else: 33 | idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1) 34 | 35 | if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified 36 | break 37 | 38 | # Same as before: append sampled index to the running sequence 39 | idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1) 40 | 41 | return idx -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | 4 | def setup_logger(name: str): 5 | 6 | # Create logger 7 | logger = logging.getLogger(name) 8 | 9 | # Set the logging level (Optional) 10 | logger.setLevel(logging.DEBUG) # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL 11 | 12 | # Create a console handler to output logs to the console 13 | console_handler = logging.StreamHandler() 14 | 15 | # Define the log format 16 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 17 | console_handler.setFormatter(formatter) 18 | 19 | # Add handlers to the logger 20 | logger.addHandler(console_handler) 21 | 22 | return logger 23 | -------------------------------------------------------------------------------- /lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | class LinearWithLoRA(nn.Module): 6 | def __init__(self, linear, rank, alpha, dtype=torch.float32): 7 | super().__init__() 8 | self.linear = linear 9 | self.lora = LoRALayer( 10 | linear.in_features, linear.out_features, rank, alpha, dtype 11 | ) 12 | 13 | def forward(self, x): 14 | return self.linear(x) + self.lora(x) 15 | 16 | 17 | class LoRALayer(nn.Module): 18 | def __init__(self, in_dim, out_dim, rank, alpha, dtype=torch.float32): 19 | 
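        # LoRA update: forward() returns alpha * (x @ A @ B), with A of shape (in_dim, rank)
        # initialized with Kaiming-uniform values and B of shape (rank, out_dim) initialized to
        # zeros, so the layer starts as a no-op on top of the frozen base Linear layer.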
super().__init__() 20 | self.A = torch.nn.Parameter(torch.empty(in_dim, rank, dtype=dtype)) 21 | nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization 22 | self.B = torch.nn.Parameter(torch.zeros(rank, out_dim, dtype=dtype)) 23 | self.alpha = alpha 24 | 25 | def forward(self, x): 26 | x = self.alpha * (x @ self.A @ self.B) 27 | return x 28 | 29 | 30 | def replace_linear_with_lora(model, rank, alpha,dtype=torch.float32): 31 | 32 | for name, module in model.named_children(): 33 | 34 | if isinstance(module, torch.nn.Linear): 35 | # Replace the Linear layer with LinearWithLoRA 36 | setattr(model, name, LinearWithLoRA(module, rank, alpha, dtype)) 37 | else: 38 | # Recursively apply the same function to child modules 39 | replace_linear_with_lora(module, rank, alpha, dtype) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from pathlib import Path 4 | import json 5 | import os 6 | from functools import partial 7 | 8 | import torch.multiprocessing as mp 9 | from torch.distributed import init_process_group, destroy_process_group 10 | 11 | from train import Trainer 12 | import utils 13 | from datautils.dataloader import DataloaderPT 14 | from datautils.dataloader_instruction_finetune import DataloaderIF 15 | from build_components import build_components 16 | from args import get_args 17 | 18 | from logger import setup_logger 19 | 20 | # Create a logger specific to this module 21 | logger = setup_logger('main') 22 | 23 | 24 | def ddp_setup(rank, world_size): 25 | """ 26 | Args: 27 | rank: Unique identifier of each process 28 | world_size: Total number of processes 29 | """ 30 | os.environ["MASTER_ADDR"] = "localhost" 31 | os.environ["MASTER_PORT"] = "12355" 32 | torch.cuda.set_device(rank) 33 | init_process_group(backend="nccl", rank=rank, world_size=world_size) 34 | 35 | 36 | def cleanup(): 37 | destroy_process_group() 38 | 39 | 40 | def main(rank: int, args): 41 | 42 | """ 43 | Main function. 44 | 45 | Args: 46 | rank: rank of the process. 47 | args: Command line input arguments. 48 | """ 49 | 50 | if(args.run_type=="multi_gpu"): 51 | ddp_setup(rank, args.world_size) 52 | device = torch.device(f"cuda:{rank}") 53 | else: 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 55 | 56 | utils.set_seed() 57 | 58 | config,model,optimizer,tokenizer = build_components(rank,device,args) 59 | 60 | data_dir = args.data_dir 61 | all_files = [os.path.join(path, name) for path, subdirs, files 62 | in os.walk(data_dir) for name in files if name.endswith((".txt",".json"))] 63 | total_files = len(all_files) 64 | 65 | if total_files == 0: 66 | raise Exception("No training text files found. 
Make sure you " 67 | "selected the correct input directory") 68 | 69 | 70 | if(rank==0): 71 | logger.info(f"Total data files: {total_files}") 72 | 73 | dataloader_kwargs = {"tokenizer":tokenizer, 74 | "batch_size":args.batch_size, 75 | "max_length":config["context_length"], 76 | "dataset_name":args.dataset, 77 | "run_type":args.run_type, 78 | "train_ratio":0.9} 79 | 80 | collate_func = None 81 | if(args.finetune): 82 | from datautils.dataloader_instruction_finetune import custom_collate_fn 83 | customized_collate_fn = partial( 84 | custom_collate_fn, 85 | allowed_max_length=config["context_length"] 86 | ) 87 | loaderObj = DataloaderIF( 88 | collate_func=customized_collate_fn, 89 | **dataloader_kwargs) 90 | else: 91 | loaderObj = DataloaderPT( 92 | stride=config["context_length"], 93 | eos_text=config["eos_text"], 94 | collate_func=collate_func, 95 | **dataloader_kwargs) 96 | 97 | output_dir = Path(args.output_dir) 98 | output_dir.mkdir(parents=True, exist_ok=True) 99 | 100 | trainer = Trainer( 101 | config=config, 102 | data_files=all_files, 103 | loaderObj=loaderObj, 104 | device=device, 105 | model=model, 106 | optimizer=optimizer, 107 | save_dir=output_dir, 108 | rank=rank, 109 | warmup_steps=args.warmup_steps, 110 | initial_lr=args.initial_lr, 111 | min_lr=args.min_lr, 112 | eval_freq=args.eval_freq, 113 | save_ckpt_freq=args.save_ckpt_freq, 114 | print_sample_iter=args.print_sample_iter, 115 | eval_iter=5 116 | ) 117 | 118 | 119 | #Test a single sentence 120 | trainer.generate_and_print_sample("Every effort moves you",temperature=1.0,top_k=5,memory_check=True) 121 | 122 | if(args.finetune): 123 | train_losses, val_losses, tokens_seen, track_lrs = trainer.finetune_model( 124 | n_epochs=args.n_epochs 125 | ) 126 | else: 127 | train_losses, val_losses, tokens_seen, track_lrs = trainer.train_model( 128 | n_epochs=args.n_epochs 129 | ) 130 | 131 | if(rank==0): 132 | epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses)) 133 | utils.plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses, output_dir) 134 | 135 | final_model_file = "model_pg_final.pth" 136 | 137 | trainer.save_checkpoint(final_model_file) 138 | 139 | logger.info(f"Maximum GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB") 140 | 141 | if(args.run_type=="multi_gpu"): 142 | torch.distributed.barrier() 143 | cleanup() 144 | 145 | if __name__ == "__main__": 146 | 147 | input_args = get_args() 148 | 149 | if(input_args.run_type=="multi_gpu"): 150 | world_size = torch.cuda.device_count() 151 | input_args.world_size=world_size 152 | mp.spawn(main, args=[input_args], nprocs=world_size,join=True) 153 | else: 154 | main(0,input_args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | gdown 3 | transformers 4 | huggingface_hub 5 | sentencepiece 6 | blobfile 7 | safetensors>=0.4.4 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | import tqdm 4 | from tqdm import tqdm 5 | import math 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 8 | from torch.distributed.fsdp import FullStateDictConfig, StateDictType 9 | 10 | import utils 11 | from generate import generate 12 | 13 | from logger import 
setup_logger 14 | 15 | # Create a logger specific to this module 16 | logger = setup_logger('train') 17 | 18 | class Trainer: 19 | 20 | def __init__(self,model,optimizer,config,data_files,loaderObj,save_dir, 21 | warmup_steps=10, initial_lr=1e-05, min_lr=1e-6,device="cpu", 22 | rank=0,eval_freq=1,save_ckpt_freq=1,print_sample_iter=1,eval_iter=1): 23 | 24 | self.config = config 25 | self.data_files = data_files 26 | self.loaderObj = loaderObj 27 | self.device = device 28 | self.optimizer = optimizer 29 | self.save_dir = save_dir 30 | self.rank = rank 31 | self.warmup_steps = warmup_steps 32 | self.initial_lr = initial_lr 33 | self.min_lr=min_lr 34 | self.model = model 35 | self.eval_freq=eval_freq 36 | self.save_ckpt_freq=save_ckpt_freq 37 | self.print_sample_iter=print_sample_iter 38 | self.eval_iter = eval_iter 39 | 40 | self.global_step=-1 41 | self.track_lrs=[] 42 | self.train_losses=[] 43 | self.val_losses=[] 44 | self.track_tokens_seen=[] 45 | 46 | self.tokens_seen=0 47 | 48 | 49 | def calc_loss_batch(self,input_batch, target_batch): 50 | 51 | input_batch, target_batch = input_batch.to(self.device), target_batch.to(self.device) 52 | logits = self.model(input_batch) 53 | loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten()) 54 | return loss 55 | 56 | def train_batch(self,input_batch, target_batch): 57 | 58 | try: 59 | self.optimizer.zero_grad() 60 | self.global_step += 1 61 | 62 | if self.global_step < self.warmup_steps: 63 | lr = self.initial_lr + self.global_step * self.lr_increment 64 | else: 65 | progress = ((self.global_step - self.warmup_steps) / (self.total_training_steps - self.warmup_steps)) 66 | lr = self.min_lr + (self.peak_lr - self.min_lr) * 0.5 * ( 67 | 1 + math.cos(math.pi * progress)) 68 | 69 | for param_group in self.optimizer.param_groups: 70 | param_group["lr"] = lr 71 | 72 | self.track_lrs.append(lr) 73 | # if(self.rank==0): 74 | # utils.print_memory_usage() 75 | loss = self.calc_loss_batch(input_batch, target_batch) 76 | # if(self.rank==0): 77 | # utils.print_memory_usage() 78 | loss.backward() 79 | # if(self.rank==0): 80 | # utils.print_memory_usage() 81 | 82 | if isinstance(self.model, FSDP): 83 | self.model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) 84 | elif(isinstance(self.model,DDP)): 85 | torch.nn.utils.clip_grad_norm_(self.model.module.parameters(), max_norm=1.0) 86 | else: 87 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) 88 | 89 | self.optimizer.step() 90 | 91 | self.tokens_seen += input_batch.numel() 92 | 93 | except Exception as e: 94 | 95 | logger.error(f"An error occurred in train_batch(): {e}") 96 | 97 | def train_epoch(self,epoch_no,train_loader,val_loader,start_context="Every effort moves you"): 98 | 99 | try: 100 | self.model.train() 101 | for input_batch, target_batch in train_loader: 102 | 103 | self.train_batch(input_batch, target_batch) 104 | 105 | # Optional evaluation step 106 | if self.global_step % self.eval_freq == 0: 107 | 108 | train_loss, val_loss = self.evaluate_model(train_loader, val_loader, 109 | self.eval_iter) 110 | 111 | self.train_losses.append(train_loss) 112 | self.val_losses.append(val_loss) 113 | self.track_tokens_seen.append(self.tokens_seen) 114 | 115 | if(self.rank==0): 116 | logger.info(f"\n Ep {epoch_no+1} (Step {self.global_step}): " 117 | f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f} \n") 118 | 119 | # Generate text passage 120 | if self.global_step % self.print_sample_iter == 0: 121 | 122 | 
self.generate_and_print_sample(start_context,temperature=1.0,top_k=5,memory_check=True) 123 | 124 | 125 | if self.global_step % self.save_ckpt_freq == 0: 126 | 127 | self.save_checkpoint(f"model_pg_{self.global_step}.pth") 128 | #logger.info(f"Successfully saved checkpoint for step {self.global_step}") 129 | 130 | except Exception as e: 131 | 132 | logger.error(f"An error occurred in train_epoch(): {e}") 133 | 134 | 135 | def train_model(self, n_epochs): 136 | 137 | self.peak_lr = self.optimizer.param_groups[0]["lr"] 138 | self.total_training_steps = self.loaderObj.get_total_steps_epoch(self.data_files) * n_epochs 139 | self.lr_increment = (self.peak_lr - self.initial_lr) / self.warmup_steps 140 | 141 | try: 142 | if(self.rank==0): 143 | pbar = tqdm(total=n_epochs*len(self.data_files)) 144 | for epoch in range(n_epochs): 145 | 146 | # Iterate over the books in the training corpus 147 | for index, file_path in enumerate(self.data_files, 1): 148 | 149 | trailing_string = " " + self.config["eos_text"] +" " 150 | text_data = utils.read_text_file(file_path) + trailing_string #" <|endoftext|> " 151 | 152 | # Initialize new data loaders for each book 153 | train_loader, val_loader = self.loaderObj.create_dataloaders( 154 | text_data, 155 | num_workers=2 156 | ) 157 | 158 | if hasattr(train_loader.sampler, 'set_epoch'): 159 | train_loader.sampler.set_epoch(epoch) 160 | 161 | self.train_epoch(epoch,train_loader, val_loader) 162 | 163 | if(self.rank==0): 164 | pbar.update(1) 165 | 166 | 167 | except KeyboardInterrupt: 168 | self.save_checkpoint(f"model_pg_{self.global_step}_interrupted.pth") 169 | 170 | return self.train_losses, self.val_losses, self.track_tokens_seen, self.track_lrs 171 | 172 | def finetune_model(self,n_epochs): 173 | 174 | self.peak_lr = self.optimizer.param_groups[0]["lr"] 175 | self.total_training_steps = self.loaderObj.get_total_steps_epoch(self.data_files) * n_epochs 176 | self.lr_increment = (self.peak_lr - self.initial_lr) / self.warmup_steps 177 | 178 | try: 179 | if(self.rank==0): 180 | pbar = tqdm(total=n_epochs*len(self.data_files)) 181 | for epoch in range(n_epochs): 182 | 183 | # Iterate over the books in the training corpus 184 | for index, file_path in enumerate(self.data_files, 1): 185 | 186 | text_data = utils.read_json_file(file_path) 187 | 188 | # Initialize new data loaders for each book 189 | train_loader, val_loader = self.loaderObj.create_dataloaders( 190 | text_data, 191 | num_workers=2 192 | ) 193 | 194 | if hasattr(train_loader.sampler, 'set_epoch'): 195 | train_loader.sampler.set_epoch(epoch) 196 | 197 | self.train_epoch(epoch,train_loader, val_loader,start_context="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\nWhat is an antonym of 'complicated'?") 198 | 199 | if(self.rank==0): 200 | pbar.update(1) 201 | 202 | except KeyboardInterrupt: 203 | 204 | self.save_checkpoint(f"model_pg_{self.global_step}_interrupted.pth") 205 | 206 | return self.train_losses, self.val_losses, self.track_tokens_seen, self.track_lrs 207 | 208 | 209 | 210 | def generate_and_print_sample(self,start_context,temperature=0.0,top_k=None,memory_check=False,max_new_tokens=200): 211 | 212 | self.model.eval() 213 | context_size = self.config["context_length"] 214 | 215 | encoded = utils.text_to_token_ids(start_context, self.loaderObj.tokenizer,self.config).to(self.device) 216 | 217 | token_ids = generate( 218 | model=self.model, idx=encoded, 219 | max_new_tokens=max_new_tokens, context_size=context_size,temperature=temperature,top_k=top_k,eos_id=self.config["eos_id"]) 220 | 221 | decoded_text = utils.token_ids_to_text(token_ids, self.loaderObj.tokenizer) 222 | 223 | self.model.train() 224 | 225 | if(self.rank==0 and memory_check): 226 | logger.info(decoded_text.replace("\n", " ")) 227 | 228 | def save_checkpoint(self,file_name): 229 | 230 | if isinstance(self.model, DDP): 231 | torch.distributed.barrier() 232 | 233 | try: 234 | file_name = self.save_dir / file_name 235 | 236 | if self.rank==0: 237 | 238 | if isinstance(self.model, DDP) : 239 | torch.save(self.model.module.state_dict(), file_name) 240 | elif(not isinstance(self.model, FSDP)): 241 | torch.save(self.model.state_dict(), file_name) 242 | 243 | if isinstance(self.model, FSDP): 244 | 245 | cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 246 | with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, cfg): 247 | cpu_state = self.model.state_dict() 248 | 249 | if self.rank==0: 250 | torch.save(cpu_state, file_name) 251 | 252 | logger.info(f"Saved checkpoint {file_name}") 253 | 254 | except Exception as e: 255 | logger.error(f"An error occurred while saving checkpoint : {e}") 256 | 257 | if isinstance(self.model, DDP) : 258 | torch.distributed.barrier() 259 | 260 | 261 | def calc_loss_loader(self,data_loader, num_batches=None): 262 | 263 | total_loss = 0. 264 | if len(data_loader) == 0: 265 | return float("nan") 266 | elif num_batches is None: 267 | num_batches = len(data_loader) 268 | else: 269 | num_batches = min(num_batches, len(data_loader)) 270 | for i, (input_batch, target_batch) in enumerate(data_loader): 271 | if i < num_batches: 272 | loss = self.calc_loss_batch(input_batch, target_batch) 273 | total_loss += loss.item() 274 | else: 275 | break 276 | return total_loss / num_batches 277 | 278 | 279 | def evaluate_model(self,train_loader, val_loader, eval_iter=5): 280 | 281 | self.model.eval() 282 | with torch.no_grad(): 283 | train_loss = self.calc_loss_loader(train_loader, num_batches=eval_iter) 284 | val_loss = self.calc_loss_loader(val_loader, num_batches=eval_iter) 285 | self.model.train() 286 | return train_loss, val_loss 287 | 288 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Util functions useful for various modules. 
2 | import random 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from matplotlib.ticker import MaxNLocator 7 | import json 8 | from enum import Enum 9 | from huggingface_hub import login 10 | 11 | from logger import setup_logger 12 | 13 | # Create a logger specific to this module 14 | logger = setup_logger('utils') 15 | 16 | datasize_mapping = { 17 | "fp32":4, 18 | "fp16":2, 19 | "bf16":2 20 | } 21 | 22 | datatype_mapping = { 23 | "fp32":torch.float32, 24 | "fp16":torch.float16, 25 | "bf16":torch.bfloat16 26 | } 27 | 28 | model_params_mapping = { 29 | "GPT2":["124M","355M","774M","1.5B"], 30 | "llama2":["7B"], 31 | "llama3":["8B"], 32 | "llama3_1":["8B"], 33 | "llama3_2":["1B"] 34 | } 35 | 36 | def set_seed(): 37 | 38 | """Set random seeds.""" 39 | 40 | RANDOM_SEED=123 41 | 42 | random.seed(RANDOM_SEED) 43 | np.random.seed(RANDOM_SEED) 44 | torch.manual_seed(seed=0) 45 | torch.random.manual_seed(seed=RANDOM_SEED) 46 | 47 | torch.backends.cudnn.benchmark = False 48 | torch.backends.cudnn.deterministic = True 49 | 50 | def text_to_token_ids(text, tokenizer, cfg): 51 | encoded = tokenizer.encode(text, allowed_special={cfg["eos_text"]}) 52 | encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension 53 | return encoded_tensor 54 | 55 | def token_ids_to_text(token_ids, tokenizer): 56 | flat = token_ids.squeeze(0) # remove batch dimension 57 | return tokenizer.decode(flat.tolist()) #convert to list from tensor and then decode 58 | 59 | def read_text_file(file_path): 60 | with open(file_path, "r", encoding="utf-8") as file: 61 | text_data = file.read() 62 | return text_data 63 | 64 | def read_json_file(file_path): 65 | with open(file_path, "r", encoding="utf-8") as file: 66 | data = json.load(file) 67 | return data 68 | 69 | def get_num_params(model): 70 | 71 | total_params = sum(p.numel() for p in model.parameters()) 72 | 73 | return total_params 74 | 75 | def get_total_size(num_params,data_type): 76 | 77 | assert data_type in datasize_mapping, "This datatype is currently not supprted." 78 | 79 | datatype_size = int(datasize_mapping[data_type]) 80 | logger.info(f"Since the datatype is {data_type}, each parameter is going to consume {datatype_size} bytes") 81 | 82 | logger.info( 83 | "Assuming that we are using Adam optimizer, this model will require 1N (N is number of parameters)" 84 | "for parameters, 1N for gradients and 2N for first and second moment estimates of Adam." 85 | "So total 4N of GPU memory" 86 | ) 87 | 88 | total_size_bytes = 4 * num_params * datatype_size 89 | 90 | # Convert to gigabytes 91 | total_size_mb = total_size_bytes / (1024 * 1024 * 1024) 92 | 93 | logger.info( 94 | f"Estimated size of the model: {total_size_mb:.2f} GB.\n\n" 95 | "During the forward pass, activations are stored for backpropagation." 96 | "These can significantly increase memory usage. This memory is not included in above calculations." 97 | "Please use activation checkpointing to decrease activation memory." 
98 | ) 99 | 100 | 101 | def model_memory_size(model, input_dtype=torch.float32): 102 | total_params = 0 103 | total_grads = 0 104 | for param in model.parameters(): 105 | # Calculate total number of elements per parameter 106 | param_size = param.numel() 107 | total_params += param_size 108 | # Check if gradients are stored for this parameter 109 | if param.requires_grad: 110 | total_grads += param_size 111 | 112 | # Calculate buffer size (non-parameters that require memory) 113 | total_buffers = sum(buf.numel() for buf in model.buffers()) 114 | 115 | # Size in bytes = (Number of elements) * (Size of each element in bytes) 116 | # We assume parameters and gradients are stored in the same type as input dtype 117 | element_size = torch.tensor(0, dtype=input_dtype).element_size() 118 | total_memory_bytes = (total_params + total_grads + total_buffers) * element_size 119 | 120 | # Convert bytes to gigabytes 121 | total_memory_gb = total_memory_bytes / (1024**3) 122 | 123 | logger.info(f"Estimated size of the model: {total_memory_gb:.2f} GB.\n\n") 124 | 125 | 126 | 127 | def start_memory_tracking(): 128 | """Initialize GPU memory tracking.""" 129 | if torch.cuda.is_available(): 130 | torch.cuda.reset_peak_memory_stats() 131 | else: 132 | logger.info("This notebook is intended for CUDA GPUs but CUDA is not available.") 133 | 134 | def print_memory_usage(): 135 | max_gpu_memory = float(torch.cuda.max_memory_allocated()) / (1024 ** 3) # Convert bytes to GB 136 | logger.info(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB") 137 | 138 | def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir): 139 | fig, ax1 = plt.subplots() 140 | 141 | # Plot training and validation loss against epochs 142 | ax1.plot(epochs_seen, train_losses, label="Training loss") 143 | ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss") 144 | ax1.set_xlabel("Epochs") 145 | ax1.set_ylabel("Loss") 146 | ax1.legend(loc="upper right") 147 | ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) 148 | 149 | # Create a second x-axis for tokens seen 150 | ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis 151 | ax2.plot(tokens_seen, train_losses, alpha=0) # Invisible plot for aligning ticks 152 | ax2.set_xlabel("Tokens seen") 153 | 154 | fig.tight_layout() # Adjust layout to make room 155 | plt.savefig(output_dir / "losses.pdf") 156 | 157 | def login_hf(): 158 | 159 | with open("config_hf.json", "r") as config_file: 160 | config = json.load(config_file) 161 | access_token = config["HF_ACCESS_TOKEN"] 162 | 163 | login(token=access_token) --------------------------------------------------------------------------------
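
To close the loop on the memory-estimation helpers in `utils.py`, here is a small usage sketch. It assumes the repository root is on `PYTHONPATH` and that the modules imported by `utils.py` (including `matplotlib` and `huggingface_hub`) are installed; the tiny `nn.Sequential` model is only a stand-in for the GPT2/Llama models built by `build_components.build_model`.

```python
import torch
import torch.nn as nn

from utils import get_num_params, get_total_size, model_memory_size

# A tiny stand-in model so the example runs instantly.
model = nn.Sequential(
    nn.Embedding(1000, 64),
    nn.Linear(64, 64),
    nn.Linear(64, 1000),
)

num_params = get_num_params(model)
print(f"Parameters: {num_params:,}")

# Logs a rough 4N-bytes-per-parameter estimate (parameters + gradients + Adam moments).
get_total_size(num_params, data_type="bf16")

# Logs an estimate based on the actual parameters, gradients, and buffers of the model.
model_memory_size(model, input_dtype=torch.bfloat16)
```
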