├── .gitignore
├── README.md
├── config
│   ├── dataset_config.json
│   ├── model_config.json
│   └── training_config.json
├── main.py
├── model_checkpoint
│   ├── hf_tokenizer
│   │   ├── merges.txt
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.json
│   ├── tokenizer.json
│   └── tokenizer_byte.json
├── notebooks
│   └── test.ipynb
├── requirements.txt
└── src
    ├── data
    │   ├── dataset.py
    │   └── load_data.py
    ├── model
    │   └── gpt2.py
    ├── tokenization
    │   ├── bpe_tokenizer.py
    │   └── bytelevel_bpe_tokenizer.py
    ├── training
    │   ├── pipeline.py
    │   └── train.py
    └── utils
        └── evaluation.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore model checkpoints
2 | model_checkpoint/model.pt
3 | 
4 | # Ignore Python cache
5 | **/__pycache__/
6 | **/*.pyc
7 | 
8 | # Ignore Jupyter notebook checkpoints
9 | **/.ipynb_checkpoints
10 | 
11 | # Ignore OS-specific files
12 | .DS_Store
13 | 
14 | # Ignore old scripts
15 | train.py
16 | utils.py
17 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GPT-2 Style LLM Training from Scratch
2 | 
3 | This project implements a simplified GPT-2 style model (~120M parameters) from scratch and trains it on a large instruction dataset. The code is modular, config-driven, and aims to be educational, demonstrating the full pipeline without relying on external large language model libraries.
4 | 
5 | ## Overview
6 | 
7 | This repository showcases:
8 | 
9 | - **Custom BPE Tokenization:** A simplified implementation with options for Byte-Level or Character-Level encoding.
10 | - **GPT-2 Style Model Architecture:** Multi-head causal self-attention, residual connections, and layer norms.
11 | - **Config-Driven Pipeline:** All parameters are controlled via JSON configs in the `config/` directory.
12 | 
13 | ## Dataset
14 | 
15 | We train on the [Orca Agent Instruct Dataset (1M)](https://huggingface.co/datasets/microsoft/orca-agentinstruct-1M-v1) by Microsoft. This dataset:
16 | 
17 | - Contains ~1 million user-assistant message pairs.
18 | - Covers a wide range of instructions and responses.
19 | - Provides realistic conversational patterns for better model adaptation.
20 | 
21 | ## Requirements
22 | 
23 | - **Python:** 3.11.10
24 | - Other dependencies are listed in `requirements.txt`.
25 | 
26 | ## Installation
27 | 
28 | ```bash
29 | git clone https://github.com/timmzimm/Train-LLM-from-scratch.git
30 | cd Train-LLM-from-scratch
31 | pip install -r requirements.txt
32 | ```
33 | 
34 | ## Training
35 | 
36 | We support both single-GPU (or CPU) and multi-GPU (DDP) training.
37 | 
38 | ### Single-GPU or CPU
39 | 1. In `config/training_config.json`, set `"distributed": false`.
40 | 2. Run:
41 | ```bash
42 | python main.py
43 | ```
44 | 
45 | ### Multi-GPU (DDP)
46 | 1. In `config/training_config.json`, set `"distributed": true` and GPU indices (for example, `"gpu_ids": [0,1]`).
47 | 2. Run:
48 | ```bash
49 | torchrun --nproc_per_node=2 main.py
50 | ```
51 | 
52 | 
53 | 
54 | ## Acknowledgments
55 | - **Sebastian Raschka:** His transparent approach to LLMs was an inspiration.
56 | - **Microsoft & Hugging Face:** For providing and hosting the Orca dataset.
57 | 
58 | 
59 | 
--------------------------------------------------------------------------------
/config/dataset_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "microsoft/orca-agentinstruct-1M-v1",
3 |     "all_splits": [
4 |         "creative_content",
5 |         "text_modification",
6 |         "struct2text_flow",
7 |         "rc",
8 |         "rag",
9 |         "text_extraction",
10 |         "mcq",
11 |         "follow_up",
12 |         "analytical_reasoning",
13 |         "fermi",
14 |         "fs_cot_flow",
15 |         "code_",
16 |         "brain_teaser",
17 |         "text_classification",
18 |         "open_domain_qa"
19 |     ],
20 |     "train_ratio": 0.7,
21 |     "val_ratio": 0.2,
22 |     "test_ratio": 0.1
23 | }
--------------------------------------------------------------------------------
/config/model_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "n_ctx": 1024,
3 |     "n_embd": 768,
4 |     "n_layer": 12,
5 |     "n_head": 12
6 | }
--------------------------------------------------------------------------------
/config/training_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "block_size": 1024,
3 |     "batch_size": 4,
4 |     "warmup_epochs": 1,
5 |     "cosine_epochs": 0,
6 |     "learning_rate": 0.0003,
7 |     "eval_every_steps": 3500,
8 |     "vocab_size_limit": 30000,
9 |     "merges_count": 2,
10 |     "special_tokens": ["<|endoftext|>"],
11 |     "tokenization_type": "huggingface",
12 |     "hf_tokenizer_name": "gpt2",
13 |     "gpu_ids": [0, 1],
14 |     "distributed": true
15 | }
16 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import torch.distributed as dist
5 | 
6 | from src.training.pipeline import run_training_pipeline_single, run_training_pipeline_ddp
7 | 
8 | def main():
9 |     """
10 |     Entry point. Reads train_config and decides if we do single-GPU/CPU or DDP.
11 |     """
12 |     with open("config/training_config.json", "r") as f:
13 |         train_config = json.load(f)
14 | 
15 |     distributed = train_config.get("distributed", False)
16 |     if distributed:
17 |         # Launch multi-GPU with PyTorch DDP
18 |         world_size = len(train_config["gpu_ids"])
19 |         ddp_main(world_size, train_config)
20 |     else:
21 |         # Single-GPU or CPU fallback
22 |         run_training_pipeline_single(train_config)
23 | 
24 | def ddp_main(world_size, train_config):
25 |     """
26 |     Initializes the process group, runs the pipeline on each process, then cleans up.
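   |     (Note: with init_method="env://", torch.distributed reads MASTER_ADDR,
   |     MASTER_PORT, RANK, and WORLD_SIZE from the environment variables that
   |     torchrun sets for each worker process, so nothing is hard-coded here.)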
27 | We assume you're calling: 28 | torchrun --nproc_per_node={world_size} main.py 29 | """ 30 | dist.init_process_group(backend="nccl", init_method="env://") 31 | local_rank = int(os.environ["LOCAL_RANK"]) 32 | run_training_pipeline_ddp(train_config, local_rank) 33 | dist.destroy_process_group() 34 | 35 | if __name__ == "__main__": 36 | main() 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /model_checkpoint/hf_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "<|endoftext|>", 3 | "eos_token": "<|endoftext|>", 4 | "unk_token": "<|endoftext|>" 5 | } 6 | -------------------------------------------------------------------------------- /model_checkpoint/hf_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "50256": { 5 | "content": "<|endoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | } 12 | }, 13 | "bos_token": "<|endoftext|>", 14 | "clean_up_tokenization_spaces": false, 15 | "eos_token": "<|endoftext|>", 16 | "model_max_length": 1024, 17 | "tokenizer_class": "GPT2Tokenizer", 18 | "unk_token": "<|endoftext|>" 19 | } 20 | -------------------------------------------------------------------------------- /model_checkpoint/tokenizer.json: -------------------------------------------------------------------------------- 1 | {"special_tokens": ["<|endoftext|>"], "token2id": {"<|endoftext|>": 0, "\n": 1, " ": 2, "!": 3, "\"": 4, "#": 5, "$": 6, "%": 7, "&": 8, "'": 9, "(": 10, ")": 11, "*": 12, "+": 13, ",": 14, "-": 15, ".": 16, "/": 17, "0": 18, "1": 19, "2": 20, "3": 21, "4": 22, "5": 23, "6": 24, "7": 25, "8": 26, "9": 27, ":": 28, ";": 29, "<": 30, "=": 31, ">": 32, "?": 33, "@": 34, "A": 35, "B": 36, "C": 37, "D": 38, "E": 39, "F": 40, "G": 41, "H": 42, "I": 43, "J": 44, "K": 45, "L": 46, "M": 47, "N": 48, "O": 49, "P": 50, "Q": 51, "R": 52, "S": 53, "T": 54, "U": 55, "V": 56, "W": 57, "X": 58, "Y": 59, "Z": 60, "[": 61, "\\": 62, "]": 63, "^": 64, "_": 65, "`": 66, "a": 67, "b": 68, "c": 69, "d": 70, "e": 71, "f": 72, "g": 73, "h": 74, "i": 75, "j": 76, "k": 77, "l": 78, "m": 79, "n": 80, "o": 81, "p": 82, "q": 83, "r": 84, "s": 85, "t": 86, "u": 87, "v": 88, "w": 89, "x": 90, "y": 91, "z": 92, "{": 93, "|": 94, "}": 95, "~": 96, "\u00a1": 97, "\u00a3": 98, "\u00a5": 99, "\u00a7": 100, "\u00ac": 101, "\u00ae": 102, "\u00b0": 103, "\u00b2": 104, "\u00b3": 105, "\u00b5": 106, "\u00b7": 107, "\u00bd": 108, "\u00c1": 109, "\u00c3": 110, "\u00c5": 111, "\u00c9": 112, "\u00cb": 113, "\u00d0": 114, "\u00d6": 115, "\u00d7": 116, "\u00d8": 117, "\u00df": 118, "\u00e0": 119, "\u00e1": 120, "\u00e2": 121, "\u00e3": 122, "\u00e4": 123, "\u00e5": 124, "\u00e7": 125, "\u00e8": 126, "\u00e9": 127, "\u00ea": 128, "\u00eb": 129, "\u00ed": 130, "\u00ef": 131, "\u00f1": 132, "\u00f3": 133, "\u00f4": 134, "\u00f6": 135, "\u00f7": 136, "\u00f8": 137, "\u00fa": 138, "\u00fb": 139, "\u00fc": 140, "\u0101": 141, "\u0107": 142, "\u010d": 143, "\u0111": 144, "\u0119": 145, "\u011f": 146, "\u012b": 147, "\u0130": 148, "\u0131": 149, "\u0142": 150, "\u0144": 151, "\u0145": 152, "\u014c": 153, "\u014d": 154, "\u0156": 155, "\u0159": 156, "\u015f": 157, "\u016b": 158, "\u0175": 159, "\u0190": 160, "\u01de": 161, "\u02be": 162, "\u02bf": 163, "\u0394": 164, 
"\u039b": 165, "\u03a9": 166, "\u03b1": 167, "\u03b2": 168, "\u03b3": 169, "\u03b4": 170, "\u03b5": 171, "\u03b8": 172, "\u03ba": 173, "\u03bb": 174, "\u03bc": 175, "\u03bd": 176, "\u03be": 177, "\u03c0": 178, "\u03c1": 179, "\u03c3": 180, "\u03c4": 181, "\u03c8": 182, "\u03c9": 183, "\u0420": 184, "\u0437": 185, "\u0438": 186, "\u0439": 187, "\u043a": 188, "\u0441": 189, "\u0443": 190, "\u044b": 191, "\u044f": 192, "\u04e8": 193, "\u1e24": 194, "\u1e25": 195, "\u1e63": 196, "\u2006": 197, "\u2010": 198, "\u2013": 199, "\u2014": 200, "\u2018": 201, "\u2019": 202, "\u201d": 203, "\u2032": 204, "\u2033": 205, "\u20ac": 206, "\u2103": 207, "\u2122": 208, "\u2126": 209, "\u2192": 210, "\u2194": 211, "\u21d4": 212, "\u2206": 213, "\u2212": 214, "\u221e": 215, "\u2228": 216, "\u222b": 217, "\u2248": 218, "\u2260": 219, "\u2261": 220, "\u2264": 221, "\u2265": 222, "\u226a": 223, "\u2705": 224, "\u2713": 225, "\u2717": 226, "\u274c": 227, "\u4e00": 228, "\u4e1a": 229, "\u4e2d": 230, "\u4efd": 231, "\u521b": 232, "\u5238": 233, "\u56fd": 234, "\u592a": 235, "\u5e73": 236, "\u65b9": 237, "\u6b63": 238, "\u6c49": 239, "\u6cb3": 240, "\u6d0b": 241, "\u7b2c": 242, "\u80a1": 243, "\u8bc1": 244, "\u8bed": 245, "\u94f6": 246, "\u9526": 247, "\u9f99": 248, "e ": 249, "s ": 250, "th": 251, "t ": 252, "in": 253, "on": 254, ", ": 255, "an": 256, "er": 257, "d ": 258, "en": 259, "the ": 260, "ti": 261, "or": 262, "re": 263, "al": 264, "ar": 265, "y ": 266, " the ": 267, "tion": 268, "ing": 269, "at": 270, "st": 271, "is ": 272, "o ": 273, "ic": 274, "es ": 275, "ro": 276, "us": 277, "es": 278, "and ": 279, "ed ": 280, "wi": 281, "of": 282, "ing ": 283, "Th": 284, "ec": 285, "a ": 286, "ac": 287, "with": 288, ". ": 289, ") ": 290, "er ": 291, ".\n": 292, "no": 293, "al ": 294, "it": 295, "for": 296, "as": 297, "with ": 298, "to ": 299, "tion ": 300, "co": 301, "\n(": 302, "le": 303, "h ": 304, "si": 305, "of ": 306, "di": 307, "ur": 308, "ol": 309, "not ": 310, "el": 311, "ent": 312, "con": 313, "The ": 314, "mus": 315, "must ": 316, "ent ": 317, "ul": 318, "be ": 319, "in ": 320, "on ": 321, "at ": 322, "vi": 323, "am": 324, "op": 325, "s, ": 326, "em": 327, "et": 328, "- ": 329, "lu": 330, "lo": 331, "an ": 332, ": ": 333, "il": 334, "ch": 335, "ra": 336, "ri": 337, "un": 338, "ir": 339, "res": 340, "mp": 341, "ex": 342, "ed": 343, "up": 344, "li": 345, "as ": 346, "ad": 347, "pro": 348, "ve ": 349, "e, ": 350, "wh": 351, ". 
The ": 352, "'s ": 353, "m ": 354, "rec": 355, "that ": 356, "clu": 357, "ce ": 358, "qu": 359, "ation": 360, "ation ": 361, "This ": 362, "all": 363, "is": 364, "ig": 365, "are ": 366, "ag": 367, "int": 368, "en ": 369, "ab": 370, "for ": 371, "pl": 372, "be": 373, "or ": 374, "of the ": 375, "cor": 376, "to": 377, "se": 378, "per": 379, ":\n": 380, "ate ": 381, "ich ": 382, "y, ": 383, "sp": 384, ".\n\n": 385, "the": 386, "ev": 387, "le ": 388, "can": 389, "ati": 390, "ly ": 391, "me": 392, "must be ": 393, "comp": 394, "ou": 395, "ran": 396, ", and ": 397, "os": 398, "uc": 399, "ap": 400, "cannot ": 401, " with ": 402, "ach ": 403, ".\n\n(": 404, "correc": 405, "use ": 406, "ect": 407, "ity ": 408, "do": 409, "ff": 410, "low": 411, "which ": 412, "sh": 413, "k ": 414, "ver": 415, "up ": 416, "in the ": 417, "pres": 418, "A ": 419, "ment": 420, "te": 421, "tic": 422, "ay": 423, "one ": 424, "bu": 425, "ow": 426, "fore ": 427, "before ": 428, "oc": 429, "if": 430, "om": 431, "ers ": 432, "A) ": 433, "on the ": 434, "ha": 435, "im": 436, "ir ": 437, "es, ": 438, "end": 439, "ic ": 440, "com": 441, "ne": 442, "fro": 443, "id": 444, "gro": 445, "vid": 446, "th ": 447, "su": 448, "l ": 449, "it ": 450, "E) ": 451, "C) ": 452, "bec": 453, "option ": 454, "D) ": 455, "ause ": 456, "; ": 457, "B) ": 458, "ter": 459, "because ": 460, "gi": 461, "str": 462, "ates ": 463, "stra": 464, "ai": 465, "\n(A) ": 466, "ak": 467, "their ": 468, "od": 469, "each ": 470, "\n- ": 471, ".\n- ": 472, "par": 473, "ence ": 474, "correct ": 475, " is ": 476, ", C": 477, "to the ": 478, "enc": 479, "af": 480, "rul": 481, "um": 482, "has ": 483, "st ": 484, "bo": 485, " and ": 486, "ear": 487, "ing the ": 488, "sc": 489, "does ": 490, "\nThis ": 491, "ure ": 492, "der": 493, "can ": 494, "constra": 495, "ment ": 496, "ay ": 497, "ts ": 498, "does not ": 499, "w ": 500, "form": 501, "her": 502, "from ": 503, "viol": 504, "iz": 505, "de": 506, "In": 507, "have ": 508, "wor": 509, "tions ": 510, "with the ": 511, "cannot be ": 512, "ased ": 513, "spec": 514, "min": 515, "present": 516, "ical ": 517, "2 ": 518, "fol": 519, "ip": 520, "ot": 521, "so ": 522, "by ": 523, ", B": 524, "ud": 525, "tr": 526, "lex": 527, "fir": 528, "follow": 529, "ther ": 530, "enti": 531, "ther": 532, "clud": 533, "ed to ": 534, "igh": 535, "1 ": 536, "ast ": 537, "e the ": 538, ", D": 539, "fi": 540, "ame ": 541, "and": 542, "is not ": 543, "uld ": 544, "eg": 545, "ble ": 546, "aly": 547, "ter ": 548, "anc": 549, "but ": 550, "than ": 551, "der ": 552, "red ": 553, "n ": 554, "rang": 555, "sec": 556, "constraint": 557, "Ch": 558, "ess": 559, ", which ": 560, "s and ": 561, "act": 562, "por": 563, "est": 564, "any ": 565, "sequ": 566, "Clu": 567, "diff": 568, "att": 569, "mor": 570, "we ": 571, "ist": 572, "es:\n": 573, "ph": 574, "ement ": 575, "imp": 576, "duc": 577, "est ": 578, "ated ": 579, "cl": 580, "dition": 581, "ma": 582, "am ": 583, "differ": 584, "provid": 585, "ar ": 586, "el ": 587, "Alex": 588, "\n\n": 589, "all the ": 590, "3 ": 591, "includ": 592, "ari": 593, "pre": 594, "adher": 595, "tw": 596, ", E": 597, "Al": 598, "clue ": 599, "' ": 600, "consi": 601, "violates ": 602, "cont": 603, "comm": 604, "we": 605, "ess ": 606, "sion": 607, "analy": 608, "\n(C) ": 609, "set": 610, "0 ": 611, "ut": 612, "gn": 613, "\n(B) ": 614, "\n(D) ": 615, "\n(E) ": 616, "alu": 617, "ure": 618, "er, ": 619, "different ": 620, "out ": 621, "same ": 622, "ens": 623, "la": 624, "iti": 625, "and the ": 626, "arrang": 627, "all ": 628, "\n- This ": 
629, "St": 630, "more ": 631, "they ": 632, "wil": 633, "po": 634, "et ": 635, "inv": 636, "our": 637, "ance ": 638, "ep": 639, "sion ": 640, "ice ": 641, "gh": 642, "tim": 643, "ext ": 644, "Clues:\n": 645, "also ": 646, "pos": 647, "og": 648, "ser": 649, "Q: ": 650, "cur": 651, "), ": 652, "correct": 653, "based ": 654, ".\n\nClues:\n": 655, "ang": 656, "s the ": 657, "day": 658, "ents ": 659, "age ": 660, "will ": 661, "ial ": 662, "sy": 663, "no ": 664, "exper": 665, "ey ": 666, "?\n(A) ": 667, "giv": 668, "ain": 669, "iv": 670, "for the ": 671, "inter": 672, "ations ": 673, "clues ": 674, "ech": 675, "cho": 676, "ass": 677, "tiv": 678, "following ": 679, "sign": 680, "eng": 681, "requ": 682, "next ": 683, "end ": 684, "det": 685, "work": 686, "s. ": 687, "group": 688, "only ": 689, "ely ": 690, "ect ": 691, "group ": 692, "ary ": 693, "conf": 694, "two ": 695, "are": 696, "av": 697, "first ": 698, "fri": 699, "r. ": 700, ":\n\n(A) ": 701, "its ": 702, "If": 703, "ant ": 704, "ing a ": 705, "s. The ": 706, "\n\nQ: ": 707, "stud": 708, "ies ": 709, "5 ": 710, "medi": 711, "fin": 712, "after ": 713, "aile": 714, "let": 715, "ob": 716, ", S": 717, "ous ": 718, "order ": 719, "sub": 720, "ron": 721, "sequence ": 722, "able ": 723, "tra": 724, "sible ": 725, "ese ": 726, "cre": 727, "resp": 728, "yp": 729, "evalu": 730, "tive ": 731, "oun": 732, "of the following ": 733, "ase ": 734, "adheres ": 735, "ould ": 736, "Ther": 737, "specif": 738, "4 ": 739, "ed, ": 740, "pa": 741, "Wh": 742, "Each ": 743, "ree ": 744, "coun": 745, "Let": 746, "a, ": 747, "atis": 748}, "merges": [["e", " "], ["s", " "], ["t", "h"], ["t", " "], ["i", "n"], ["o", "n"], [",", " "], ["a", "n"], ["e", "r"], ["d", " "], ["e", "n"], ["th", "e "], ["t", "i"], ["o", "r"], ["r", "e"], ["a", "l"], ["a", "r"], ["y", " "], [" ", "the "], ["ti", "on"], ["in", "g"], ["a", "t"], ["s", "t"], ["i", "s "], ["o", " "], ["i", "c"], ["e", "s "], ["r", "o"], ["u", "s"], ["e", "s"], ["an", "d "], ["e", "d "], ["w", "i"], ["o", "f"], ["ing", " "], ["T", "h"], ["e", "c"], ["a", " "], ["a", "c"], ["wi", "th"], [".", " "], [")", " "], ["er", " "], [".", "\n"], ["n", "o"], ["al", " "], ["i", "t"], ["f", "or"], ["a", "s"], ["with", " "], ["t", "o "], ["tion", " "], ["c", "o"], ["\n", "("], ["l", "e"], ["h", " "], ["s", "i"], ["of", " "], ["d", "i"], ["u", "r"], ["o", "l"], ["no", "t "], ["e", "l"], ["en", "t"], ["c", "on"], ["Th", "e "], ["m", "us"], ["mus", "t "], ["en", "t "], ["u", "l"], ["b", "e "], ["in", " "], ["on", " "], ["a", "t "], ["v", "i"], ["a", "m"], ["o", "p"], ["s", ", "], ["e", "m"], ["e", "t"], ["-", " "], ["l", "u"], ["l", "o"], ["an", " "], [":", " "], ["i", "l"], ["c", "h"], ["r", "a"], ["r", "i"], ["u", "n"], ["i", "r"], ["re", "s"], ["m", "p"], ["e", "x"], ["e", "d"], ["u", "p"], ["l", "i"], ["a", "s "], ["a", "d"], ["p", "ro"], ["v", "e "], ["e", ", "], ["w", "h"], [". 
", "The "], ["'", "s "], ["m", " "], ["re", "c"], ["th", "at "], ["c", "lu"], ["c", "e "], ["q", "u"], ["a", "tion"], ["a", "tion "], ["Th", "is "], ["al", "l"], ["i", "s"], ["i", "g"], ["ar", "e "], ["a", "g"], ["in", "t"], ["en", " "], ["a", "b"], ["for", " "], ["p", "l"], ["b", "e"], ["or", " "], ["of", " the "], ["c", "or"], ["t", "o"], ["s", "e"], ["p", "er"], [":", "\n"], ["at", "e "], ["ic", "h "], ["y", ", "], ["s", "p"], [".\n", "\n"], ["th", "e"], ["e", "v"], ["l", "e "], ["c", "an"], ["a", "ti"], ["l", "y "], ["m", "e"], ["must ", "be "], ["co", "mp"], ["o", "u"], ["r", "an"], [", ", "and "], ["o", "s"], ["u", "c"], ["a", "p"], ["can", "not "], [" ", "with "], ["ac", "h "], [".\n", "\n("], ["cor", "rec"], ["us", "e "], ["ec", "t"], ["it", "y "], ["d", "o"], ["f", "f"], ["lo", "w"], ["wh", "ich "], ["s", "h"], ["k", " "], ["v", "er"], ["up", " "], ["in", " the "], ["p", "res"], ["A", " "], ["m", "ent"], ["t", "e"], ["ti", "c"], ["a", "y"], ["on", "e "], ["b", "u"], ["o", "w"], ["for", "e "], ["be", "fore "], ["o", "c"], ["i", "f"], ["o", "m"], ["er", "s "], ["A", ") "], ["on", " the "], ["h", "a"], ["i", "m"], ["ir", " "], ["es", ", "], ["en", "d"], ["ic", " "], ["co", "m"], ["n", "e"], ["f", "ro"], ["i", "d"], ["g", "ro"], ["vi", "d"], ["th", " "], ["s", "u"], ["l", " "], ["i", "t "], ["E", ") "], ["C", ") "], ["b", "ec"], ["op", "tion "], ["D", ") "], ["a", "use "], [";", " "], ["B", ") "], ["t", "er"], ["bec", "ause "], ["g", "i"], ["st", "r"], ["at", "es "], ["st", "ra"], ["a", "i"], ["\n(", "A) "], ["a", "k"], ["the", "ir "], ["o", "d"], ["e", "ach "], ["\n", "- "], [".\n", "- "], ["p", "ar"], ["en", "ce "], ["correc", "t "], [" ", "is "], [", ", "C"], ["to", " the "], ["en", "c"], ["a", "f"], ["r", "ul"], ["u", "m"], ["h", "as "], ["s", "t "], ["b", "o"], [" ", "and "], ["e", "ar"], ["ing", " the "], ["s", "c"], ["do", "es "], ["\n", "This "], ["ur", "e "], ["d", "er"], ["c", "an "], ["con", "stra"], ["m", "ent "], ["a", "y "], ["t", "s "], ["does ", "not "], ["w", " "], ["for", "m"], ["h", "er"], ["fro", "m "], ["vi", "ol"], ["i", "z"], ["d", "e"], ["I", "n"], ["ha", "ve "], ["w", "or"], ["tion", "s "], ["with", " the "], ["cannot ", "be "], ["as", "ed "], ["sp", "ec"], ["m", "in"], ["pres", "ent"], ["ic", "al "], ["2", " "], ["f", "ol"], ["i", "p"], ["o", "t"], ["s", "o "], ["b", "y "], [", ", "B"], ["u", "d"], ["t", "r"], ["le", "x"], ["f", "ir"], ["fol", "low"], ["th", "er "], ["en", "ti"], ["th", "er"], ["clu", "d"], ["ed ", "to "], ["ig", "h"], ["1", " "], ["as", "t "], ["e ", "the "], [", ", "D"], ["f", "i"], ["am", "e "], ["an", "d"], ["is ", "not "], ["ul", "d "], ["e", "g"], ["b", "le "], ["al", "y"], ["t", "er "], ["an", "c"], ["bu", "t "], ["th", "an "], ["d", "er "], ["re", "d "], ["n", " "], ["ran", "g"], ["s", "ec"], ["constra", "int"], ["C", "h"], ["es", "s"], [", ", "which "], ["s ", "and "], ["ac", "t"], ["p", "or"], ["e", "st"], ["an", "y "], ["se", "qu"], ["C", "lu"], ["di", "ff"], ["at", "t"], ["m", "or"], ["w", "e "], ["i", "st"], ["es", ":\n"], ["p", "h"], ["em", "ent "], ["i", "mp"], ["d", "uc"], ["es", "t "], ["at", "ed "], ["c", "l"], ["di", "tion"], ["m", "a"], ["am", " "], ["diff", "er"], ["pro", "vid"], ["ar", " "], ["el", " "], ["A", "lex"], ["\n", "\n"], ["all", " the "], ["3", " "], ["in", "clud"], ["ar", "i"], ["p", "re"], ["ad", "her"], ["t", "w"], [", ", "E"], ["A", "l"], ["clu", "e "], ["'", " "], ["con", "si"], ["viol", "ates "], ["con", "t"], ["com", "m"], ["w", "e"], ["es", "s "], ["si", "on"], ["an", "aly"], ["\n(", "C) "], ["s", 
"et"], ["0", " "], ["u", "t"], ["g", "n"], ["\n(", "B) "], ["\n(", "D) "], ["\n(", "E) "], ["al", "u"], ["u", "re"], ["er", ", "], ["differ", "ent "], ["ou", "t "], ["s", "ame "], ["en", "s"], ["l", "a"], ["i", "ti"], ["and ", "the "], ["ar", "rang"], ["all", " "], ["\n- ", "This "], ["S", "t"], ["mor", "e "], ["the", "y "], ["wi", "l"], ["p", "o"], ["e", "t "], ["in", "v"], ["o", "ur"], ["an", "ce "], ["e", "p"], ["si", "on "], ["ic", "e "], ["g", "h"], ["ti", "m"], ["ex", "t "], ["Clu", "es:\n"], ["al", "so "], ["p", "os"], ["o", "g"], ["s", "er"], ["Q", ": "], ["c", "ur"], [")", ", "], ["correc", "t"], ["b", "ased "], [".\n\n", "Clues:\n"], ["an", "g"], ["s ", "the "], ["d", "ay"], ["ent", "s "], ["ag", "e "], ["wil", "l "], ["i", "al "], ["s", "y"], ["n", "o "], ["ex", "per"], ["e", "y "], ["?", "\n(A) "], ["gi", "v"], ["a", "in"], ["i", "v"], ["for", " the "], ["int", "er"], ["ation", "s "], ["clu", "es "], ["ec", "h"], ["ch", "o"], ["as", "s"], ["ti", "v"], ["follow", "ing "], ["si", "gn"], ["en", "g"], ["re", "qu"], ["n", "ext "], ["en", "d "], ["d", "et"], ["wor", "k"], ["s", ". "], ["gro", "up"], ["on", "ly "], ["el", "y "], ["ec", "t "], ["gro", "up "], ["ar", "y "], ["con", "f"], ["tw", "o "], ["a", "re"], ["a", "v"], ["fir", "st "], ["f", "ri"], ["r", ". "], [":\n", "\n(A) "], ["it", "s "], ["I", "f"], ["an", "t "], ["ing ", "a "], ["s", ". The "], ["\n\n", "Q: "], ["st", "ud"], ["i", "es "], ["5", " "], ["me", "di"], ["f", "in"], ["af", "ter "], ["ai", "le"], ["le", "t"], ["o", "b"], [", ", "S"], ["ou", "s "], ["or", "der "], ["su", "b"], ["r", "on"], ["sequ", "ence "], ["ab", "le "], ["t", "ra"], ["si", "ble "], ["es", "e "], ["c", "re"], ["res", "p"], ["y", "p"], ["ev", "alu"], ["ti", "ve "], ["o", "un"], ["of the ", "following "], ["as", "e "], ["adher", "es "], ["o", "uld "], ["Th", "er"], ["spec", "if"], ["4", " "], ["ed", ", "], ["p", "a"], ["W", "h"], ["E", "ach "], ["re", "e "], ["co", "un"], ["L", "et"], ["a", ", "], ["ati", "s"]]} -------------------------------------------------------------------------------- /model_checkpoint/tokenizer_byte.json: -------------------------------------------------------------------------------- 1 | {"special_tokens": ["<|endoftext|>"], "token2id": {"<|endoftext|>": 0, "<0x00>": 1, "<0x01>": 2, "<0x02>": 3, "<0x03>": 4, "<0x04>": 5, "<0x05>": 6, "<0x06>": 7, "<0x07>": 8, "<0x08>": 9, "<0x09>": 10, "<0x0A>": 11, "<0x0B>": 12, "<0x0C>": 13, "<0x0D>": 14, "<0x0E>": 15, "<0x0F>": 16, "<0x10>": 17, "<0x11>": 18, "<0x12>": 19, "<0x13>": 20, "<0x14>": 21, "<0x15>": 22, "<0x16>": 23, "<0x17>": 24, "<0x18>": 25, "<0x19>": 26, "<0x1A>": 27, "<0x1B>": 28, "<0x1C>": 29, "<0x1D>": 30, "<0x1E>": 31, "<0x1F>": 32, "<0x20>": 33, "<0x21>": 34, "<0x22>": 35, "<0x23>": 36, "<0x24>": 37, "<0x25>": 38, "<0x26>": 39, "<0x27>": 40, "<0x28>": 41, "<0x29>": 42, "<0x2A>": 43, "<0x2B>": 44, "<0x2C>": 45, "<0x2D>": 46, "<0x2E>": 47, "<0x2F>": 48, "<0x30>": 49, "<0x31>": 50, "<0x32>": 51, "<0x33>": 52, "<0x34>": 53, "<0x35>": 54, "<0x36>": 55, "<0x37>": 56, "<0x38>": 57, "<0x39>": 58, "<0x3A>": 59, "<0x3B>": 60, "<0x3C>": 61, "<0x3D>": 62, "<0x3E>": 63, "<0x3F>": 64, "<0x40>": 65, "<0x41>": 66, "<0x42>": 67, "<0x43>": 68, "<0x44>": 69, "<0x45>": 70, "<0x46>": 71, "<0x47>": 72, "<0x48>": 73, "<0x49>": 74, "<0x4A>": 75, "<0x4B>": 76, "<0x4C>": 77, "<0x4D>": 78, "<0x4E>": 79, "<0x4F>": 80, "<0x50>": 81, "<0x51>": 82, "<0x52>": 83, "<0x53>": 84, "<0x54>": 85, "<0x55>": 86, "<0x56>": 87, "<0x57>": 88, "<0x58>": 89, "<0x59>": 90, "<0x5A>": 91, "<0x5B>": 92, "<0x5C>": 
93, "<0x5D>": 94, "<0x5E>": 95, "<0x5F>": 96, "<0x60>": 97, "<0x61>": 98, "<0x62>": 99, "<0x63>": 100, "<0x64>": 101, "<0x65>": 102, "<0x66>": 103, "<0x67>": 104, "<0x68>": 105, "<0x69>": 106, "<0x6A>": 107, "<0x6B>": 108, "<0x6C>": 109, "<0x6D>": 110, "<0x6E>": 111, "<0x6F>": 112, "<0x70>": 113, "<0x71>": 114, "<0x72>": 115, "<0x73>": 116, "<0x74>": 117, "<0x75>": 118, "<0x76>": 119, "<0x77>": 120, "<0x78>": 121, "<0x79>": 122, "<0x7A>": 123, "<0x7B>": 124, "<0x7C>": 125, "<0x7D>": 126, "<0x7E>": 127, "<0x7F>": 128, "<0x80>": 129, "<0x81>": 130, "<0x82>": 131, "<0x83>": 132, "<0x84>": 133, "<0x85>": 134, "<0x86>": 135, "<0x87>": 136, "<0x88>": 137, "<0x89>": 138, "<0x8A>": 139, "<0x8B>": 140, "<0x8C>": 141, "<0x8D>": 142, "<0x8E>": 143, "<0x8F>": 144, "<0x90>": 145, "<0x91>": 146, "<0x92>": 147, "<0x93>": 148, "<0x94>": 149, "<0x95>": 150, "<0x96>": 151, "<0x97>": 152, "<0x98>": 153, "<0x99>": 154, "<0x9A>": 155, "<0x9B>": 156, "<0x9C>": 157, "<0x9D>": 158, "<0x9E>": 159, "<0x9F>": 160, "<0xA0>": 161, "<0xA1>": 162, "<0xA2>": 163, "<0xA3>": 164, "<0xA4>": 165, "<0xA5>": 166, "<0xA6>": 167, "<0xA7>": 168, "<0xA8>": 169, "<0xA9>": 170, "<0xAA>": 171, "<0xAB>": 172, "<0xAC>": 173, "<0xAD>": 174, "<0xAE>": 175, "<0xAF>": 176, "<0xB0>": 177, "<0xB1>": 178, "<0xB2>": 179, "<0xB3>": 180, "<0xB4>": 181, "<0xB5>": 182, "<0xB6>": 183, "<0xB7>": 184, "<0xB8>": 185, "<0xB9>": 186, "<0xBA>": 187, "<0xBB>": 188, "<0xBC>": 189, "<0xBD>": 190, "<0xBE>": 191, "<0xBF>": 192, "<0xC0>": 193, "<0xC1>": 194, "<0xC2>": 195, "<0xC3>": 196, "<0xC4>": 197, "<0xC5>": 198, "<0xC6>": 199, "<0xC7>": 200, "<0xC8>": 201, "<0xC9>": 202, "<0xCA>": 203, "<0xCB>": 204, "<0xCC>": 205, "<0xCD>": 206, "<0xCE>": 207, "<0xCF>": 208, "<0xD0>": 209, "<0xD1>": 210, "<0xD2>": 211, "<0xD3>": 212, "<0xD4>": 213, "<0xD5>": 214, "<0xD6>": 215, "<0xD7>": 216, "<0xD8>": 217, "<0xD9>": 218, "<0xDA>": 219, "<0xDB>": 220, "<0xDC>": 221, "<0xDD>": 222, "<0xDE>": 223, "<0xDF>": 224, "<0xE0>": 225, "<0xE1>": 226, "<0xE2>": 227, "<0xE3>": 228, "<0xE4>": 229, "<0xE5>": 230, "<0xE6>": 231, "<0xE7>": 232, "<0xE8>": 233, "<0xE9>": 234, "<0xEA>": 235, "<0xEB>": 236, "<0xEC>": 237, "<0xED>": 238, "<0xEE>": 239, "<0xEF>": 240, "<0xF0>": 241, "<0xF1>": 242, "<0xF2>": 243, "<0xF3>": 244, "<0xF4>": 245, "<0xF5>": 246, "<0xF6>": 247, "<0xF7>": 248, "<0xF8>": 249, "<0xF9>": 250, "<0xFA>": 251, "<0xFB>": 252, "<0xFC>": 253, "<0xFD>": 254, "<0xFE>": 255, "<0xFF>": 256, "<0x65><0x20>": 257, "<0x69><0x6E>": 258}, "merges": [["<0x65>", "<0x20>"], ["<0x69>", "<0x6E>"]]} -------------------------------------------------------------------------------- /notebooks/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/zimin_timur/miniconda3/envs/timmzimm_nlp/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 |     " from .autonotebook import tqdm as notebook_tqdm\n"
14 |    ]
15 |   },
16 |   {
17 |    "name": "stdout",
18 |    "output_type": "stream",
19 |    "text": [
20 |     "Available splits: dict_keys(['creative_content', 'text_modification', 'struct2text_flow', 'rc', 'rag', 'text_extraction', 'mcq', 'follow_up', 'analytical_reasoning', 'fermi', 'fs_cot_flow', 'code_', 'brain_teaser', 'text_classification', 'open_domain_qa'])\n"
21 |    ]
22 |   }
23 |  ],
24 |  "source": [
25 |   "from datasets import load_dataset\n",
26 |   "dataset_name = \"microsoft/orca-agentinstruct-1M-v1\"\n",
27 |   "dataset_info = load_dataset(dataset_name, split=None)\n",
28 |   "print(\"Available splits:\", dataset_info.keys())"
29 |  ]
30 | },
31 | {
32 |  "cell_type": "code",
33 |  "execution_count": null,
34 |  "metadata": {},
35 |  "outputs": [],
36 |  "source": []
37 | }
38 | ],
39 | "metadata": {
40 |  "kernelspec": {
41 |   "display_name": "timmzimm_nlp",
42 |   "language": "python",
43 |   "name": "python3"
44 |  },
45 |  "language_info": {
46 |   "codemirror_mode": {
47 |    "name": "ipython",
48 |    "version": 3
49 |   },
50 |   "file_extension": ".py",
51 |   "mimetype": "text/x-python",
52 |   "name": "python",
53 |   "nbconvert_exporter": "python",
54 |   "pygments_lexer": "ipython3",
55 |   "version": "3.11.10"
56 |  }
57 | },
58 | "nbformat": 4,
59 | "nbformat_minor": 2
60 | }
61 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.4.3
2 | aiohttp==3.11.0
3 | aiosignal==1.3.1
4 | asttokens==2.4.1
5 | attrs==24.2.0
6 | bzip2==1.0.8
7 | ca-certificates==2024.9.24
8 | certifi==2024.8.30
9 | charset-normalizer==3.4.0
10 | comm==0.2.2
11 | datasets==3.1.0
12 | debugpy==1.8.7
13 | decorator==5.1.1
14 | dill==0.3.8
15 | exceptiongroup==1.2.2
16 | executing==2.1.0
17 | filelock==3.16.1
18 | frozenlist==1.5.0
19 | fsspec==2024.9.0
20 | huggingface-hub==0.26.2
21 | idna==3.10
22 | importlib-metadata==8.5.0
23 | ipykernel==6.29.5
24 | ipython==8.29.0
25 | jedi==0.19.1
26 | jinja2==3.1.3
27 | joblib==1.4.2
28 | jupyter_client==8.6.3
29 | jupyter_core==5.7.2
30 | keyutils==1.6.1
31 | krb5==1.21.3
32 | ld_impl_linux-64==2.43
33 | libedit==3.1.20191231
34 | libexpat==2.6.4
35 | libffi==3.4.4
36 | libgcc==14.2.0
37 | libgcc-ng==14.2.0
38 | libgomp==14.2.0
39 | libnsl==2.0.1
40 | libsodium==1.0.20
41 | libsqlite==3.47.0
42 | libstdcxx==14.2.0
43 | libstdcxx-ng==14.2.0
44 | libuuid==2.38.1
45 | libxcrypt==4.4.36
46 | libzlib==1.3.1
47 | markupsafe==2.1.5
48 | matplotlib-inline==0.1.7
49 | mpmath==1.3.0
50 | multidict==6.1.0
51 | multiprocess==0.70.16
52 | ncurses==6.5
53 | nest-asyncio==1.6.0
54 | networkx==3.2.1
55 | numpy==2.1.3
56 | nvidia-cublas-cu12==12.1.3.1
57 | nvidia-cuda-cupti-cu12==12.1.105
58 | nvidia-cuda-nvrtc-cu12==12.1.105
59 | nvidia-cuda-runtime-cu12==12.1.105
60 | nvidia-cudnn-cu12==8.9.2.26
61 | nvidia-cufft-cu12==11.0.2.54
62 | nvidia-curand-cu12==10.3.2.106
63 | nvidia-cusolver-cu12==11.4.5.107
64 | nvidia-cusparse-cu12==12.1.0.106
65 | nvidia-nccl-cu12==2.20.5
66 | nvidia-nvjitlink-cu12==12.1.105
67 | nvidia-nvtx-cu12==12.1.105
68 | openssl==3.3.2
69 | packaging==24.1
70 | pandas==2.2.3
71 | parso==0.8.4
72 | pexpect==4.9.0
73 | pickleshare==0.7.5
74 | pillow==10.2.0
75 | pip==24.3.1
76 | platformdirs==4.3.6
77 | prompt-toolkit==3.0.48
78 | propcache==0.2.0
79 | protobuf==5.28.3
80 | psutil==6.1.0
81 | ptyprocess==0.7.0
82 | pure_eval==0.2.3
83 | pyarrow==18.0.0
84 | pygments==2.18.0
85 | python==3.11.10
86 | python-dateutil==2.9.0
87 | python_abi==3.11
88 | pytz==2024.2
89 | pyyaml==6.0.2
90 | pyzmq==26.2.0
91 | readline==8.2
92 | regex==2024.11.6
93 | requests==2.32.3
94 | safetensors==0.4.5
95 | scikit-learn==1.5.2
96 | scipy==1.14.1
97 | sentencepiece==0.2.0
98 | setuptools==75.3.0
99 | six==1.16.0
100 | sqlite==3.47.0
101 | stack_data==0.6.2
102 | sympy==1.13.1
103 | threadpoolctl==3.5.0
104 | tiktoken==0.8.0
105 | tk==8.6.13
106 | tokenizers==0.20.3
107 | torch==2.3.1+cu121
108 | torchaudio==2.3.1+cu121
109 | torchvision==0.18.1+cu121
110 | tornado==6.4.1
111 | tqdm==4.67.0
112 | traitlets==5.14.3
113 | transformers==4.46.2
114 | triton==2.3.1
115 | typing_extensions==4.12.2
116 | tzdata==2024.2
117 | urllib3==2.2.3
118 | wcwidth==0.2.13
119 | wheel==0.44.0
120 | xxhash==3.5.0
121 | xz==5.4.6
122 | yarl==1.17.1
123 | zeromq==4.3.5
124 | zipp==3.20.2
125 | zlib==1.3.1
126 | 
127 | 
--------------------------------------------------------------------------------
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from torch.utils.data import Dataset
4 | from typing import List
5 | 
6 | class TextDataset(Dataset):
7 |     """
8 |     A dataset that splits a long sequence of token IDs into smaller blocks of fixed length.
9 | 
10 |     For each block of size N, the inputs are the first N tokens, and the targets are
11 |     the next N tokens (shifted by one).
12 |     """
13 |     def __init__(self, token_ids: List[int], block_size: int):
14 |         self.block_size = block_size
15 |         self.data = torch.tensor(token_ids, dtype=torch.long)
16 |         self.num_sequences = max((len(self.data) - 1) // block_size, 0)
17 | 
18 |     def __len__(self):
19 |         return self.num_sequences
20 | 
21 |     def __getitem__(self, idx):
22 |         start = idx * self.block_size
23 |         end = start + self.block_size
24 |         x = self.data[start:end]
25 |         y = self.data[start+1:end+1]
26 |         return x, y
27 | 
28 | def extract_texts(dataset) -> List[str]:
29 |     """
30 |     Extracts all user-assistant interaction texts from the given dataset split.
31 |     Each extracted text combines the user's instruction and the assistant's response.
32 |     """
33 |     texts = []
34 |     for ex in dataset:
35 |         if 'messages' in ex:
36 |             try:
37 |                 messages = json.loads(ex['messages'])
38 |                 instruction = next((m['content'] for m in messages if m.get('role')=='user'), "")
39 |                 output = next((m['content'] for m in messages if m.get('role')=='assistant'), "")
40 |                 text = f"{instruction} {output}".strip()
41 |                 if text:
42 |                     texts.append(text)
43 |             except json.JSONDecodeError:
44 |                 # Skip if messages can't be decoded
45 |                 pass
46 |     return texts
47 | 
--------------------------------------------------------------------------------
/src/data/load_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | from datasets import load_dataset, concatenate_datasets
4 | 
5 | def load_and_merge_splits(config_path: str):
6 |     """
7 |     Loads the orca-agentinstruct dataset from Hugging Face.
8 |     Splits each subset (from 'all_splits') into train/val/test
9 |     in the ratio (train_ratio, val_ratio, test_ratio),
10 |     then concatenates them across all subsets.
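   |     Worked example: with ratios 0.7/0.2/0.1, each subset first gives 10% to
   |     test; validation then takes 0.2 / (0.7 + 0.2) ≈ 0.2222 of the remaining
   |     90%, so the overall proportions come out to the intended 70/20/10.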
11 | 
12 |     Returns:
13 |         train_ds, val_ds, test_ds (huggingface Datasets)
14 |     """
15 |     with open(config_path, "r") as f:
16 |         dataset_config = json.load(f)
17 | 
18 |     dataset_name = dataset_config["dataset_name"]
19 |     all_splits = dataset_config["all_splits"]
20 |     train_ratio = dataset_config["train_ratio"]
21 |     val_ratio = dataset_config["val_ratio"]
22 |     test_ratio = dataset_config["test_ratio"]
23 | 
24 |     # Sanity check
25 |     if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-7:
26 |         raise ValueError("Sum of train_ratio, val_ratio, and test_ratio must be 1.0")
27 | 
28 |     # Load the entire dataset
29 |     print(f"Loading entire dataset: {dataset_name}")
30 |     dataset = load_dataset(dataset_name)
31 | 
32 |     train_sets = []
33 |     val_sets = []
34 |     test_sets = []
35 | 
36 |     for split_name in all_splits:
37 |         if split_name not in dataset:
38 |             print(f"Warning: split '{split_name}' not found. Skipping.")
39 |             continue
40 | 
41 |         ds_current = dataset[split_name]
42 |         # 1) first "cut" test portion
43 |         test_split = ds_current.train_test_split(test_size=test_ratio, seed=42)
44 |         test_ds = test_split["test"]
45 |         rest_ds = test_split["train"]
46 | 
47 |         # 2) now from rest, cut val portion
48 |         # val_ratio is fraction of total, so relative fraction from rest is val_ratio / (train_ratio + val_ratio)
49 |         total_tv = train_ratio + val_ratio
50 |         if abs(total_tv) < 1e-9:
51 |             raise ValueError("train_ratio + val_ratio cannot be 0.")
52 |         relative_val_ratio = val_ratio / total_tv
53 | 
54 |         tv_split = rest_ds.train_test_split(test_size=relative_val_ratio, seed=42)
55 |         cur_train = tv_split["train"]
56 |         cur_val = tv_split["test"]
57 | 
58 |         train_sets.append(cur_train)
59 |         val_sets.append(cur_val)
60 |         test_sets.append(test_ds)
61 | 
62 |     if not train_sets:
63 |         raise ValueError("No valid splits found. Check your all_splits in dataset_config.")
64 | 
65 |     # Combine
66 |     train_ds = concatenate_datasets(train_sets)
67 |     val_ds = concatenate_datasets(val_sets)
68 |     test_ds = concatenate_datasets(test_sets)
69 | 
70 |     print(f"Final dataset sizes: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")
71 |     return train_ds, val_ds, test_ds
72 | 
--------------------------------------------------------------------------------
/src/model/gpt2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | 
5 | class GPT2Config:
6 |     """
7 |     Configuration class for the GPT-2 style model.
8 |     Defines model hyperparameters.
9 |     """
10 |     def __init__(self,
11 |                  vocab_size: int,
12 |                  n_ctx: int = 1024,
13 |                  n_embd: int = 768,
14 |                  n_layer: int = 12,
15 |                  n_head: int = 12,
16 |                  embedding_dropout: float = 0.1,
17 |                  resid_dropout: float = 0.1,
18 |                  attn_dropout: float = 0.1):
19 |         self.vocab_size = vocab_size
20 |         self.n_ctx = n_ctx
21 |         self.n_embd = n_embd
22 |         self.n_layer = n_layer
23 |         self.n_head = n_head
24 |         self.embedding_dropout = embedding_dropout
25 |         self.resid_dropout = resid_dropout
26 |         self.attn_dropout = attn_dropout
27 | 
28 | class CausalSelfAttention(nn.Module):
29 |     """
30 |     Causal self-attention mechanism for GPT-2 style models.
31 |     Uses a lower-triangular mask to ensure causal (autoregressive) behavior.
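   |     Shape summary: the input x is (B, T, C); q, k, v are reshaped to
   |     (B, n_head, T, d_head) with C = n_head * d_head, and the attention
   |     weights are softmax(q @ k^T / sqrt(d_head)) over the key dimension.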
32 | """ 33 | def __init__(self, config: GPT2Config): 34 | super().__init__() 35 | assert config.n_embd % config.n_head == 0 36 | self.n_head = config.n_head 37 | self.d_head = config.n_embd // config.n_head 38 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 39 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 40 | self.register_buffer("bias", torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(1,1,config.n_ctx,config.n_ctx)) 41 | self.attn_dropout = nn.Dropout(config.attn_dropout) 42 | self.resid_dropout = nn.Dropout(config.resid_dropout) 43 | 44 | def forward(self, x): 45 | B,T,C = x.size() 46 | qkv = self.c_attn(x) 47 | q, k, v = qkv.split(C, dim=2) 48 | q = q.view(B,T,self.n_head,self.d_head).transpose(1,2) 49 | k = k.view(B,T,self.n_head,self.d_head).transpose(1,2) 50 | v = v.view(B,T,self.n_head,self.d_head).transpose(1,2) 51 | 52 | att = (q @ k.transpose(-2,-1)) / math.sqrt(self.d_head) 53 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 54 | att = torch.softmax(att, dim=-1) 55 | att = self.attn_dropout(att) 56 | y = att @ v 57 | y = y.transpose(1,2).contiguous().view(B,T,C) 58 | y = self.resid_dropout(self.c_proj(y)) 59 | return y 60 | 61 | class MLP(nn.Module): 62 | """ 63 | Multi-layer perceptron block in the GPT-2 model. 64 | Expands the embedding dimension, applies a non-linearity, and projects back. 65 | """ 66 | def __init__(self, config: GPT2Config): 67 | super().__init__() 68 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) 69 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) 70 | self.act = nn.GELU() 71 | self.dropout = nn.Dropout(config.resid_dropout) 72 | 73 | def forward(self, x): 74 | x = self.c_fc(x) 75 | x = self.act(x) 76 | x = self.c_proj(x) 77 | x = self.dropout(x) 78 | return x 79 | 80 | class Block(nn.Module): 81 | """ 82 | A single Transformer block consisting of: 83 | - LayerNorm 84 | - Causal self-attention 85 | - Another LayerNorm 86 | - MLP 87 | Residual connections are applied around both the attention and MLP sub-blocks. 88 | """ 89 | def __init__(self, config: GPT2Config): 90 | super().__init__() 91 | self.ln_1 = nn.LayerNorm(config.n_embd, eps=1e-5) 92 | self.attn = CausalSelfAttention(config) 93 | self.ln_2 = nn.LayerNorm(config.n_embd, eps=1e-5) 94 | self.mlp = MLP(config) 95 | 96 | def forward(self, x): 97 | x = x + self.attn(self.ln_1(x)) 98 | x = x + self.mlp(self.ln_2(x)) 99 | return x 100 | 101 | class GPT2Model(nn.Module): 102 | """ 103 | GPT-2 style model: 104 | - Token embedding + positional embedding 105 | - N Transformer blocks 106 | - Layer norm + linear head at the end 107 | """ 108 | def __init__(self, config: GPT2Config): 109 | super().__init__() 110 | self.config = config 111 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 112 | self.wpe = nn.Embedding(config.n_ctx, config.n_embd) 113 | self.drop = nn.Dropout(config.embedding_dropout) 114 | self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) 115 | self.ln_f = nn.LayerNorm(config.n_embd, eps=1e-5) 116 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 117 | 118 | self.apply(self._init_weights) 119 | 120 | def _init_weights(self, module): 121 | """ 122 | Initialize weights in a manner similar to the original GPT-2 initialization. 
123 | """ 124 | if isinstance(module, (nn.Linear, nn.Embedding)): 125 | nn.init.normal_(module.weight, mean=0.0, std=0.02) 126 | if isinstance(module, nn.Linear) and module.bias is not None: 127 | nn.init.zeros_(module.bias) 128 | 129 | def forward(self, idx, targets=None): 130 | """ 131 | Forward pass of the GPT-2 model. 132 | 133 | Args: 134 | idx: (B,T) tensor of token IDs 135 | targets: (B,T) tensor of token IDs to compute cross-entropy loss 136 | 137 | Returns: 138 | logits: (B,T,vocab_size) predictions of the next token 139 | loss: scalar cross-entropy loss if targets are provided 140 | """ 141 | B,T = idx.size() 142 | if T > self.config.n_ctx: 143 | raise ValueError("Sequence length exceeds model context length") 144 | pos = torch.arange(0,T,dtype=torch.long,device=idx.device).unsqueeze(0) 145 | x = self.wte(idx) + self.wpe(pos) 146 | x = self.drop(x) 147 | 148 | for block in self.h: 149 | x = block(x) 150 | 151 | x = self.ln_f(x) 152 | logits = self.lm_head(x) 153 | 154 | loss = None 155 | if targets is not None: 156 | loss_fct = nn.CrossEntropyLoss() 157 | loss = loss_fct(logits.view(-1, self.config.vocab_size), targets.view(-1)) 158 | 159 | return logits, loss 160 | 161 | -------------------------------------------------------------------------------- /src/tokenization/bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | from collections import Counter 4 | from tqdm import tqdm 5 | 6 | class BPEVocabulary: 7 | """ 8 | Maintains the vocabulary for BPE tokenization, including: 9 | - A mapping from token to ID and ID to token. 10 | - A list of merges (pairs of tokens combined into a new token). 11 | """ 12 | def __init__(self, special_tokens: List[str] = ["<|endoftext|>"]): 13 | self.special_tokens = special_tokens 14 | self.token2id = {} 15 | self.id2token = {} 16 | self.merges = [] 17 | self.vocab_size = 0 18 | 19 | def build_initial_vocab(self, texts: List[str]): 20 | """ 21 | Builds the initial character-level vocabulary from the given texts. 22 | Special tokens are also added at the beginning. 23 | """ 24 | char_counter = Counter() 25 | for text in texts: 26 | char_counter.update(list(text)) 27 | chars = sorted(char_counter.keys()) 28 | idx = 0 29 | for st in self.special_tokens: 30 | self.token2id[st] = idx 31 | idx += 1 32 | for ch in chars: 33 | if ch not in self.token2id: 34 | self.token2id[ch] = idx 35 | idx += 1 36 | self.id2token = {v:k for k,v in self.token2id.items()} 37 | self.vocab_size = len(self.id2token) 38 | 39 | def add_merge(self, new_token: str, t1: str, t2: str): 40 | """ 41 | Adds a new merged token to the vocabulary. 42 | """ 43 | if new_token not in self.token2id: 44 | idx = len(self.token2id) 45 | self.token2id[new_token] = idx 46 | self.id2token[idx] = new_token 47 | self.vocab_size += 1 48 | self.merges.append((t1, t2)) 49 | 50 | def encode_tokens_to_ids(self, tokens: List[str]) -> List[int]: 51 | """ 52 | Encodes a list of tokens into their corresponding IDs. 53 | """ 54 | return [self.token2id[t] for t in tokens if t in self.token2id] 55 | 56 | def decode_ids_to_tokens(self, ids: List[int]) -> List[str]: 57 | """ 58 | Decodes a list of token IDs back into tokens. 59 | """ 60 | return [self.id2token[i] for i in ids] 61 | 62 | class BPETokenizer: 63 | """ 64 | A simplified BPE tokenizer (not byte-level) that: 65 | - Learns merges from a corpus. 66 | - Encodes and decodes text using the learned merges. 
67 | """ 68 | def __init__(self, special_tokens=["<|endoftext|>"], vocab_size_limit=30000, merges_count=10000): 69 | self.vocab = BPEVocabulary(special_tokens) 70 | self.vocab_size_limit = vocab_size_limit 71 | self.merges_count = merges_count 72 | 73 | def _get_token_sequences(self, texts: List[str]) -> List[List[str]]: 74 | """ 75 | Converts each text into a list of characters plus the end-of-text token. 76 | """ 77 | sequences = [] 78 | for t in texts: 79 | seq = list(t) 80 | seq.append("<|endoftext|>") 81 | sequences.append(seq) 82 | return sequences 83 | 84 | def train(self, texts: List[str]): 85 | """ 86 | Trains the BPE tokenizer on the given corpus. 87 | This involves iteratively merging the most frequent pairs of tokens. 88 | """ 89 | print("Training BPE tokenizer on the corpus...") 90 | self.vocab.build_initial_vocab(texts) 91 | sequences = self._get_token_sequences(texts) 92 | 93 | for i in tqdm(range(self.merges_count), desc="BPE merges"): 94 | pair_counts = Counter() 95 | # Count pair frequencies 96 | for seq in sequences: 97 | for j in range(len(seq)-1): 98 | pair = (seq[j], seq[j+1]) 99 | pair_counts[pair] += 1 100 | 101 | if not pair_counts: 102 | break 103 | best_pair, best_count = pair_counts.most_common(1)[0] 104 | new_token = best_pair[0] + best_pair[1] 105 | self.vocab.add_merge(new_token, best_pair[0], best_pair[1]) 106 | 107 | # Replace all occurrences of the best pair with the new token 108 | new_sequences = [] 109 | for seq in sequences: 110 | j = 0 111 | new_seq = [] 112 | while j < len(seq): 113 | if j < len(seq)-1 and (seq[j], seq[j+1]) == best_pair: 114 | new_seq.append(new_token) 115 | j += 2 116 | else: 117 | new_seq.append(seq[j]) 118 | j += 1 119 | new_sequences.append(new_seq) 120 | sequences = new_sequences 121 | 122 | if self.vocab.vocab_size >= self.vocab_size_limit: 123 | print("Reached vocab size limit.") 124 | break 125 | 126 | print("BPE training complete. Final vocab size:", self.vocab.vocab_size) 127 | 128 | def encode(self, text: str) -> List[int]: 129 | """ 130 | Encodes a given text into a sequence of token IDs using the learned BPE merges. 131 | """ 132 | chars = list(text) + ["<|endoftext|>"] 133 | merges_set = set(self.vocab.merges) 134 | 135 | # Apply merges until no more merges can be made 136 | while True: 137 | merged = False 138 | new_seq = [] 139 | j = 0 140 | while j < len(chars): 141 | if j < len(chars)-1 and (chars[j], chars[j+1]) in merges_set: 142 | new_token = chars[j] + chars[j+1] 143 | new_seq.append(new_token) 144 | j += 2 145 | merged = True 146 | else: 147 | new_seq.append(chars[j]) 148 | j += 1 149 | chars = new_seq 150 | if not merged: 151 | break 152 | 153 | return self.vocab.encode_tokens_to_ids(chars) 154 | 155 | def decode(self, ids: List[int]) -> str: 156 | """ 157 | Decodes a sequence of token IDs back into a text string. 158 | """ 159 | tokens = self.vocab.decode_ids_to_tokens(ids) 160 | if tokens and tokens[-1] == "<|endoftext|>": 161 | tokens = tokens[:-1] 162 | return "".join(tokens) -------------------------------------------------------------------------------- /src/tokenization/bytelevel_bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | from collections import Counter 4 | from tqdm import tqdm 5 | 6 | class ByteLevelBPEVocabulary: 7 | """ 8 | Maintains the vocabulary for Byte-Level BPE tokenization, 9 | including mappings from token to ID, ID to token, and merges. 
10 | """ 11 | def __init__(self, special_tokens: List[str] = ["<|endoftext|>"]): 12 | self.special_tokens = special_tokens 13 | self.token2id = {} 14 | self.id2token = {} 15 | self.merges = [] 16 | self.vocab_size = 0 17 | 18 | def build_initial_vocab(self): 19 | """ 20 | Build an initial byte-level vocab: 256 possible bytes plus special tokens. 21 | """ 22 | idx = 0 23 | # Add special tokens first 24 | for st in self.special_tokens: 25 | self.token2id[st] = idx 26 | idx += 1 27 | 28 | # Add all 256 bytes 29 | for b in range(256): 30 | byte_char = f"<0x{b:02X}>" # a representation for each byte 31 | self.token2id[byte_char] = idx 32 | idx += 1 33 | 34 | self.id2token = {v: k for k, v in self.token2id.items()} 35 | self.vocab_size = len(self.id2token) 36 | 37 | def add_merge(self, new_token: str, t1: str, t2: str): 38 | """ 39 | Adds a new merged token to the vocabulary. 40 | """ 41 | if new_token not in self.token2id: 42 | idx = len(self.token2id) 43 | self.token2id[new_token] = idx 44 | self.id2token[idx] = new_token 45 | self.vocab_size += 1 46 | self.merges.append((t1, t2)) 47 | 48 | def encode_tokens_to_ids(self, tokens: List[str]) -> List[int]: 49 | """ 50 | Encodes a list of tokens into IDs. 51 | """ 52 | return [self.token2id[t] for t in tokens if t in self.token2id] 53 | 54 | def decode_ids_to_tokens(self, ids: List[int]) -> List[str]: 55 | """ 56 | Decodes token IDs to the corresponding tokens. 57 | """ 58 | return [self.id2token[i] for i in ids] 59 | 60 | class ByteLevelBPETokenizer: 61 | """ 62 | A Byte-Level BPE tokenizer: 63 | - Processes text at the byte level. 64 | - Learns merges similarly to a standard BPE approach. 65 | """ 66 | def __init__(self, special_tokens=["<|endoftext|>"], vocab_size_limit=30000, merges_count=10000): 67 | self.vocab = ByteLevelBPEVocabulary(special_tokens) 68 | self.vocab_size_limit = vocab_size_limit 69 | self.merges_count = merges_count 70 | 71 | def _text_to_bytes(self, text: str) -> List[str]: 72 | """ 73 | Converts a string into a list of byte-representations, plus end-of-text token. 74 | """ 75 | byte_list = [] 76 | for b in text.encode("utf-8"): 77 | byte_list.append(f"<0x{b:02X}>") 78 | byte_list.append("<|endoftext|>") 79 | return byte_list 80 | 81 | def train(self, texts: List[str]): 82 | """ 83 | Train the byte-level BPE tokenizer on the given corpus. 
84 | """ 85 | print("Training Byte-Level BPE tokenizer on the corpus...") 86 | # Build the initial vocab (256 bytes + special tokens) 87 | self.vocab.build_initial_vocab() 88 | 89 | # Convert texts to sequences of byte-level tokens 90 | sequences = [self._text_to_bytes(t) for t in texts] 91 | 92 | for i in tqdm(range(self.merges_count), desc="Byte BPE merges"): 93 | pair_counts = Counter() 94 | # Count pair frequencies 95 | for seq in sequences: 96 | for j in range(len(seq)-1): 97 | pair = (seq[j], seq[j+1]) 98 | pair_counts[pair] += 1 99 | 100 | if not pair_counts: 101 | break 102 | best_pair, best_count = pair_counts.most_common(1)[0] 103 | new_token = best_pair[0] + best_pair[1] 104 | self.vocab.add_merge(new_token, best_pair[0], best_pair[1]) 105 | 106 | # Replace all occurrences of the best pair with the new token 107 | new_sequences = [] 108 | for seq in sequences: 109 | j = 0 110 | new_seq = [] 111 | while j < len(seq): 112 | if j < len(seq)-1 and (seq[j], seq[j+1]) == best_pair: 113 | new_seq.append(new_token) 114 | j += 2 115 | else: 116 | new_seq.append(seq[j]) 117 | j += 1 118 | new_sequences.append(new_seq) 119 | sequences = new_sequences 120 | 121 | if self.vocab.vocab_size >= self.vocab_size_limit: 122 | print("Reached vocab size limit.") 123 | break 124 | 125 | print("Byte-Level BPE training complete. Final vocab size:", self.vocab.vocab_size) 126 | 127 | def encode(self, text: str) -> List[int]: 128 | """ 129 | Encodes a given text into a sequence of byte-level BPE token IDs. 130 | """ 131 | tokens = self._text_to_bytes(text) 132 | merges_set = set(self.vocab.merges) 133 | 134 | while True: 135 | merged = False 136 | new_seq = [] 137 | j = 0 138 | while j < len(tokens): 139 | if j < len(tokens)-1 and (tokens[j], tokens[j+1]) in merges_set: 140 | new_token = tokens[j] + tokens[j+1] 141 | new_seq.append(new_token) 142 | j += 2 143 | merged = True 144 | else: 145 | new_seq.append(tokens[j]) 146 | j += 1 147 | tokens = new_seq 148 | if not merged: 149 | break 150 | 151 | return self.vocab.encode_tokens_to_ids(tokens) 152 | 153 | def decode(self, ids: List[int]) -> str: 154 | """ 155 | Decodes a sequence of byte-level BPE IDs back into a text string. 
156 | """ 157 | tokens = self.vocab.decode_ids_to_tokens(ids) 158 | # Remove trailing end-of-text if present 159 | if tokens and tokens[-1] == "<|endoftext|>": 160 | tokens = tokens[:-1] 161 | 162 | # Reconstruct text from byte tokens 163 | byte_values = [] 164 | for tok in tokens: 165 | if tok.startswith("<0x") and tok.endswith(">"): 166 | # convert hex to int then to bytes 167 | hex_val = tok[3:-1] # e.g., "FF" 168 | byte_values.append(int(hex_val, 16)) 169 | 170 | return bytearray(byte_values).decode("utf-8", errors="replace") 171 | -------------------------------------------------------------------------------- /src/training/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data.distributed import DistributedSampler 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | 10 | from transformers import GPT2TokenizerFast 11 | from src.tokenization.bpe_tokenizer import BPETokenizer 12 | from src.tokenization.bytelevel_bpe_tokenizer import ByteLevelBPETokenizer 13 | from src.data.dataset import TextDataset, extract_texts 14 | from src.data.load_data import load_and_merge_splits 15 | from src.model.gpt2 import GPT2Config, GPT2Model 16 | from src.training.train import train 17 | 18 | def run_training_pipeline_single(train_config): 19 | """ 20 | Single-GPU or CPU training pipeline. 21 | """ 22 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 23 | local_rank = 0 24 | _run_pipeline_core(train_config, device, local_rank, ddp=False) 25 | 26 | def run_training_pipeline_ddp(train_config, local_rank): 27 | """ 28 | Multi-GPU DDP training pipeline for a single process (GPU). 29 | local_rank is set by PyTorch (0..N-1). 30 | """ 31 | gpu_ids = train_config["gpu_ids"] 32 | device_id = gpu_ids[local_rank] 33 | device = torch.device(f"cuda:{device_id}") 34 | torch.cuda.set_device(device) 35 | 36 | _run_pipeline_core(train_config, device, local_rank, ddp=True) 37 | 38 | def _run_pipeline_core(train_config, device, local_rank, ddp=False): 39 | """ 40 | Core logic: 41 | 1) Load dataset config and merge all splits. 42 | 2) Extract texts from train/val (test is optional). 43 | 3) Tokenize texts (choose huggingface/char/byte). 44 | 4) Create DataLoader (distributed if ddp). 45 | 5) Create GPT-2 model, DDP if needed. 46 | 6) Run training. 47 | 7) Save model (on rank=0). 
48 | """ 49 | 50 | # 1) Load & merge splits 51 | train_ds_full, val_ds_full, test_ds_full = load_and_merge_splits("config/dataset_config.json") 52 | train_texts = extract_texts(train_ds_full) 53 | val_texts = extract_texts(val_ds_full) 54 | 55 | # Limit data for quick test (optional) 56 | max_train = 1000 # or any number you want 57 | max_val = 200 58 | train_texts = train_texts[:max_train] 59 | val_texts = val_texts[:max_val] 60 | 61 | # 2) Initialize tokenizer 62 | tokenization_type = train_config["tokenization_type"] 63 | if tokenization_type == "huggingface": 64 | tokenizer = GPT2TokenizerFast.from_pretrained(train_config["hf_tokenizer_name"]) 65 | vocab_size = tokenizer.vocab_size 66 | 67 | def encode_func(txt): 68 | return tokenizer.encode(txt) 69 | 70 | elif tokenization_type == "char": 71 | tokenizer = BPETokenizer( 72 | special_tokens=train_config["special_tokens"], 73 | vocab_size_limit=train_config["vocab_size_limit"], 74 | merges_count=train_config["merges_count"] 75 | ) 76 | tokenizer.train(train_texts) 77 | vocab_size = tokenizer.vocab.vocab_size 78 | 79 | def encode_func(txt): 80 | return tokenizer.encode(txt) 81 | 82 | elif tokenization_type == "byte": 83 | tokenizer = ByteLevelBPETokenizer( 84 | special_tokens=train_config["special_tokens"], 85 | vocab_size_limit=train_config["vocab_size_limit"], 86 | merges_count=train_config["merges_count"] 87 | ) 88 | tokenizer.train(train_texts) 89 | vocab_size = tokenizer.vocab.vocab_size 90 | 91 | def encode_func(txt): 92 | return tokenizer.encode(txt) 93 | else: 94 | raise ValueError(f"Unknown tokenization_type: {tokenization_type}") 95 | 96 | # 3) Encode data with tqdm progress 97 | 98 | train_ids = [] 99 | for t in tqdm(train_texts, desc="Encoding train texts"): 100 | train_ids.extend(encode_func(t)) 101 | 102 | val_ids = [] 103 | for t in tqdm(val_texts, desc="Encoding val texts"): 104 | val_ids.extend(encode_func(t)) 105 | 106 | # 4) Create PyTorch datasets & (optionally) distributed sampler 107 | block_size = train_config["block_size"] 108 | train_dataset = TextDataset(train_ids, block_size) 109 | val_dataset = TextDataset(val_ids, block_size) 110 | 111 | if ddp: 112 | train_sampler = DistributedSampler(train_dataset) 113 | val_sampler = DistributedSampler(val_dataset, shuffle=False) 114 | train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config["batch_size"]) 115 | val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=train_config["batch_size"]) 116 | else: 117 | train_loader = DataLoader(train_dataset, batch_size=train_config["batch_size"], shuffle=True) 118 | val_loader = DataLoader(val_dataset, batch_size=train_config["batch_size"], shuffle=False) 119 | 120 | # 5) Build model 121 | with open("config/model_config.json", "r") as f: 122 | model_config_data = json.load(f) 123 | 124 | gpt2_cfg = GPT2Config( 125 | vocab_size=vocab_size, 126 | n_ctx=model_config_data["n_ctx"], 127 | n_embd=model_config_data["n_embd"], 128 | n_layer=model_config_data["n_layer"], 129 | n_head=model_config_data["n_head"] 130 | ) 131 | model = GPT2Model(gpt2_cfg).to(device) 132 | 133 | if ddp: 134 | model = DDP(model, device_ids=[device.index], output_device=device.index) 135 | 136 | # 6) Train 137 | train(model, train_loader, val_loader, device, train_config) 138 | 139 | # 7) Save model on rank=0 140 | if (not ddp) or (ddp and local_rank == 0): 141 | _save_model_and_tokenizer(model, tokenizer, tokenization_type) 142 | 143 | def _save_model_and_tokenizer(model, tokenizer, tokenization_type): 144 | """ 145 
| Saves the model checkpoint and tokenizer (if any). 146 | """ 147 | os.makedirs("model_checkpoint", exist_ok=True) 148 | torch.save(model.state_dict(), "model_checkpoint/model.pt") 149 | 150 | if tokenization_type == "huggingface": 151 | # GPT2TokenizerFast can save_pretrained 152 | tokenizer.save_pretrained("model_checkpoint/hf_tokenizer") 153 | else: 154 | # For char/byte custom BPE 155 | # If you want merges & token2id, add them here 156 | pass 157 | 158 | print("[Pipeline] Model saved.") 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /src/training/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from tqdm import tqdm 4 | from src.utils.evaluation import evaluate 5 | import math 6 | 7 | def train(model, train_loader, val_loader, device, train_config): 8 | """ 9 | Training loop for GPT-2 with linear warmup + cosine decay. 10 | 11 | Args: 12 | model: GPT-2 model 13 | train_loader: DataLoader for training 14 | val_loader: DataLoader for validation (can be empty) 15 | device: 'cuda' or 'cpu' 16 | train_config: dictionary with training parameters, e.g.: 17 | { 18 | "warmup_epochs": 1, 19 | "cosine_epochs": 5, 20 | "learning_rate": 0.0003, 21 | "eval_every_steps": 1000, 22 | "block_size": 1024, 23 | "batch_size": 2, 24 | ... 25 | } 26 | """ 27 | warmup_epochs = train_config["warmup_epochs"] 28 | cosine_epochs = train_config["cosine_epochs"] 29 | lr = train_config["learning_rate"] 30 | eval_every_steps = train_config["eval_every_steps"] 31 | 32 | epochs = warmup_epochs + cosine_epochs 33 | optimizer = optim.AdamW(model.parameters(), lr=lr) 34 | 35 | def get_lr_for_epoch(epoch): 36 | """ 37 | Compute learning rate for the given epoch 38 | based on linear warmup and cosine decay. 39 | """ 40 | if epoch < warmup_epochs: 41 | # Linear warmup from 0 to lr 42 | return lr * float(epoch + 1) / float(warmup_epochs) 43 | else: 44 | # Cosine decay from lr down to 0 45 | progress = float(epoch - warmup_epochs) / float(cosine_epochs) 46 | return lr * 0.5 * (1.0 + math.cos(math.pi * progress)) 47 | 48 | global_step = 0 49 | 50 | for epoch in range(epochs): 51 | model.train() 52 | current_lr = get_lr_for_epoch(epoch) 53 | for param_group in optimizer.param_groups: 54 | param_group['lr'] = current_lr 55 | 56 | pbar = tqdm(train_loader, desc=f"Epoch {epoch} (lr={current_lr:.6f})", leave=True) 57 | for i, (x, y) in enumerate(pbar): 58 | x = x.to(device) 59 | y = y.to(device) 60 | optimizer.zero_grad() 61 | _, loss = model(x, y) 62 | loss.backward() 63 | optimizer.step() 64 | 65 | global_step += 1 66 | 67 | # Evaluate only every 'eval_every_steps' steps 68 | if len(val_loader) > 0 and global_step % eval_every_steps == 0: 69 | val_loss = evaluate(model, val_loader, device) 70 | pbar.set_postfix({"train_loss": loss.item(), "val_loss": val_loss}) 71 | 72 | # Optionally evaluate at the end of each epoch 73 | if len(val_loader) > 0: 74 | val_loss = evaluate(model, val_loader, device) 75 | print(f"End of epoch {epoch}, validation loss: {val_loss}") 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /src/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import numpy as np 4 | 5 | def evaluate(model, dataloader, device): 6 | """ 7 | Evaluate the model on a given dataloader. 
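   |     (The returned value is the mean of per-batch losses; taking exp() of it
   |     gives an approximate validation perplexity.)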
8 | 
9 |     Computes the average cross-entropy loss over the dataset.
10 |     """
11 |     model.eval()
12 |     losses = []
13 |     with torch.no_grad():
14 |         for x, y in tqdm(dataloader, desc="Evaluating", leave=False):
15 |             x = x.to(device)
16 |             y = y.to(device)
17 |             _, loss = model(x, y)
18 |             losses.append(loss.item())
19 |     model.train()
20 |     return float(np.mean(losses))
21 | 
--------------------------------------------------------------------------------