├── LICENSE ├── README.md ├── config └── config.py ├── data_loader └── data_loader.py ├── requirements.txt ├── scripts ├── data_download.py ├── data_preprocess.py ├── generate_text.py └── train_transformer.py ├── sft_rlhf_guide.ipynb └── src ├── __init__.py └── models ├── __init__.py ├── attention.py ├── mlp.py ├── transformer.py └── transformer_block.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Fareed Khan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![main image](https://cdn-images-1.medium.com/max/5200/1*r99Hq3YBd5FTTWLNYKKvPw.png) 2 | 3 |
4 | 5 | 6 | # Train LLM From Scratch 7 | 8 | ![Python](https://img.shields.io/badge/Python-3.8%2B-blue) ![License](https://img.shields.io/badge/License-MIT-green) ![Contributions](https://img.shields.io/badge/Contributions-Welcome-blue) [![Docs](https://img.shields.io/badge/Docs-Available-success)](#step-by-step-code-explanation) 9 | 10 | **I am Looking for a PhD position in AI**. Take a look at my [Resume](https://drive.google.com/file/d/1Q_iklJ1RVGSb-Pdey8BHy3k8IF3UJv0z/view?usp=sharing) or [GitHub](https://github.com/FareedKhan-dev) 11 | 12 |
13 | 14 | I implemented a transformer model from scratch using PyTorch, based on the paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). You can use my scripts to train your own **billion** or **million** parameter LLM using a single GPU. 15 | 16 | Below is the output of the trained 13 million parameter LLM: 17 | 18 | ``` 19 | In ***1978, The park was returned to the factory-plate that 20 | the public share to the lower of the electronic fence that 21 | follow from the Station's cities. The Canal of ancient Western 22 | nations were confined to the city spot. The villages were directly 23 | linked to cities in China that revolt that the US budget and in 24 | Odambinais is uncertain and fortune established in rural areas. 25 | ``` 26 | 27 | ## Table of Contents 28 | - [Training Data Info](#training-data-info) 29 | - [Prerequisites and Training Time](#prerequisites-and-training-time) 30 | - [Code Structure](#code-structure) 31 | - [Usage](#usage) 32 | - [Step by Step Code Explanation](#step-by-step-code-explanation) 33 | - [Importing Libraries](#importing-libraries) 34 | - [Preparing the Training Data](#preparing-the-training-data) 35 | - [Transformer Overview](#transformer-overview) 36 | - [Multi Layer Perceptron (MLP)](#multi-layer-perceptron-mlp) 37 | - [Single Head Attention](#single-head-attention) 38 | - [Multi Head Attention](#multi-head-attention) 39 | - [Transformer Block](#transformer-block) 40 | - [The Final Model](#the-final-model) 41 | - [Batch Processing](#batch-processing) 42 | - [Training Parameters](#training-parameters) 43 | - [Training the Model](#training-the-model) 44 | - [Saving the Trained Model](#saving-the-trained-model) 45 | - [Training Loss](#training-loss) 46 | - [Generating Text](#generating-text) 47 | - [What’s Next](#whats-next) 48 | 49 | ## Training Data Info 50 | 51 | Training data is from the Pile dataset, which is a diverse, open-source, and large-scale dataset for training language models. The Pile dataset is a collection of 22 diverse datasets, including text from books, articles, websites, and more. The total size of the Pile dataset is 825GB, Below is the sample of the training data: 52 | 53 | ```python 54 | Line: 0 55 | { 56 | "text": "Effect of sleep quality ... epilepsy.", 57 | "meta": { 58 | "pile_set_name": "PubMed Abstracts" 59 | } 60 | } 61 | 62 | Line: 1 63 | { 64 | "text": "LLMops a new GitHub Repository ...", 65 | "meta": { 66 | "pile_set_name": "Github" 67 | } 68 | } 69 | ``` 70 | 71 | ## Prerequisites and Training Time 72 | 73 | Make sure you have a basic understanding of object-oriented programming (OOP), neural networks (NN) and PyTorch to understand the code. Below are some resources to help you get started: 74 | 75 | | Topic | Video Link | 76 | |---------------------|-----------------------------------------------------------| 77 | | OOP | [OOP Video](https://www.youtube.com/watch?v=Ej_02ICOIgs&pp=ygUKb29wIHB5dGhvbg%3D%3D) | 78 | | Neural Network | [Neural Network Video](https://www.youtube.com/watch?v=Jy4wM2X21u0&pp=ygUbbmV1cmFsIG5ldHdvcmsgcHl0aG9uIHRvcmNo) | 79 | | Pytorch | [Pytorch Video](https://www.youtube.com/watch?v=V_xro1bcAuA&pp=ygUbbmV1cmFsIG5ldHdvcmsgcHl0aG9uIHRvcmNo) | 80 | 81 | You will need a GPU to train your model. Colab or Kaggle T4 will work for training a 13+ million-parameter model, but they will fail for billion-parameter training. 
Take a look at the comparison: 82 | 83 | | GPU Name | Memory | Data Size | 2B LLM Training | 13M LLM Training | Max Practical LLM Size (Training) | 84 | |--------------------------|--------|-----------|-----------------|------------------|-----------------------------------| 85 | | NVIDIA A100 | 40 GB | Large | ✔ | ✔ | ~6B–8B | 86 | | NVIDIA V100 | 16 GB | Medium | ✘ | ✔ | ~2B | 87 | | AMD Radeon VII | 16 GB | Medium | ✘ | ✔ | ~1.5B–2B | 88 | | NVIDIA RTX 3090 | 24 GB | Large | ✔ | ✔ | ~3.5B–4B | 89 | | Tesla P100 | 16 GB | Medium | ✘ | ✔ | ~1.5B–2B | 90 | | NVIDIA RTX 3080 | 10 GB | Medium | ✘ | ✔ | ~1.2B | 91 | | AMD RX 6900 XT | 16 GB | Large | ✘ | ✔ | ~2B | 92 | | NVIDIA GTX 1080 Ti | 11 GB | Medium | ✘ | ✔ | ~1.2B | 93 | | Tesla T4 | 16 GB | Small | ✘ | ✔ | ~1.5B–2B | 94 | | NVIDIA Quadro RTX 8000 | 48 GB | Large | ✔ | ✔ | ~8B–10B | 95 | | NVIDIA RTX 4070 | 12 GB | Medium | ✘ | ✔ | ~1.5B | 96 | | NVIDIA RTX 4070 Ti | 12 GB | Medium | ✘ | ✔ | ~1.5B | 97 | | NVIDIA RTX 4080 | 16 GB | Medium | ✘ | ✔ | ~2B | 98 | | NVIDIA RTX 4090 | 24 GB | Large | ✔ | ✔ | ~4B | 99 | | NVIDIA RTX 4060 Ti | 8 GB | Small | ✘ | ✔ | ~1B | 100 | | NVIDIA RTX 4060 | 8 GB | Small | ✘ | ✔ | ~1B | 101 | | NVIDIA RTX 4050 | 6 GB | Small | ✘ | ✔ | ~0.75B | 102 | | NVIDIA RTX 3070 | 8 GB | Small | ✘ | ✔ | ~1B | 103 | | NVIDIA RTX 3060 Ti | 8 GB | Small | ✘ | ✔ | ~1B | 104 | | NVIDIA RTX 3060 | 12 GB | Medium | ✘ | ✔ | ~1.5B | 105 | | NVIDIA RTX 3050 | 8 GB | Small | ✘ | ✔ | ~1B | 106 | | NVIDIA GTX 1660 Ti | 6 GB | Small | ✘ | ✔ | ~0.75B | 107 | | AMD RX 7900 XTX | 24 GB | Large | ✔ | ✔ | ~3.5B–4B | 108 | | AMD RX 7900 XT | 20 GB | Large | ✔ | ✔ | ~3B | 109 | | AMD RX 7800 XT | 16 GB | Medium | ✘ | ✔ | ~2B | 110 | | AMD RX 7700 XT | 12 GB | Medium | ✘ | ✔ | ~1.5B | 111 | | AMD RX 7600 | 8 GB | Small | ✘ | ✔ | ~1B | 112 | 113 | The 13M LLM training is the training of a 13+ million-parameter model, and the 2B LLM training is the training of a 2+ billion-parameter model. The data size is categorized as small, medium, and large. The small data size is around 1 GB, the medium data size is around 5 GB, and the large data size is around 10 GB. 114 | 115 | ## Code Structure 116 | 117 | The codebase is organized as follows: 118 | ```bash 119 | train-llm-from-scratch/ 120 | ├── src/ 121 | │ ├── models/ 122 | │ │ ├── mlp.py # Definition of the Multi-Layer Perceptron (MLP) module 123 | │ │ ├── attention.py # Definitions for attention mechanisms (single-head, multi-head) 124 | │ │ ├── transformer_block.py # Definition of a single Transformer block 125 | │ │ ├── transformer.py # Definition of the main Transformer model 126 | ├── config/ 127 | │ └── config.py # Contains default configurations (model parameters, file paths, etc.) 128 | ├── data_loader/ 129 | │ └── data_loader.py # Contains functions for creating data loaders/iterators 130 | ├── scripts/ 131 | │ ├── train_transformer.py # Script for training the Transformer model 132 | │ ├── data_download.py # Script for downloading the dataset 133 | │ ├── data_preprocess.py # Script for preprocessing the downloaded data 134 | │ ├── generate_text.py # Script for generating text using a trained model 135 | ├── data/ # Directory to store the dataset 136 | │ ├── train/ # Contains training data 137 | │ └── val/ # Contains validation data 138 | ├── models/ # Directory where trained models are saved 139 | ``` 140 | 141 | `scripts/` directory contains scripts for downloading the dataset, preprocessing the data, training the model, and generating text using the trained model. 
`src/models/` directory contains the implementation of the transformer model, multi-layer perceptron (MLP), attention mechanisms, and transformer blocks. `config/` directory contains the configuration file with default parameters. `data_loader/` directory contains functions for creating data loaders/iterators.

## Usage

Clone the repository and navigate to the directory:
```bash
git clone https://github.com/FareedKhan-dev/train-llm-from-scratch.git
cd train-llm-from-scratch
```

If you encounter any issues with imports, make sure to add the project root to `PYTHONPATH`:
```bash
export PYTHONPATH="${PYTHONPATH}:/path/to/train-llm-from-scratch"

# or if you are already in the train-llm-from-scratch directory
export PYTHONPATH="$PYTHONPATH:."
```

Install the required dependencies:
```bash
pip install -r requirements.txt
```

You can modify the transformer architecture under `src/models/transformer.py` and the training configurations under `config/config.py`.

To download the training data, run:
```bash
python scripts/data_download.py
```

The script supports the following arguments:
* `--train_max`: Maximum number of training files to download. Default is 1 (maximum is 30). Each file is around 11 GB.
* `--train_dir`: Directory for storing training data. Default is `data/train`.
* `--val_dir`: Directory for storing validation data. Default is `data/val`.

To preprocess the downloaded data, run:
```bash
python scripts/data_preprocess.py
```

The script supports the following arguments:
- `--train_dir`: Directory where the training data files are stored (default is `data/train`).
- `--val_dir`: Directory where the validation data files are stored (default is `data/val`).
- `--out_train_file`: Path to store the processed training data in HDF5 format (default is `data/train/pile_train.h5`).
- `--out_val_file`: Path to store the processed validation data in HDF5 format (default is `data/val/pile_dev.h5`).
- `--tokenizer_name`: Name of the tokenizer to use for processing the data (default is `r50k_base`).
- `--max_data`: Maximum number of JSON objects ([lines](#training-data-info)) to process from each dataset (both train and validation). The default is 1000.

Now that the data is preprocessed, you can train the 13-million-parameter LLM by changing the configuration in `config/config.py` to this:

```python
# Define vocabulary size and transformer configuration (13M setup)
VOCAB_SIZE = 50304     # Number of unique tokens in the vocabulary
CONTEXT_LENGTH = 128   # Maximum sequence length for the model
N_EMBED = 128          # Dimension of the embedding space
N_HEAD = 8             # Number of attention heads in each transformer block
N_BLOCKS = 1           # Number of transformer blocks in the model
```

To train the model, run:
```bash
python scripts/train_transformer.py
```

It will start training the model and save the trained model in the default `models/` directory, or in the directory specified in the configuration file.
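
Before starting a long run, it is worth sanity-checking how many parameters a given configuration produces. The snippet below is a rough estimate that mirrors the architecture implemented in `src/models/` (token and position embedding tables, bias-free key/query/value projections, a 4x MLP with biases, two LayerNorms per block, and a vocabulary-sized output head); printing `sum(p.numel() for p in model.parameters())` on the instantiated model remains the authoritative count.

```python
def count_parameters(vocab_size: int, context_length: int, n_embed: int,
                     n_head: int, n_blocks: int) -> int:
    """Approximate parameter count for the transformer defined in src/models/."""
    embeddings = vocab_size * n_embed + context_length * n_embed   # token + position tables
    # n_head cancels out: each head projects n_embed -> n_embed // n_head, so the K/Q/V totals
    # across all heads add up to 3 * n_embed * n_embed regardless of the head count.
    attention = 3 * n_embed * n_embed                               # K/Q/V projections (bias=False)
    mlp = (n_embed * 4 * n_embed + 4 * n_embed) + (4 * n_embed * n_embed + n_embed)  # hidden + projection
    layer_norms = 2 * (2 * n_embed)                                 # ln1 and ln2 (weight + bias each)
    block = attention + mlp + layer_norms
    lm_head = n_embed * vocab_size + vocab_size                     # output projection with bias
    final_norm = 2 * n_embed
    return embeddings + n_blocks * block + final_norm + lm_head

# The 13M configuration shown above:
print(f"{count_parameters(50304, 128, 128, 8, 1):,}")  # roughly 13.1 million parameters
```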

To generate text using the trained model, run:
```bash
python scripts/generate_text.py --model_path models/your_model.pth --input_text hi
```

The script supports the following arguments:
- `--model_path`: Path to the trained model.
- `--input_text`: Initial text prompt for generating new text.
- `--max_new_tokens`: Maximum number of tokens to generate (default is 100).

It will generate text based on the input prompt using the trained model.

## Step by Step Code Explanation

This section is for those who want to understand the code in detail. I will explain the code step by step, starting from importing the libraries to training the model and generating text.

Previously, I wrote an article on Medium about creating a [2.3+ million-parameter](https://levelup.gitconnected.com/building-a-million-parameter-llm-from-scratch-using-python-f612398f06c2) LLM using the Tiny Shakespeare dataset, but the output didn’t make sense. Here is a sample output:

```bash
# 2.3 Million Parameter LLM Output
ZELBETH:
Sey solmenter! tis tonguerered if
Vurint as steolated have loven OID the queend refore
Are been, good plmp:

Proforne, wiftes swleen, was no blunderesd a a quain beath!
Tybell is my gateer stalk smend as be matious dazest
```

I had a thought: what if I made the transformer architecture smaller and less complex, and the training data more diverse? How large a model, in terms of parameters, could a single person with a nearly dead GPU train so that it speaks proper grammar and generates text that makes some sense?

I found that **13+ million-parameter** models are enough to start making sense in terms of proper grammar and punctuation, which is a positive point. This means we can use a very specific dataset to further fine-tune our previously trained model for a narrowed task. We might end up with a model under 1 billion parameters, or even around 500 million parameters, that is perfect for our specific use case, especially for running it on private data securely.

I recommend you **first train a 13+ million-parameter** model using the script available in my GitHub repository. You will get results within a day instead of waiting much longer, and your local GPU may not be strong enough to train a billion-parameter model anyway.
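
To relate parameter counts to the GPU memory column in the table above: with plain FP32 training and an Adam-style optimizer, each parameter typically costs around 16 bytes (4 for the weight, 4 for its gradient, and 8 for the two optimizer moments), before counting activations and framework overhead. This is only a back-of-the-envelope figure; mixed precision, gradient checkpointing, or a different optimizer can change it substantially.

```python
def estimate_training_memory_gb(n_params: float, bytes_per_param: int = 16) -> float:
    """Rough FP32 + AdamW footprint: weights (4) + gradients (4) + optimizer moments (8) bytes per parameter.
    Activations, the data batch, and CUDA overhead come on top of this."""
    return n_params * bytes_per_param / (1024 ** 3)

for n_params in (13e6, 500e6, 2e9):
    print(f"{n_params / 1e6:>6,.0f}M params -> ~{estimate_training_memory_gb(n_params):5.1f} GB "
          f"for weights, gradients and optimizer state")
```

This is why a 13M-parameter model trains comfortably on a free T4, while 2B-parameter training pushes even 24 GB cards unless memory-saving tricks are used.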

### Importing Libraries

Let’s import the required libraries that will be used throughout this blog:

```python
# PyTorch for deep learning functions and tensors
import torch
import torch.nn as nn
import torch.nn.functional as F

# Numerical operations and arrays handling
import numpy as np

# Handling HDF5 files
import h5py

# Operating system and file management
import os

# Command-line argument parsing
import argparse

# HTTP requests and interactions
import requests

# Progress bar for loops
from tqdm import tqdm

# JSON handling
import json

# Zstandard compression library
import zstandard as zstd

# Tokenization library for large language models
import tiktoken

# Math operations (used for advanced math functions)
import math
```

### Preparing the Training Data

Our training dataset needs to be diverse, containing information from different domains, and The Pile is the right choice for it. Although it is 825 GB in size, we will stick to only a small portion of it, i.e., 5%–10%. Let’s first download the dataset and see how it looks. I will be downloading the version available on [HuggingFace](https://huggingface.co/datasets/monology/pile-uncopyrighted).

```python
# Download validation dataset
!wget https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main/val.jsonl.zst

# Download the first part of the training dataset
!wget https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main/train/00.jsonl.zst

# Download the second part of the training dataset
!wget https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main/train/01.jsonl.zst

# Download the third part of the training dataset
!wget https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main/train/02.jsonl.zst
```

It will take some time to download, but you can also limit the training dataset to just one file, `00.jsonl.zst`, instead of three. It is already split into train/val/test. Once it's done, make sure to place the files correctly in their respective directories.

```python
import os
import shutil
import glob

# Define directory structure
train_dir = "data/train"
val_dir = "data/val"

# Create directories if they don't exist
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)

# Move all downloaded files (e.g., val.jsonl.zst, 00.jsonl.zst, 01.jsonl.zst, ...)
train_files = glob.glob("*.jsonl.zst")
for file in train_files:
    if file.startswith("val"):
        # Move validation file
        dest = os.path.join(val_dir, file)
    else:
        # Move training file
        dest = os.path.join(train_dir, file)
    shutil.move(file, dest)
```

Our dataset is in the .jsonl.zst format, which is a compressed file format commonly used for storing large datasets. It combines JSON Lines (.jsonl), where each line represents a valid JSON object, with Zstandard (.zst) compression. Let's read a sample of one of the downloaded files and see how it looks.

```python
329 | 330 | in_file = "data/val/val.jsonl.zst" # Path to our validation file 331 | 332 | with zstd.open(in_file, 'r') as in_f: 333 | for i, line in tqdm(enumerate(in_f)): # Read first 5 lines 334 | data = json.loads(line) 335 | print(f"Line {i}: {data}") # Print the raw data for inspection 336 | if i == 2: 337 | break 338 | ``` 339 | 340 | The output of the above code is this: 341 | 342 | ```python 343 | #### OUTPUT #### 344 | Line: 0 345 | { 346 | "text": "Effect of sleep quality ... epilepsy.", 347 | "meta": { 348 | "pile_set_name": "PubMed Abstracts" 349 | } 350 | } 351 | 352 | Line: 1 353 | { 354 | "text": "LLMops a new GitHub Repository ...", 355 | "meta": { 356 | "pile_set_name": "Github" 357 | } 358 | } 359 | ``` 360 | 361 | Now we need to encode (tokenize) our dataset. Our goal is to have an LLM that can at least output proper words. For that, we need to use an already available tokenizer. We will use the tiktoken open-source tokenizer by OpenAI. We will use the r50k_base tokenizer, which is used for the ChatGPT (GPT-3) model, to tokenize our dataset. 362 | 363 | We need to create a function for this to avoid duplication, as we will be tokenizing both the train and validation datasets. 364 | 365 | ```python 366 | def process_files(input_dir, output_file): 367 | """ 368 | Process all .zst files in the specified input directory and save encoded tokens to an HDF5 file. 369 | 370 | Args: 371 | input_dir (str): Directory containing input .zst files. 372 | output_file (str): Path to the output HDF5 file. 373 | """ 374 | with h5py.File(output_file, 'w') as out_f: 375 | # Create an expandable dataset named 'tokens' in the HDF5 file 376 | dataset = out_f.create_dataset('tokens', (0,), maxshape=(None,), dtype='i') 377 | start_index = 0 378 | 379 | # Iterate through all .zst files in the input directory 380 | for filename in sorted(os.listdir(input_dir)): 381 | if filename.endswith(".jsonl.zst"): 382 | in_file = os.path.join(input_dir, filename) 383 | print(f"Processing: {in_file}") 384 | 385 | # Open the .zst file for reading 386 | with zstd.open(in_file, 'r') as in_f: 387 | # Iterate through each line in the compressed file 388 | for line in tqdm(in_f, desc=f"Processing {filename}"): 389 | # Load the line as JSON 390 | data = json.loads(line) 391 | 392 | # Append the end-of-text token to the text and encode it 393 | text = data['text'] + "<|endoftext|>" 394 | encoded = enc.encode(text, allowed_special={'<|endoftext|>'}) 395 | encoded_len = len(encoded) 396 | 397 | # Calculate the end index for the new tokens 398 | end_index = start_index + encoded_len 399 | 400 | # Expand the dataset size and store the encoded tokens 401 | dataset.resize(dataset.shape[0] + encoded_len, axis=0) 402 | dataset[start_index:end_index] = encoded 403 | 404 | # Update the start index for the next batch of tokens 405 | start_index = end_index 406 | ``` 407 | 408 | There are two important points regarding this function: 409 | 410 | 1. We are storing the tokenized data in an HDF5 file, which allows us flexibility for quicker data access while training the model. 411 | 412 | 2. Appending the `<|endoftext|>` token marks the end of each text sequence, signaling to the model that it has reached the end of a meaningful context, which helps in generating coherent outputs. 
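
To make the second point concrete, here is a quick check (using the same tiktoken API as the preprocessing function) that `<|endoftext|>` is mapped to a single reserved token id rather than being split into ordinary subwords; the exact ids shown in the comments are what r50k_base should produce, but verifying them on your own installation only takes a second:

```python
import tiktoken

enc = tiktoken.get_encoding('r50k_base')

# With allowed_special, the marker is encoded as one reserved id (50256 in this vocabulary).
ids = enc.encode("Hello world<|endoftext|>", allowed_special={'<|endoftext|>'})
print(ids)              # e.g. [15496, 995, 50256] -- the last id is the end-of-text marker
print(enc.decode(ids))  # 'Hello world<|endoftext|>'

# Without allowed_special, tiktoken refuses to encode special tokens, which protects us from
# accidentally treating text that merely *contains* the marker string as a control token.
try:
    enc.encode("Hello world<|endoftext|>")
except ValueError as error:
    print(f"Raised {type(error).__name__}: special tokens must be explicitly allowed")
```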
413 | 414 | Now we can simply encode our train and validation datasets using: 415 | 416 | ```python 417 | # Define tokenized data output directories 418 | out_train_file = "data/train/pile_train.h5" 419 | out_val_file = "data/val/pile_dev.h5" 420 | 421 | # Loading tokenizer of (GPT-3/GPT-2 Model) 422 | enc = tiktoken.get_encoding('r50k_base') 423 | 424 | # Process training data 425 | process_files(train_dir, out_train_file) 426 | 427 | # Process validation data 428 | process_files(val_dir, out_val_file) 429 | ``` 430 | 431 | Let’s take a look at the sample of our tokenized data: 432 | 433 | ```python 434 | with h5py.File(out_val_file, 'r') as file: 435 | # Access the 'tokens' dataset 436 | tokens_dataset = file['tokens'] 437 | 438 | # Print the dtype of the dataset 439 | print(f"Dtype of 'tokens' dataset: {tokens_dataset.dtype}") 440 | 441 | # load and print the first few elements of the dataset 442 | print("First few elements of the 'tokens' dataset:") 443 | print(tokens_dataset[:10]) # First 10 token 444 | ``` 445 | 446 | The output of the above code is this: 447 | 448 | ```python 449 | #### OUTPUT #### 450 | Dtype of 'tokens' dataset: int32 451 | 452 | First few elements of the 'tokens' dataset: 453 | [ 2725 6557 83 23105 157 119 229 77 5846 2429] 454 | ``` 455 | We have prepared our dataset for training. Now we will code the transformer architecture and look into its theory correspondingly. 456 | 457 | ### Transformer Overview 458 | 459 | Let’s have a quick look at how a transformer architecture is used to process and understand text. It works by breaking text into smaller pieces called tokens and predicting the next token in the sequence. A transformer has many layers, called transformer blocks, stacked on top of each other, with a final layer at the end to make the prediction. 460 | 461 | Each transformer block has two main components: 462 | 463 | * **Self-Attention Heads**: These figure out which parts of the input are most important for the model to focus on. For example, when processing a sentence, the attention heads can highlight relationships between words, such as how a pronoun relates to the noun it refers to. 464 | 465 | * **MLP (Multi-Layer Perceptron)**: This is a simple feed-forward neural network. It takes the information emphasized by the attention heads and processes it further. The MLP has an input layer that receives data from the attention heads, a hidden layer that adds complexity to the processing, and an output layer that passes the results to the next transformer block. 466 | 467 | Together, the attention heads act as the “what to think about” part, while the MLP is the “how to think about it” part. Stacking many transformer blocks allows the model to understand complex patterns and relationships in the text, but this is not always guaranteed. 468 | 469 | Instead of looking at the original paper diagram, let’s visualize a simpler and easier architecture diagram that we will be coding. 470 | 471 | ![Transformer Architecture by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/11808/1*QXmeA-H52C-p82AwawslbQ.png) 472 | 473 | Let’s read through the flow of our architecture that we will be coding: 474 | 475 | 1. Input tokens are converted to embeddings and combined with position information. 476 | 477 | 2. The model has 64 identical transformer blocks that process data sequentially. 478 | 479 | 3. Each block first runs multi-head attention to look at relationships between tokens. 480 | 481 | 4. 
Each block then processes data through an MLP that expands and then compresses the data. 482 | 483 | 5. Each step uses residual connections (shortcuts) to help information flow. 484 | 485 | 6. Layer normalization is used throughout to stabilize training. 486 | 487 | 7. The attention mechanism calculates which tokens should pay attention to each other. 488 | 489 | 8. The MLP expands the data to 4x size, applies ReLU, and then compresses it back down. 490 | 491 | 9. The model uses 16 attention heads to capture different types of relationships. 492 | 493 | 10. The final layer converts the processed data into vocabulary-sized predictions. 494 | 495 | 11. The model generates text by repeatedly predicting the next most likely token. 496 | 497 | ### Multi Layer Perceptron (MLP) 498 | 499 | MLP is a fundamental building block within the transformer’s feed-forward network. Its role is to introduce non-linearity and learn complex relationships within the embedded representations. When defining an MLP module, an important parameter is n_embed, which defines the dimensionality of the input embedding. 500 | 501 | The MLP typically consists of a hidden linear layer that expands the input dimension by a factor (often 4, which we will use), followed by a non-linear activation function, commonly ReLU. This structure allows our network to learn more complex features. Finally, a projection linear layer maps the expanded representation back to the original embedding dimension. This sequence of transformations enables the MLP to refine the representations learned by the attention mechanism. 502 | 503 | ![MLP by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/4866/1*GXxiLMW4kUXqOEimBA7g0A.png) 504 | 505 | ```python 506 | # --- MLP (Multi-Layer Perceptron) Class --- 507 | 508 | class MLP(nn.Module): 509 | """ 510 | A simple Multi-Layer Perceptron with one hidden layer. 511 | 512 | This module is used within the Transformer block for feed-forward processing. 513 | It expands the input embedding size, applies a ReLU activation, and then projects it back 514 | to the original embedding size. 515 | """ 516 | def __init__(self, n_embed): 517 | super().__init__() 518 | self.hidden = nn.Linear(n_embed, 4 * n_embed) # Linear layer to expand embedding size 519 | self.relu = nn.ReLU() # ReLU activation function 520 | self.proj = nn.Linear(4 * n_embed, n_embed) # Linear layer to project back to original size 521 | 522 | def forward(self, x): 523 | """ 524 | Forward pass through the MLP. 525 | 526 | Args: 527 | x (torch.Tensor): Input tensor of shape (B, T, C), where B is batch size, 528 | T is sequence length, and C is embedding size. 529 | 530 | Returns: 531 | torch.Tensor: Output tensor of the same shape as the input. 532 | """ 533 | x = self.forward_embedding(x) 534 | x = self.project_embedding(x) 535 | return x 536 | 537 | def forward_embedding(self, x): 538 | """ 539 | Applies the hidden linear layer followed by ReLU activation. 540 | 541 | Args: 542 | x (torch.Tensor): Input tensor. 543 | 544 | Returns: 545 | torch.Tensor: Output after the hidden layer and ReLU. 546 | """ 547 | x = self.relu(self.hidden(x)) 548 | return x 549 | 550 | def project_embedding(self, x): 551 | """ 552 | Applies the projection linear layer. 553 | 554 | Args: 555 | x (torch.Tensor): Input tensor. 556 | 557 | Returns: 558 | torch.Tensor: Output after the projection layer. 
559 | """ 560 | x = self.proj(x) 561 | return x 562 | ``` 563 | 564 | We just coded our MLP part, where the __init__ method initializes a hidden linear layer that expands the input embedding size (n_embed) and a projection layer that reduces it back. ReLU activation is applied after the hidden layer. The forward method defines the data flow through these layers, applying the hidden layer and ReLU via forward_embedding, and the projection layer via project_embedding. 565 | 566 | ### Single Head Attention 567 | 568 | The attention head is the core part of our model. Its purpose is to focus on relevant parts of the input sequence. When defining a Head module, some important parameters are head_size, n_embed, and context_length. The head_size parameter determines the dimensionality of the key, query, and value projections, influencing the representational capacity of the attention mechanism. 569 | 570 | The input embedding dimension n_embed defines the size of the input to these projection layers. context_length is used to create a causal mask, ensuring that the model only attends to preceding tokens. 571 | 572 | Within the Head, linear layers (nn.Linear) for key, query, and value are initialized without bias. A lower triangular matrix (tril) of size context_length x context_length is registered as a buffer to implement causal masking, preventing the attention mechanism from attending to future tokens. 573 | 574 | ![Single Head Attention by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/5470/1*teNwEhicq9ebVURiMS8WkA.png) 575 | 576 | ```python 577 | # --- Attention Head Class --- 578 | 579 | class Head(nn.Module): 580 | """ 581 | A single attention head. 582 | 583 | This module calculates attention scores and applies them to the values. 584 | It includes key, query, and value projections, and uses causal masking 585 | to prevent attending to future tokens. 586 | """ 587 | def __init__(self, head_size, n_embed, context_length): 588 | super().__init__() 589 | self.key = nn.Linear(n_embed, head_size, bias=False) # Key projection 590 | self.query = nn.Linear(n_embed, head_size, bias=False) # Query projection 591 | self.value = nn.Linear(n_embed, head_size, bias=False) # Value projection 592 | # Lower triangular matrix for causal masking 593 | self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length))) 594 | 595 | def forward(self, x): 596 | """ 597 | Forward pass through the attention head. 598 | 599 | Args: 600 | x (torch.Tensor): Input tensor of shape (B, T, C). 601 | 602 | Returns: 603 | torch.Tensor: Output tensor after applying attention. 604 | """ 605 | B, T, C = x.shape 606 | k = self.key(x) # (B, T, head_size) 607 | q = self.query(x) # (B, T, head_size) 608 | scale_factor = 1 / math.sqrt(C) 609 | # Calculate attention weights: (B, T, head_size) @ (B, head_size, T) -> (B, T, T) 610 | attn_weights = q @ k.transpose(-2, -1) * scale_factor 611 | # Apply causal masking 612 | attn_weights = attn_weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) 613 | attn_weights = F.softmax(attn_weights, dim=-1) 614 | v = self.value(x) # (B, T, head_size) 615 | # Apply attention weights to values 616 | out = attn_weights @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size) 617 | return out 618 | ``` 619 | 620 | Our attention head class’s __init__ method initializes linear layers for key, query, and value projections, each projecting the input embedding (n_embed) to head_size. A lower triangular matrix based on context_length is used for causal masking. 
The forward method calculates attention weights by scaling the dot product of the query and key, applies the causal mask, normalizes the weights using softmax, and computes the weighted sum of the values to produce the attention output. 621 | 622 | ### Multi Head Attention 623 | 624 | To capture diverse relationships within the input sequence, we are going to use the concept of multi-head attention. The MultiHeadAttention module manages multiple independent attention heads operating in parallel. 625 | 626 | The key parameter here is n_head, which determines the number of parallel attention heads. The input embedding dimension (n_embed) and context_length are also necessary to instantiate the individual attention heads. Each head processes the input independently, projecting it into a lower-dimensional subspace of size n_embed // n_head. By having multiple heads, the model can attend to different aspects of the input simultaneously. 627 | 628 | ![Multi Head Attention by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/6864/1*fa-YjrZdtbpuCLp7An99dg.png) 629 | 630 | ```python 631 | # --- Multi-Head Attention Class --- 632 | 633 | class MultiHeadAttention(nn.Module): 634 | """ 635 | Multi-Head Attention module. 636 | 637 | This module combines multiple attention heads in parallel. The outputs of each head 638 | are concatenated to form the final output. 639 | """ 640 | def __init__(self, n_head, n_embed, context_length): 641 | super().__init__() 642 | self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]) 643 | 644 | def forward(self, x): 645 | """ 646 | Forward pass through the multi-head attention. 647 | 648 | Args: 649 | x (torch.Tensor): Input tensor of shape (B, T, C). 650 | 651 | Returns: 652 | torch.Tensor: Output tensor after concatenating the outputs of all heads. 653 | """ 654 | # Concatenate the output of each head along the last dimension (C) 655 | x = torch.cat([h(x) for h in self.heads], dim=-1) 656 | return x 657 | ``` 658 | 659 | Now that we have defined the MultiHeadAttention class, which combines multiple attention heads, the __init__ method initializes a list of Head instances (a total of n_head), each with a head_size of n_embed // n_head. The forward method applies each attention head to the input x and concatenates their outputs along the last dimension, merging the information learned by each head. 660 | 661 | ### Transformer Block 662 | 663 | To create a billion-parameter model, we definitely need a deep architecture. For that, we need to code a transformer block and stack them. The key parameters of a block are n_head, n_embed, and context_length. Each block comprises a multi-head attention layer and a feed-forward network (MLP), with layer normalization applied before each and residual connections after each. 664 | 665 | Layer normalization, parameterized by the embedding dimension n_embed, helps stabilize training. The multi-head attention mechanism, as described before, takes n_head, n_embed, and context_length. The MLP also utilizes the embedding dimension n_embed. These components work together to process the input and learn complex patterns. 666 | 667 | ![Transformer Block by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/6942/1*uLWGajZc6StnQHfZjcb6eA.png) 668 | 669 | ```python 670 | # --- Transformer Block Class --- 671 | 672 | class Block(nn.Module): 673 | """ 674 | A single Transformer block. 
675 | 676 | This block consists of a multi-head attention layer followed by an MLP, 677 | with layer normalization and residual connections. 678 | """ 679 | def __init__(self, n_head, n_embed, context_length): 680 | super().__init__() 681 | self.ln1 = nn.LayerNorm(n_embed) 682 | self.attn = MultiHeadAttention(n_head, n_embed, context_length) 683 | self.ln2 = nn.LayerNorm(n_embed) 684 | self.mlp = MLP(n_embed) 685 | 686 | def forward(self, x): 687 | """ 688 | Forward pass through the Transformer block. 689 | 690 | Args: 691 | x (torch.Tensor): Input tensor. 692 | 693 | Returns: 694 | torch.Tensor: Output tensor after the block. 695 | """ 696 | # Apply multi-head attention with residual connection 697 | x = x + self.attn(self.ln1(x)) 698 | # Apply MLP with residual connection 699 | x = x + self.mlp(self.ln2(x)) 700 | return x 701 | 702 | def forward_embedding(self, x): 703 | """ 704 | Forward pass focusing on the embedding and attention parts. 705 | 706 | Args: 707 | x (torch.Tensor): Input tensor. 708 | 709 | Returns: 710 | tuple: A tuple containing the output after MLP embedding and the residual. 711 | """ 712 | res = x + self.attn(self.ln1(x)) 713 | x = self.mlp.forward_embedding(self.ln2(res)) 714 | return x, res 715 | ``` 716 | 717 | Our Block class represents a single transformer block. The __init__ method initializes layer normalization layers (ln1, ln2), a MultiHeadAttention module, and an MLP module, all parameterized by n_head, n_embed, and context_length. 718 | 719 | The forward method implements the block's forward pass, applying layer normalization and multi-head attention with a residual connection, followed by another layer normalization and the MLP, again with a residual connection. The forward_embedding method provides an alternative forward pass focused on the attention and initial MLP embedding stages. 720 | 721 | ### The Final Model 722 | 723 | So far, we have coded small components of the transformer model. Next, we integrate token and position embeddings with a series of transformer blocks to perform sequence-to-sequence tasks. To do that, we need to code several key parameters: n_head, n_embed, context_length, vocab_size, and N_BLOCKS. 724 | 725 | vocab_size determines the size of the token embedding layer, mapping each token to a dense vector of size n_embed. The context_length parameter is important for the position embedding layer, which encodes the position of each token in the input sequence, also with dimension n_embed. The number of attention heads (n_head) and the number of blocks (N_BLOCKS) dictate the depth and complexity of the network. 726 | 727 | These parameters collectively define the architecture and capacity of the transformer model, so let’s code it. 728 | 729 | ![Transformer Class by [Fareed Khan](undefined)](https://cdn-images-1.medium.com/max/5418/1*0XXd_R2EOhkKCQDfqUQg0w.png) 730 | 731 | ```python 732 | # --- Transformer Model Class --- 733 | 734 | class Transformer(nn.Module): 735 | """ 736 | The main Transformer model. 737 | 738 | This class combines token and position embeddings with a sequence of Transformer blocks 739 | and a final linear layer for language modeling. 
740 | """ 741 | def __init__(self, n_head, n_embed, context_length, vocab_size, N_BLOCKS): 742 | super().__init__() 743 | self.context_length = context_length 744 | self.N_BLOCKS = N_BLOCKS 745 | self.token_embed = nn.Embedding(vocab_size, n_embed) 746 | self.position_embed = nn.Embedding(context_length, n_embed) 747 | self.attn_blocks = nn.ModuleList([Block(n_head, n_embed, context_length) for _ in range(N_BLOCKS)]) 748 | self.layer_norm = nn.LayerNorm(n_embed) 749 | self.lm_head = nn.Linear(n_embed, vocab_size) 750 | self.register_buffer('pos_idxs', torch.arange(context_length)) 751 | 752 | def _pre_attn_pass(self, idx): 753 | """ 754 | Combines token and position embeddings. 755 | 756 | Args: 757 | idx (torch.Tensor): Input token indices. 758 | 759 | Returns: 760 | torch.Tensor: Sum of token and position embeddings. 761 | """ 762 | B, T = idx.shape 763 | tok_embedding = self.token_embed(idx) 764 | pos_embedding = self.position_embed(self.pos_idxs[:T]) 765 | return tok_embedding + pos_embedding 766 | 767 | def forward(self, idx, targets=None): 768 | """ 769 | Forward pass through the Transformer. 770 | 771 | Args: 772 | idx (torch.Tensor): Input token indices. 773 | targets (torch.Tensor, optional): Target token indices for loss calculation. Defaults to None. 774 | 775 | Returns: 776 | tuple: Logits and loss (if targets are provided). 777 | """ 778 | x = self._pre_attn_pass(idx) 779 | for block in self.attn_blocks: 780 | x = block(x) 781 | x = self.layer_norm(x) 782 | logits = self.lm_head(x) 783 | loss = None 784 | if targets is not None: 785 | B, T, C = logits.shape 786 | flat_logits = logits.view(B * T, C) 787 | targets = targets.view(B * T).long() 788 | loss = F.cross_entropy(flat_logits, targets) 789 | return logits, loss 790 | 791 | def forward_embedding(self, idx): 792 | """ 793 | Forward pass focusing on the embedding and attention blocks. 794 | 795 | Args: 796 | idx (torch.Tensor): Input token indices. 797 | 798 | Returns: 799 | tuple: Output after attention blocks and the residual. 800 | """ 801 | x = self._pre_attn_pass(idx) 802 | residual = x 803 | for block in self.attn_blocks: 804 | x, residual = block.forward_embedding(x) 805 | return x, residual 806 | 807 | def generate(self, idx, max_new_tokens): 808 | """ 809 | Generates new tokens given a starting sequence. 810 | 811 | Args: 812 | idx (torch.Tensor): Initial sequence of token indices. 813 | max_new_tokens (int): Number of tokens to generate. 814 | 815 | Returns: 816 | torch.Tensor: The extended sequence of tokens. 817 | """ 818 | for _ in range(max_new_tokens): 819 | idx_cond = idx[:, -self.context_length:] 820 | logits, _ = self(idx_cond) 821 | logits = logits[:, -1, :] 822 | probs = F.softmax(logits, dim=-1) 823 | idx_next = torch.multinomial(probs, num_samples=1) 824 | idx = torch.cat((idx, idx_next), dim=1) 825 | return idx 826 | ``` 827 | 828 | Our Transformer class `__init__` method initializes token and position embedding layers (token_embed, position_embed), a sequence of Block modules (attn_blocks), a final layer normalization layer (layer_norm), and a linear layer for language modeling (lm_head). 829 | 830 | The _pre_attn_pass method combines token and position embeddings. The forward method processes the input sequence through the embedding layers and the series of transformer blocks, applies final layer normalization, and generates logits. It also calculates the loss if targets are provided. 
The forward_embedding method provides an intermediate forward pass up to the output of the attention blocks, and the generate method implements token generation. 831 | 832 | ### Batch Processing 833 | 834 | When we train a deep learning model on big data, we process it in batches due to GPU availability. So, let’s create a get_batch_iterator function, taking the data_path to an HDF5 file, the desired batch_size, the context_length for each sequence, and the device to load the data onto. 835 | 836 | The batch_size determines how many sequences are processed in parallel during training, while the context_length specifies the length of each input sequence. The data_path points to the location of the training data. 837 | 838 | ```python 839 | # --- Data Loading Utility --- 840 | 841 | def get_batch_iterator(data_path, batch_size, context_length, device="gpu"): 842 | """ 843 | Creates an iterator for generating batches of data from an HDF5 file. 844 | 845 | Args: 846 | data_path (str): Path to the HDF5 file containing tokenized data. 847 | batch_size (int): Number of sequences in each batch. 848 | context_length (int): Length of each sequence. 849 | device (str, optional): Device to load the data onto ('cpu' or 'cuda'). Defaults to "cpu". 850 | 851 | Yields: 852 | tuple: A tuple containing input sequences (xb) and target sequences (yb). 853 | """ 854 | # Open the HDF5 file in read mode 855 | with h5py.File(data_path, 'r') as hdf5_file: 856 | 857 | # Extract the dataset of tokenized sequences 858 | dataset = hdf5_file['tokens'] 859 | 860 | # Get the total size of the dataset 861 | dataset_size = dataset.shape[0] 862 | 863 | # Calculate the number of examples (sequences) that can be made from the data 864 | n_examples = (dataset_size - 1) // context_length 865 | 866 | # Create an array of indices for examples and shuffle them for randomness 867 | example_idxs = np.arange(n_examples) 868 | np.random.shuffle(example_idxs) 869 | 870 | # Initialize epoch counter and example counter 871 | epochs = 0 872 | counter = 0 873 | 874 | while True: 875 | # Check if the current batch exceeds the number of available examples 876 | if counter + batch_size > n_examples: 877 | # Shuffle the indices again and reset the counter to 0 878 | np.random.shuffle(example_idxs) 879 | counter = 0 880 | print(f"Finished epoch {epochs}") # Print epoch number when an epoch finishes 881 | epochs += 1 # Increment the epoch counter 882 | 883 | # Select a batch of random indices to generate sequences 884 | random_indices = example_idxs[counter:counter+batch_size] * context_length 885 | 886 | # Retrieve sequences from the dataset based on the random indices 887 | random_samples = torch.tensor(np.array([dataset[idx:idx+context_length+1] for idx in random_indices])) 888 | 889 | # Separate the input sequences (xb) and target sequences (yb) 890 | xb = random_samples[:, :context_length].to(device) # Input sequence (first half of the random sample) 891 | yb = random_samples[:, 1:context_length+1].to(device) # Target sequence (second half of the random sample) 892 | 893 | # Increment the counter to move to the next batch 894 | counter += batch_size 895 | 896 | # Yield the input and target sequences as a tuple for the current batch 897 | yield xb, yb 898 | ``` 899 | Our get_batch_iterator function handles the loading and batching of training data. It takes data_path, batch_size, context_length, and device as input. The function opens the HDF5 file, shuffles the data, and then enters an infinite loop to generate batches. 
In each iteration, it selects a random subset of the data to form a batch of input sequences (xb) and their corresponding target sequences (yb).

### Training Parameters

Now that we have coded our model, we need to define the training parameters, such as the number of heads, blocks, and more, along with the data paths.

```python
# --- Configuration ---

# Define vocabulary size and transformer configuration
VOCAB_SIZE = 50304     # Number of unique tokens in the vocabulary
CONTEXT_LENGTH = 512   # Maximum sequence length for the model
N_EMBED = 2048         # Dimension of the embedding space
N_HEAD = 16            # Number of attention heads in each transformer block
N_BLOCKS = 64          # Number of transformer blocks in the model

# Paths to training and development datasets (produced by the preprocessing step above)
TRAIN_PATH = "data/train/pile_train.h5"  # File path for the training dataset
DEV_PATH = "data/val/pile_dev.h5"        # File path for the validation dataset

# Transformer training parameters
T_BATCH_SIZE = 32         # Number of samples per training batch
T_CONTEXT_LENGTH = 16     # Context length for training batches
T_TRAIN_STEPS = 200000    # Total number of training steps
T_EVAL_STEPS = 1000       # Frequency (in steps) to perform evaluation
T_EVAL_ITERS = 250        # Number of iterations to evaluate the model
T_LR_DECAY_STEP = 50000   # Step at which to decay the learning rate
T_LR = 5e-4               # Initial learning rate for training
T_LR_DECAYED = 5e-5       # Learning rate after decay
T_OUT_PATH = "models/transformer_B.pt"  # Path to save the trained model

# Device configuration
DEVICE = 'cuda'

# Store all configurations in a dictionary for easy access and modification
default_config = {
    'vocab_size': VOCAB_SIZE,
    'context_length': CONTEXT_LENGTH,
    'n_embed': N_EMBED,
    'n_head': N_HEAD,
    'n_blocks': N_BLOCKS,
    'train_path': TRAIN_PATH,
    'dev_path': DEV_PATH,
    't_batch_size': T_BATCH_SIZE,
    't_context_length': T_CONTEXT_LENGTH,
    't_train_steps': T_TRAIN_STEPS,
    't_eval_steps': T_EVAL_STEPS,
    't_eval_iters': T_EVAL_ITERS,
    't_lr_decay_step': T_LR_DECAY_STEP,
    't_lr': T_LR,
    't_lr_decayed': T_LR_DECAYED,
    't_out_path': T_OUT_PATH,
    'device': DEVICE,
}

# The training and generation code below refers to this dictionary simply as `config`
config = default_config
```

For most of the parameters, I have used the most common values and also stored them in a dictionary for easy access. Here, the parameters are for a billion-parameter model. If you want to train a model with millions of parameters, you can reduce the main parameters, which include CONTEXT_LENGTH, N_EMBED, N_HEAD, and N_BLOCKS. However, you can also run the million-parameter model script in my GitHub repository.

### Training the Model

Let's initialize our transformer model and check its total number of parameters.
960 | ```python 961 | # --- Initialize the Model and Print Parameters --- 962 | 963 | model = Transformer( 964 | n_head=config['n_head'], 965 | n_embed=config['n_embed'], 966 | context_length=config['context_length'], 967 | vocab_size=config['vocab_size'], 968 | N_BLOCKS=config['n_blocks'] 969 | ).to(config['device']) 970 | 971 | 972 | # Print the total number of parameters 973 | total_params = sum(p.numel() for p in model.parameters()) 974 | print(f"Total number of parameters in the model: {total_params:,}") 975 | 976 | 977 | #### OUTPUT #### 978 | 2,141,346,251 979 | ``` 980 | 981 | Now that we have 2 Billion parameter model, we need to define our Adam optimizer and loss tracking function, which will help us track the progress of our model throughout the training. 982 | 983 | ```python 984 | # --- Optimizer Setup and Loss Tracking --- 985 | 986 | # Set up the AdamW optimizer with the specified learning rate. 987 | optimizer = torch.optim.AdamW(model.parameters(), lr=config['t_lr']) 988 | 989 | # List to track loss values during training. 990 | losses = [] 991 | 992 | # Define a window size for averaging recent losses in the training loop. 993 | AVG_WINDOW = 64 994 | 995 | # Helper function to estimate the average loss for training and development data. 996 | @torch.no_grad() 997 | def estimate_loss(steps): 998 | """ 999 | Evaluate the model on training and development datasets and calculate average loss. 1000 | 1001 | Args: 1002 | steps (int): Number of steps to evaluate. 1003 | 1004 | Returns: 1005 | dict: Dictionary containing average losses for 'train' and 'dev' splits. 1006 | """ 1007 | out = {} 1008 | model.eval() # Set the model to evaluation mode. 1009 | 1010 | for split in ['train', 'dev']: 1011 | # Select the appropriate data path for the current split. 1012 | data_path = config['train_path'] if split == 'train' else config['dev_path'] 1013 | 1014 | # Create a batch iterator for evaluation. 1015 | batch_iterator_eval = get_batch_iterator( 1016 | data_path, config['t_batch_size'], config['t_context_length'], device=config['device'] 1017 | ) 1018 | 1019 | # Initialize a tensor to track loss values for each evaluation step. 1020 | losses_eval = torch.zeros(steps) 1021 | for k in range(steps): 1022 | try: 1023 | # Fetch a batch and calculate the loss. 1024 | xb, yb = next(batch_iterator_eval) 1025 | _, loss = model(xb, yb) 1026 | losses_eval[k] = loss.item() 1027 | except StopIteration: 1028 | # Handle the case where the data iterator ends early. 1029 | print(f"Warning: Iterator for {split} ended early.") 1030 | break 1031 | 1032 | # Compute the mean loss for the current split. 1033 | out[split] = losses_eval[:k + 1].mean() 1034 | 1035 | model.train() # Restore the model to training mode. 1036 | return out 1037 | ``` 1038 | 1039 | We will now initialize our batch processing function and training loop, which will start our training. 1040 | 1041 | ```python 1042 | # --- Training Loop --- 1043 | 1044 | # Create a batch iterator for the training data. 1045 | batch_iterator = get_batch_iterator( 1046 | config['train_path'], 1047 | config['t_batch_size'], 1048 | config['t_context_length'], 1049 | device=config['device'] 1050 | ) 1051 | 1052 | # Create a progress bar to monitor training progress. 1053 | pbar = tqdm(range(config['t_train_steps'])) 1054 | for step in pbar: 1055 | try: 1056 | # Fetch a batch of input and target data. 1057 | xb, yb = next(batch_iterator) 1058 | 1059 | # Perform a forward pass and compute the loss. 
1060 | _, loss = model(xb, yb) 1061 | 1062 | # Record the loss for tracking. 1063 | losses.append(loss.item()) 1064 | pbar.set_description(f"Train loss: {np.mean(losses[-AVG_WINDOW:]):.4f}") 1065 | 1066 | # Backpropagate the loss and update the model parameters. 1067 | optimizer.zero_grad(set_to_none=True) 1068 | loss.backward() 1069 | optimizer.step() 1070 | 1071 | # Periodically evaluate the model on training and development data. 1072 | if step % config['t_eval_steps'] == 0: 1073 | train_loss, dev_loss = estimate_loss(config['t_eval_iters']).values() 1074 | print(f"Step: {step}, Train loss: {train_loss:.4f}, Dev loss: {dev_loss:.4f}") 1075 | 1076 | # Decay the learning rate at the specified step. 1077 | if step == config['t_lr_decay_step']: 1078 | print('Decaying learning rate') 1079 | for g in optimizer.param_groups: 1080 | g['lr'] = config['t_lr_decayed'] 1081 | except StopIteration: 1082 | # Handle the case where the training data iterator ends early. 1083 | print("Training data iterator finished early.") 1084 | break 1085 | ``` 1086 | ### Saving the Trained Model 1087 | 1088 | Since our training loop has the ability to handle errors, in case the loop throws any error, it will save our partially trained model to avoid loss. Once the training is complete, we can save our trained model to use it later for inference. 1089 | 1090 | ```python 1091 | # --- Save Model and Final Evaluation --- 1092 | 1093 | # Perform a final evaluation of the model on training and development datasets. 1094 | train_loss, dev_loss = estimate_loss(200).values() 1095 | 1096 | # Ensure unique model save path in case the file already exists. 1097 | modified_model_out_path = config['t_out_path'] 1098 | save_tries = 0 1099 | while os.path.exists(modified_model_out_path): 1100 | save_tries += 1 1101 | model_out_name = os.path.splitext(config['t_out_path'])[0] 1102 | modified_model_out_path = model_out_name + f"_{save_tries}" + ".pt" 1103 | 1104 | # Save the model's state dictionary, optimizer state, and training metadata. 1105 | torch.save( 1106 | { 1107 | 'model_state_dict': model.state_dict(), 1108 | 'optimizer_state_dict': optimizer.state_dict(), 1109 | 'losses': losses, 1110 | 'train_loss': train_loss, 1111 | 'dev_loss': dev_loss, 1112 | 'steps': len(losses), 1113 | }, 1114 | modified_model_out_path 1115 | ) 1116 | print(f"Saved model to {modified_model_out_path}") 1117 | print(f"Finished training. Train loss: {train_loss:.4f}, Dev loss: {dev_loss:.4f}") 1118 | ``` 1119 | The final training loss for the billion-parameter model is 0.2314, and the dev loss is 0.643. 1120 | 1121 | ### Training Loss 1122 | 1123 | When I plot the loss of both the million- and billion-parameter models, they look very different. 1124 | 1125 | ![Training Loss Comparison](https://cdn-images-1.medium.com/max/6696/1*8Gl7cEbainB4GRVwL3cc7Q.png) 1126 | 1127 | The billion-parameter model starts with a much higher loss and fluctuates a lot at the beginning. It goes down quickly at first, but then wobbles before becoming smoother. This shows that the bigger model has a harder time finding the right way to learn at the start. It might need more data and careful settings. When the learning rate is lowered (the red line), the loss goes down more steadily, showing that this helps it fine-tune. 1128 | 1129 | The million-parameter model’s loss goes down more easily from the start. It doesn’t fluctuate as much as the bigger model. When the learning rate is lowered, it doesn’t change the curve as much. 
This is likely because the smaller model is simpler to train and finds a good solution faster. The big difference shows how much harder it is to train very large models: they need different methods and maybe more time to learn well.

We now have our saved model. We can finally use it for inference and see how it generates text. 😓

### Generating Text

Let’s create a function to generate text from our saved model. It takes the path of the saved checkpoint and an input prompt, loads the tokenizer internally, and returns the generated text.

```python
def generate_text(model_path, input_text, max_new_tokens=512, device="cuda"):
    """
    Generate text using a pre-trained model based on the given input text.

    Args:
    - model_path (str): Path to the model checkpoint.
    - input_text (str): The input text to seed the generation.
    - max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 512.
    - device (str, optional): Device to load the model on ('cpu' or 'cuda'). Defaults to "cuda".

    Returns:
    - str: The generated text.
    """

    # Load the model checkpoint onto the target device
    checkpoint = torch.load(model_path, map_location=device)

    # Initialize the model with the same hyperparameters used for training
    # (make sure the configuration matches the checkpoint you are loading)
    model = Transformer(
        n_head=config['n_head'],
        n_embed=config['n_embed'],
        context_length=config['context_length'],
        vocab_size=config['vocab_size'],
        N_BLOCKS=config['n_blocks']
    ).to(device)

    # Load the model's state dictionary and switch to evaluation mode
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Load the tokenizer used during preprocessing (r50k_base)
    enc = tiktoken.get_encoding('r50k_base')

    # Encode the input text (the end-of-text token is allowed as part of the prompt)
    input_ids = torch.tensor(
        enc.encode(input_text, allowed_special={'<|endoftext|>'}),
        dtype=torch.long
    )[None, :].to(device)  # Add batch dimension and move to the specified device

    # Generate text with the model using the encoded input
    with torch.no_grad():
        # Generate up to 'max_new_tokens' additional tokens
        generated_output = model.generate(input_ids, max_new_tokens)

    # Decode the generated tokens back into text
    generated_text = enc.decode(generated_output[0].tolist())

    return generated_text
```

The transformer architecture we defined earlier is instantiated here with the same configuration used for training, and the saved weights are then loaded into it as the model state.

Let’s first observe what both the million- and billion-parameter models generate without a meaningful prompt (just the `<|endoftext|>` token), and see what they produce.
1184 | 1185 | ```python 1186 | # Defining the file paths for the pre-trained models 1187 | Billion_model_path = 'models/transformer_B.pt' # Path to the Billion model 1188 | Million_model_path = 'models/transformer_M.pt' # Path to the Million model 1189 | 1190 | # Using '<|endoftext|>' as input to the models (acts as a prompt that allows the models to generate text freely) 1191 | input_text = "<|endoftext|>" 1192 | 1193 | # Call the function to generate text based on the input text using the Billion model 1194 | B_output = generate_text(Billion_model_path, input_text) 1195 | 1196 | # Call the function to generate text based on the input text using the Million model 1197 | M_output = generate_text(Million_model_path, input_text) 1198 | 1199 | # Print the output generated by both models 1200 | print(B_output) # Output from the Billion model 1201 | print(M_output) # Output from the Million model 1202 | ``` 1203 | 1204 | | **Million Parameter Output** | **Billion Parameter Output** | 1205 | |------------------------------|------------------------------| 1206 | | In 1978, The park was returned to the factory-plate that the public share to the lower of the electronic fence that follow from the Station's cities. The Canal of ancient Western nations were confined to the city spot. The villages were directly linked to cities in China that revolt that the US budget and in Odambinais is uncertain and fortune established in rural areas. | There are two miles east coast from 1037 and 73 million refugees (hypotetus) as the same men and defeated Harvard, and Croft. At right east and West Nile's Mediterranean Sea jets. It was found there a number of parties, blacksmith, musician and boutique hospitality and inspire the strain delivered Canadians have already killed, rural branches with coalition railholder against Abyssy. | 1207 | 1208 | 1209 | Both LLMs are able to generate clear and accurate words when the context is short and simple. For example, in the million-parameter output, the phrase **“The villages were directly linked to cities in China”** makes sense and conveys a clear idea. It is easy to understand and logically connects the villages to the cities. 1210 | 1211 | However, when the context becomes longer and more complex, the clarity begins to fade. In the billion-parameter output, sentences like **“There are two miles east coast from 1037 and 73 million refugees (hypotetus)”** and **“blacksmith, musician and boutique hospitality and inspire the strain delivered Canadians”** become harder to follow. The ideas seem disjointed, and the sentence structure doesn’t flow naturally. While the words used might still be correct, the overall meaning becomes confusing and unclear. 1212 | 1213 | The encouraging part is that the 13+ million-parameter LLM also starts producing somewhat meaningful content with correctly spelled words. For instance, when I prompt it with an email subject line, it starts drafting an email for me. The longer continuation is obviously still not very meaningful, but take a look at the output: 1214 | 1215 | ```python 1216 | # Input text 1217 | input_text = "Subject: " 1218 | 1219 | # Call the million-parameter model 1220 | m_output = generate_text(Million_model_path, input_text) 1221 | 1222 | print(m_output) # Output from the Million model 1223 | ``` 1224 | | **Million Parameter LLM Output** | 1225 | |--------------------------------------------------------------------------------------------------| 1226 | | Subject: ClickPaper-summary Study for Interview

Good morning, I hope this message finds you well, as the sun gently peeks through the clouds, ... | 1227 | 1228 | Our million-parameter model is encouraging: it suggests we can build a narrow, goal-oriented LLM at well under 1B parameters. Our 1B run, on the other hand, shows that simply adding parameters is not enough; without a carefully designed, sufficiently deep architecture, the billion-sized model doesn’t train or perform any better than the million-parameter one and mostly just overfits the data. 1229 | 1230 | ## What’s Next 1231 | 1232 | I recommend that you first build the 13+ million-parameter model and then scale it up gradually by adding more parameters, improving its ability to handle shorter contexts. How many more parameters you train for specific tasks is up to you. Then, while staying under 1B parameters, try fine-tuning the model on domain-specific data, such as emails or essays, and see how it generates text; a rough sketch of this kind of continued training is shown below. 1233 | 1234 | 

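As a concrete starting point for that fine-tuning step, here is a minimal sketch of how you could continue training a saved checkpoint on domain-specific data using the pieces already in this repository. It is only an illustration: the dataset path `data/train/emails.h5`, the step count, and the learning rate are placeholders, and it assumes `default_config` matches the architecture the checkpoint was trained with (and that the domain data has already been tokenized into HDF5 with `scripts/data_preprocess.py`).

```python
import torch
from config.config import default_config as config
from src.models.transformer import Transformer
from data_loader.data_loader import get_batch_iterator

# Placeholder paths: a pre-trained checkpoint and a domain-specific HDF5 dataset.
CHECKPOINT_PATH = "models/transformer_M.pt"
DOMAIN_DATA_PATH = "data/train/emails.h5"
FINETUNE_STEPS = 5000   # Illustrative; adjust to the size of your domain dataset.
FINETUNE_LR = 1e-4      # A smaller learning rate than pre-training is usually safer.

# Rebuild the architecture; the config must match the checkpoint's hyperparameters.
model = Transformer(
    n_head=config['n_head'],
    n_embed=config['n_embed'],
    context_length=config['context_length'],
    vocab_size=config['vocab_size'],
    N_BLOCKS=config['n_blocks'],
).to(config['device'])

# Load the pre-trained weights and continue training from them.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=config['device'])
model.load_state_dict(checkpoint['model_state_dict'])
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=FINETUNE_LR)
batch_iterator = get_batch_iterator(
    DOMAIN_DATA_PATH, config['t_batch_size'], config['t_context_length'], device=config['device']
)

for step in range(FINETUNE_STEPS):
    xb, yb = next(batch_iterator)           # Fetch a batch of domain-specific data.
    _, loss = model(xb, yb)                 # Forward pass returns (logits, loss).
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(f"Step {step}, fine-tune loss: {loss.item():.4f}")

# Save the fine-tuned weights under a new, separate name.
torch.save({'model_state_dict': model.state_dict()}, "models/transformer_M_finetuned.pt")
```

Using a lower learning rate and far fewer steps than pre-training keeps the model close to its general-language starting point while nudging it toward the new domain.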
1235 | 1236 | Wanna chat on something? [My Linkedin](https://www.linkedin.com/in/fareed-khan-dev/) 1237 | 1238 | ## Star History 1239 | 1240 | [![](https://api.star-history.com/svg?repos=FareedKhan-dev/train-llm-from-scratch&type=Date)](https://star-history.com/#FareedKhan-dev/train-llm-from-scratch&Date) 1241 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | # --- Configuration --- 2 | 3 | # Define vocabulary size and transformer configuration (3 Billion) 4 | VOCAB_SIZE = 50304 # Number of unique tokens in the vocabulary 5 | CONTEXT_LENGTH = 512 # Maximum sequence length for the model 6 | N_EMBED = 2048 # Dimension of the embedding space 7 | N_HEAD = 16 # Number of attention heads in each transformer block 8 | N_BLOCKS = 64 # Number of transformer blocks in the model 9 | 10 | # Paths to training and development datasets 11 | TRAIN_PATH = "data/train/pile_train.h5" # File path for the training dataset 12 | DEV_PATH = "data/val/pile_dev.h5" # File path for the validation dataset 13 | 14 | # Transformer training parameters 15 | T_BATCH_SIZE = 32 # Number of samples per training batch 16 | T_CONTEXT_LENGTH = 16 # Context length for training batches 17 | T_TRAIN_STEPS = 200000 # Total number of training steps 18 | T_EVAL_STEPS = 1000 # Frequency (in steps) to perform evaluation 19 | T_EVAL_ITERS = 250 # Number of iterations to evaluate the model 20 | T_LR_DECAY_STEP = 50000 # Step at which to decay the learning rate 21 | T_LR = 5e-4 # Initial learning rate for training 22 | T_LR_DECAYED = 5e-5 # Learning rate after decay 23 | T_OUT_PATH = "models/transformer_B.pt" # Path to save the trained model 24 | 25 | # Device configuration 26 | DEVICE = 'cuda' 27 | 28 | # Store all configurations in a dictionary for easy access and modification 29 | default_config = { 30 | 'vocab_size': VOCAB_SIZE, 31 | 'context_length': CONTEXT_LENGTH, 32 | 'n_embed': N_EMBED, 33 | 'n_head': N_HEAD, 34 | 'n_blocks': N_BLOCKS, 35 | 'train_path': TRAIN_PATH, 36 | 'dev_path': DEV_PATH, 37 | 't_batch_size': T_BATCH_SIZE, 38 | 't_context_length': T_CONTEXT_LENGTH, 39 | 't_train_steps': T_TRAIN_STEPS, 40 | 't_eval_steps': T_EVAL_STEPS, 41 | 't_eval_iters': T_EVAL_ITERS, 42 | 't_lr_decay_step': T_LR_DECAY_STEP, 43 | 't_lr': T_LR, 44 | 't_lr_decayed': T_LR_DECAYED, 45 | 't_out_path': T_OUT_PATH, 46 | 'device': DEVICE, 47 | } -------------------------------------------------------------------------------- /data_loader/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import h5py 4 | from typing import Iterator, Tuple 5 | 6 | def get_batch_iterator(data_path: str, batch_size: int, context_length: int, device: str = "cpu") -> Iterator[Tuple[torch.Tensor, torch.Tensor]]: 7 | """ 8 | Creates an iterator for generating batches of data from an HDF5 file. 9 | 10 | Args: 11 | data_path (str): Path to the HDF5 file containing tokenized data. 12 | batch_size (int): Number of sequences in each batch. 13 | context_length (int): Length of each sequence. 14 | device (str, optional): Device to load the data onto ('cpu' or 'cuda'). Defaults to "cpu". 15 | 16 | Yields: 17 | tuple: A tuple containing input sequences (xb) and target sequences (yb). 
18 | """ 19 | # Open the HDF5 file in read mode 20 | with h5py.File(data_path, 'r') as hdf5_file: 21 | 22 | # Extract the dataset of tokenized sequences 23 | dataset = hdf5_file['tokens'] 24 | 25 | # Get the total size of the dataset 26 | dataset_size = dataset.shape[0] 27 | 28 | # Calculate the number of examples (sequences) that can be made from the data 29 | n_examples = (dataset_size - 1) // context_length 30 | 31 | # Create an array of indices for examples and shuffle them for randomness 32 | example_idxs = np.arange(n_examples) 33 | np.random.shuffle(example_idxs) 34 | 35 | # Initialize epoch counter and example counter 36 | epochs = 0 37 | counter = 0 38 | 39 | while True: 40 | # Check if the current batch exceeds the number of available examples 41 | if counter + batch_size > n_examples: 42 | # Shuffle the indices again and reset the counter to 0 43 | np.random.shuffle(example_idxs) 44 | counter = 0 45 | print(f"Finished epoch {epochs}") # Print epoch number when an epoch finishes 46 | epochs += 1 # Increment the epoch counter 47 | 48 | # Select a batch of random indices to generate sequences 49 | random_indices = example_idxs[counter:counter+batch_size] * context_length 50 | 51 | # Retrieve sequences from the dataset based on the random indices 52 | random_samples = torch.tensor(np.array([dataset[idx:idx+context_length+1] for idx in random_indices])) 53 | 54 | # Separate the input sequences (xb) and target sequences (yb) 55 | xb = random_samples[:, :context_length].to(device) # Input sequence (first half of the random sample) 56 | yb = random_samples[:, 1:context_length+1].to(device) # Target sequence (second half of the random sample) 57 | 58 | # Increment the counter to move to the next batch 59 | counter += batch_size 60 | 61 | # Yield the input and target sequences as a tuple for the current batch 62 | yield xb, yb 63 | 64 | if __name__ == '__main__': 65 | # Example Usage (requires a dummy HDF5 file for testing) 66 | # Create a dummy HDF5 file 67 | import os 68 | dummy_data_path = "dummy_data.h5" 69 | if not os.path.exists(dummy_data_path): 70 | with h5py.File(dummy_data_path, 'w') as f: 71 | f.create_dataset('tokens', data=np.arange(1000)) 72 | 73 | batch_size = 4 74 | context_length = 10 75 | for xb, yb in get_batch_iterator(dummy_data_path, batch_size, context_length): 76 | print("Input Batch Shape:", xb.shape) 77 | print("Target Batch Shape:", yb.shape) 78 | break -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | torch 3 | torchvision 4 | torchaudio 5 | numpy 6 | h5py 7 | requests 8 | tqdm 9 | zstandard 10 | tiktoken -------------------------------------------------------------------------------- /scripts/data_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import requests 4 | from tqdm import tqdm 5 | from typing import List 6 | 7 | # Base URL for the dataset files 8 | BASE_URL = "https://huggingface.co/datasets/monology/pile-uncopyrighted/resolve/main" 9 | VAL_URL = f"{BASE_URL}/val.jsonl.zst" # URL for the validation dataset 10 | TRAIN_URLS = [f"{BASE_URL}/train/{i:02d}.jsonl.zst" for i in range(65)] # URLs for 65 training files (adjust the range if needed) 11 | 12 | def download_file(url: str, file_name: str) -> None: 13 | """ 14 | Downloads a file from the given URL and saves it with the 
specified file name. 15 | Displays a progress bar using tqdm. 16 | 17 | Args: 18 | url (str): The URL of the file to download. 19 | file_name (str): The local path where the file will be saved. 20 | """ 21 | print(f"Downloading: {file_name}...") 22 | response = requests.get(url, stream=True) # Stream the file content 23 | total_size = int(response.headers.get('content-length', 0)) # Get total file size if available 24 | block_size = 1024 # Size of each block for the progress bar 25 | with open(file_name, 'wb') as f: # Open file for writing in binary mode 26 | for chunk in tqdm(response.iter_content(block_size), total=total_size // block_size, desc="Downloading", leave=True): 27 | f.write(chunk) # Write each chunk to the file 28 | 29 | def download_dataset(val_url: str, train_urls: List[str], val_dir: str, train_dir: str, max_train_files: int) -> None: 30 | """ 31 | Manages downloading of the dataset, including both validation and training files. 32 | 33 | Args: 34 | val_url (str): URL for the validation dataset. 35 | train_urls (list): List of URLs for the training dataset files. 36 | val_dir (str): Directory where the validation file will be stored. 37 | train_dir (str): Directory where the training files will be stored. 38 | max_train_files (int): Maximum number of training files to download. 39 | """ 40 | # Define the path for the validation file 41 | val_file_path = os.path.join(val_dir, "val.jsonl.zst") 42 | if not os.path.exists(val_file_path): # Check if the validation file already exists 43 | print(f"Validation file not found. Downloading from {val_url}...") 44 | download_file(val_url, val_file_path) # Download the validation file 45 | else: 46 | print("Validation data already present. Skipping download.") 47 | 48 | # Loop through the training file URLs and download if not already present 49 | for idx, url in enumerate(train_urls[:max_train_files]): # Limit to max_train_files 50 | file_name = f"{idx:02d}.jsonl.zst" # Format file name (e.g., 00.jsonl.zst) 51 | file_path = os.path.join(train_dir, file_name) # Construct the full file path 52 | if not os.path.exists(file_path): # Check if the file already exists 53 | print(f"Training file {file_name} not found. Downloading...") 54 | download_file(url, file_path) # Download the training file 55 | else: 56 | print(f"Training file {file_name} already present. Skipping download.") 57 | 58 | def main() -> None: 59 | """ 60 | Main function to parse arguments and orchestrate the dataset download process. 
61 | """ 62 | # Parse command-line arguments using argparse 63 | parser = argparse.ArgumentParser(description="Download PILE dataset.") # Description of the script 64 | parser.add_argument('--train_max', type=int, default=1, help="Max number of training files to download.") # Max training files 65 | parser.add_argument('--train_dir', default="data/train", help="Directory for storing training data.") # Training directory 66 | parser.add_argument('--val_dir', default="data/val", help="Directory for storing validation data.") # Validation directory 67 | 68 | args = parser.parse_args() # Parse the arguments provided by the user 69 | 70 | # Ensure directories for training and validation data exist 71 | os.makedirs(args.train_dir, exist_ok=True) # Create training directory if it doesn't exist 72 | os.makedirs(args.val_dir, exist_ok=True) # Create validation directory if it doesn't exist 73 | 74 | # Start downloading the dataset 75 | download_dataset(VAL_URL, TRAIN_URLS, args.val_dir, args.train_dir, args.train_max) 76 | 77 | print("Dataset downloaded successfully.") # Indicate successful download 78 | 79 | if __name__ == "__main__": 80 | # Entry point of the script 81 | main() -------------------------------------------------------------------------------- /scripts/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import zstandard as zstd 4 | import tiktoken 5 | import h5py 6 | from tqdm import tqdm 7 | import argparse 8 | from typing import Optional 9 | 10 | def process_files(input_dir: str, output_file: str, tokenizer_name: str, max_data: Optional[int] = None) -> None: 11 | """ 12 | Process a specified number of lines from each .jsonl.zst file in the input directory 13 | and save encoded tokens to an HDF5 file. 14 | 15 | Args: 16 | input_dir (str): Directory containing input .jsonl.zst files. 17 | output_file (str): Path to the output HDF5 file. 18 | tokenizer_name (str): Name of the tiktoken tokenizer to use (e.g., 'r50k_base'). 19 | max_data (int, optional): Maximum number of lines to process from each file. 20 | If None, process all lines. 21 | """ 22 | # Print processing strategy based on max_data 23 | if max_data is not None: 24 | print(f"You have chosen max_data = {max_data}. 
Processing only the top {max_data} JSON objects from each file.") 25 | else: 26 | print("Processing all available JSON objects from each file.") 27 | 28 | # Load the tokenizer using the provided tokenizer name 29 | enc = tiktoken.get_encoding(tokenizer_name) 30 | 31 | # Create an HDF5 file for output 32 | with h5py.File(output_file, 'w') as out_f: 33 | # Initialize the dataset for storing tokenized data 34 | dataset = out_f.create_dataset('tokens', (0,), maxshape=(None,), dtype='i') 35 | start_index = 0 # Track the starting index for the next batch of tokens 36 | 37 | # Process each .jsonl.zst file in the input directory 38 | for filename in sorted(os.listdir(input_dir)): 39 | if filename.endswith(".jsonl.zst"): # Only process .jsonl.zst files 40 | in_file = os.path.join(input_dir, filename) 41 | print(f"Processing: {in_file}") 42 | 43 | processed_lines = 0 # Counter for processed lines in the current file 44 | 45 | # Open the compressed .jsonl.zst file for reading 46 | with zstd.open(in_file, 'rt', encoding='utf-8') as in_f: 47 | # Iterate over each line in the file 48 | for line in tqdm(in_f, desc=f"Processing {filename}", total=max_data if max_data is not None else None): 49 | try: 50 | # Parse the line as JSON 51 | data = json.loads(line) 52 | text = data.get('text') # Extract the 'text' field from the JSON object 53 | 54 | if text: 55 | # Tokenize the text and append an end-of-text token 56 | encoded = enc.encode(text + "<|endoftext|>", allowed_special={'<|endoftext|>'}) 57 | encoded_len = len(encoded) 58 | 59 | # Resize the dataset to accommodate new tokens 60 | end_index = start_index + encoded_len 61 | dataset.resize(dataset.shape[0] + encoded_len, axis=0) 62 | 63 | # Store the encoded tokens in the dataset 64 | dataset[start_index:end_index] = encoded 65 | start_index = end_index # Update the start index 66 | else: 67 | # Warn if 'text' key is missing in the JSON object 68 | print(f"Warning: 'text' key missing in line from {filename}") 69 | except json.JSONDecodeError: 70 | # Handle JSON decoding errors 71 | print(f"Warning: Could not decode JSON from line in {filename}") 72 | except Exception as e: 73 | # Handle any other errors 74 | print(f"An error occurred while processing line in {filename}: {e}") 75 | 76 | processed_lines += 1 77 | # Stop processing if max_data limit is reached 78 | if max_data is not None and processed_lines >= max_data: 79 | break 80 | 81 | def main(): 82 | """ 83 | Main function to parse arguments, validate directories, and process files. 
84 | """ 85 | # Parse command-line arguments 86 | parser = argparse.ArgumentParser(description="Preprocess PILE dataset files and save tokens to HDF5.") 87 | parser.add_argument("--train_dir", type=str, default="data/train", help="Directory containing training .jsonl.zst files.") 88 | parser.add_argument("--val_dir", type=str, default="data/val", help="Directory containing validation .jsonl.zst files.") 89 | parser.add_argument("--out_train_file", type=str, default="data/train/pile_train.h5", help="Path to the output training HDF5 file.") 90 | parser.add_argument("--out_val_file", type=str, default="data/val/pile_dev.h5", help="Path to the output validation HDF5 file.") 91 | parser.add_argument("--tokenizer_name", type=str, default="r50k_base", help="Name of the tiktoken tokenizer to use.") 92 | parser.add_argument("--max_data", type=int, default=1000, help="Maximum number of json objects to process from each file in both train and val datasets (default: 1000).") 93 | 94 | args = parser.parse_args() 95 | 96 | # Validate the existence of the training and validation directories 97 | if not os.path.isdir(args.train_dir): 98 | print(f"Error: Training directory not found: {args.train_dir}") 99 | return 100 | if not os.path.isdir(args.val_dir): 101 | print(f"Error: Validation directory not found: {args.val_dir}") 102 | return 103 | 104 | # Process training data 105 | print("Starting training data preprocessing...") 106 | process_files(args.train_dir, args.out_train_file, args.tokenizer_name, args.max_data) 107 | print("Training data preprocessing complete.") 108 | 109 | # Process validation data 110 | print("Starting validation data preprocessing...") 111 | process_files(args.val_dir, args.out_val_file, args.tokenizer_name, args.max_data) 112 | print("Validation data preprocessing complete.") 113 | 114 | # Entry point of the script 115 | if __name__ == "__main__": 116 | main() -------------------------------------------------------------------------------- /scripts/generate_text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tiktoken 3 | import argparse 4 | from config.config import default_config as config 5 | from src.models.transformer import Transformer # Assuming your Transformer class is in this module 6 | 7 | def generate_text(model_path: str, input_text: str, max_new_tokens: int = 100, device: str = 'cuda') -> str: 8 | """ 9 | Generates text using a pre-trained Transformer model. 10 | 11 | Args: 12 | model_path (str): Path to the saved model checkpoint. 13 | input_text (str): The initial text to start generation from. 14 | max_new_tokens (int): The maximum number of new tokens to generate. 15 | device (str): 'cuda' or 'cpu', the device to run the model on. 16 | 17 | Returns: 18 | str: The generated text. 
19 | """ 20 | # Load the model checkpoint 21 | checkpoint = torch.load(model_path, map_location=torch.device(device)) 22 | 23 | # Initialize the model using the configuration from config.py 24 | model = Transformer( 25 | n_head=config['n_head'], 26 | n_embed=config['n_embed'], 27 | context_length=config['context_length'], 28 | vocab_size=config['vocab_size'], 29 | N_BLOCKS=config['n_blocks'] 30 | ) 31 | model.load_state_dict(checkpoint['model_state_dict']) 32 | model.eval().to(device) 33 | 34 | # Load the tokenizer 35 | enc = tiktoken.get_encoding("r50k_base") 36 | 37 | start_ids = enc.encode_ordinary(input_text) 38 | context = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0) 39 | 40 | # Generation process 41 | with torch.no_grad(): 42 | generated_tokens = model.generate(context, max_new_tokens=max_new_tokens)[0].tolist() 43 | 44 | # Decode the generated tokens 45 | output_text = enc.decode(generated_tokens) 46 | 47 | return output_text 48 | 49 | def main() -> None: 50 | parser = argparse.ArgumentParser(description="Generate text using a pre-trained Transformer model.") 51 | parser.add_argument('--model_path', type=str, help='Path to the saved model checkpoint.') 52 | parser.add_argument('--input_text', type=str, help='The initial text to start generation from.') 53 | parser.add_argument('--max_new_tokens', type=int, default=100, help='Maximum number of new tokens to generate.') 54 | 55 | args = parser.parse_args() 56 | 57 | generated = generate_text(args.model_path, args.input_text, args.max_new_tokens) 58 | print(f"Generated text:\n{generated}") 59 | 60 | if __name__ == "__main__": 61 | main() -------------------------------------------------------------------------------- /scripts/train_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | from tqdm import tqdm 5 | import numpy as np 6 | from config.config import default_config as config 7 | from src.models.transformer import Transformer 8 | from data_loader.data_loader import get_batch_iterator 9 | from typing import Dict 10 | 11 | # --- Initialize the Model and Print Parameters --- 12 | 13 | model = Transformer( 14 | n_head=config['n_head'], 15 | n_embed=config['n_embed'], 16 | context_length=config['context_length'], 17 | vocab_size=config['vocab_size'], 18 | N_BLOCKS=config['n_blocks'] 19 | ).to(config['device']) 20 | 21 | # Print the total number of parameters 22 | total_params = sum(p.numel() for p in model.parameters()) 23 | print(f"Total number of parameters in the model: {total_params:,}") 24 | 25 | # --- Optimizer Setup and Loss Tracking --- 26 | 27 | # Set up the AdamW optimizer with the specified learning rate. 28 | optimizer = torch.optim.AdamW(model.parameters(), lr=config['t_lr']) 29 | 30 | # List to track loss values during training. 31 | losses = [] 32 | 33 | # Define a window size for averaging recent losses in the training loop. 34 | AVG_WINDOW = 64 35 | 36 | # Helper function to estimate the average loss for training and development data. 37 | @torch.no_grad() 38 | def estimate_loss(steps: int) -> Dict[str, float]: 39 | """ 40 | Evaluate the model on training and development datasets and calculate average loss. 41 | 42 | Args: 43 | steps (int): Number of steps to evaluate. 44 | 45 | Returns: 46 | dict: Dictionary containing average losses for 'train' and 'dev' splits. 47 | """ 48 | out = {} 49 | model.eval() # Set the model to evaluation mode. 
50 | 51 | for split in ['train', 'dev']: 52 | # Select the appropriate data path for the current split. 53 | data_path = config['train_path'] if split == 'train' else config['dev_path'] 54 | 55 | # Create a batch iterator for evaluation. 56 | batch_iterator_eval = get_batch_iterator( 57 | data_path, config['t_batch_size'], config['t_context_length'], device=config['device'] 58 | ) 59 | 60 | # Initialize a tensor to track loss values for each evaluation step. 61 | losses_eval = torch.zeros(steps) 62 | for k in range(steps): 63 | try: 64 | # Fetch a batch and calculate the loss. 65 | xb, yb = next(batch_iterator_eval) 66 | _, loss = model(xb, yb) 67 | losses_eval[k] = loss.item() 68 | except StopIteration: 69 | # Handle the case where the data iterator ends early. 70 | print(f"Warning: Iterator for {split} ended early.") 71 | break 72 | 73 | # Compute the mean loss for the current split. 74 | out[split] = losses_eval[:k + 1].mean() 75 | 76 | model.train() # Restore the model to training mode. 77 | return out 78 | 79 | # --- Training Loop --- 80 | 81 | # Create a batch iterator for the training data. 82 | batch_iterator = get_batch_iterator( 83 | config['train_path'], 84 | config['t_batch_size'], 85 | config['t_context_length'], 86 | device=config['device'] 87 | ) 88 | 89 | # Create a progress bar to monitor training progress. 90 | pbar = tqdm(range(config['t_train_steps'])) 91 | for step in pbar: 92 | try: 93 | # Fetch a batch of input and target data. 94 | xb, yb = next(batch_iterator) 95 | 96 | # Perform a forward pass and compute the loss. 97 | _, loss = model(xb, yb) 98 | 99 | # Record the loss for tracking. 100 | losses.append(loss.item()) 101 | pbar.set_description(f"Train loss: {np.mean(losses[-AVG_WINDOW:]):.4f}") 102 | 103 | # Backpropagate the loss and update the model parameters. 104 | optimizer.zero_grad(set_to_none=True) 105 | loss.backward() 106 | optimizer.step() 107 | 108 | # Periodically evaluate the model on training and development data. 109 | if step % config['t_eval_steps'] == 0: 110 | evaluation_losses = estimate_loss(config['t_eval_iters']) 111 | train_loss = evaluation_losses['train'] 112 | dev_loss = evaluation_losses['dev'] 113 | print(f"Step: {step}, Train loss: {train_loss:.4f}, Dev loss: {dev_loss:.4f}") 114 | 115 | # Decay the learning rate at the specified step. 116 | if step == config['t_lr_decay_step']: 117 | print('Decaying learning rate') 118 | for g in optimizer.param_groups: 119 | g['lr'] = config['t_lr_decayed'] 120 | except StopIteration: 121 | # Handle the case where the training data iterator ends early. 122 | print("Training data iterator finished early.") 123 | break 124 | 125 | # --- Save Model and Final Evaluation --- 126 | 127 | # Create the output directory if it does not exist. 128 | os.makedirs(config['t_out_path'].split('/')[0], exist_ok=True) 129 | 130 | # Perform a final evaluation of the model on training and development datasets. 131 | evaluation_losses = estimate_loss(200) 132 | train_loss = evaluation_losses['train'] 133 | dev_loss = evaluation_losses['dev'] 134 | 135 | # Ensure unique model save path in case the file already exists. 136 | modified_model_out_path = config['t_out_path'] 137 | save_tries = 0 138 | while os.path.exists(modified_model_out_path): 139 | save_tries += 1 140 | model_out_name = os.path.splitext(config['t_out_path'])[0] 141 | modified_model_out_path = model_out_name + f"_{save_tries}" + ".pt" 142 | 143 | # Save the model's state dictionary, optimizer state, and training metadata. 
144 | torch.save( 145 | { 146 | 'model_state_dict': model.state_dict(), 147 | 'optimizer_state_dict': optimizer.state_dict(), 148 | 'losses': losses, 149 | 'train_loss': train_loss, 150 | 'dev_loss': dev_loss, 151 | 'steps': len(losses), 152 | }, 153 | modified_model_out_path 154 | ) 155 | print(f"Saved model to {modified_model_out_path}") 156 | print(f"Finished training. Train loss: {train_loss:.4f}, Dev loss: {dev_loss:.4f}") -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FareedKhan-dev/train-llm-from-scratch/163ba68c408576174fa9b601f13f6b1c52fb3714/src/__init__.py -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # This file makes the 'src.models' directory a Python package. 2 | from .mlp import MLP 3 | from .attention import Head, MultiHeadAttention 4 | from .transformer_block import Block 5 | from .transformer import Transformer -------------------------------------------------------------------------------- /src/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class Head(nn.Module): 7 | """ 8 | A single attention head. 9 | 10 | This module calculates attention scores and applies them to the values. 11 | It includes key, query, and value projections, and uses causal masking 12 | to prevent attending to future tokens. 13 | 14 | Args: 15 | head_size (int): The dimensionality of the key, query, and value projections. 16 | n_embed (int): The dimensionality of the input embedding. 17 | context_length (int): The maximum length of the input sequence, used for causal masking. 18 | """ 19 | def __init__(self, head_size: int, n_embed: int, context_length: int) -> None: 20 | """ 21 | Initializes the attention head. 22 | 23 | Args: 24 | head_size (int): The dimensionality of the key, query, and value projections. 25 | n_embed (int): The dimensionality of the input embedding. 26 | context_length (int): The maximum length of the input sequence. 27 | """ 28 | super().__init__() 29 | self.key = nn.Linear(n_embed, head_size, bias=False) # Key projection 30 | self.query = nn.Linear(n_embed, head_size, bias=False) # Query projection 31 | self.value = nn.Linear(n_embed, head_size, bias=False) # Value projection 32 | # Lower triangular matrix for causal masking 33 | self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length))) 34 | 35 | def forward(self, x: torch.Tensor) -> torch.Tensor: 36 | """ 37 | Forward pass through the attention head. 38 | 39 | Args: 40 | x (torch.Tensor): Input tensor of shape (B, T, C). 41 | 42 | Returns: 43 | torch.Tensor: Output tensor after applying attention. 
44 | """ 45 | B, T, C = x.shape 46 | head_size = self.key.out_features 47 | k = self.key(x) # (B, T, head_size) 48 | q = self.query(x) # (B, T, head_size) 49 | scale_factor = 1 / math.sqrt(head_size) 50 | # Calculate attention weights: (B, T, head_size) @ (B, head_size, T) -> (B, T, T) 51 | attn_weights = q @ k.transpose(-2, -1) * scale_factor 52 | # Apply causal masking 53 | attn_weights = attn_weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) 54 | attn_weights = F.softmax(attn_weights, dim=-1) 55 | v = self.value(x) # (B, T, head_size) 56 | # Apply attention weights to values 57 | out = attn_weights @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size) 58 | return out 59 | 60 | class MultiHeadAttention(nn.Module): 61 | """ 62 | Multi-Head Attention module. 63 | 64 | This module combines multiple attention heads in parallel. The outputs of each head 65 | are concatenated to form the final output. 66 | 67 | Args: 68 | n_head (int): The number of parallel attention heads. 69 | n_embed (int): The dimensionality of the input embedding. 70 | context_length (int): The maximum length of the input sequence. 71 | """ 72 | def __init__(self, n_head: int, n_embed: int, context_length: int) -> None: 73 | """ 74 | Initializes the multi-head attention module. 75 | 76 | Args: 77 | n_head (int): The number of parallel attention heads. 78 | n_embed (int): The dimensionality of the input embedding. 79 | context_length (int): The maximum length of the input sequence. 80 | """ 81 | super().__init__() 82 | self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]) 83 | 84 | def forward(self, x: torch.Tensor) -> torch.Tensor: 85 | """ 86 | Forward pass through the multi-head attention. 87 | 88 | Args: 89 | x (torch.Tensor): Input tensor of shape (B, T, C). 90 | 91 | Returns: 92 | torch.Tensor: Output tensor after concatenating the outputs of all heads. 93 | """ 94 | # Concatenate the output of each head along the last dimension (C) 95 | x = torch.cat([h(x) for h in self.heads], dim=-1) 96 | return x 97 | 98 | if __name__ == '__main__': 99 | # Example Usage (optional, for testing the module independently) 100 | batch_size = 2 101 | sequence_length = 5 102 | embedding_dim = 32 103 | num_heads = 4 104 | context_len = 5 105 | input_tensor = torch.randn(batch_size, sequence_length, embedding_dim) 106 | 107 | multihead_attn = MultiHeadAttention(n_head=num_heads, n_embed=embedding_dim, context_length=context_len) 108 | output_tensor = multihead_attn(input_tensor) 109 | 110 | print("MultiHeadAttention Input Shape:", input_tensor.shape) 111 | print("MultiHeadAttention Output Shape:", output_tensor.shape) 112 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | class MLP(nn.Module): 6 | """ 7 | A simple Multi-Layer Perceptron with one hidden layer. 8 | 9 | This module is used within the Transformer block for feed-forward processing. 10 | It expands the input embedding size, applies a ReLU activation, and then projects it back 11 | to the original embedding size. 12 | 13 | Args: 14 | n_embed (int): The dimensionality of the input embedding. 15 | """ 16 | def __init__(self, n_embed: int) -> None: 17 | """ 18 | Initializes the MLP module. 19 | 20 | Args: 21 | n_embed (int): The dimensionality of the input embedding. 
22 | """ 23 | super().__init__() 24 | self.hidden = nn.Linear(n_embed, 4 * n_embed) # Linear layer to expand embedding size 25 | self.relu = nn.ReLU() # ReLU activation function 26 | self.proj = nn.Linear(4 * n_embed, n_embed) # Linear layer to project back to original size 27 | 28 | def forward(self, x: Tensor) -> Tensor: 29 | """ 30 | Forward pass through the MLP. 31 | 32 | Args: 33 | x (torch.Tensor): Input tensor of shape (B, T, C), where B is batch size, 34 | T is sequence length, and C is embedding size. 35 | 36 | Returns: 37 | torch.Tensor: Output tensor of the same shape as the input. 38 | """ 39 | x = self.forward_embedding(x) 40 | x = self.project_embedding(x) 41 | return x 42 | 43 | def forward_embedding(self, x: Tensor) -> Tensor: 44 | """ 45 | Applies the hidden linear layer followed by ReLU activation. 46 | 47 | Args: 48 | x (torch.Tensor): Input tensor. 49 | 50 | Returns: 51 | torch.Tensor: Output after the hidden layer and ReLU. 52 | """ 53 | x = self.relu(self.hidden(x)) 54 | return x 55 | 56 | def project_embedding(self, x: Tensor) -> Tensor: 57 | """ 58 | Applies the projection linear layer. 59 | 60 | Args: 61 | x (torch.Tensor): Input tensor. 62 | 63 | Returns: 64 | torch.Tensor: Output after the projection layer. 65 | """ 66 | x = self.proj(x) 67 | return x 68 | 69 | if __name__ == '__main__': 70 | # Example Usage (optional, for testing the module independently) 71 | batch_size = 2 72 | sequence_length = 3 73 | embedding_dim = 16 74 | input_tensor = torch.randn(batch_size, sequence_length, embedding_dim) 75 | 76 | mlp_module = MLP(n_embed=embedding_dim) 77 | output_tensor = mlp_module(input_tensor) 78 | 79 | print("MLP Input Shape:", input_tensor.shape) 80 | print("MLP Output Shape:", output_tensor.shape) -------------------------------------------------------------------------------- /src/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from src.models.transformer_block import Block 5 | 6 | class Transformer(nn.Module): 7 | """ 8 | The main Transformer model. 9 | 10 | This class combines token and position embeddings with a sequence of Transformer blocks 11 | and a final linear layer for language modeling. 12 | 13 | Args: 14 | n_head (int): The number of attention heads in each transformer block. 15 | n_embed (int): The dimensionality of the embedding space. 16 | context_length (int): The maximum length of the input sequence. 17 | vocab_size (int): The size of the vocabulary. 18 | N_BLOCKS (int): The number of transformer blocks in the model. 19 | """ 20 | def __init__(self, n_head: int, n_embed: int, context_length: int, vocab_size: int, N_BLOCKS: int) -> None: 21 | """ 22 | Initializes the Transformer model. 23 | 24 | Args: 25 | n_head (int): Number of attention heads. 26 | n_embed (int): Embedding dimension. 27 | context_length (int): Maximum sequence length. 28 | vocab_size (int): Size of the vocabulary. 29 | N_BLOCKS (int): Number of transformer blocks. 
30 | """ 31 | super().__init__() 32 | self.context_length = context_length 33 | self.N_BLOCKS = N_BLOCKS 34 | self.token_embed = nn.Embedding(vocab_size, n_embed) 35 | self.position_embed = nn.Embedding(context_length, n_embed) 36 | self.attn_blocks = nn.ModuleList([Block(n_head, n_embed, context_length) for _ in range(N_BLOCKS)]) 37 | self.layer_norm = nn.LayerNorm(n_embed) 38 | self.lm_head = nn.Linear(n_embed, vocab_size) 39 | self.register_buffer('pos_idxs', torch.arange(context_length)) 40 | 41 | def _pre_attn_pass(self, idx: torch.Tensor) -> torch.Tensor: 42 | """ 43 | Combines token and position embeddings. 44 | 45 | Args: 46 | idx (torch.Tensor): Input token indices. 47 | 48 | Returns: 49 | torch.Tensor: Sum of token and position embeddings. 50 | """ 51 | B, T = idx.shape 52 | tok_embedding = self.token_embed(idx) 53 | pos_embedding = self.position_embed(self.pos_idxs[:T]) 54 | return tok_embedding + pos_embedding 55 | 56 | def forward(self, idx: torch.Tensor, targets: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor | None]: 57 | """ 58 | Forward pass through the Transformer. 59 | 60 | Args: 61 | idx (torch.Tensor): Input token indices. 62 | targets (torch.Tensor, optional): Target token indices for loss calculation. Defaults to None. 63 | 64 | Returns: 65 | tuple: Logits and loss (if targets are provided). 66 | """ 67 | x = self._pre_attn_pass(idx) 68 | for block in self.attn_blocks: 69 | x = block(x) 70 | x = self.layer_norm(x) 71 | logits = self.lm_head(x) 72 | loss = None 73 | if targets is not None: 74 | B, T, C = logits.shape 75 | flat_logits = logits.view(B * T, C) 76 | targets = targets.view(B * T).long() 77 | loss = F.cross_entropy(flat_logits, targets) 78 | return logits, loss 79 | 80 | def forward_embedding(self, idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 81 | """ 82 | Forward pass focusing on the embedding and attention blocks. 83 | 84 | Args: 85 | idx (torch.Tensor): Input token indices. 86 | 87 | Returns: 88 | tuple: Output after attention blocks and the residual. 89 | """ 90 | x = self._pre_attn_pass(idx) 91 | residual = x 92 | for block in self.attn_blocks: 93 | x, residual = block.forward_embedding(x) 94 | return x, residual 95 | 96 | def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor: 97 | """ 98 | Generates new tokens given a starting sequence. 99 | 100 | Args: 101 | idx (torch.Tensor): Initial sequence of token indices. 102 | max_new_tokens (int): Number of tokens to generate. 103 | 104 | Returns: 105 | torch.Tensor: The extended sequence of tokens. 
106 | """ 107 | for _ in range(max_new_tokens): 108 | idx_cond = idx[:, -self.context_length:] 109 | logits, _ = self(idx_cond) 110 | logits = logits[:, -1, :] 111 | probs = F.softmax(logits, dim=-1) 112 | idx_next = torch.multinomial(probs, num_samples=1) 113 | idx = torch.cat((idx, idx_next), dim=1) 114 | return idx 115 | 116 | if __name__ == '__main__': 117 | # Example Usage (optional, for testing the module independently) 118 | batch_size = 2 119 | sequence_length = 5 120 | vocab_size = 100 121 | embedding_dim = 32 122 | num_heads = 4 123 | num_blocks = 2 124 | context_len = 5 125 | input_indices = torch.randint(0, vocab_size, (batch_size, sequence_length)) 126 | 127 | transformer_model = Transformer(n_head=num_heads, n_embed=embedding_dim, context_length=context_len, vocab_size=vocab_size, N_BLOCKS=num_blocks) 128 | logits, loss = transformer_model(input_indices, targets=input_indices) # Using input as target for simplicity 129 | 130 | print("Transformer Logits Shape:", logits.shape) 131 | print("Transformer Loss:", loss) 132 | 133 | # Example of generating tokens 134 | start_indices = input_indices[:, :1] # Take the first token of each sequence as start 135 | generated_tokens = transformer_model.generate(start_indices, max_new_tokens=5) 136 | print("Generated Tokens Shape:", generated_tokens.shape) -------------------------------------------------------------------------------- /src/models/transformer_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.models.attention import MultiHeadAttention 4 | from src.models.mlp import MLP 5 | 6 | class Block(nn.Module): 7 | """ 8 | A single Transformer block. 9 | 10 | This block consists of a multi-head attention layer followed by an MLP, 11 | with layer normalization and residual connections. 12 | 13 | Args: 14 | n_head (int): The number of attention heads in the multi-head attention layer. 15 | n_embed (int): The dimensionality of the input embedding. 16 | context_length (int): The maximum length of the input sequence. 17 | """ 18 | def __init__(self, n_head: int, n_embed: int, context_length: int) -> None: 19 | """ 20 | Initializes the Transformer block. 21 | 22 | Args: 23 | n_head (int): The number of attention heads. 24 | n_embed (int): The dimensionality of the embedding space. 25 | context_length (int): The maximum sequence length. 26 | """ 27 | super().__init__() 28 | self.ln1 = nn.LayerNorm(n_embed) 29 | self.attn = MultiHeadAttention(n_head, n_embed, context_length) 30 | self.ln2 = nn.LayerNorm(n_embed) 31 | self.mlp = MLP(n_embed) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Forward pass through the Transformer block. 36 | 37 | Args: 38 | x (torch.Tensor): Input tensor. 39 | 40 | Returns: 41 | torch.Tensor: Output tensor after the block. 42 | """ 43 | # Apply multi-head attention with residual connection 44 | x = x + self.attn(self.ln1(x)) 45 | # Apply MLP with residual connection 46 | x = x + self.mlp(self.ln2(x)) 47 | return x 48 | 49 | def forward_embedding(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: 50 | """ 51 | Forward pass focusing on the embedding and attention parts. 52 | 53 | Args: 54 | x (torch.Tensor): Input tensor. 55 | 56 | Returns: 57 | tuple: A tuple containing the output after MLP embedding and the residual. 
58 | """ 59 | res = x + self.attn(self.ln1(x)) 60 | x = self.mlp.forward_embedding(self.ln2(res)) 61 | return x, res 62 | 63 | if __name__ == '__main__': 64 | # Example Usage (optional, for testing the module independently) 65 | batch_size = 2 66 | sequence_length = 5 67 | embedding_dim = 32 68 | num_heads = 4 69 | context_len = 5 70 | input_tensor = torch.randn(batch_size, sequence_length, embedding_dim) 71 | 72 | transformer_block = Block(n_head=num_heads, n_embed=embedding_dim, context_length=context_len) 73 | output_tensor = transformer_block(input_tensor) 74 | 75 | print("Transformer Block Input Shape:", input_tensor.shape) 76 | print("Transformer Block Output Shape:", output_tensor.shape) --------------------------------------------------------------------------------