├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── loss_eval.png
├── requirements.txt
└── src
    ├── __init__.py
    ├── dataloader.py
    ├── hellaswag_eval.py
    ├── inference.py
    ├── model.py
    ├── prepare_dataset.py
    └── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .DS_Store
3 | 
4 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Saqib Azim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPT-2 Implementation in PyTorch
2 | 
3 | This project reproduces the GPT-2 model in PyTorch and trains it from scratch on the FineWeb-Edu dataset - a high-quality subset of the FineWeb dataset tailored for educational content. The goal is to offer a simplified, easy-to-understand PyTorch implementation. Note that this code is intended primarily for educational purposes and is not optimized for speed or production deployment.
4 | 
5 | ### Key Features
6 | - **Simplified PyTorch Implementation:** Designed to be accessible and well-commented for ease of understanding.
7 | - **Customizable Training:** Hyperparameters are configurable via the command line and can be easily modified.
8 | - **Multi-GPU Training Support:** Training can be performed on multiple GPUs using PyTorch Distributed Data Parallel (DDP).
9 | 
10 | 
11 | ## Repository Structure
12 | - `src/train.py`: Script to train the GPT-2 model with customizable configurations.
13 | - `src/model.py`: Contains the GPT-2 model implementation, including embedding layers, transformer blocks, and output layers.
14 | - `src/dataloader.py`: Handles data loading and batching for the model during training (see the usage sketch below).
15 | - `src/prepare_dataset.py`: Downloads and preprocesses the FineWebEdu dataset. Run this script before starting the training process.
- `src/inference.py`: Generates text from a trained checkpoint for a given prompt.
- `src/hellaswag_eval.py`: Downloads the HellaSwag benchmark and evaluates a model on it.
16 | - `requirements.txt`: Python dependencies required to run the project.
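
The pieces above fit together roughly as follows. This is a minimal illustrative sketch (not part of the repository) that assumes the dataset shards have already been created with `prepare_dataset.py` and that it is run from inside the `src/` directory:

```python
from model import GPT, GPTConfig
from dataloader import DataLoaderLite

# single-process loader over the prepared shards in data/edu_fineweb10B
loader = DataLoaderLite(B=8, T=1024, process_rank=0, num_processes=1, split='train')
x, y = loader.next_batch()   # (B, T) input tokens and shifted next-token targets

model = GPT(GPTConfig())     # default 124M-parameter GPT-2 configuration
logits, loss = model(x, y)   # logits: (B, T, vocab_size); loss: cross-entropy
```

`train.py` wraps this same forward pass with gradient accumulation, a cosine learning-rate schedule, and optional DDP across multiple GPUs.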
17 | 
18 | 
19 | ## Getting Started
20 | 
21 | ### Prerequisites
22 | Ensure you have the following dependencies installed:
23 | 
24 | - numpy
25 | - torch (PyTorch)
26 | - tiktoken
27 | - transformers (from Hugging Face)
28 | 
29 | You can install all dependencies with:
30 | ```bash
31 | pip install -r requirements.txt
32 | ```
33 | `prepare_dataset.py` and `hellaswag_eval.py` additionally require the `datasets`, `tqdm`, and `requests` packages, which are not listed in `requirements.txt` (`pip install datasets tqdm requests`).
34 | ## Dataset
35 | 
36 | The GPT-2 model was originally trained on the WebText dataset (not publicly released). For this project, we use the FineWebEdu-10B dataset, a specialized educational subset of the FineWeb dataset containing approximately 10 billion tokens of high-quality educational content.
37 | 
38 | To download and prepare the dataset (run from the `src/` directory):
39 | ```bash
40 | python prepare_dataset.py
41 | ```
42 | 
43 | ### Running the Training Script
44 | You can start training the GPT-2 model using the following commands:
45 | 
46 | You can experiment with different training and model configuration hyperparameters by setting them through the command line.
47 | 
48 | - Single-GPU Training:
49 | ```bash
50 | python train.py --num_epochs=5
51 | ```
52 | 
53 | - Multi-GPU Training (uses PyTorch DDP):
54 | ```bash
55 | torchrun --standalone --nproc_per_node=4 train.py # adjust number of GPUs as per availability
56 | ```
57 | 
58 | For more details on the training process and customizing hyperparameters, refer to the `src/train.py` script.
59 | 
60 | Training was performed from scratch using multiple GPUs with PyTorch's DDP framework.
61 | 
62 | ### Running Inference
63 | After training the model, you can generate text based on custom prompts. Use the `src/inference.py` script to interact with the trained model and generate creative continuations.
64 | 
65 | Run the inference script from the command line (from the `src/` directory) with the following syntax:
66 | ```bash
67 | python3 inference.py --prompt="I am an AI and robotics enthusiast, I want to" --max_tokens=50 --num_seq=5
68 | ```
69 | 
70 | This command will output 5 unique text sequences, each starting with the provided prompt and continuing for up to 50 tokens.
71 | 
72 | 
73 | ### Model Architecture
74 | The GPT-2 model consists of the following components:
75 | 
76 | - **Token Embedding Layer:** Maps input tokens to dense vectors.
77 | - **Positional Embedding Layer:** Adds positional information to the token embeddings.
78 | - **Transformer Blocks:** Each block applies layer normalization, multi-head causal self-attention, and an MLP, with residual connections around both sub-layers.
79 | - **Output Head:** Predicts the next token in the sequence based on the preceding context.
80 | 
81 | The model is trained to predict the next token in a sequence, enabling coherent text generation. For tokenization, the project uses OpenAI's `tiktoken` library with the GPT-2 BPE vocabulary of 50,257 tokens.
82 | 
83 | 
84 | ### Results
85 | 
86 | The GPT-2 model was trained for roughly 95,365 steps (5 epochs) using two NVIDIA A100 GPUs. Training took approximately 46 hours.
87 | 
88 | ![Training loss and HellaSwag evaluation](./assets/loss_eval.png)
89 | 
90 | To generate from the trained model, we provide an input prompt sequence and ask the model to generate the next N tokens. Here are some samples of text generated by the trained model:
91 | 
92 | - **prompt text:** "Hello, I am a language model"
93 | - **Model output:**
94 | ```
95 | - Hello, I am a language modeler. I use the API, in whatever language I require it to write out. On first, I define a model for
96 | 
97 | - Hello, I am a language model expert and need help with building these model. The project is designed in C++ and the Python library is used. 
The project 98 | 99 | - Hello, I am a language model developer at Google Cloud. It has great features on most platforms which makes it one of most popular. It also integrates with third 100 | ``` 101 | 102 | - **prompt text:** "I am a machine learning and robotics enthusiast, and I want to" 103 | - **Model output:** 104 | ``` 105 | - I am a machine learning and robotics enthusiast, and I want to share my excitement about this work as soon as possible. 106 | The purpose of this project was to help the engineers and programmers understand how the HURD and AVR circuits work and how 107 | 108 | - I am a machine learning and robotics enthusiast, and I want to try and train a new machine learning-based system such as a deep learning algorithm that is completely new to me. 109 | 110 | - I am a machine learning and robotics enthusiast, and I want to help you by helping you improve your Python programming skills.To understand the concept of machine learning, you must understand the concept of a machine learning model. Machine learning models 111 | 112 | - I am a machine learning and robotics enthusiast, and I want to be a part of the team.<|endoftext|>In your next project, you need to gather some interesting information from your team team. This data will help form a map that you can use to 113 | 114 | - I am a machine learning and robotics enthusiast, and I want to create a new, more sophisticated machine learning-based library for programming languages. To start, I am interested in the machine learning (ML) capabilities of new AI methods and techniques. 115 | ``` 116 | 117 | 118 | ## Potential Future Work 119 | 120 | 1. **Dataset Shuffling:** The current training code does not shuffle the dataset after each epoch. Implementing dataset shuffling between epochs could improve the model's ability to generalize and prevent overfitting to the order of the training data. 121 | 122 | 2. **Extended Training:** Experiment with training the model for more epochs to potentially improve performance. Monitor validation loss to determine the optimal number of epochs and implement early stopping if necessary. 123 | 124 | 125 | ## References: 126 | - [Language Models are Unsupervised Multitask Learners (GPT-2 Paper)](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 127 | - [GPT-3 Paper: Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) 128 | - [FineWebEdu-10B Dataset](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu) 129 | - [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/abs/2205.14135) 130 | - [Attention is all you need](https://arxiv.org/abs/1706.03762) 131 | - [HellaSwag: Can a Machine Really Finish Your Sentence?](https://arxiv.org/abs/1905.07830) 132 | - Andrej Karpathy's Video Tutorial on GPT 133 | 134 | 135 | ## Acknowledgments 136 | This implementation is inspired by Andrej Karpathy’s tutorial and his approach to making complex AI concepts more accessible. 
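
## Appendix: Dataset Shuffling Sketch

The first item under Potential Future Work suggests reshuffling the training data between epochs. The snippet below is a minimal, untested sketch of one way to do this by shuffling the shard order used by `DataLoaderLite`; the `ShuffledDataLoaderLite` class and the `set_epoch` method are illustrative names, not part of the current codebase, and only the order of shards (not the token order within a shard) is shuffled.

```python
import random
from dataloader import DataLoaderLite

class ShuffledDataLoaderLite(DataLoaderLite):
    """Illustrative subclass that reshuffles the shard order at each epoch boundary."""

    def set_epoch(self, epoch: int):
        # Use the same epoch-dependent seed on every DDP rank so all processes
        # agree on the shard order, then restart from the first (shuffled) shard.
        rng = random.Random(1337 + epoch)
        rng.shuffle(self.shard_filepaths)
        self.reset()
```

In `train.py`, the training loop could then call `train_loader.set_epoch(...)` whenever `step` crosses a multiple of `steps_per_epoch`.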
-------------------------------------------------------------------------------- /assets/loss_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saqib1707/gpt2-from-scratch/af97fdb27c4979e8080a4bab8bbdfff8d4bb4474/assets/loss_eval.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tiktoken 4 | transformers 5 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saqib1707/gpt2-from-scratch/af97fdb27c4979e8080a4bab8bbdfff8d4bb4474/src/__init__.py -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | script_dir = os.path.dirname(__file__) 6 | 7 | 8 | class DataLoaderLite: 9 | """ A simple dataloader for FineWebEdu-10B dataset """ 10 | 11 | def __init__(self, B, T, process_rank, num_processes, split='train'): 12 | super().__init__() 13 | self.B, self.T = B, T 14 | self.process_rank = process_rank 15 | self.num_processes = num_processes 16 | assert split in {'train', 'val'} 17 | 18 | # get the shard filenames 19 | data_root = os.path.join(script_dir, "../data/edu_fineweb10B") 20 | shard_filenames = os.listdir(data_root) 21 | shard_filenames = sorted([filename for filename in shard_filenames if split in filename]) 22 | self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames] 23 | assert len(self.shard_filepaths) > 0, f'no shards found for split {split}' 24 | master_process = process_rank == 0 25 | if master_process: 26 | print(f'found {len(self.shard_filepaths)} shards for split {split}') 27 | self.reset() 28 | 29 | def load_tokens(self, filepath): 30 | tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long) 31 | return tokens 32 | 33 | def reset(self): 34 | # state, init at shard 0 35 | self.curr_shard = 0 36 | self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard]) 37 | self.curr_pos = self.B * self.T * self.process_rank 38 | 39 | def next_batch(self): 40 | B, T = self.B, self.T 41 | batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1] 42 | x_batch = batch[:-1].view(B, T) 43 | y_batch = batch[1:].view(B, T) 44 | self.curr_pos += B * T * self.num_processes 45 | if self.curr_pos + (B * T + 1) > len(self.tokens): 46 | self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths) 47 | self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard]) 48 | self.curr_pos = self.B * self.T * self.process_rank 49 | return x_batch, y_batch -------------------------------------------------------------------------------- /src/hellaswag_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Downloads and evaluates HellaSwag in Python. 3 | https://github.com/rowanz/hellaswag 4 | 5 | Example HellaSwag json item: 6 | 7 | {"ind": 24, "activity_label": "Roof shingle removal", "ctx_a": "A man is sitting on a roof.", "ctx_b": "he", "ctx": "A man is sitting on a roof. 
he", "split": "val", "split_type": "indomain", "label": 3, "endings": ["is using wrap to wrap a pair of skis.", "is ripping level tiles off.", "is holding a rubik's cube.", "starts pulling up roofing on a roof."], "source_id": "activitynet~v_-JhWjGDPHMY"} 8 | 9 | ind: dataset ID 10 | activity_label: The ActivityNet or WikiHow label for this example 11 | context: There are two formats. The full context is in ctx. When the context ends in an (incomplete) noun phrase, like for ActivityNet, this incomplete noun phrase is in ctx_b, and the context up until then is in ctx_a. This can be useful for models such as BERT that need the last sentence to be complete. However, it's never required. If ctx_b is nonempty, then ctx is the same thing as ctx_a, followed by a space, then ctx_b. 12 | endings: a list of 4 endings. The correct index is given by label (0,1,2, or 3) 13 | split: train, val, or test. 14 | split_type: indomain if the activity label is seen during training, else zeroshot 15 | source_id: Which video or WikiHow article this example came from 16 | 17 | gpt2 (124M) 18 | - eleuther harness reports acc 28.92%, acc_norm 31.14% (multiple choice style) 19 | - this script: 10042 acc: 0.2859 acc_norm: 0.2955 (completion style) 20 | 21 | gpt2-xl (1558M) 22 | - eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style) 23 | - this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style) 24 | 25 | The validation set of HellaSwag has a total of 10,042 examples. 26 | """ 27 | 28 | import os 29 | import json 30 | import requests 31 | import tiktoken 32 | from tqdm import tqdm 33 | import torch 34 | import torch.nn as nn 35 | from torch.nn import functional as F 36 | from transformers import GPT2LMHeadModel 37 | 38 | # ----------------------------------------------------------------------------- 39 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag") 40 | 41 | def download_file(url: str, fname: str, chunk_size=1024): 42 | """Helper function to download a file from a given url""" 43 | resp = requests.get(url, stream=True) 44 | total = int(resp.headers.get("content-length", 0)) 45 | with open(fname, "wb") as file, tqdm( 46 | desc=fname, 47 | total=total, 48 | unit="iB", 49 | unit_scale=True, 50 | unit_divisor=1024, 51 | ) as bar: 52 | for data in resp.iter_content(chunk_size=chunk_size): 53 | size = file.write(data) 54 | bar.update(size) 55 | 56 | hellaswags = { 57 | "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl", 58 | "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl", 59 | "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl", 60 | } 61 | 62 | enc = tiktoken.get_encoding("gpt2") 63 | 64 | def download(split): 65 | """Downloads HellaSwag DATA_CACHE_DIR""" 66 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 67 | data_url = hellaswags[split] 68 | data_filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl") 69 | if not os.path.exists(data_filename): 70 | print(f"Downloading {data_url} to {data_filename}...") 71 | download_file(data_url, data_filename) 72 | 73 | def render_example(example): 74 | """ 75 | Given the example as a dictionary, render it as three torch tensors: 76 | - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates) 77 | - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods) 78 | - label (the index of the correct completion, which we hope has the 
highest likelihood) 79 | """ 80 | ctx = example["ctx"] 81 | label = example["label"] 82 | endings = example["endings"] 83 | # data needed to reproduce this eval on the C size 84 | data = { 85 | "label": label, 86 | "ctx_tokens": None, 87 | "ending_tokens": [], 88 | } 89 | # gather up all the tokens 90 | ctx_tokens = enc.encode(ctx) 91 | data["ctx_tokens"] = ctx_tokens 92 | tok_rows = [] 93 | mask_rows = [] 94 | for end in endings: 95 | end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer 96 | tok_rows.append(ctx_tokens + end_tokens) 97 | mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens)) 98 | data["ending_tokens"].append(end_tokens) 99 | 100 | # have to be careful during the collation because the number of tokens in each row can differ 101 | max_len = max(len(row) for row in tok_rows) 102 | tokens = torch.zeros((4, max_len), dtype=torch.long) 103 | mask = torch.zeros((4, max_len), dtype=torch.long) 104 | for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)): 105 | tokens[i, :len(tok_row)] = torch.tensor(tok_row) 106 | mask[i, :len(mask_row)] = torch.tensor(mask_row) 107 | return data, tokens, mask, label 108 | 109 | def iterate_examples(split): 110 | # there are 10,042 examples in total in val 111 | download(split) 112 | with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f: 113 | for line in f: 114 | example = json.loads(line) 115 | yield example 116 | 117 | @torch.no_grad() 118 | def evaluate(model_type, device): 119 | torch.set_float32_matmul_precision('high') # use tf32 120 | model = GPT2LMHeadModel.from_pretrained(model_type) 121 | model.to(device) 122 | # model = torch.compile(model) # optionally torch compile the model 123 | num_correct_norm = 0 124 | num_correct = 0 125 | num_total = 0 126 | for example in iterate_examples("val"): 127 | data, tokens, mask, label = render_example(example) 128 | tokens = tokens.to(device) 129 | mask = mask.to(device) 130 | 131 | # get the logits 132 | logits = model(tokens).logits 133 | # evaluate the autoregressive loss at all positions 134 | shift_logits = (logits[..., :-1, :]).contiguous() 135 | shift_tokens = (tokens[..., 1:]).contiguous() 136 | flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) 137 | flat_shift_tokens = shift_tokens.view(-1) 138 | shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none') 139 | shift_losses = shift_losses.view(tokens.size(0), -1) 140 | # now get the average loss just for the completion region (where mask == 1), in each row 141 | shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token 142 | masked_shift_losses = shift_losses * shift_mask 143 | # sum and divide by the number of 1s in the mask 144 | sum_loss = masked_shift_losses.sum(dim=1) 145 | avg_loss = sum_loss / shift_mask.sum(dim=1) 146 | # now we have a loss for each of the 4 completions 147 | # the one with the lowest loss should be the most likely 148 | pred = sum_loss.argmin().item() 149 | pred_norm = avg_loss.argmin().item() 150 | 151 | # accumulate stats 152 | num_total += 1 153 | num_correct += int(pred == label) 154 | num_correct_norm += int(pred_norm == label) 155 | print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}") 156 | 157 | # debug: pretty print a few examples, and the losses in each case 158 | if num_total < 10: 159 | print("---") 160 | print(f"Context:\n {example['ctx']}") 161 | print(f"Endings:") 162 | for i, end in 
enumerate(example["endings"]): 163 | print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}") 164 | print(f"predicted: {pred_norm}, actual: {label}") 165 | 166 | 167 | def get_most_likely_row(tokens, mask, logits): 168 | """ 169 | helper function for HellaSwag eval. Takes tokens, mask, and logits, 170 | returns the index of the completion with the lowest loss 171 | """ 172 | # evaluate the autoregressive loss at all positions 173 | shift_logits = (logits[..., :-1, :]).contiguous() 174 | shift_tokens = (tokens[..., 1:]).contiguous() 175 | flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) 176 | flat_shift_tokens = shift_tokens.view(-1) 177 | shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none') 178 | shift_losses = shift_losses.view(tokens.size(0), -1) 179 | # now get the average loss just for the completion region (where mask == 1), in each row 180 | shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token 181 | masked_shift_losses = shift_losses * shift_mask 182 | # sum and divide by the number of 1s in the mask 183 | sum_loss = masked_shift_losses.sum(dim=1) 184 | avg_loss = sum_loss / shift_mask.sum(dim=1) 185 | # now we have a loss for each of the 4 completions 186 | # the one with the lowest loss should be the most likely 187 | pred_norm = avg_loss.argmin().item() 188 | return pred_norm 189 | 190 | 191 | if __name__ == "__main__": 192 | import argparse 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument("-m", "--model_type", type=str, default="gpt2", help="the model type to use") 195 | parser.add_argument("-d", "--device", type=str, default="cuda", help="the device to use") 196 | args = parser.parse_args() 197 | evaluate(args.model_type, args.device) 198 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import tiktoken 5 | from dataclasses import dataclass 6 | 7 | from model import GPT 8 | 9 | 10 | class GPT2Inference: 11 | """ To generate text sequences using a trained GPT2 model """ 12 | 13 | def __init__(self, model, token_encoder, device): 14 | self.model = model 15 | self.token_encoder = token_encoder 16 | self.device = device 17 | self.device_type = 'cuda' if device.startswith('cuda') else 'cpu' 18 | 19 | def generate_sequences(self, prompt, num_seq=5, max_tokens=50): 20 | self.model.eval() 21 | tokens = self.token_encoder.encode(prompt) 22 | tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length 23 | tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n) 24 | gen_tokens = tokens.to(self.device) 25 | # create a different rng generator so as not to impact the global rng state used for training 26 | sample_rng = torch.Generator(device=self.device).manual_seed(42) 27 | 28 | # generate new tokens one token at a time until the sequence length becomes 'max_tokens' 29 | while gen_tokens.shape[-1] <= max_tokens: 30 | with torch.no_grad(): 31 | with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): 32 | logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size) 33 | logits = logits[:, -1, :] # (num_seq, vocab_size) 34 | probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size) 35 | # take top-k 50 probs 36 | topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50) 37 | # sample a 
token from top-50 probabilities 38 | ix = torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1) 39 | next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1) 40 | gen_tokens = torch.cat([gen_tokens, next_tok], dim=1) 41 | # decode generated tokens and print generated text 42 | for i in range(num_seq): 43 | tokens = gen_tokens[i, :max_tokens].tolist() 44 | gen_text = self.token_encoder.decode(tokens) 45 | print(f"> sample {i}: {gen_text}") 46 | 47 | 48 | def parse_args(): 49 | import argparse 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--prompt', type=str, default="Hello, I am a language model,") 52 | parser.add_argument('--num_seq', type=int, default=5) 53 | parser.add_argument('--max_tokens', type=int, default=50) 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | @dataclass 59 | class GPTConfig: 60 | context_length: int = 1024 # max context / sequence length 61 | vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 token 62 | num_layers: int = 12 63 | embd_size: int = 768 # embedding dim 64 | num_heads: int = 12 65 | 66 | 67 | def inference(args=None): 68 | if args is None: 69 | args = parse_args() 70 | 71 | device = 'cpu' 72 | if torch.cuda.is_available(): 73 | device = 'cuda' 74 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): 75 | device = 'mps' # for apple macbook GPUs 76 | print(f'using device: {device}') 77 | 78 | model_path = './logs/model_95364.pt' 79 | checkpoint = torch.load(model_path, weights_only=False) 80 | print(f"loaded model from: {model_path}") 81 | # print(checkpoint['model'].keys()) 82 | 83 | model = GPT(config=checkpoint['config']) 84 | model.load_state_dict(checkpoint['model']) 85 | model = model.to(device) 86 | token_encoder = tiktoken.get_encoding('gpt2') 87 | generator = GPT2Inference(model, token_encoder, device) 88 | 89 | generator.generate_sequences(args.prompt, args.num_seq, args.max_tokens) 90 | 91 | 92 | if __name__ == '__main__': 93 | inference() 94 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dataclasses import dataclass 5 | import inspect 6 | 7 | 8 | @dataclass 9 | class GPTConfig: 10 | context_length: int = 1024 # max context / sequence length 11 | vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 token 12 | num_layers: int = 12 13 | embd_size: int = 768 # embedding dim 14 | num_heads: int = 12 15 | 16 | 17 | class CausalSelfAttention(nn.Module): 18 | def __init__(self, config): 19 | super().__init__() 20 | # 'embd_size' sized vector divided into 'num_heads' heads 21 | assert config.embd_size % config.num_heads == 0, f"embedding dim should be divisible by number of heads" 22 | self.num_heads = config.num_heads 23 | self.embd_size = config.embd_size 24 | # batched key, query, and value projections for all heads 25 | self.c_attn = nn.Linear(config.embd_size, 3 * config.embd_size) 26 | self.c_proj = nn.Linear(config.embd_size, config.embd_size) 27 | self.c_proj.SCALE_INIT = 1.0 28 | # not really a bias, more of a mask, but following OpenAI/HF naming convention 29 | # self.register_buffer("bias", torch.tril(torch.ones(config.context_length, config.context_length)).view(1, 1, config.context_length, config.context_length)) 30 | 31 | def forward(self, x): 32 | B, T, C = x.shape 33 | # 
calculate query, key, values for all heads in a batch and move head forward to be the batch dim 34 | # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs 35 | # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels 36 | qkv = self.c_attn(x) # (B, T, 3C) 37 | q, k, v = qkv.split(self.embd_size, dim=-1) # (B,T,C), (B,T,C), (B,T,C) 38 | q = q.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs) 39 | k = k.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs) 40 | v = v.view(B, T, self.num_heads, self.embd_size // self.num_heads).transpose(1, 2) # (B,nh,T,hs) 41 | # attn = q @ k.transpose(-2, -1) / np.sqrt(k.shape[-1]) # (B,nh,T,hs) @ (B,nh,hs,T) --> (B,nh,T,T) 42 | # attn = attn.masked_fill(self.bias[:,:,:T,:T] == 0, float("-inf")) 43 | # attn = F.softmax(attn, dim=-1) 44 | # out = attn @ v # (B,nh,T,T) @ (B,nh,T,hs) --> (B,nh,T,hs) 45 | # flash-attention paper (significantly faster, but logically the same as above 4 lines) 46 | out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # (B,nh,T,hs) 47 | out = out.transpose(1, 2).contiguous().view(B, T, C) # (B,nh,T,hs) --> (B,T,nh,hs) --> (B,T,C=nh*hs) 48 | out = self.c_proj(out) # (B,T,C) --> (B,T,C) 49 | return out 50 | 51 | 52 | class MLP(nn.Module): 53 | def __init__(self, config): 54 | super().__init__() 55 | self.c_fc = nn.Linear(config.embd_size, 4 * config.embd_size) 56 | self.gelu = nn.GELU(approximate='tanh') # approximate='tanh' used to try to reproduce gpt2 paper 57 | self.c_proj = nn.Linear(4 * config.embd_size, config.embd_size) 58 | self.c_proj.SCALE_INIT = 1.0 59 | 60 | def forward(self, x): 61 | x = self.c_fc(x) 62 | x = self.gelu(x) 63 | x = self.c_proj(x) 64 | return x 65 | 66 | 67 | class Block(nn.Module): 68 | """ Transformer Encoder block """ 69 | 70 | def __init__(self, config): 71 | super().__init__() 72 | self.ln_1 = nn.LayerNorm(config.embd_size) 73 | self.attn = CausalSelfAttention(config) 74 | self.ln_2 = nn.LayerNorm(config.embd_size) 75 | self.mlp = MLP(config) 76 | 77 | def forward(self, x): 78 | x = x + self.attn(self.ln_1(x)) 79 | x = x + self.mlp(self.ln_2(x)) 80 | return x 81 | 82 | 83 | class GPT(nn.Module): 84 | def __init__(self, config): 85 | super().__init__() 86 | self.config = config 87 | self.transformer = nn.ModuleDict(dict( 88 | wte = nn.Embedding(self.config.vocab_size, self.config.embd_size), 89 | wpe = nn.Embedding(self.config.context_length, self.config.embd_size), 90 | h = nn.ModuleList([Block(self.config) for _ in range(self.config.num_layers)]), 91 | ln_f = nn.LayerNorm(self.config.embd_size) 92 | )) 93 | # language modeling head 94 | self.lm_head = nn.Linear(self.config.embd_size, self.config.vocab_size, bias=False) 95 | # weight sharing scheme (reduces 768*50267=~40M params, fewer params, more efficient) 96 | self.transformer.wte.weight = self.lm_head.weight 97 | # init params (iterates over all submodules and applies _init_weights) 98 | self.apply(self._init_weights) 99 | 100 | def _init_weights(self, module): 101 | if isinstance(module, nn.Linear): 102 | std = 0.02 103 | if hasattr(module, 'SCALE_INIT'): 104 | std /= (2 * self.config.num_layers)**0.5 105 | torch.nn.init.normal_(module.weight, mean=0, std=std) # as per openai gpt-2 source code 106 | if module.bias is not None: 107 | torch.nn.init.zeros_(module.bias) 108 | elif isinstance(module, nn.Embedding): 109 | torch.nn.init.normal_(module.weight, mean=0, std=0.02) 110 | 111 | def forward(self, idx, targets=None): 
112 | B, T = idx.shape 113 | assert T <= self.config.context_length, f'sequence length {T} should be <= {self.config.context_length}' 114 | pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # (T,) 115 | pos_embd = self.transformer.wpe(pos) # (T, embd_size) 116 | tok_embd = self.transformer.wte(idx) # (B, T, embd_size) 117 | x = pos_embd + tok_embd # (B, T, embd_size) 118 | for block in self.transformer.h: 119 | x = block(x) 120 | x = self.transformer.ln_f(x) # (B, T, embd_size) 121 | logits = self.lm_head(x) # (B, T, vocab_size) 122 | loss = None 123 | if targets is not None: 124 | loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1)) 125 | return logits, loss 126 | 127 | @classmethod 128 | def from_pretrained(cls, model_type): 129 | """ Loads pretrained GPT2 model weights from huggingface """ 130 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 131 | from transformers import GPT2LMHeadModel 132 | print(f"loading weights from pretrained gpt: {model_type}") 133 | 134 | config_args = { 135 | 'gpt2': dict(num_layers=12, num_heads=12, embd_size=768), # 124M params 136 | 'gpt2-medium': dict(num_layers=24, num_heads=16, embd_size=1024), # 350M params 137 | 'gpt2-large': dict(num_layers=36, num_heads=20, embd_size=1280), # 774M params 138 | 'gpt2-xl': dict(num_layers=48, num_heads=25, embd_size=1600), # 1558M params 139 | }[model_type] 140 | config_args['vocab_size'] = 50257 141 | config_args['context_length'] = 1024 142 | 143 | # create a from-scratch minGPT model 144 | config = GPTConfig(**config_args) 145 | model = GPT(config) 146 | sd = model.state_dict() 147 | sd_keys = sd.keys() 148 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] 149 | 150 | # init a huggingface transformers model 151 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 152 | sd_hf = model_hf.state_dict() 153 | sd_keys_hf = sd_hf.keys() 154 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] 155 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] 156 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 157 | 158 | assert len(sd_keys) == len(sd_keys_hf), f"mismatched keys {len(sd_keys)} != {len(sd_keys_hf)}" 159 | 160 | # copy while ensuring all parameters are aligned in names and shape 161 | for k in sd_keys_hf: 162 | if any(k.endswith(w) for w in transposed): 163 | # need to transpose Conv1D weights 164 | assert sd_hf[k].shape[::-1] == sd[k].shape 165 | with torch.no_grad(): 166 | sd[k].copy_(sd_hf[k].T) 167 | else: 168 | assert sd_hf[k].shape == sd[k].shape 169 | with torch.no_grad(): 170 | sd[k].copy_(sd_hf[k]) 171 | return model 172 | 173 | def configure_optimizers(self, weight_decay, lr, device_type, master_process): 174 | """ 175 | Essentially implements weight decay (regularization tool, by decaying the weights, we 176 | forcing the optimizer to use more of the weights, and not allowing any single weight to dominate) 177 | """ 178 | # start with all of the candidate params (that require gradient) 179 | param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad} 180 | 181 | # create optim groups: any parameters that are 2D will be weight decayed, otherwise no. 
182 | # i.e., all weight tensors in matmuls + embeddings will decay, whereas biases and layernorms won't be decayed 183 | decay_params = [p for pn, p in param_dict.items() if p.dim() >= 2] 184 | nodecay_params = [p for pn, p in param_dict.items() if p.dim() < 2] 185 | optim_groups = [ 186 | {'params': decay_params, 'weight_decay': weight_decay}, 187 | {'params': nodecay_params, 'weight_decay': 0.0} 188 | ] 189 | num_decay_params = sum(p.numel() for p in decay_params) 190 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 191 | if master_process: 192 | print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters') 193 | print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters') 194 | 195 | # use fused version of AdamW optimizer (faster than non-fused version) 196 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 197 | use_fused = fused_available and device_type == 'cuda' 198 | if master_process: 199 | print(f'using fused AdamW optimizer: {use_fused}') 200 | optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, fused=use_fused) 201 | return optimizer -------------------------------------------------------------------------------- /src/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | FineWeb-Edu dataset (for srs pretraining) 3 | https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu 4 | Downloads and tokenizes the data and saves data shards to disk. 5 | Run simply as: 6 | $ python prepare_dataset.py 7 | Will save shards to the local directory "edu_fineweb10B". 8 | """ 9 | 10 | import os 11 | import numpy as np 12 | import multiprocessing as mp 13 | import tiktoken 14 | from datasets import load_dataset 15 | from tqdm import tqdm 16 | 17 | 18 | script_dir = os.path.dirname(__file__) 19 | local_dir = os.path.join(script_dir, "../data/edu_fineweb10B") 20 | remote_name = "sample-10BT" 21 | shard_size = int(1e8) # 100M tokens per shard, total 100 shards = 10B gpt2 tokens 22 | 23 | # create cache and local dir if it doesn't exist yet 24 | DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir) 25 | os.makedirs(DATA_CACHE_DIR, exist_ok=True) 26 | 27 | # download the dataset 28 | fw = load_dataset('HuggingFaceFW/fineweb-edu', name=remote_name, split='train') 29 | 30 | # init the tokenizer 31 | enc = tiktoken.get_encoding('gpt2') 32 | eot = enc._special_tokens['<|endoftext|>'] # end of text token 33 | 34 | def tokenize(doc): 35 | """ tokenizes a single document and returns a np array of uint16 tokens """ 36 | tokens = [eot] # special <|endoftext|> token delimits all documents 37 | tokens.extend(enc.encode_ordinary(doc['text'])) 38 | tokens_np = np.array(tokens) 39 | assert (tokens_np >= 0).all() and (tokens_np < 2**16).all(), 'token dict too large for uint16' 40 | tokens_np_uint16 = tokens_np.astype(np.uint16) 41 | return tokens_np_uint16 42 | 43 | # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder tokens) 44 | nprocs = max(1, os.cpu_count() // 2) 45 | with mp.Pool(nprocs) as pool: 46 | shard_idx = 0 47 | # preallocate buffer to hold current shard 48 | all_tokens_np = np.empty((shard_size,), dtype=np.uint16) 49 | token_count = 0 50 | progress_bar = None 51 | 52 | for tokens in pool.imap(tokenize, fw, chunksize=16): 53 | # check if there is enough space in current shard for new tokens 54 | if token_count + len(tokens) < shard_size: 55 
|             # simply append tokens to current shard
56 |             all_tokens_np[token_count : token_count + len(tokens)] = tokens
57 |             token_count += len(tokens)
58 |             if progress_bar is None:
59 |                 progress_bar = tqdm(total=shard_size, unit='tokens', desc=f'shard {shard_idx}')
60 |             progress_bar.update(len(tokens))
61 |         else:
62 |             # write current shard and start a new one
63 |             split = 'val' if shard_idx == 0 else 'train'
64 |             filepath = os.path.join(DATA_CACHE_DIR, f'edufineweb_{split}_{shard_idx:06d}')
65 |             # split the document into whatever fits in this shard, remainder goes to next one
66 |             remainder = shard_size - token_count
67 |             progress_bar.update(remainder)
68 |             all_tokens_np[token_count : token_count + remainder] = tokens[:remainder]
69 |             np.save(filepath, all_tokens_np)
70 |             shard_idx += 1
71 |             progress_bar = None
72 |             all_tokens_np[0:len(tokens) - remainder] = tokens[remainder:]
73 |             token_count = len(tokens) - remainder
74 | 
75 |     if token_count != 0:
76 |         split = 'val' if shard_idx == 0 else 'train'
77 |         filepath = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_idx:06d}")
78 |         np.save(filepath, all_tokens_np[:token_count])
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import time
5 | from dataclasses import dataclass
6 | import tiktoken
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.distributed as dist
11 | from torch.nn.parallel import DistributedDataParallel as DDP
12 | # import code; code.interact(local=locals())
13 | 
14 | from model import GPT
15 | from dataloader import DataLoaderLite
16 | from hellaswag_eval import render_example, iterate_examples, get_most_likely_row
17 | 
18 | torch.set_float32_matmul_precision('high') # enable TF32 precision
19 | 
20 | # set use_torch_compile to True (if it doesn't throw any errors) to speed up training
21 | use_torch_compile = False
22 | 
23 | 
24 | class Trainer:
25 |     def __init__(
26 |         self,
27 |         model,
28 |         optimizer,
29 |         train_loader,
30 |         val_loader,
31 |         token_encoder,
32 |         eval_freq,
33 |         grad_accum_steps,
34 |         ddp,
35 |         ddp_rank,
36 |         ddp_world_size,
37 |         device,
38 |         logpath
39 |     ):
40 |         self.ddp = ddp
41 |         self.ddp_rank = ddp_rank
42 |         self.master_process = ddp_rank == 0
43 |         self.ddp_world_size = ddp_world_size
44 | 
45 |         self.model = model
46 |         self.optimizer = optimizer
47 |         self.train_loader = train_loader
48 |         self.val_loader = val_loader
49 |         self.token_encoder = token_encoder
50 | 
51 |         self.eval_freq = eval_freq
52 |         self.grad_accum_steps = grad_accum_steps
53 |         self.device = device
54 |         self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'
55 |         self.logpath = logpath
56 | 
57 | 
58 |     def train(
59 |         self,
60 |         max_steps,
61 |         warmup_steps,
62 |         max_lr,
63 |         min_lr
64 |     ):
65 |         for step in range(max_steps):
66 |             t0 = time.time()
67 |             self.is_last_step = (step == max_steps - 1)
68 | 
69 |             # evaluate validation loss
70 |             if step % self.eval_freq == 0 or self.is_last_step:
71 |                 self.evaluate_validation(step)
72 | 
73 |             # evaluate model performance on HellaSwag every once in a while
74 |             if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
75 |                 self.evaluate_helloswag(step)
76 | 
77 |             # generate sequences from the model every once in a while
78 |             if ((step > 0 and step % self.eval_freq == 0) or self.is_last_step) and (not use_torch_compile):
79 | 
self.generate_sequences(num_seq=5, max_tokens=32) 80 | 81 | # training loop starts here 82 | self.model.train() # sets model to train mode 83 | self.optimizer.zero_grad() # resets all gradients 84 | batch_loss = 0.0 85 | 86 | for mini_step in range(self.grad_accum_steps): 87 | inp, tar = self.train_loader.next_batch() 88 | inp, tar = inp.to(self.device), tar.to(self.device) 89 | 90 | # FORWARD PASS !!! 91 | # autocast to bfloat16 for faster compute and memory efficiency 92 | with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): 93 | logits, loss = self.model(inp, tar) 94 | 95 | # loss is scaled to account for gradient accumulation, because the gradients just add 96 | # on each successive backward() call. Addition of gradients corresponds to SUM in the objective, 97 | # but we want MEAN instead of a SUM 98 | loss /= self.grad_accum_steps 99 | batch_loss += loss.detach() 100 | 101 | if self.ddp: 102 | # in the final mini_step, sync and avg all gradients across all processes. used by both forward and backward processes 103 | # can use 'no_sync()' context manager alternatively. 104 | self.model.require_backward_grad_sync = (mini_step == self.grad_accum_steps - 1) 105 | 106 | # each process accumulates gradients separately when 'require_backward_grad_sync'=False 107 | # in the final 'mini_step', 'require_backward_grad_sync' becomes True, therefore 108 | # gradients are averaged across all processes and shared among them by loss.backward() 109 | loss.backward() 110 | 111 | if self.ddp: 112 | # 'batch_loss' is outside of DDP container, so need to perform 'all_reduce' to 113 | # average out 'batch_loss' across all processes of all ranks. 'batch_loss' tensor exists on all GPUs. 114 | # 'all_reduce' averages and deposits the result on all the processes 115 | dist.all_reduce(batch_loss, op=dist.ReduceOp.AVG) 116 | 117 | # once gradients are computed, clip the global l2-norm of the gradient at 1.0 118 | norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) # monitor/print 'norm' 119 | 120 | # determine learning rate with decay 121 | lr = self.estimate_lr(step, warmup_steps, max_steps, max_lr, min_lr) 122 | # set learning rate for this iteration 123 | for param_group in self.optimizer.param_groups: 124 | param_group['lr'] = lr 125 | 126 | self.optimizer.step() 127 | if self.device_type == 'cuda': 128 | torch.cuda.synchronize() # wait for the GPU to finish work 129 | 130 | dt = (time.time() - t0) * 1000.0 # in ms 131 | tokens_processed = self.train_loader.B * self.train_loader.T * self.grad_accum_steps * self.ddp_world_size 132 | tokens_per_sec = tokens_processed / dt 133 | 134 | if self.master_process: 135 | print(f'step {step:4d} | loss: {batch_loss.item():.6f} | lr: {lr:.2e} | norm: {norm:.4f} | dt: {dt:.4f}ms | tok/sec: {tokens_per_sec:.4f}') 136 | with open(self.logpath, 'a') as f: 137 | f.write(f'{step} train {batch_loss.item():.6f}\n') 138 | 139 | 140 | def evaluate_validation(self, step): 141 | self.model.eval() # sets model to eval mode 142 | self.val_loader.reset() 143 | # evaluate the model on validation set 144 | with torch.no_grad(): 145 | val_loss_accum = 0.0 146 | val_steps = 20 147 | for _ in range(val_steps): 148 | inp, tar = self.val_loader.next_batch() 149 | inp, tar = inp.to(self.device), tar.to(self.device) 150 | with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): 151 | logits, loss = self.model(inp, tar) 152 | loss /= val_steps 153 | val_loss_accum += loss.detach() 154 | 155 | if self.ddp: 156 | dist.all_reduce(val_loss_accum, 
op=dist.ReduceOp.AVG) 157 | if self.master_process: 158 | print(f'Val loss: {val_loss_accum.item():.4f}') 159 | with open(self.logpath, 'a') as f: 160 | f.write(f'{step} val {val_loss_accum.item():.4f}\n') 161 | 162 | if step > 0 and (step % 10000 == 0 or self.is_last_step): 163 | raw_model = self.model.module if self.ddp else self.model 164 | logdir = os.path.dirname(self.logpath) 165 | ckpt_path = os.path.join(logdir, f'model_{step:05d}.pt') 166 | checkpoint = { 167 | 'model': raw_model.state_dict(), 168 | 'config': raw_model.config, 169 | 'step': step, 170 | 'val_loss': val_loss_accum.item() 171 | } # add optimizer.state_dict(), rng_seeds, etc. if resuming training 172 | torch.save(checkpoint, ckpt_path) 173 | 174 | 175 | def evaluate_helloswag(self, step): 176 | """ 177 | Construct a batch of 4 sequences and perform token completion using 178 | our model. 179 | """ 180 | n_total = 0 181 | n_correct_norm = 0 182 | for i, example in enumerate(iterate_examples('val')): 183 | # only process examples where i % ddp_world_size == ddp_rank 184 | if i % self.ddp_world_size != self.ddp_rank: 185 | continue 186 | # render the example into tokens and labels 187 | _, tokens, mask, label = render_example(example) # (4,N), (4,N), (4,N) 188 | tokens, mask = tokens.to(self.device), mask.to(self.device) 189 | with torch.no_grad(): 190 | with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): 191 | logits, loss = self.model(tokens) 192 | pred_norm = get_most_likely_row(tokens, mask, logits) 193 | n_total += 1 194 | n_correct_norm += int(pred_norm == label) 195 | # reduce the stats across all processes 196 | if self.ddp: 197 | n_total = torch.tensor(n_total, device=self.device, dtype=torch.long) 198 | n_correct_norm = torch.tensor(n_correct_norm, device=self.device, dtype=torch.long) 199 | dist.all_reduce(n_total, op=dist.ReduceOp.SUM) 200 | dist.all_reduce(n_correct_norm, op=dist.ReduceOp.SUM) 201 | n_total = n_total.item() 202 | n_correct_norm = n_correct_norm.item() 203 | acc_norm = n_correct_norm / n_total 204 | if self.master_process: 205 | print(f'HelloSwag accuracy: {n_correct_norm}/{n_total}={acc_norm:.4f}') 206 | with open(self.logpath, 'a') as f: 207 | f.write(f'{step} hellaswag {acc_norm:.4f}\n') 208 | 209 | 210 | def generate_sequences(self, num_seq=4, max_tokens=32): 211 | self.model.eval() 212 | tokens = self.token_encoder.encode("Hello, I am a language model") 213 | tokens = torch.tensor(tokens, dtype=torch.long) # (n,) n : current sequence length 214 | tokens = tokens.unsqueeze(0).repeat(num_seq, 1) # (1,n) --> (num_seq, n) 215 | gen_tokens = tokens.to(self.device) 216 | # create a different rng generator so as not to impact the global rng state used for training 217 | sample_rng = torch.Generator(device=self.device) 218 | # adding 'ddp_rank' in seeding to generate different tokens for different rank processes 219 | sample_rng.manual_seed(42 + self.ddp_rank) 220 | # generate new tokens one token at a time until the sequence length becomes 'max_tokens' 221 | while gen_tokens.shape[-1] <= max_tokens: 222 | with torch.no_grad(): 223 | with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16): 224 | logits, loss = self.model(gen_tokens) # (num_seq, n, vocab_size) 225 | logits = logits[:, -1, :] # (num_seq, vocab_size) 226 | probs = F.softmax(logits, dim=-1) # (num_seq, vocab_size) 227 | # take top-k 50 probs 228 | topk_probs, topk_indices = torch.topk(probs, 50, dim=-1) # (num_seq, 50), (num_seq, 50) 229 | # sample a token from top-50 probabilities 230 | ix = 
torch.multinomial(topk_probs, num_samples=1, generator=sample_rng) # (num_seq, 1) 231 | next_tok = torch.gather(topk_indices, -1, ix) # (num_seq, 1) 232 | gen_tokens = torch.cat([gen_tokens, next_tok], dim=1) 233 | # decode generated tokens and print generated text 234 | for i in range(num_seq): 235 | tokens = gen_tokens[i, :max_tokens].tolist() 236 | gen_text = self.token_encoder.decode(tokens) 237 | print(f"> rank {self.ddp_rank} sample {i}: {gen_text}") 238 | 239 | 240 | def estimate_lr(self, step, warmup_steps, max_steps, max_lr, min_lr): 241 | """ 242 | Learning rate scheduler: Cosine-decay learning schedule with warmup 243 | """ 244 | # 1) linear warmup for 'warmup_iters' steps 245 | if step < warmup_steps: 246 | return max_lr * (step+1) / warmup_steps 247 | # 2) if step > lr_decay_iters, return min lr 248 | if step > max_steps: 249 | return min_lr 250 | # 3) in between, use cosine decay down to min lr 251 | decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps) 252 | assert 0 <= decay_ratio <= 1 253 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0 254 | return min_lr + coeff * (max_lr - min_lr) 255 | 256 | 257 | @dataclass 258 | class GPTConfig: 259 | context_length: int = 1024 # max context / sequence length 260 | vocab_size: int = 50257 # number of tokens: 50000 BPE merges + 256 bytes tokens + 1 token 261 | num_layers: int = 12 262 | embd_size: int = 768 # embedding dim 263 | num_heads: int = 12 264 | 265 | 266 | def get_args(): 267 | import argparse 268 | parser = argparse.ArgumentParser(description="Hyperparameter Configuration") 269 | parser.add_argument("--total_batch_size", type=int, default=524288, help="number of tokens processed for each weight update") # =2^19 tokens/step update, (~0.5M tokens used in openai gpt3 paper) 270 | parser.add_argument("--mini_batch_size", type=int, default=32, help="setting of mini_batch_size is just a performance optimization. 
bigger gpu, bigger mini_batch_size") 271 | parser.add_argument("--context_length", type=int, default=1024) # max sequence length (can also try 2048) 272 | parser.add_argument("--num_layers", type=int, default=12) 273 | parser.add_argument("--embd_size", type=int, default=768) 274 | parser.add_argument("--num_heads", type=int, default=12) 275 | parser.add_argument("--max_lr", type=float, default=1e-3) 276 | parser.add_argument("--min_lr", type=float, default=1e-3 * 0.1) 277 | parser.add_argument("--warmup_steps", type=int, default=715) 278 | parser.add_argument("--weight_decay", type=float, default=0.1) 279 | parser.add_argument("--num_epochs", type=int, default=5) 280 | parser.add_argument("--steps_per_epoch", type=int, default=19073) # 10^10 / 2^19 ~ 19073 for 1 epoch on FineWebEdu-sample10BT 281 | parser.add_argument("--eval_freq", type=int, default=250) 282 | # parser.add_argument("--use_torch_compile", action='store_true') # default False 283 | parser.add_argument("--seed", type=int, default=1337, help="Random seed for reproducibility") 284 | parser.add_argument("--logdir", type=str, default="./logs/") 285 | return parser.parse_args() 286 | 287 | 288 | def main(): 289 | args = get_args() 290 | 291 | # Print the hyperparameters 292 | print("Hyperparameter Configuration:") 293 | for key, value in vars(args).items(): 294 | print(f"{key}: {value}") 295 | 296 | # create the logs directory if it doesn't exist 297 | os.makedirs(args.logdir, exist_ok=True) 298 | logpath = os.path.join(args.logdir, 'log.txt') 299 | with open(logpath, 'w') as f: 300 | pass 301 | 302 | # set up DDP (distributed data parallel) 303 | # 'torchrun' command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE 304 | # RANK and LOCAL_RANK same for (single node, multi-GPU) settings, may differ for (multinode, 305 | # multi GPU) settings. 306 | ddp = int(os.environ.get('RANK', -1)) != -1 # if this is a ddp run or not 307 | if ddp: 308 | # use of ddp requires CUDA 309 | assert torch.cuda.is_available(), f'use of DDP requires CUDA' 310 | dist.init_process_group(backend='nccl') 311 | ddp_rank = int(os.environ['RANK']) 312 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 313 | ddp_world_size = int(os.environ['WORLD_SIZE']) 314 | device = f'cuda:{ddp_local_rank}' 315 | torch.cuda.set_device(device) 316 | # master process (arbitrarily set to 0) will do printing, logging, checkpointing, etc. 
317 | master_process = ddp_rank == 0 318 | else: 319 | # not using ddp 320 | ddp_rank = 0 321 | ddp_local_rank = 0 322 | ddp_world_size = 1 323 | master_process = True # ddp_rank == 0 324 | device = 'cpu' 325 | if torch.cuda.is_available(): 326 | device = 'cuda' 327 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): 328 | device = 'mps' # for apple macbook GPUs 329 | print(f'using device: {device}') 330 | 331 | device_type = 'cuda' if device.startswith('cuda') else 'cpu' 332 | 333 | # setting seed for reproducibility 334 | np.random.seed(args.seed) 335 | torch.manual_seed(args.seed) # sets seed for random number generation on CPU 336 | if torch.cuda.is_available(): 337 | torch.cuda.manual_seed(args.seed) # sets seed for random number generation on GPU 338 | torch.cuda.manual_seed_all(args.seed) # sets seed for all GPUs 339 | 340 | assert args.total_batch_size % (args.mini_batch_size * args.context_length * ddp_world_size) == 0, f'ensure total_batch_size divisible by B*T*ddp_world_size' 341 | grad_accum_steps = args.total_batch_size // (args.mini_batch_size * args.context_length * ddp_world_size) 342 | if master_process: 343 | print(f'desired batch size (number of tokens): {args.total_batch_size}') 344 | print(f'gradient accumulation steps: {grad_accum_steps}') 345 | print(f'GPU: {ddp_rank}, {ddp_local_rank}') 346 | 347 | train_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='train') 348 | val_loader = DataLoaderLite(B=args.mini_batch_size, T=args.context_length, process_rank=ddp_rank, num_processes=ddp_world_size, split='val') 349 | 350 | # create GPT model. each ddp process will create its own instance of the model but since the seed is fixed, 351 | # they will create same identical model 352 | gpt_config = GPTConfig(vocab_size=50304, # 50304 (nice number, lots of power of 2s) used instead of 50257 (bad, odd number) 353 | context_length=args.context_length, 354 | num_layers=args.num_layers, 355 | num_heads=args.num_heads, 356 | embd_size=args.embd_size 357 | ) 358 | model = GPT(config=gpt_config) 359 | # model = GPT.from_pretrained('gpt2') # init from OpenAI GPT-2 360 | model.to(device) # move model to device 361 | if use_torch_compile: 362 | # use torch compile almost always unless debugging (requires compilation time, but makes training faster) 363 | # speedup comes from reducing python overhead and GPU read/write 364 | model = torch.compile(model) 365 | 366 | if ddp: 367 | # wraps the model in DDP container (forward pass is unchanged, but after backward pass, 368 | # gradients computed across each processes averaged by DDP using 'AllReduce' and shared across 369 | # all processes so that each process has same gradients) 370 | model = DDP(model, device_ids=[ddp_local_rank]) 371 | 372 | raw_model = model.module if ddp else model 373 | optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, lr=args.max_lr, device_type=device_type, master_process=master_process) 374 | token_encoder = tiktoken.get_encoding('gpt2') 375 | 376 | start_time = time.time() 377 | # init the trainer object 378 | trainer = Trainer(model, optimizer, train_loader, val_loader, token_encoder, args.eval_freq, grad_accum_steps, 379 | ddp, ddp_rank, ddp_world_size, device, logpath) 380 | 381 | max_steps = args.steps_per_epoch * args.num_epochs 382 | trainer.train(max_steps, args.warmup_steps, args.max_lr, args.min_lr) 383 | 384 | dt = (time.time() - start_time) / (60*60) 385 | print(f"Total 
training time: {dt:.4f}hr") 386 | 387 | if ddp: 388 | dist.destroy_process_group() 389 | 390 | 391 | if __name__ == "__main__": 392 | main() 393 | --------------------------------------------------------------------------------