├── .gitignore ├── LICENSE ├── README.md ├── assets ├── loss_slimpajama.png └── pplx_slimpajama.png ├── requirements.txt ├── scripts └── train-baselines-example.sh └── src ├── config ├── __init__.py └── base.py ├── data ├── arxiv.py ├── benchmarks.py ├── openwebtext2.py ├── shakespeare.py ├── slimpajama.py ├── utils.py └── wikitext.py ├── distributed ├── __init__.py ├── backend.py ├── ddp.py └── single.py ├── main.py ├── models ├── base.py ├── llama.py └── utils.py └── optim ├── base.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Dataset folder 2 | src/data/datasets/ 3 | wandb/ 4 | exps/ 5 | scripts/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 EPFL MLO Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM-baselines 2 | 3 | A modular codebase to experiment with transformers, inspired by NanoGPT. 4 | 5 | ## Quickstart 6 | 7 | Install dependencies: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | Run a simple training on the Slimpajama dataset ([6B subset](https://huggingface.co/datasets/DKYoon/SlimPajama-6B), 24GBs decompressed, takes a few minutes to download): 14 | 15 | ```sh 16 | python ./src/main.py --config_format base 17 | ``` 18 | 19 | The above command trains a 123.59M parameters model. It trains for 25k iterations with a batch size of 128=32x4 (4 gradient accumulation steps), using a cosine schedule with a maximum learning rate of 1e-3 that is reduced to 1e-4 at the end of training. The model is saved in the `./exps` folder. 20 | 21 | This training takes roughly ~3h on a single A100 (80GB) GPU. The plot of the training and validation loss should look roughly like this: 22 | 23 | Loss on SlimPajama 24 | Perplexity on SlimPajama 25 | 26 | You can check out the wandb run for yourself [here](https://wandb.ai/haeggee/llm-lauzhack/runs/lm2obqy9?nw=nwuserhaeggee). 
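For reference, here is a minimal sketch of the same quickstart run with the main defaults written out as explicit flags (the flag names come from the config listing in the next section; the values mirror the README defaults and may differ from the current defaults in `src/config/base.py`):

```sh
# Roughly equivalent to the quickstart command above, with defaults made explicit
python ./src/main.py --config_format base \
  --dataset slimpajama --model base \
  --n_layer 12 --n_head 12 --n_embd 768 --sequence_length 512 \
  --batch_size 32 --acc_steps 4 --iterations 25000 \
  --lr 1e-3 --warmup_percent 0.05 --scheduler cos --opt adamw
```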
27 | 28 | 29 | ## Less quick start 30 | 31 | Here are the possible parameters you can use (copypasted from `config/base.py`): 32 | 33 | ```python 34 | # General training params 35 | parser.add_argument('--batch_size', default=32, type=int) 36 | parser.add_argument('--acc_steps', default=4, type=int) 37 | parser.add_argument('--seed', default=0, type=int) # random seed for the parameters 38 | parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering 39 | parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs 40 | parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations 41 | parser.add_argument('--lr', default=1e-3, type=float) 42 | parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup steps is iterations * warmup_percent 43 | parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise 44 | parser.add_argument('--beta1', default=0.9, type=float) # adam parameter 45 | parser.add_argument('--beta2', default=0.95, type=float) # adam parameter 46 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) 47 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) 48 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations 49 | parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved 50 | parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT 51 | # Dataset params 52 | parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2']) 53 | parser.add_argument('--vocab_size', default=50304, type=int) 54 | parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this 55 | # Model params 56 | parser.add_argument('--model', default='base', choices=['base', 'llama2']) 57 | parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt-2' or a path to the pretraind model 58 | parser.add_argument('--dropout', default=0.0, type=float) # keep to 0 unless in low data regime (e.g. wikitext) 59 | parser.add_argument('--n_head', default=12, type=int) 60 | parser.add_argument('--n_layer', default=12, type=int) # depth in (att + ff) blocks 61 | parser.add_argument('--n_embd', default=768, type=int) # hidden size ... 
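# Back-of-the-envelope parameter count for the defaults above (assuming a NanoGPT-style count:
# position embeddings excluded, output head tied to the token embedding):
# attention + MLP weights ~ 12 * n_layer * n_embd^2 = 12*12*768^2 ~ 84.9M,
# token embedding ~ vocab_size * n_embd = 50304*768 ~ 38.6M, layer norms ~ 0.02M,
# for a total of ~ 123.59M parameters, matching the quickstart above.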
62 | parser.add_argument('--sequence_length', default=512, type=int) 63 | parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype) 64 | parser.add_argument('--bias', default=False, type=bool) 65 | parser.add_argument('--compile', action='store_true') # if true then model is compiled 66 | parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model 67 | parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model to make the SwiGLU hidden layer size a multiple of a large power of 2 68 | # logging params (WandB) 69 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not 70 | parser.add_argument('--wandb_project', default="my-project", type=str) 71 | parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name 72 | parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences 73 | # Distributed args 74 | parser.add_argument('--distributed_backend', default=None, type=str, required=False, 75 | choices=distributed.registered_backends()) # distributed backend type (e.g. nccl) 76 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) 77 | ``` 78 | 79 | ## Using WandB 80 | 81 | You need to provide your wandb authorization key in order to send the data to your wandb account. If you start jobs on a server without access to an interactive prompt, you can set the `WANDB_API_KEY` variable within your script: 82 | 83 | ```bash 84 | # this is a script that could be executed on a server 85 | pip install -r requirements.txt # install requirements 86 | export WANDB_API_KEY="put your authorization key here, to find it: https://wandb.ai/authorize" 87 | python ./src/main.py --config_format base --wandb --wandb_project "my awesome project" --n_layer 7 --model base --seed 123 88 | ``` 89 | 90 | ## How to add your own transformer architecture? 91 | 92 | The structure of the project is the following: 93 | 94 | ```sh 95 | src/ 96 | main.py # pick the right data, model, and training function 97 | config/ 98 | __init__.py # contains CONFIG_FORMAT_TO_MODULE_MAP mapping the name given to the --config_format flag to a python config file 99 | base.py # config for the base model 100 | data/ 101 | utils.py # contains the get_dataset function 102 | wikitext.py # load/process wikitext 103 | arxiv.py # load/process arxiv 104 | shakespeare.py # load/process the Shakespeare dataset 105 | slimpajama.py 106 | ... 107 | models/ 108 | utils.py # contains the get_model function 109 | base.py # contains the standard transformer base architecture 110 | llama.py # llama architecture 111 | optim/ 112 | utils.py # contains eval and get_batch functions 113 | base.py # training function for the base and llama models 114 | distributed/ 115 | # code to enable simple distributed training 116 | ``` 117 | 118 | Given the above structure, to add your own model, you can simply fork the `./src/models/base.py` file, make your modifications, and then, if necessary, fork `./src/optim/base.py` in case you need a custom training loop or evaluation. You also need to fork the `./src/config/base.py` file to add your own parameters, which implies adding your new config to the mapping `CONFIG_FORMAT_TO_MODULE_MAP` in `./src/config/__init__.py`. To add a new dataset, create a new file in the `data` folder; check `wikitext.py` for the expected format. 119 | 120 | ## Multi-GPU training 121 | 122 | Given a multi-GPU machine with e.g.
4 GPUs, one can distribute the training using data-parallelism: 123 | 124 | ```sh 125 | torchrun --nproc_per_node=4 ./src/main.py --config_format base --distributed_backend nccl --dataset slimpajama --model base 126 | ``` 127 | 128 | When using multiple GPUs, the data will be distributed among the GPUs by dividing the number of accumulation steps by the number of GPUs. For instance, if we train with a batch size of 32 and 4 accumulation steps on 4 GPUs, then each GPU will process batches of 32 elements and do 1 accumulation step. For this reason we require `acc_steps` to be a multiple of the number of GPUs. 129 | 130 | 131 | ## Experimenting locally on your device with CPU 132 | If you do not have access to a GPU or just want to try the code locally on your device, you can try the Shakespeare dataset with character-level tokens: 133 | 134 | ```sh 135 | python ./src/main.py --n_layer=2 --n_head=4 --n_embd=128 --sequence_length=256 --dataset=shakespeare-char --device=cpu --vocab_size=96 136 | ``` 137 | -------------------------------------------------------------------------------- /assets/loss_slimpajama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/llm-baselines/2d172dfd06fb45dc379cae5f34abcb0872b2a93a/assets/loss_slimpajama.png -------------------------------------------------------------------------------- /assets/pplx_slimpajama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/llm-baselines/2d172dfd06fb45dc379cae5f34abcb0872b2a93a/assets/pplx_slimpajama.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | --find-links https://download.pytorch.org/whl/torch_stable.html 3 | torch==2.0.0+cu118 4 | torchaudio==2.0.0+cu118 5 | torchvision==0.15.0+cu118 6 | tqdm==4.65.0 7 | transformers 8 | wandb 9 | datasets 10 | zstandard -------------------------------------------------------------------------------- /scripts/train-baselines-example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /scratch/homes/sfan/llm-baselines 4 | pip install -r requirements.txt 5 | 6 | export WANDB_API_KEY="put your wandb api key here" 7 | python ./src/main.py --model base --n_embd 768 --n_head 6 --wandb_run_prefix h768_nh12_nlyr24_sl512_d005 --n_layer 6 --batch_size 50 --sequence_length 512 --wandb --acc_steps 4 --dropout 0.05 8 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import base 2 | 3 | CONFIG_FORMAT_TO_MODULE_MAP = { 4 | "base": base, 5 | } 6 | 7 | 8 | def parse_args_with_format(format, base_parser, args, namespace): 9 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace) 10 | 11 | 12 | def registered_formats(): 13 | return CONFIG_FORMAT_TO_MODULE_MAP.keys() 14 | -------------------------------------------------------------------------------- /src/config/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import distributed 4 | 5 | def none_or_str(value): 6 | if value == 'None': 7 | return None 8 | return value 9 | 10 | def parse_args(base_parser, args, namespace): 11 | parser = base_parser 12 | # General training params 13 | parser.add_argument('--batch_size', default=32, type=int) 14 | parser.add_argument('--acc_steps', default=4, type=int) 15 | parser.add_argument('--seed', default=0, type=int) 16 | parser.add_argument('--data_seed', default=1337, type=int) 17 | parser.add_argument('--device', default='cuda:0', type=str) 18 | parser.add_argument('--iterations', default=25000, type=int) 19 | parser.add_argument('--lr', default=1e-3, type=float) 20 | parser.add_argument('--warmup_percent', default=0.05, type=float) 21 | parser.add_argument('--weight_decay', default=0.1, type=float) 22 | parser.add_argument('--beta1', default=0.9, type=float) 23 | parser.add_argument('--beta2', default=0.95, type=float) 24 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none']) 25 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd']) 26 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations 27 | parser.add_argument('--results_base_folder', default="./exps", type=str) 28 | parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT 29 | # Dataset params 30 | parser.add_argument('--dataset', default='mathqa', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2', 'mathqa']) 31 | parser.add_argument('--vocab_size', default=50304, type=int) 32 | parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2 33 | # Model params 34 | parser.add_argument('--model', default='base', choices=['base', 'llama2']) 35 | parser.add_argument('--use_pretrained', default="auto", type=none_or_str) # 'none', 'gpt-2' or a path to the pretraind model 36 | parser.add_argument('--dropout', default=0.0, type=float) 37 | parser.add_argument('--n_head', default=6, type=int) 38 | parser.add_argument('--n_layer', default=6, type=int) # depths in att + ff blocks 39 | parser.add_argument('--n_embd', default=768, type=int) # embedding size / hidden size ... 
40 | parser.add_argument('--sequence_length', default=512, type=int) 41 | parser.add_argument('--dtype', default=torch.float16, type=torch.dtype) 42 | parser.add_argument('--bias', default=False, type=bool) 43 | parser.add_argument('--compile', action='store_true') # if true then model is compiled 44 | parser.add_argument("--rmsnorm_eps", default=1e-5, type=float) 45 | parser.add_argument( 46 | "--multiple_of", # make SwiGLU hidden layer size multiple of large power of 2 47 | default=256, 48 | type=int, 49 | ) 50 | parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name 51 | parser.add_argument('--exp_name', default=None, type=str, required=False) 52 | # logging params (WandB) 53 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not 54 | parser.add_argument('--wandb_project', default="my-project", type=str) 55 | parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name 56 | parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences 57 | # Distributed args 58 | parser.add_argument('--distributed_backend', default=None, type=str, required=False, 59 | choices=distributed.registered_backends()) # distributed backend type 60 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False) 61 | 62 | args = parser.parse_args(args, namespace) 63 | 64 | if args.exp_name is None: 65 | special_name_handle_fields = {"model", "lr", "batch_size", 66 | "acc_steps", "seed", "exp_name", 67 | "wandb", "wandb_project", "eval_seq_prefix", 68 | "run_prefix", "distributed_backend", "config_format", 69 | "sequence_length"} 70 | overriden_values = [] 71 | for key in vars(args): 72 | if key in special_name_handle_fields: 73 | continue 74 | if getattr(args, key) != parser.get_default(key): 75 | overriden_values.append((key, getattr(args, key))) 76 | chunk_len = 10 77 | overriden_values_str_parts = [] 78 | for chunk_id in range(0, len(overriden_values), chunk_len): 79 | overriden_values_str = "_".join(["{}={}".format(key, value) for key, value in overriden_values[chunk_id:chunk_id+chunk_len]]) 80 | overriden_values_str_parts.append(overriden_values_str) 81 | overriden_values_str = "/".join(overriden_values_str_parts) 82 | exp_name = "" 83 | if args.run_prefix is not None: 84 | exp_name += f"{args.run_prefix}_" 85 | exp_name += f"{args.model}_lr{args.lr}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overriden_values_str}_seed={args.seed}" 86 | args.exp_name = exp_name 87 | 88 | if args.dtype == "torch.bfloat16": 89 | args.dtype = torch.bfloat16 90 | elif args.dtype == "torch.float16": 91 | args.dtype = torch.float16 92 | elif args.dtype == "torch.float32": 93 | args.dtype = torch.float32 94 | 95 | return args 96 | -------------------------------------------------------------------------------- /src/data/arxiv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | import logging 4 | from pathlib import Path 5 | from typing import Optional 6 | from multiprocessing import Pool 7 | from tempfile import NamedTemporaryFile 8 | from subprocess import Popen, TimeoutExpired, PIPE 9 | from typing import Tuple, List 10 | 11 | import numpy as np 12 | import requests 13 | from tqdm.auto import tqdm 14 | import tiktoken 15 | 16 | 17 | def convert_to_markdown(args: Tuple[Path, Path]): 18 | texfile, mdroot = 
args 19 | mdfile = mdroot/f"{texfile.name}.md" 20 | with Popen(["pandoc", "--wrap=none", "--from", "latex", texfile, 21 | "--output", mdfile], stderr=PIPE) as proc: 22 | try: 23 | proc.communicate(timeout=1) 24 | except TimeoutExpired: 25 | proc.kill() 26 | 27 | 28 | 29 | def fetch_arxiv(root: Path, year: int): 30 | # download latex 31 | url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz" 32 | texroot = root/"tex" 33 | print("Downloading Arxiv year", year) 34 | req = requests.get(url, timeout=60) 35 | with NamedTemporaryFile(suffix=".tar.gz") as f: 36 | f.write(req.content) 37 | logging.debug("Tar saved in tempfile %s" % f.name) 38 | with tarfile.open(f.name) as tar: 39 | logging.debug("Extracting tarfile") 40 | tar.extractall(texroot) 41 | 42 | # convert to markdown 43 | mdroot = root/"md"/str(year) 44 | mdroot.mkdir(parents=True) 45 | files = list((texroot/str(year)).iterdir()) 46 | with Pool(os.cpu_count()) as p: 47 | args = [(texfile, mdroot) for texfile in files] 48 | for _ in tqdm(p.imap_unordered(convert_to_markdown, args), 49 | desc="Converting to markdown", total=len(files)): 50 | pass 51 | 52 | 53 | def tokenize_arxiv(root: Path, year: int): 54 | tokenizer = tiktoken.get_encoding("gpt2") 55 | tokens = [] 56 | tokens_val = [] 57 | tokens_test = [] 58 | mds = root/"md"/str(year) 59 | 60 | # tokenize 61 | desc = f"Tokenizing {year}" 62 | for i, mdpath in enumerate(tqdm(list(mds.iterdir()), desc=desc)): 63 | with open(mdpath, encoding="utf8") as f: 64 | text = "".join(f.readlines()) 65 | if i % 10 <= 6: # train split 66 | tokens += tokenizer.encode(text) 67 | elif i % 10 <= 8: # val split 68 | tokens_val += tokenizer.encode(text) 69 | else: # test split 70 | tokens_test += tokenizer.encode(text) 71 | 72 | # save to dir 73 | tpath = root/str(year) 74 | tpath.mkdir(parents=True) 75 | for x, name in zip([tokens, tokens_val, tokens_test], 76 | ["train", "val", "test"]): 77 | mem = np.memmap(tpath/f"{name}.npy", dtype=np.uint16, mode="w+", 78 | shape=len(x)) 79 | for i, v in enumerate(x): 80 | mem[i] = v 81 | 82 | 83 | def load_arxiv(cachedir: Path, years: Optional[List[int]] = None): 84 | all_years = list(range(1992, 2004)) 85 | if years is None: 86 | years = all_years 87 | assert set(years) <= set(all_years) 88 | root = cachedir/"arxiv" 89 | root.mkdir(exist_ok=True, parents=True) 90 | 91 | # download all years requested that are not present 92 | for year in years: 93 | if not (root/"md"/str(year)).exists(): 94 | fetch_arxiv(root, year) 95 | 96 | # tokenize all years not previously tokenized 97 | for year in years: 98 | if not (root/str(year)).exists(): 99 | tokenize_arxiv(root, year) 100 | 101 | # load meta 102 | ret = {} 103 | for split in ["train", "val"]: 104 | paths = [root/str(year)/f"{split}.npy" for year in years] 105 | x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths] 106 | ret[split] = np.concatenate(x) 107 | return ret 108 | 109 | 110 | def get_arxiv_2000(): 111 | return load_arxiv(Path(os.path.dirname(__file__))/"datasets", [2000]) 112 | 113 | 114 | def get_arxiv_full(): 115 | return load_arxiv(Path(os.path.dirname(__file__))/"datasets") 116 | -------------------------------------------------------------------------------- /src/data/benchmarks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import importlib 3 | 4 | ## load arc 5 | import os 6 | from tqdm import tqdm 7 | import numpy as np 8 | import tiktoken 9 | from datasets import load_dataset 10 | import torch 11 
| 12 | # Initialize the tokenizer 13 | tknzr = tiktoken.get_encoding("gpt2") 14 | ARC_E_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/arc_easy/") 15 | ARC_C_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/arc_challenge/") 16 | HELLASWAG_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/hellaswag/") 17 | LOGIQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/logiqa/") 18 | PIQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/piqa/") 19 | SCIQ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/sciq/") 20 | HUMANEVAL_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/humaneval/") 21 | KODCODE_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/kodcode/") 22 | GSM8K_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/gsm8k/") 23 | MATHQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/mathqa/") 24 | MEDQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/medqa/") 25 | 26 | 27 | def tokenize_with_pad(text, pad_to_multiple=1024): 28 | ids = tknzr.encode_ordinary(text) 29 | ids.append(tknzr.eot_token) 30 | pad_token_id = tknzr.eot_token 31 | # Calculate padding length (next multiple of pad_multiple) 32 | padded_length = ((len(ids) + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple 33 | # Initialize the padded array with pad token (not zeros) 34 | padded_tokens = np.ones(padded_length, dtype=np.uint16) * pad_token_id 35 | padded_tokens[:len(ids)] = ids 36 | return padded_tokens 37 | 38 | 39 | def get_humaneval(num_proc=10, return_torch=False, pad_to_multiple=1024): 40 | """ 41 | Load and process the HumanEval (for coding) dataset. 42 | Tokenize the text and store it in binary format for efficient loading. 43 | """ 44 | if not os.path.exists(os.path.join(HUMANEVAL_DATA_PATH, 'val.bin')): 45 | os.makedirs(HUMANEVAL_DATA_PATH, exist_ok=True) 46 | 47 | # Load the HumanEval dataset from Hugging Face Datasets 48 | human_eval_path = "openai/openai_humaneval" 49 | dataset = load_dataset(human_eval_path, trust_remote_code=True) 50 | split_dataset = dataset["test"].train_test_split(test_size=0.5, seed=2357, shuffle=True) 51 | split_dataset['val'] = split_dataset.pop('test') 52 | data_dict = { 53 | 'train': split_dataset['train'], 54 | 'val': split_dataset['val'], 55 | } 56 | 57 | def process(example): 58 | """ 59 | Tokenize the example text by encoding it into token IDs. 
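The prompt and its canonical solution are concatenated, padded with the GPT-2 EOT token to a multiple of `pad_to_multiple` tokens (see `tokenize_with_pad`), and later written into a single uint16 memmap per split alongside a `.len` file of per-example lengths.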
60 | """ 61 | prompt = example['prompt'] 62 | completion = example['canonical_solution'] 63 | 64 | concatenated_text = f"{prompt} \n {completion}" 65 | # print(concatenated_text) 66 | ids = tokenize_with_pad(text=concatenated_text, 67 | pad_to_multiple=pad_to_multiple) 68 | return {'ids': ids, 'len': len(ids)} 69 | 70 | # Tokenize and map the dataset 71 | tokenized = {} 72 | for split, dset in data_dict.items(): 73 | tokenized[split] = dset.map( 74 | process, 75 | remove_columns=['task_id', 'prompt', 'canonical_solution', 'test', 'entry_point'], 76 | desc=f"Tokenizing {split} split", 77 | num_proc=num_proc 78 | ) 79 | 80 | # Concatenate all the token IDs into one large binary file per split 81 | for split, dset in tokenized.items(): 82 | # Save token IDs length 83 | len_arr = np.array(dset['len'], dtype=np.uint16) 84 | with open(os.path.join(HUMANEVAL_DATA_PATH, f'{split}.len'), 'wb') as f: 85 | np.save(f, len_arr) 86 | # Total number of tokens 87 | arr_len = np.sum(dset['len']) 88 | filename = os.path.join(HUMANEVAL_DATA_PATH, f'{split}.bin') 89 | dtype = np.uint16 90 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 91 | total_batches = 10 92 | 93 | idx = 0 94 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 95 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 96 | arr_batch = np.concatenate(batch['ids']) 97 | arr[idx: idx + len(arr_batch)] = arr_batch 98 | idx += len(arr_batch) 99 | arr.flush() 100 | 101 | # Load tokenized binary files for training, validation 102 | train_data = np.memmap(os.path.join(HUMANEVAL_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 103 | val_data = np.memmap(os.path.join(HUMANEVAL_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 104 | 105 | if return_torch: 106 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 107 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 108 | print(f'Benchmark HumanEval: train[{len(train_data)}] | val[{len(val_data)}]') 109 | return { 110 | 'train': train_data, 111 | 'train_len': np.load(os.path.join(HUMANEVAL_DATA_PATH, 'train.len')), 112 | 'val': val_data, 113 | 'val_len': np.load(os.path.join(HUMANEVAL_DATA_PATH, 'val.len')), 114 | } 115 | 116 | 117 | def get_kodcode(num_proc=10, return_torch=False, pad_to_multiple=1024): 118 | """ 119 | Load and process the KodCode (for code reasoning) dataset. 120 | Tokenize the text and store it in binary format for efficient loading. 121 | """ 122 | if not os.path.exists(os.path.join(KODCODE_DATA_PATH, 'val.bin')): 123 | os.makedirs(KODCODE_DATA_PATH, exist_ok=True) 124 | 125 | # Load the GSM8K dataset from Hugging Face Datasets 126 | dataset = load_dataset("KodCode/KodCode-V1-SFT-R1", trust_remote_code=True) 127 | dataset = dataset["train"].train_test_split(test_size=0.1, seed=2357, shuffle=True) 128 | data_dict = { 129 | 'train': dataset["train"], 130 | 'val': dataset["test"], 131 | } 132 | 133 | def process(example): 134 | """ 135 | Tokenize the example text by encoding it into token IDs. 
136 | """ 137 | question = example['question'] 138 | answer = example['solution'] 139 | 140 | concatenated_text = f"{question}\n{answer}" 141 | # print(concatenated_text) 142 | ids = tokenize_with_pad(text=concatenated_text, 143 | pad_to_multiple=pad_to_multiple) 144 | return {'ids': ids, 'len': len(ids)} 145 | 146 | # Tokenize and map the dataset 147 | tokenized = {} 148 | for split, dset in data_dict.items(): 149 | tokenized[split] = dset.map( 150 | process, 151 | remove_columns=['style', 'question_id', 'subset', 'question', 'solution', 'test_code', 'test_info', 152 | 'gpt_pass_sequence', 'gpt_pass_trial_num', 'gpt_difficulty', 'gpt_pass_percentage', 153 | 'r1_pass_sequence', 'r1_pass_trial_num', 'r1_correctness', 'r1_solution', 'metadata', 'conversations'], 154 | desc=f"Tokenizing {split} split", 155 | num_proc=num_proc 156 | ) 157 | 158 | # Concatenate all the token IDs into one large binary file per split 159 | for split, dset in tokenized.items(): 160 | # Save token IDs length 161 | len_arr = np.array(dset['len'], dtype=np.uint16) 162 | with open(os.path.join(KODCODE_DATA_PATH, f'{split}.len'), 'wb') as f: 163 | np.save(f, len_arr) 164 | # Total number of tokens 165 | arr_len = np.sum(dset['len']) 166 | filename = os.path.join(KODCODE_DATA_PATH, f'{split}.bin') 167 | dtype = np.uint16 168 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 169 | total_batches = 10 170 | 171 | idx = 0 172 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 173 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 174 | arr_batch = np.concatenate(batch['ids']) 175 | arr[idx: idx + len(arr_batch)] = arr_batch 176 | idx += len(arr_batch) 177 | arr.flush() 178 | 179 | # Load tokenized binary files for training, validation 180 | train_data = np.memmap(os.path.join(KODCODE_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 181 | val_data = np.memmap(os.path.join(KODCODE_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 182 | 183 | if return_torch: 184 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 185 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 186 | print(f'Benchmark KodCode: train[{len(train_data)}] | val[{len(val_data)}]') 187 | return { 188 | 'train': train_data, 189 | 'train_len': np.load(os.path.join(KODCODE_DATA_PATH, 'train.len')), 190 | 'val': val_data, 191 | 'val_len': np.load(os.path.join(KODCODE_DATA_PATH, 'val.len')), 192 | } 193 | 194 | 195 | def get_gsm8k(num_proc=10, return_torch=False, pad_to_multiple=1024): 196 | """ 197 | Load and process the GSM8K (for math) dataset. 198 | Tokenize the text and store it in binary format for efficient loading. 199 | """ 200 | if not os.path.exists(os.path.join(GSM8K_DATA_PATH, 'val.bin')): 201 | os.makedirs(GSM8K_DATA_PATH, exist_ok=True) 202 | 203 | # Load the GSM8K dataset from Hugging Face Datasets 204 | dataset = load_dataset("openai/gsm8k", "main", trust_remote_code=True) 205 | data_dict = { 206 | 'train': dataset["train"], 207 | 'val': dataset["test"], 208 | } 209 | 210 | def process(example): 211 | """ 212 | Tokenize the example text by encoding it into token IDs. 
213 | """ 214 | question = example['question'] 215 | answer = example['answer'] 216 | 217 | concatenated_text = f"{question}\n{answer}" 218 | # print(concatenated_text) 219 | ids = tokenize_with_pad(text=concatenated_text, 220 | pad_to_multiple=pad_to_multiple) 221 | return {'ids': ids, 'len': len(ids)} 222 | 223 | # Tokenize and map the dataset 224 | tokenized = {} 225 | for split, dset in data_dict.items(): 226 | tokenized[split] = dset.map( 227 | process, 228 | remove_columns=['question', 'answer'], 229 | desc=f"Tokenizing {split} split", 230 | num_proc=num_proc 231 | ) 232 | 233 | # Concatenate all the token IDs into one large binary file per split 234 | for split, dset in tokenized.items(): 235 | # Save token IDs length 236 | len_arr = np.array(dset['len'], dtype=np.uint16) 237 | with open(os.path.join(GSM8K_DATA_PATH, f'{split}.len'), 'wb') as f: 238 | np.save(f, len_arr) 239 | # Total number of tokens 240 | arr_len = np.sum(dset['len']) 241 | filename = os.path.join(GSM8K_DATA_PATH, f'{split}.bin') 242 | dtype = np.uint16 243 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 244 | total_batches = 10 245 | 246 | idx = 0 247 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 248 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 249 | arr_batch = np.concatenate(batch['ids']) 250 | arr[idx: idx + len(arr_batch)] = arr_batch 251 | idx += len(arr_batch) 252 | arr.flush() 253 | 254 | # Load tokenized binary files for training, validation 255 | train_data = np.memmap(os.path.join(GSM8K_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 256 | val_data = np.memmap(os.path.join(GSM8K_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 257 | 258 | if return_torch: 259 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 260 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 261 | print(f'Benchmark GSM8k: train[{len(train_data)}] | val[{len(val_data)}]') 262 | return { 263 | 'train': train_data, 264 | 'train_len': np.load(os.path.join(GSM8K_DATA_PATH, 'train.len')), 265 | 'val': val_data, 266 | 'val_len': np.load(os.path.join(GSM8K_DATA_PATH, 'val.len')), 267 | } 268 | 269 | def get_mathqa(num_proc=10, return_torch=False, pad_to_multiple=1024): 270 | """ 271 | Load and process the MATH-QA dataset. 272 | Tokenize the text and store it in binary format for efficient loading. 273 | """ 274 | if not os.path.exists(os.path.join(MATHQA_DATA_PATH, 'val.bin')): 275 | os.makedirs(MATHQA_DATA_PATH, exist_ok=True) 276 | 277 | # Load the MATH-QA dataset from Hugging Face Datasets 278 | dataset = load_dataset("allenai/math_qa",trust_remote_code=True) 279 | data_dict = { 280 | 'train': dataset["train"], 281 | 'val': dataset["test"], 282 | } 283 | 284 | def process(example): 285 | """ 286 | Tokenize the example text by encoding it into token IDs. 
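The `options` string is split on commas into a {letter: text} mapping, and only the text of the option matching `correct` is appended to the problem statement before tokenization and padding.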
287 | """ 288 | question = example['Problem'] 289 | choices = {d.strip()[0]:d.split(")")[-1].strip() for d in example['options'].split(",")} 290 | answer = choices.get(example['correct']) 291 | 292 | concatenated_text = f"{question} {answer}" 293 | # print(concatenated_text) 294 | ids = tokenize_with_pad(text=concatenated_text, 295 | pad_to_multiple=pad_to_multiple) 296 | return {'ids': ids, 'len': len(ids)} 297 | 298 | # Tokenize and map the dataset 299 | tokenized = {} 300 | for split, dset in data_dict.items(): 301 | tokenized[split] = dset.map( 302 | process, 303 | remove_columns=['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'], 304 | desc=f"Tokenizing {split} split", 305 | num_proc=num_proc 306 | ) 307 | 308 | # Concatenate all the token IDs into one large binary file per split 309 | for split, dset in tokenized.items(): 310 | # Save token IDs length 311 | len_arr = np.array(dset['len'], dtype=np.uint16) 312 | with open(os.path.join(MATHQA_DATA_PATH, f'{split}.len'), 'wb') as f: 313 | np.save(f, len_arr) 314 | # Total number of tokens 315 | arr_len = np.sum(dset['len']) 316 | filename = os.path.join(MATHQA_DATA_PATH, f'{split}.bin') 317 | dtype = np.uint16 318 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 319 | total_batches = 10 320 | 321 | idx = 0 322 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 323 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 324 | arr_batch = np.concatenate(batch['ids']) 325 | arr[idx: idx + len(arr_batch)] = arr_batch 326 | idx += len(arr_batch) 327 | arr.flush() 328 | 329 | # Load tokenized binary files for training, validation 330 | train_data = np.memmap(os.path.join(MATHQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 331 | val_data = np.memmap(os.path.join(MATHQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 332 | 333 | if return_torch: 334 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 335 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 336 | print(f'Benchmark MATHQA: train[{len(train_data)}] | val[{len(val_data)}]') 337 | return { 338 | 'train': train_data, 339 | 'train_len': np.load(os.path.join(MATHQA_DATA_PATH, 'train.len')), 340 | 'val': val_data, 341 | 'val_len': np.load(os.path.join(MATHQA_DATA_PATH, 'val.len')), 342 | } 343 | 344 | def get_medqa(num_proc=10, return_torch=False, pad_to_multiple=1024): 345 | """ 346 | Load and process the MEDQA dataset. 347 | Tokenize the text and store it in binary format for efficient loading. 348 | """ 349 | if not os.path.exists(os.path.join(MEDQA_DATA_PATH, 'val.bin')): 350 | os.makedirs(MEDQA_DATA_PATH, exist_ok=True) 351 | 352 | # Load the MATH-QA dataset from Hugging Face Datasets 353 | dataset = load_dataset("bigbio/med_qa",trust_remote_code=True) 354 | data_dict = { 355 | 'train': dataset["train"], 356 | 'val': dataset["test"], 357 | } 358 | 359 | def process(example): 360 | """ 361 | Tokenize the example text by encoding it into token IDs. 
362 | """ 363 | question = example["question"] 364 | answer = example["answer"] 365 | 366 | concatenated_text = f"{question} {answer}" 367 | # print(concatenated_text) 368 | ids = tokenize_with_pad(text=concatenated_text, 369 | pad_to_multiple=pad_to_multiple) 370 | return {'ids': ids, 'len': len(ids)} 371 | 372 | # Tokenize and map the dataset 373 | tokenized = {} 374 | for split, dset in data_dict.items(): 375 | tokenized[split] = dset.map( 376 | process, 377 | remove_columns=['meta_info', 'question', 'answer_idx', 'answer', 'options'], 378 | desc=f"Tokenizing {split} split", 379 | num_proc=num_proc 380 | ) 381 | 382 | # Concatenate all the token IDs into one large binary file per split 383 | for split, dset in tokenized.items(): 384 | # Save token IDs length 385 | len_arr = np.array(dset['len'], dtype=np.uint16) 386 | with open(os.path.join(MEDQA_DATA_PATH, f'{split}.len'), 'wb') as f: 387 | np.save(f, len_arr) 388 | # Total number of tokens 389 | arr_len = np.sum(dset['len']) 390 | filename = os.path.join(MEDQA_DATA_PATH, f'{split}.bin') 391 | dtype = np.uint16 392 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 393 | total_batches = 10 394 | 395 | idx = 0 396 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 397 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 398 | arr_batch = np.concatenate(batch['ids']) 399 | arr[idx: idx + len(arr_batch)] = arr_batch 400 | idx += len(arr_batch) 401 | arr.flush() 402 | 403 | # Load tokenized binary files for training, validation 404 | train_data = np.memmap(os.path.join(MEDQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 405 | val_data = np.memmap(os.path.join(MEDQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 406 | 407 | if return_torch: 408 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 409 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 410 | print(f'Benchmark MedQA: train[{len(train_data)}] | val[{len(val_data)}]') 411 | return { 412 | 'train': train_data, 413 | 'train_len': np.load(os.path.join(MEDQA_DATA_PATH, 'train.len')), 414 | 'val': val_data, 415 | 'val_len': np.load(os.path.join(MEDQA_DATA_PATH, 'val.len')), 416 | } 417 | 418 | def get_arc_easy(num_proc=10, return_torch=False, pad_to_multiple=1024): 419 | """ 420 | Load and process the AI2-ARC dataset. 421 | Tokenize the text and store it in binary format for efficient loading. 422 | """ 423 | if not os.path.exists(os.path.join(ARC_E_DATA_PATH, 'val.bin')): 424 | os.makedirs(ARC_E_DATA_PATH, exist_ok=True) 425 | 426 | # Load the AI2-ARC dataset from Hugging Face Datasets 427 | dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split=['train', 'test']) 428 | data_dict = { 429 | 'train': dataset[0], 430 | 'val': dataset[1], 431 | } 432 | 433 | def process(example): 434 | """ 435 | Tokenize the example text by encoding it into token IDs. 
436 | """ 437 | question = example['question'] 438 | choices = example['choices'] 439 | answer = dict(zip(choices["label"], choices["text"])).get(example['answerKey']) 440 | 441 | concatenated_text = f"{question} {answer}" 442 | # print(concatenated_text) 443 | ids = tokenize_with_pad(text=concatenated_text, 444 | pad_to_multiple=pad_to_multiple) 445 | return {'ids': ids, 'len': len(ids)} 446 | 447 | # Tokenize and map the dataset 448 | tokenized = {} 449 | for split, dset in data_dict.items(): 450 | tokenized[split] = dset.map( 451 | process, 452 | remove_columns=['question', 'choices', 'answerKey'], 453 | desc=f"Tokenizing {split} split", 454 | num_proc=num_proc 455 | ) 456 | 457 | # Concatenate all the token IDs into one large binary file per split 458 | for split, dset in tokenized.items(): 459 | # Save token IDs length 460 | len_arr = np.array(dset['len'], dtype=np.uint16) 461 | with open(os.path.join(ARC_E_DATA_PATH, f'{split}.len'), 'wb') as f: 462 | np.save(f, len_arr) 463 | # Total number of tokens 464 | arr_len = np.sum(dset['len']) 465 | filename = os.path.join(ARC_E_DATA_PATH, f'{split}.bin') 466 | dtype = np.uint16 467 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 468 | total_batches = 10 469 | 470 | idx = 0 471 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 472 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 473 | arr_batch = np.concatenate(batch['ids']) 474 | arr[idx: idx + len(arr_batch)] = arr_batch 475 | idx += len(arr_batch) 476 | arr.flush() 477 | 478 | # Load tokenized binary files for training, validation 479 | train_data = np.memmap(os.path.join(ARC_E_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 480 | val_data = np.memmap(os.path.join(ARC_E_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 481 | 482 | if return_torch: 483 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 484 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 485 | print(f'Benchmark ARC-Easy: train[{len(train_data)}] | val[{len(val_data)}]') 486 | return { 487 | 'train': train_data, 488 | 'train_len': np.load(os.path.join(ARC_E_DATA_PATH, 'train.len')), 489 | 'val': val_data, 490 | 'val_len': np.load(os.path.join(ARC_E_DATA_PATH, 'val.len')), 491 | } 492 | 493 | def get_arc_challenge(num_proc=10, return_torch=False, pad_to_multiple=1024): 494 | """ 495 | Load and process the AI2-ARC dataset. 496 | Tokenize the text and store it in binary format for efficient loading. 497 | """ 498 | if not os.path.exists(os.path.join(ARC_C_DATA_PATH, 'val.bin')): 499 | os.makedirs(ARC_C_DATA_PATH, exist_ok=True) 500 | 501 | # Load the AI2-ARC dataset from Hugging Face Datasets 502 | dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split=['train', 'test']) 503 | data_dict = { 504 | 'train': dataset[0], 505 | 'val': dataset[1], 506 | } 507 | 508 | def process(example): 509 | """ 510 | Tokenize the example text by encoding it into token IDs. 
511 | """ 512 | question = example['question'] 513 | choices = example['choices'] 514 | answer = dict(zip(choices["label"], choices["text"])).get(example['answerKey']) 515 | 516 | concatenated_text = f"{question} {answer}" 517 | # print(concatenated_text) 518 | ids = tokenize_with_pad(text=concatenated_text, 519 | pad_to_multiple=pad_to_multiple) 520 | return {'ids': ids, 'len': len(ids)} 521 | 522 | # Tokenize and map the dataset 523 | tokenized = {} 524 | for split, dset in data_dict.items(): 525 | tokenized[split] = dset.map( 526 | process, 527 | remove_columns=['question', 'choices', 'answerKey'], 528 | desc=f"Tokenizing {split} split", 529 | num_proc=num_proc 530 | ) 531 | 532 | # Concatenate all the token IDs into one large binary file per split 533 | for split, dset in tokenized.items(): 534 | # Save token IDs length 535 | len_arr = np.array(dset['len'], dtype=np.uint16) 536 | with open(os.path.join(ARC_C_DATA_PATH, f'{split}.len'), 'wb') as f: 537 | np.save(f, len_arr) 538 | # Total number of tokens 539 | arr_len = np.sum(dset['len']) 540 | filename = os.path.join(ARC_C_DATA_PATH, f'{split}.bin') 541 | dtype = np.uint16 542 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 543 | total_batches = 10 544 | 545 | idx = 0 546 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 547 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 548 | arr_batch = np.concatenate(batch['ids']) 549 | arr[idx: idx + len(arr_batch)] = arr_batch 550 | idx += len(arr_batch) 551 | arr.flush() 552 | 553 | # Load tokenized binary files for training, validation 554 | train_data = np.memmap(os.path.join(ARC_C_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 555 | val_data = np.memmap(os.path.join(ARC_C_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 556 | 557 | if return_torch: 558 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 559 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 560 | print(f'Benchmark ARC-Challenge: train[{len(train_data)}] | val[{len(val_data)}]') 561 | return { 562 | 'train': train_data, 563 | 'train_len': np.load(os.path.join(ARC_C_DATA_PATH, 'train.len')), 564 | 'val': val_data, 565 | 'val_len': np.load(os.path.join(ARC_C_DATA_PATH, 'val.len')), 566 | } 567 | 568 | def get_hellaswag(num_proc=10, return_torch=False, pad_to_multiple=1024): 569 | """ 570 | Load and process the HellaSwag dataset. 571 | Tokenize the text and store it in binary format for efficient loading. 572 | """ 573 | if not os.path.exists(os.path.join(HELLASWAG_DATA_PATH, 'val.bin')): 574 | os.makedirs(HELLASWAG_DATA_PATH, exist_ok=True) 575 | 576 | # Load the HellaSwag dataset from Hugging Face Datasets 577 | dataset = load_dataset("Rowan/hellaswag", split=['train', 'validation']) 578 | data_dict = { 579 | 'train': dataset[0], 580 | 'val': dataset[1], 581 | } 582 | 583 | def process(example): 584 | """ 585 | Tokenize the example text by encoding it into token IDs. 
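The context (`ctx`) is concatenated with the single ending selected by the integer `label`; the distractor endings are discarded.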
586 | """ 587 | context = example['ctx'] 588 | ending_options = example['endings'] 589 | answer = ending_options[int(example['label'])] 590 | concatenated_text = f"{context} {answer}" 591 | # print(concatenated_text) 592 | ids = tokenize_with_pad(text=concatenated_text, 593 | pad_to_multiple=pad_to_multiple) 594 | return {'ids': ids, 'len': len(ids)} 595 | 596 | # Tokenize and map the dataset 597 | tokenized = {} 598 | for split, dset in data_dict.items(): 599 | tokenized[split] = dset.map( 600 | process, 601 | remove_columns=['ctx', 'endings', 'label'], 602 | desc=f"Tokenizing {split} split", 603 | num_proc=num_proc 604 | ) 605 | 606 | # Concatenate all the token IDs into one large binary file per split 607 | for split, dset in tokenized.items(): 608 | # Save token IDs length 609 | len_arr = np.array(dset['len'], dtype=np.uint16) 610 | with open(os.path.join(HELLASWAG_DATA_PATH, f'{split}.len'), 'wb') as f: 611 | np.save(f, len_arr) 612 | # Total number of tokens 613 | arr_len = np.sum(dset['len']) 614 | filename = os.path.join(HELLASWAG_DATA_PATH, f'{split}.bin') 615 | dtype = np.uint16 616 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 617 | total_batches = 10 618 | 619 | idx = 0 620 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 621 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 622 | arr_batch = np.concatenate(batch['ids']) 623 | arr[idx: idx + len(arr_batch)] = arr_batch 624 | idx += len(arr_batch) 625 | arr.flush() 626 | 627 | # Load tokenized binary files for training, validation 628 | train_data = np.memmap(os.path.join(HELLASWAG_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 629 | val_data = np.memmap(os.path.join(HELLASWAG_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 630 | 631 | if return_torch: 632 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 633 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 634 | 635 | print(f'Benchmark Hellaswag: train[{len(train_data)}] | val[{len(val_data)}]') 636 | return { 637 | 'train': train_data, 638 | 'train_len': np.load(os.path.join(HELLASWAG_DATA_PATH, 'train.len')), 639 | 'val': val_data, 640 | 'val_len': np.load(os.path.join(HELLASWAG_DATA_PATH, 'val.len')), 641 | } 642 | 643 | def get_logiqa(num_proc=10, return_torch=False, pad_to_multiple=1024): 644 | """ 645 | Load and process the LogiQA dataset. 646 | Tokenize the text and store it in binary format for efficient loading. 647 | """ 648 | if not os.path.exists(os.path.join(LOGIQA_DATA_PATH, 'val.bin')): 649 | os.makedirs(LOGIQA_DATA_PATH, exist_ok=True) 650 | 651 | # Load the LogiQA dataset from HuggingFace Datasets 652 | dataset = load_dataset("lucasmccabe/logiqa", split=['train', 'test']) 653 | data_dict = { 654 | 'train': dataset[0], 655 | 'val': dataset[1], 656 | } 657 | 658 | def process(example): 659 | """ 660 | Tokenize the example text by encoding it into token IDs. 661 | """ 662 | context = example['context'] 663 | query = example['query'] 664 | correct_option = example['correct_option'] 665 | options = example['options'] 666 | answer = options[correct_option] 667 | concatenated_text = f"{context} {query} {answer}" 668 | # concatenated_text = f"Context: {context} Query: {query} Answer: {answer}" # TODO : try this ? 
669 | ids = tokenize_with_pad(text=concatenated_text, 670 | pad_to_multiple=pad_to_multiple) 671 | return {'ids': ids, 'len': len(ids)} 672 | 673 | # Tokenize and map the dataset 674 | tokenized = {} 675 | for split, dset in data_dict.items(): 676 | tokenized[split] = dset.map( 677 | process, 678 | remove_columns=['context', 'query', 'correct_option', 'options'], 679 | desc=f"Tokenizing {split} split", 680 | num_proc=num_proc 681 | ) 682 | 683 | # Concatenate all the token IDs into one large binary file per split 684 | for split, dset in tokenized.items(): 685 | # Save token IDs length 686 | len_arr = np.array(dset['len'], dtype=np.uint16) 687 | with open(os.path.join(LOGIQA_DATA_PATH, f'{split}.len'), 'wb') as f: 688 | np.save(f, len_arr) 689 | # total number of tokens 690 | arr_len = np.sum(dset['len']) 691 | filename = os.path.join(LOGIQA_DATA_PATH, f'{split}.bin') 692 | dtype = np.uint16 693 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 694 | total_batches = 10 695 | 696 | idx = 0 697 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 698 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 699 | arr_batch = np.concatenate(batch['ids']) 700 | arr[idx: idx + len(arr_batch)] = arr_batch 701 | idx += len(arr_batch) 702 | arr.flush() 703 | 704 | # Load tokenized binary files for training, validation 705 | train_data = np.memmap(os.path.join(LOGIQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 706 | val_data = np.memmap(os.path.join(LOGIQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 707 | 708 | if return_torch: 709 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 710 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 711 | print(f'Benchmark LogiQA: train[{len(train_data)}] | val[{len(val_data)}]') 712 | return { 713 | 'train': train_data, 714 | 'train_len': np.load(os.path.join(LOGIQA_DATA_PATH, 'train.len')), 715 | 'val': val_data, 716 | 'val_len': np.load(os.path.join(LOGIQA_DATA_PATH, 'val.len')), 717 | } 718 | 719 | def get_piqa(num_proc=10, return_torch=False, pad_to_multiple=1024): 720 | """ 721 | Load and process the PIQA dataset. 722 | Tokenize the text and store it in binary format for efficient loading. 723 | """ 724 | if not os.path.exists(os.path.join(PIQA_DATA_PATH, 'val.bin')): 725 | os.makedirs(PIQA_DATA_PATH, exist_ok=True) 726 | 727 | # Load the PIQA dataset from HuggingFace Datasets 728 | dataset = load_dataset("piqa", split=['train', 'test']) 729 | 730 | data_dict = { 731 | 'train': dataset[0], 732 | 'val': dataset[1], 733 | } 734 | 735 | def process(example): 736 | """ 737 | Tokenize the example text by encoding it into token IDs. 
738 | """ 739 | goal = example['goal'] 740 | sols = [example['sol1'], example['sol2']] 741 | label = example['label'] 742 | sol = sols[label] 743 | concatenated_text = f"{goal} {sol}" # only include the correct solution 744 | ids = tokenize_with_pad(text=concatenated_text, 745 | pad_to_multiple=pad_to_multiple) 746 | return {'ids': ids, 'len': len(ids)} 747 | 748 | # Tokenize and map the dataset 749 | tokenized = {} 750 | for split, dset in data_dict.items(): 751 | tokenized[split] = dset.map( 752 | process, 753 | remove_columns=['goal', 'sol1', 'sol2', 'label'], 754 | desc=f"Tokenizing {split} split", 755 | num_proc=num_proc 756 | ) 757 | 758 | # Concatenate all the token IDs into one large binary file per split 759 | for split, dset in tokenized.items(): 760 | # Save token IDs length 761 | len_arr = np.array(dset['len'], dtype=np.uint16) 762 | with open(os.path.join(PIQA_DATA_PATH, f'{split}.len'), 'wb') as f: 763 | np.save(f, len_arr) 764 | # total number of tokens 765 | arr_len = np.sum(dset['len']) 766 | filename = os.path.join(PIQA_DATA_PATH, f'{split}.bin') 767 | dtype = np.uint16 768 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 769 | total_batches = 10 770 | 771 | idx = 0 772 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'): 773 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict() 774 | arr_batch = np.concatenate(batch['ids']) 775 | arr[idx: idx + len(arr_batch)] = arr_batch 776 | idx += len(arr_batch) 777 | arr.flush() 778 | 779 | # Load tokenized binary files for training, validation 780 | train_data = np.memmap(os.path.join(PIQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 781 | val_data = np.memmap(os.path.join(PIQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 782 | print(f'Benchmark PIQA: train[{len(train_data)}] | val[{len(val_data)}]') 783 | if return_torch: 784 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16)) 785 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16)) 786 | 787 | return { 788 | 'train': train_data, 789 | 'train_len': np.load(os.path.join(PIQA_DATA_PATH, 'train.len')), 790 | 'val': val_data, 791 | 'val_len': np.load(os.path.join(PIQA_DATA_PATH, 'val.len')), 792 | } 793 | 794 | def get_sciq(num_proc=10, return_torch=False, pad_to_multiple=1024): 795 | """ 796 | Load and process the SciQ dataset. 797 | Tokenize the text and store it in binary format for efficient loading. 798 | """ 799 | if not os.path.exists(os.path.join(SCIQ_DATA_PATH, 'val.bin')): 800 | os.makedirs(SCIQ_DATA_PATH, exist_ok=True) 801 | 802 | # Load the SciQ dataset from HuggingFace Datasets 803 | dataset = load_dataset("sciq", split=['train', 'test']) 804 | 805 | data_dict = { 806 | 'train': dataset[0], 807 | 'val': dataset[1], 808 | } 809 | 810 | def process(example): 811 | """ 812 | Tokenize the example text by encoding it into token IDs. 813 | """ 814 | question = example['question'] 815 | answer = example['correct_answer'] 816 | # explanation = example.get('support', '') # Explanation text (optional) 817 | explantation = example['support'] 818 | # concatenated_text = f"Question: {question} Answer: {answer} {explantation}" #TODO: Try this? 
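        # The support passage, question and answer below are joined without separating whitespace.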
819 | concatenated_text = f"{explanation} {question} {answer}"
820 | ids = tokenize_with_pad(text=concatenated_text,
821 | pad_to_multiple=pad_to_multiple)
822 | return {'ids': ids, 'len': len(ids)}
823 |
824 | # Tokenize and map the dataset
825 | tokenized = {}
826 | for split, dset in data_dict.items():
827 | tokenized[split] = dset.map(
828 | process,
829 | remove_columns=['question', 'distractor1', 'distractor2', 'distractor3', 'correct_answer', 'support'],
830 | desc=f"Tokenizing {split} split",
831 | num_proc=num_proc
832 | )
833 |
834 | # Concatenate all the token IDs into one large binary file per split
835 | for split, dset in tokenized.items():
836 | # Save token IDs length
837 | len_arr = np.array(dset['len'], dtype=np.uint16)
838 | with open(os.path.join(SCIQ_DATA_PATH, f'{split}.len'), 'wb') as f:
839 | np.save(f, len_arr)
840 | # total number of tokens
841 | arr_len = np.sum(dset['len'])
842 | filename = os.path.join(SCIQ_DATA_PATH, f'{split}.bin')
843 | dtype = np.uint16
844 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
845 | total_batches = 10
846 |
847 | idx = 0
848 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
849 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
850 | arr_batch = np.concatenate(batch['ids'])
851 | arr[idx: idx + len(arr_batch)] = arr_batch
852 | idx += len(arr_batch)
853 | arr.flush()
854 |
855 |
856 | # Load tokenized binary files for training and validation
857 | train_data = np.memmap(os.path.join(SCIQ_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
858 | val_data = np.memmap(os.path.join(SCIQ_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
859 | print(f'Benchmark SCIQ: train[{len(train_data)}] | val[{len(val_data)}]')
860 |
861 | if return_torch:
862 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
863 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
864 |
865 | return {
866 | 'train': train_data,
867 | 'train_len': np.load(os.path.join(SCIQ_DATA_PATH, 'train.len')),
868 | 'val': val_data,
869 | 'val_len': np.load(os.path.join(SCIQ_DATA_PATH, 'val.len')),
870 | }
871 |
872 |
873 |
874 | SUPPORTED_TASK_MAP = {"arc_easy": get_arc_easy,
875 | "arc_challenge": get_arc_challenge,
876 | "hellaswag": get_hellaswag,
877 | "logiqa": get_logiqa,
878 | "piqa": get_piqa,
879 | "sciq": get_sciq,
880 | "humaneval": get_humaneval,
881 | "gsm8k": get_gsm8k,
882 | "kodcode": get_kodcode,
883 | "mathqa": get_mathqa,
884 | "medqa": get_medqa}
--------------------------------------------------------------------------------
/src/data/openwebtext2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import tiktoken
5 | from datasets import load_dataset
6 |
7 |
8 | OWT2_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/openwebtext2/")
9 | tknzr = tiktoken.get_encoding("gpt2")
10 |
11 | def get_openwebtext2_data(num_proc=40):
12 | """ https://openwebtext2.readthedocs.io/en/latest/
13 | """
14 | if not os.path.exists(os.path.join(OWT2_DATA_PATH, 'train.bin')):
15 | os.makedirs(OWT2_DATA_PATH, exist_ok=True)
16 | dataset = load_dataset("the_pile_openwebtext2")
17 |
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test')
20 |
21 | def process(example):
22 | ids = tknzr.encode_ordinary(example['text']) # encode_ordinary
ignores any special tokens 23 | ids.append(tknzr.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 24 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 25 | out = {'ids': ids, 'len': len(ids)} 26 | return out 27 | 28 | # tokenize the dataset 29 | tokenized = split_dataset.map( 30 | process, 31 | remove_columns=['text'], 32 | desc="tokenizing the splits", 33 | num_proc=num_proc, 34 | ) 35 | 36 | # concatenate all the ids in each dataset into one large file we can use for training 37 | for split, dset in tokenized.items(): 38 | arr_len = np.sum(dset['len']) 39 | filename = os.path.join(OWT2_DATA_PATH, f'{split}.bin') 40 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 41 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 42 | total_batches = 1024 43 | 44 | idx = 0 45 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 46 | # Batch together samples for faster write 47 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 48 | arr_batch = np.concatenate(batch['ids']) 49 | # Write into mmap 50 | arr[idx : idx + len(arr_batch)] = arr_batch 51 | idx += len(arr_batch) 52 | arr.flush() 53 | 54 | train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 55 | val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 56 | 57 | return {'train': train_data, 'val': val_data} 58 | 59 | -------------------------------------------------------------------------------- /src/data/shakespeare.py: -------------------------------------------------------------------------------- 1 | import os 2 | from string import ascii_letters, digits, punctuation 3 | 4 | import numpy as np 5 | import requests 6 | 7 | 8 | _char_decode = dict(enumerate(sorted(set(ascii_letters + digits + punctuation + " \n")))) 9 | _char_encode = {char: i for i, char in _char_decode.items()} 10 | 11 | 12 | def char_tknzr(txt: str): 13 | return [_char_encode[char] for char in txt if char in _char_encode] 14 | 15 | 16 | DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets", "shakespeare") 17 | 18 | def get_shakespeare_data(): 19 | """Inspired from https://github.com/karpathy/nanoGPT/""" 20 | raw_path = os.path.join(DATA_PATH, "raw.txt") 21 | train_path = os.path.join(DATA_PATH, f"train.npy") 22 | test_path = os.path.join(DATA_PATH, f"test.npy") 23 | 24 | # if path is not even there, download all data 25 | if not os.path.exists(DATA_PATH): 26 | print("Downloading raw Shakespeare texts") 27 | url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 28 | os.makedirs(DATA_PATH, exist_ok=True) 29 | text = requests.get(url, timeout=60).text 30 | with open(raw_path, "w+", encoding="utf8") as f: 31 | f.write(text) 32 | 33 | # attempt to find cached version for current tokenizer 34 | if not os.path.exists(train_path) or not os.path.exists(test_path): 35 | print("Tokenizing Shakespeare texts") 36 | # load text 37 | with open(raw_path, encoding="utf8") as f: 38 | text = "".join(f.readlines()) 39 | i = int(0.8*len(text)) 40 | # encode text 41 | x = np.array(char_tknzr(text[:i]), dtype=np.uint16) 42 | x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16) 43 | # map memory 44 | mem = np.memmap(train_path, dtype=np.uint16, mode="w+", shape=x.shape) 45 | mem[:] = x 46 | mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape) 47 | mem[:] = x_test 48 | 49 | # 
at this point we know that the binfile was properly created so we load it 50 | return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"), 51 | "val": np.memmap(test_path, dtype=np.uint16, mode="r")} 52 | -------------------------------------------------------------------------------- /src/data/slimpajama.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import tiktoken 4 | from datasets import load_dataset 5 | import os 6 | 7 | 8 | SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/") 9 | SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1") 10 | 11 | 12 | tknzr = tiktoken.get_encoding("gpt2") 13 | 14 | 15 | def get_slimpajama_data(num_proc=40): 16 | if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")): 17 | os.makedirs(SPJ_DATA_PATH, exist_ok=True) 18 | dataset = load_dataset("DKYoon/SlimPajama-6B") 19 | 20 | split_dataset = dataset["train"].train_test_split( 21 | test_size=0.0005, seed=2357, shuffle=True 22 | ) 23 | split_dataset["val"] = split_dataset.pop("test") 24 | 25 | def process(example): 26 | ids = tknzr.encode_ordinary( 27 | example["text"] 28 | ) # encode_ordinary ignores any special tokens 29 | ids.append( 30 | tknzr.eot_token 31 | ) # add the end of text token, e.g. 50256 for gpt2 bpe 32 | out = {"ids": ids, "len": len(ids)} 33 | return out 34 | 35 | # tokenize the dataset 36 | tokenized = split_dataset.map( 37 | process, 38 | remove_columns=["text"], 39 | desc="tokenizing the splits", 40 | num_proc=num_proc, 41 | ) 42 | 43 | # concatenate all the ids in each dataset into one large file we can use for training 44 | for split, dset in tokenized.items(): 45 | arr_len = np.sum(dset["len"]) 46 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin") 47 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 48 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 49 | total_batches = min(1024, len(dset)) 50 | 51 | idx = 0 52 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 53 | # Batch together samples for faster write 54 | batch = dset.shard( 55 | num_shards=total_batches, index=batch_idx, contiguous=True 56 | ).with_format("numpy") 57 | arr_batch = np.concatenate(batch["ids"]) 58 | # Write into mmap 59 | arr[idx : idx + len(arr_batch)] = arr_batch 60 | idx += len(arr_batch) 61 | arr.flush() 62 | 63 | train_data = np.memmap( 64 | os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" 65 | ) 66 | val_data = np.memmap( 67 | os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" 68 | ) 69 | 70 | return {"train": train_data, "val": val_data} 71 | 72 | 73 | def get_slimpajama_chunk1(num_proc=40): 74 | if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")): 75 | os.makedirs(SPJ_DATA_PATH, exist_ok=True) 76 | dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1") 77 | 78 | split_dataset = dataset["train"].train_test_split( 79 | test_size=0.0005, seed=2357, shuffle=True 80 | ) 81 | split_dataset["val"] = split_dataset.pop("test") 82 | 83 | def process(example): 84 | ids = tknzr.encode_ordinary( 85 | example["text"] 86 | ) # encode_ordinary ignores any special tokens 87 | ids.append( 88 | tknzr.eot_token 89 | ) # add the end of text token, e.g. 
50256 for gpt2 bpe 90 | out = {"ids": ids, "len": len(ids)} 91 | return out 92 | 93 | # tokenize the dataset 94 | tokenized = split_dataset.map( 95 | process, 96 | remove_columns=["text"], 97 | desc="tokenizing the splits", 98 | num_proc=num_proc, 99 | ) 100 | 101 | # concatenate all the ids in each dataset into one large file we can use for training 102 | for split, dset in tokenized.items(): 103 | arr_len = np.sum(dset["len"]) 104 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin") 105 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 106 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,)) 107 | total_batches = min(1024, len(dset)) 108 | 109 | idx = 0 110 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): 111 | # Batch together samples for faster write 112 | batch = dset.shard( 113 | num_shards=total_batches, index=batch_idx, contiguous=True 114 | ).with_format("numpy") 115 | arr_batch = np.concatenate(batch["ids"]) 116 | # Write into mmap 117 | arr[idx : idx + len(arr_batch)] = arr_batch 118 | idx += len(arr_batch) 119 | arr.flush() 120 | 121 | train_data = np.memmap( 122 | os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r" 123 | ) 124 | val_data = np.memmap( 125 | os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r" 126 | ) 127 | 128 | return {"train": train_data, "val": val_data} 129 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | import torch 4 | 5 | from .shakespeare import get_shakespeare_data 6 | from .wikitext import get_wikitext_data 7 | from .arxiv import get_arxiv_2000, get_arxiv_full 8 | from .openwebtext2 import get_openwebtext2_data 9 | from .slimpajama import get_slimpajama_data 10 | from .benchmarks import get_mathqa 11 | 12 | 13 | def get_dataset(args) -> Dict[str, np.ndarray]: 14 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is 15 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap 16 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. 
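A minimal usage sketch (illustrative; assumes `args.dataset` is set to one of the keys handled below):

    data = get_dataset(args)
    train_tokens = data['train']  # typically a np.memmap of uint16 token ids
    val_tokens = data['val']      # held-out split in the same format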
""" 17 | if args.dataset == 'wikitext': 18 | return get_wikitext_data() 19 | if args.dataset == "shakespeare-char": 20 | return get_shakespeare_data() 21 | if args.dataset == "arxiv2000": 22 | return get_arxiv_2000() 23 | if args.dataset == "arxiv": 24 | return get_arxiv_full() 25 | if args.dataset == "arxiv+wiki": 26 | arxiv_data = get_arxiv_full() 27 | wiki_data = get_wikitext_data() 28 | train_data = np.concatenate((arxiv_data['train'], wiki_data['train'])) 29 | val_data = np.concatenate((arxiv_data['val'], wiki_data['val'])) 30 | return {'train': train_data, 'val': val_data} 31 | if args.dataset == 'openwebtext2': 32 | return get_openwebtext2_data() 33 | if args.dataset == "slimpajama": 34 | return get_slimpajama_data() 35 | if args.dataset == "mathqa": 36 | return get_mathqa() 37 | else: 38 | raise NotImplementedError(f"Unknow dataset key '{args.dataset}'") 39 | 40 | class Dataset(torch.utils.data.Dataset): 41 | def __init__(self, data, sequence_length): 42 | super().__init__() 43 | self.data = data 44 | self.sequence_length = sequence_length 45 | 46 | def __len__(self): 47 | total_length = len(self.data) 48 | # chunk the data into sequences of length `sequence_length` 49 | # NOTE: we discard the last remainding sequence if it's not of length `sequence_length` 50 | return (total_length - 1) // self.sequence_length 51 | 52 | def __getitem__(self, idx): 53 | seq_length = self.sequence_length 54 | idx = idx * seq_length 55 | x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64)) 56 | 57 | y = torch.from_numpy( 58 | (self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64) 59 | ) 60 | return x, y 61 | 62 | 63 | def get_dataloader(data, sequence_length, batch_size, seed=0, distributed_backend=None): 64 | """Create a DataLoader for the given data. If distributed_backend is provided and is truly 65 | distributed (world size > 1), the DataLoader will be created with a DistributedSampler that 66 | splits the data across the processes (in conjunction with DDP). 67 | Otherwise, use a RandomSampler with the specified seed. 68 | 69 | Returns both the dataloader and the sampler. 
70 | """ 71 | dataset = Dataset(data, sequence_length=sequence_length) 72 | if distributed_backend and distributed_backend.get_world_size() > 1: 73 | sampler = torch.utils.data.DistributedSampler( 74 | dataset, 75 | shuffle=True, 76 | seed=seed, 77 | ) 78 | else: 79 | g = torch.Generator() 80 | g.manual_seed(seed) 81 | sampler = torch.utils.data.RandomSampler( 82 | dataset, replacement=False, generator=g 83 | ) 84 | 85 | loader = torch.utils.data.DataLoader( 86 | dataset, 87 | sampler=sampler, 88 | batch_size=batch_size, 89 | num_workers=4, 90 | ) 91 | return loader, sampler 92 | -------------------------------------------------------------------------------- /src/data/wikitext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import urllib 4 | import numpy as np 5 | import tiktoken 6 | 7 | 8 | WIKITEXT_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/wikitext/") 9 | 10 | 11 | def get_wikitext_data(): 12 | """ Inspired from https://github.com/tysam-code/hlb-gpt """ 13 | if not os.path.exists(WIKITEXT_DATA_PATH): 14 | os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True) 15 | print("downloading data and tokenizing (1-2 min)") 16 | raw_data_source = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip' 17 | urllib.request.urlretrieve(raw_data_source, os.path.join(WIKITEXT_DATA_PATH,'data.zip')) 18 | 19 | with zipfile.ZipFile(os.path.join(WIKITEXT_DATA_PATH, "data.zip"), 'r') as zip_ref: 20 | zip_ref.extractall(WIKITEXT_DATA_PATH) 21 | 22 | with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), 'r') as data_file: 23 | raw_train_data = data_file.read() 24 | 25 | with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), 'r') as data_file: 26 | raw_eval_data = data_file.read() 27 | 28 | tokenizer = tiktoken.get_encoding("gpt2") 29 | raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data) 30 | raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data) 31 | 32 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16) 33 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16) 34 | 35 | train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'train.bin')) 36 | eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin')) 37 | print("completed the tokenization process!") 38 | 39 | train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r') 40 | val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r') 41 | 42 | return {'train': train_data, 'val': val_data} 43 | -------------------------------------------------------------------------------- /src/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from . import ddp 3 | from . 
import single
4 |
5 | BACKEND_TYPE_TO_MODULE_MAP = {
6 | "nccl": ddp.DataParallelDistributedBackend,
7 | None: single.SinlgeNodeBackend,
8 | }
9 |
10 |
11 | def make_backend_from_args(args):
12 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args)
13 |
14 |
15 | def registered_backends():
16 | return BACKEND_TYPE_TO_MODULE_MAP.keys()
17 |
--------------------------------------------------------------------------------
/src/distributed/backend.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List
3 |
4 |
5 | class DistributedBackend(object):
6 |
7 | def __init__(self, args):
8 | pass
9 |
10 | def transform_model(self, model):
11 | raise NotImplementedError
12 |
13 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
14 | raise NotImplementedError
15 |
16 | def is_master_process(self) -> bool:
17 | raise NotImplementedError
18 |
19 | def get_adjusted_args_for_process(self, args):
20 | raise NotImplementedError
21 |
22 | def get_raw_model(self, model):
23 | raise NotImplementedError
24 |
25 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]:
26 | raise NotImplementedError
27 |
28 | def get_world_size(self):
29 | raise NotImplementedError
30 |
31 | def finalize(self):
32 | pass
33 |
34 | def sync(self):
35 | pass
36 |
--------------------------------------------------------------------------------
/src/distributed/ddp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from contextlib import contextmanager
4 |
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 | from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier
7 |
8 | from .backend import DistributedBackend
9 |
10 |
11 | class DataParallelDistributedBackend(DistributedBackend):
12 |
13 | def __init__(self, args):
14 | self.rank = int(os.environ.get('RANK', -1))
15 | assert self.rank != -1, "DDP backend cannot be used without rank"
16 | assert "cuda" in args.device, "DDP backend cannot be used on non-CUDA devices"
17 | init_process_group(backend=args.distributed_backend)
18 | self.local_rank = int(os.environ['LOCAL_RANK'])
19 |
20 | def get_adjusted_args_for_process(self, args):
21 | effective_batch_size = args.batch_size * args.acc_steps
22 | world_size = self.get_world_size()
23 | if args.acc_steps % world_size != 0:
24 | raise ValueError(f"Number of accumulation steps "
25 | f"{args.acc_steps} is not divisible "
26 | f"by the world size {world_size}.")
27 | if effective_batch_size % world_size != 0:
28 | raise ValueError(f"Effective batch size "
29 | f"{effective_batch_size} is not divisible "
30 | f"by the world size {world_size}.")
31 | acc_steps_div = math.gcd(args.acc_steps, world_size)
32 | args.acc_steps = args.acc_steps // acc_steps_div
33 | args.batch_size = args.batch_size // (world_size // acc_steps_div)
34 | args.device = f'cuda:{self.local_rank}'
35 | args.seed = args.seed + self.local_rank
36 | return args
37 |
38 | def transform_model(self, model):
39 | return DDP(model, device_ids=[self.local_rank])
40 |
41 | @contextmanager
42 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
43 | model.require_backward_grad_sync = (
44 | microstep_idx == gradient_accumulation_steps - 1)
45 | yield
46 |
47 | def is_master_process(self) -> bool:
48 | return self.rank == 0
49 |
50 | def get_raw_model(self, model):
51 |
return model.module 52 | 53 | def translate_model_parameter_name_for_node(self, parameter_name): 54 | return [f'module.{parameter_name}'] 55 | 56 | def get_world_size(self): 57 | return get_world_size() 58 | 59 | def finalize(self): 60 | destroy_process_group() 61 | 62 | def sync(self): 63 | barrier() 64 | -------------------------------------------------------------------------------- /src/distributed/single.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | 3 | from .backend import DistributedBackend 4 | 5 | 6 | class SinlgeNodeBackend(DistributedBackend): 7 | 8 | def transform_model(self, model): 9 | return model 10 | 11 | def get_context_for_microstep_forward(self, *args, **kwargs): 12 | return nullcontext() 13 | 14 | def get_adjusted_args_for_process(self, args): 15 | return args 16 | 17 | def is_master_process(self) -> bool: 18 | return True 19 | 20 | def get_raw_model(self, model): 21 | return model 22 | 23 | def get_world_size(self): 24 | return 1 25 | 26 | def translate_model_parameter_name_for_node(self, parameter_name): 27 | return [parameter_name] 28 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import inspect 6 | import json 7 | import copy 8 | import argparse 9 | import random 10 | import wandb 11 | 12 | import config 13 | from models.utils import get_model 14 | from data.utils import get_dataset 15 | from optim.base import train_base 16 | import distributed 17 | 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser(allow_abbrev=False) 21 | parser.add_argument('--config_format', default='base', choices=config.registered_formats()) 22 | 23 | args, rem_args = parser.parse_known_args() 24 | 25 | return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args) 26 | 27 | 28 | def main(args): 29 | 30 | torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training 31 | torch.backends.cudnn.allow_tf32 = True 32 | 33 | distributed_backend = distributed.make_backend_from_args(args) 34 | args = distributed_backend.get_adjusted_args_for_process(args) 35 | 36 | args.device = torch.device(args.device) 37 | device_type = "cuda" if "cuda" in str(args.device) else "cpu" 38 | if device_type == "cuda": 39 | torch.cuda.set_device(args.device) 40 | 41 | torch.manual_seed(args.seed) 42 | random.seed(args.seed) 43 | np.random.seed(args.seed) 44 | 45 | print(f"Loading dataset '{args.dataset}'") 46 | 47 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized} 48 | if args.data_in_ram: 49 | data = {'train': np.array(data['train']), 'val': np.array(data['val'])} 50 | 51 | print(f"Num training tokens: {len(data['train'])}") 52 | print(f"Num validation tokens: {len(data['val'])}") 53 | 54 | model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none' 55 | 56 | model = distributed_backend.transform_model(model) 57 | 58 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs() 59 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()} 60 | optimized_params_cnt = 0 61 | for g in group_specs: 62 | params = [] 63 | for p_name in g["params"]: 64 | translated_p_names = 
distributed_backend.translate_model_parameter_name_for_node(p_name) 65 | params += [param_name_mapping[p_name] for p_name in translated_p_names] 66 | g["params"] = params 67 | optimized_params_cnt += sum([p.numel() for p in g["params"]]) 68 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,)) 69 | if args.opt == 'adamw': 70 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) 71 | print(f"using fused AdamW: {use_fused}") 72 | extra_args = dict(fused=True) if use_fused else dict() 73 | opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2), 74 | weight_decay=args.weight_decay, **extra_args) 75 | else: 76 | opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 77 | 78 | if args.scheduler != 'none': 79 | if args.scheduler in ['cos', 'linear']: 80 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations, 81 | pct_start=args.warmup_percent, anneal_strategy=args.scheduler, 82 | cycle_momentum=False, div_factor=1e2, final_div_factor=.1) 83 | else: 84 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.") 85 | else: 86 | scheduler = None 87 | 88 | args.world_size = distributed_backend.get_world_size() 89 | exp_name = args.exp_name 90 | if distributed_backend.is_master_process() and args.wandb: 91 | params_copy = copy.deepcopy(vars(args)) 92 | del params_copy['device'] 93 | wandb.init(project=args.wandb_project, name=exp_name, config=params_copy) 94 | 95 | ckpt_path = os.path.join(args.results_base_folder, args.dataset, args.model, exp_name) 96 | if not os.path.exists(ckpt_path): 97 | if distributed_backend.is_master_process(): 98 | os.makedirs(ckpt_path) 99 | distributed_backend.sync() 100 | elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed 101 | print(f"Already found experiment '{ckpt_path}'.\nSkipping.") 102 | sys.exit(0) 103 | 104 | itr = 0 105 | rng_state_dict = None 106 | if args.use_pretrained == "auto": 107 | checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file] 108 | if checkpoints: 109 | args.use_pretrained = sorted(checkpoints)[-1] 110 | else: 111 | args.use_pretrained = None 112 | 113 | if args.use_pretrained is not None: 114 | last_ckpt_path = args.use_pretrained 115 | print(f"Resuming from {last_ckpt_path}") 116 | checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path)) 117 | model_state_dict = {distributed_backend.translate_model_parameter_name_for_node(k.replace("_orig_mod.", ""))[0]:v for k,v in checkpoint['model'].items()} 118 | # FIXME checkpoints from compiled model have _orig_mod keyword 119 | 120 | optimizer_state_dict = checkpoint['optimizer'] 121 | rng_state_dict = { 122 | module: checkpoint[module] for module in [ 123 | "cpu_rng_state", 124 | "gpu_rng_state", 125 | "numpy_rng_state", 126 | "py_rng_state", 127 | "train_sampler_state" 128 | ] 129 | } 130 | 131 | model.load_state_dict(model_state_dict) 132 | opt.load_state_dict(optimizer_state_dict) 133 | itr = checkpoint['itr'] 134 | if scheduler is not None: 135 | scheduler_state_dict = checkpoint['scheduler'] 136 | scheduler.load_state_dict(scheduler_state_dict) 137 | 138 | if args.model in ['base', 'llama2']: # all train functions have the same interface 139 | train = train_base 140 | else: 141 | raise NotImplementedError(f"No training method implemented for model type '{args.model}'.") 142 | 143 | print(f"\nTraining model={args.model} 
\n{vars(args)}\n") 144 | 145 | stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length, 146 | eval_freq=args.eval_freq, 147 | distributed_backend=distributed_backend, 148 | ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args) 149 | 150 | args.device = None 151 | args.dtype = None 152 | stats['args'] = vars(args) 153 | if distributed_backend.is_master_process(): 154 | with open(f"{ckpt_path}/summary.json", "w") as fs: 155 | json.dump(stats, fs) 156 | distributed_backend.finalize() 157 | 158 | 159 | if __name__ == "__main__": 160 | args = get_args() 161 | main(args) 162 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | 13 | import tiktoken 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 21 | 22 | def __init__(self, ndim, bias): 23 | super().__init__() 24 | self.weight = nn.Parameter(torch.ones(ndim)) 25 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 26 | 27 | def forward(self, input): 28 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 29 | 30 | 31 | class CausalSelfAttention(nn.Module): 32 | 33 | def __init__(self, config): 34 | super().__init__() 35 | assert config.n_embd % config.n_head == 0 36 | # key, query, value projections for all heads, but in a batch 37 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 38 | # output projection 39 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 40 | # regularization 41 | self.attn_dropout = nn.Dropout(config.dropout) 42 | self.resid_dropout = nn.Dropout(config.dropout) 43 | self.n_head = config.n_head 44 | self.n_embd = config.n_embd 45 | self.dropout = config.dropout 46 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 47 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 48 | if not self.flash: 49 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 50 | # causal mask to ensure that attention is only applied to the left in the input sequence 51 | self.register_buffer("bias", torch.tril(torch.ones(config.sequence_length, config.sequence_length)) 52 | .view(1, 1, config.sequence_length, config.sequence_length)) 53 | 54 | def forward(self, x): 55 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 56 | 57 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 58 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 59 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 61 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 62 | 63 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 64 | if self.flash: 65 | # efficient attention using Flash Attention CUDA kernels 66 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) 67 | else: 68 | # manual implementation of attention 69 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 70 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 71 | att = F.softmax(att, dim=-1) 72 | att = self.attn_dropout(att) 73 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 74 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 75 | 76 | # output projection 77 | y = self.resid_dropout(self.c_proj(y)) 78 | return y 79 | 80 | 81 | class MLP(nn.Module): 82 | 83 | def __init__(self, config): 84 | super().__init__() 85 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 86 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 87 | self.dropout = nn.Dropout(config.dropout) 88 | self.activation = nn.GELU() 89 | 90 | def forward(self, x): 91 | x = self.c_fc(x) 92 | x = self.activation(x) 93 | x = self.c_proj(x) 94 | x = self.dropout(x) 95 | return x 96 | 97 | 98 | class Block(nn.Module): 99 | 100 | def __init__(self, config): 101 | super().__init__() 102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 103 | self.attn = CausalSelfAttention(config) 104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 105 | self.mlp = MLP(config) 106 | 107 | def forward(self, x): 108 | x = x + self.attn(self.ln_1(x)) 109 | x = x + self.mlp(self.ln_2(x)) 110 | return x 111 | 112 | 113 | class GPTBase(nn.Module): 114 | 115 | def __init__(self, config): 116 | super().__init__() 117 | assert config.vocab_size is not None 118 | assert config.sequence_length is not None 119 | self.config = config 120 | self.tokenizer = tiktoken.get_encoding("gpt2") 121 | 122 | self.transformer = nn.ModuleDict(dict( 123 | wte = nn.Embedding(config.vocab_size, config.n_embd), 124 | wpe = nn.Embedding(config.sequence_length, config.n_embd), 125 | drop = nn.Dropout(config.dropout), 126 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 127 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 128 | )) 129 | 130 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 131 | # with weight tying when using torch.compile() some warnings get generated: 132 | # "UserWarning: functional_call was passed multiple values for tied weights. 
133 | # This behavior is deprecated and will be an error in future versions" 134 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 135 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 136 | 137 | # init all weights 138 | self.apply(self._init_weights) 139 | # apply special scaled init to the residual projections, per GPT-2 paper 140 | for pn, p in self.named_parameters(): 141 | if pn.endswith('c_proj.weight'): 142 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 143 | 144 | # report number of parameters 145 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 146 | 147 | def get_num_params(self, non_embedding=True): 148 | """ 149 | Return the number of parameters in the model. 150 | For non-embedding count (default), the position embeddings get subtracted. 151 | The token embeddings would too, except due to the parameter sharing these 152 | params are actually used as weights in the final layer, so we include them. 153 | """ 154 | n_params = sum(p.numel() for p in self.parameters()) 155 | if non_embedding: 156 | n_params -= self.transformer.wpe.weight.numel() 157 | return n_params 158 | 159 | def _init_weights(self, module): 160 | if isinstance(module, nn.Linear): 161 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 162 | if module.bias is not None: 163 | torch.nn.init.zeros_(module.bias) 164 | elif isinstance(module, nn.Embedding): 165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 166 | 167 | def forward(self, idx, targets=None, get_logits=False): 168 | device = idx.device 169 | b, t = idx.size() 170 | assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" 171 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 172 | 173 | # forward the GPT model itself 174 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 175 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 176 | x = self.transformer.drop(tok_emb + pos_emb) 177 | for block in self.transformer.h: 178 | x = block(x) 179 | x = self.transformer.ln_f(x) 180 | 181 | if targets is not None: 182 | # if we are given some desired targets also calculate the loss 183 | logits = self.lm_head(x) 184 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 185 | else: 186 | # inference-time mini-optimization: only forward the lm_head on the very last position 187 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 188 | loss = None 189 | logits = logits if get_logits else None 190 | return {'logits': logits, 'loss': loss} 191 | 192 | def crop_sequence_length(self, sequence_length): 193 | # model surgery to decrease the block size if necessary 194 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 195 | # but want to use a smaller block size for some smaller, simpler model 196 | assert sequence_length <= self.config.sequence_length 197 | self.config.sequence_length = sequence_length 198 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:sequence_length]) 199 | for block in self.transformer.h: 200 | block.attn.bias = block.attn.bias[:,:,:sequence_length,:sequence_length] 201 | 202 | @classmethod 203 | def from_pretrained(cls, model_type, override_args=None): 204 | # TODO 205 | pass 206 | 207 | def get_parameter_group_specs(self): 208 | """ 209 | This long function is unfortunately doing something very simple and is being very defensive: 210 | We are separating out all parameters of the model into two buckets: those that will experience 211 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 212 | We are then returning the PyTorch optimizer object. 213 | """ 214 | 215 | # separate out all parameters to those that will and won't experience regularizing weight decay 216 | decay = set() 217 | no_decay = set() 218 | whitelist_weight_modules = (torch.nn.Linear,) 219 | # need to do import here to avoid circular import (since llama imports from base here) 220 | from .utils import BLACKLIST_WEIGHT_MODULES 221 | 222 | for mn, m in self.named_modules(): 223 | for pn, p in m.named_parameters(): 224 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 225 | # random note: because named_modules and named_parameters are recursive 226 | # we will see the same tensors p many many times. but doing it this way 227 | # allows us to know which parent module any tensor p belongs to... 228 | if pn.endswith("bias"): 229 | # all biases will not be decayed 230 | no_decay.add(fpn) 231 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 232 | # weights of whitelist modules will be weight decayed 233 | decay.add(fpn) 234 | elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES): 235 | # weights of blacklist modules will NOT be weight decayed 236 | no_decay.add(fpn) 237 | 238 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 239 | # will appear in the no_decay and decay sets respectively after the above. 240 | # In addition, because named_parameters() doesn't return duplicates, it 241 | # will only return the first occurence, key'd by 'transformer.wte.weight', below. 242 | # so let's manually remove 'lm_head.weight' from decay set. This will include 243 | # this tensor into optimization via transformer.wte.weight only, and not decayed. 244 | decay.remove("lm_head.weight") 245 | 246 | # validate that we considered every parameter 247 | param_dict = {pn: p for pn, p in self.named_parameters()} 248 | inter_params = decay & no_decay 249 | union_params = decay | no_decay 250 | assert ( 251 | len(inter_params) == 0 252 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 253 | assert ( 254 | len(param_dict.keys() - union_params) == 0 255 | ), "parameters %s were not separated into either decay/no_decay set!" 
% ( 256 | str(param_dict.keys() - union_params), 257 | ) 258 | 259 | # create the pytorch optimizer object 260 | return [ 261 | {"params": sorted(list(decay))}, 262 | {"params": sorted(list(no_decay)), "weight_decay": 0.0}, 263 | ] 264 | 265 | 266 | @torch.no_grad() 267 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 268 | """ 269 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 270 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 271 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 272 | """ 273 | for _ in range(max_new_tokens): 274 | # if the sequence context is growing too long we must crop it at sequence_length 275 | idx_cond = idx if idx.size(1) <= self.config.sequence_length else idx[:, -self.config.sequence_length:] 276 | # forward the model to get the logits for the index in the sequence 277 | logits = self(idx_cond, get_logits=True)['logits'] 278 | # pluck the logits at the final step and scale by desired temperature 279 | logits = logits[:, -1, :] / temperature 280 | # optionally crop the logits to only the top k options 281 | if top_k is not None: 282 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 283 | logits[logits < v[:, [-1]]] = -float('Inf') 284 | # apply softmax to convert logits to (normalized) probabilities 285 | probs = F.softmax(logits, dim=-1) 286 | # sample from the distribution 287 | idx_next = torch.multinomial(probs, num_samples=1) 288 | # append sampled index to the running sequence and continue 289 | idx = torch.cat((idx, idx_next), dim=1) 290 | 291 | return idx 292 | 293 | @torch.no_grad() 294 | def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None): 295 | idx = torch.tensor(self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})).view(1,-1).to(self.lm_head.weight.device) 296 | out_idx = self.generate(idx, max_new_tokens, temperature, top_k).view(-1).to('cpu').numpy() 297 | return self.tokenizer.decode(out_idx) 298 | -------------------------------------------------------------------------------- /src/models/llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | Llama style Language Model. 
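(A brief note on the rotary embeddings used below: for head dimension d and position t,
precompute_freqs_cis builds complex phases exp(i * t * theta_j) with
theta_j = 10000^(-2j/d) for j = 0..d/2-1, and apply_rotary_emb multiplies each
consecutive (even, odd) pair of query/key features by that phase, i.e. rotates
it by the angle t * theta_j.)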
3 | References: 4 | 1) Llama inference code: 5 | https://github.com/facebookresearch/llama/blob/main/llama/model.py 6 | 2) Mistral one file ref: 7 | https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py 8 | 3) Llama paper: 9 | https://arxiv.org/pdf/2302.13971.pdf 10 | 11 | Main differences from GPT2: 12 | * Uses RMSNorm instead of LayerNorm 13 | * Uses a slightly different MLP (SwiGLU) 14 | * rotary embeddings (RoPE) 15 | """ 16 | 17 | import math 18 | 19 | import tiktoken 20 | import torch 21 | import torch.nn as nn 22 | from torch.nn import functional as F 23 | from models.base import CausalSelfAttention, GPTBase 24 | 25 | 26 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: 27 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 28 | t = torch.arange(end, device=freqs.device) # type: ignore 29 | freqs = torch.outer(t, freqs).float() # type: ignore 30 | return torch.polar(torch.ones_like(freqs), freqs) # complex64 31 | 32 | 33 | def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: 34 | """ 35 | freqs_cis: complex - (seq_len, head_dim / 2) 36 | x: complex - (bsz, seq_len, head_dim / 2) 37 | """ 38 | ndim = x.ndim 39 | assert 1 < ndim 40 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( 41 | freqs_cis.shape, 42 | (x.shape[1], x.shape[-1]), 43 | ) 44 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 45 | return freqs_cis.view(*shape) 46 | 47 | 48 | def apply_rotary_emb(q, k, freqs_cis): 49 | # q, k: (B, T, nh, hs) 50 | # freq_cis: (T, hs) 51 | # return: (B, T, nh, hs), (B, T, nh, hs) 52 | q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2)) 53 | k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2)) 54 | freqs_cis = _reshape_for_broadcast(freqs_cis, q_) 55 | xq_out = torch.view_as_real(q_ * freqs_cis).flatten(3) 56 | xk_out = torch.view_as_real(k_ * freqs_cis).flatten(3) 57 | return xq_out.type_as(q), xk_out.type_as(k) 58 | 59 | 60 | class RMSNorm(nn.Module): 61 | def __init__(self, dim: int, eps: float = 1e-6): 62 | super().__init__() 63 | self.eps = eps 64 | self.weight = nn.Parameter(torch.ones(dim)) 65 | 66 | def _norm(self, x): 67 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 68 | 69 | def forward(self, x): 70 | output = self._norm(x.float()).type_as(x) 71 | return output * self.weight 72 | 73 | 74 | class LlamaMLP(nn.Module): 75 | def __init__(self, config): 76 | super().__init__() 77 | 78 | hidden_dim = config.n_embd * 4 79 | hidden_dim = int(2 * hidden_dim / 3) 80 | hidden_dim = config.multiple_of * ( 81 | (hidden_dim + config.multiple_of - 1) // config.multiple_of 82 | ) 83 | 84 | self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False) 85 | self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False) 86 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False) 87 | 88 | def forward(self, x): 89 | return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x)) 90 | 91 | 92 | class LlamaAttention(CausalSelfAttention): 93 | def forward(self, x, freqs_cis): 94 | # batch size, sequence length, embedding dimensionality (n_embd) 95 | ( 96 | B, 97 | T, 98 | C, 99 | ) = x.size() 100 | 101 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 102 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 103 | # (B, T, nh, hs) 104 | k = k.view(B, T, self.n_head, C // self.n_head) 105 | q = q.view(B, T, self.n_head, C // self.n_head) 106 | q, k = 
apply_rotary_emb(q, k, freqs_cis) 107 | # (B, nh, T, hs) 108 | q, k = q.transpose(1, 2), k.transpose(1, 2) 109 | 110 | # (B, nh, T, hs) 111 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 112 | 113 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 114 | if self.flash: 115 | # efficient attention using Flash Attention CUDA kernels 116 | y = torch.nn.functional.scaled_dot_product_attention( 117 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True 118 | ) 119 | else: 120 | # manual implementation of attention 121 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 122 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) 123 | att = F.softmax(att, dim=-1) 124 | att = self.attn_dropout(att) 125 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 126 | y = ( 127 | y.transpose(1, 2).contiguous().view(B, T, C) 128 | ) # re-assemble all head outputs side by side 129 | 130 | # output projection 131 | y = self.resid_dropout(self.c_proj(y)) 132 | return y 133 | 134 | 135 | class LlamaBlock(nn.Module): 136 | def __init__(self, config): 137 | super().__init__() 138 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 139 | self.attn = LlamaAttention(config) 140 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps) 141 | self.mlp = LlamaMLP(config) 142 | 143 | def forward(self, x, freqs_cis): 144 | x = x + self.attn(self.ln_1(x), freqs_cis) 145 | x = x + self.mlp(self.ln_2(x)) 146 | return x 147 | 148 | 149 | class Llama(GPTBase): 150 | def __init__(self, config): 151 | super().__init__(config) 152 | assert config.vocab_size is not None 153 | assert config.sequence_length is not None 154 | self.config = config 155 | self.tokenizer = tiktoken.get_encoding("gpt2") 156 | 157 | # create the token and position embeddings 158 | self.head_dim = config.n_embd // config.n_head 159 | self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length) 160 | 161 | self.transformer = nn.ModuleDict( 162 | dict( 163 | wte=nn.Embedding(config.vocab_size, config.n_embd), 164 | drop=nn.Dropout(config.dropout), 165 | h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]), 166 | ln_f=RMSNorm(config.n_embd, eps=config.rmsnorm_eps), 167 | ) 168 | ) 169 | 170 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 171 | # with weight tying when using torch.compile() some warnings get generated: 172 | # "UserWarning: functional_call was passed multiple values for tied weights. 173 | # This behavior is deprecated and will be an error in future versions" 174 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 175 | self.transformer.wte.weight = ( 176 | self.lm_head.weight 177 | ) # https://paperswithcode.com/method/weight-tying 178 | 179 | # init all weights 180 | self.apply(self._init_weights) 181 | # apply special scaled init to the residual projections, per GPT-2 paper 182 | for pn, p in self.named_parameters(): 183 | if pn.endswith("c_proj.weight"): 184 | torch.nn.init.normal_( 185 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) 186 | ) 187 | 188 | 189 | def get_num_params(self, non_embedding=True): 190 | """ 191 | Return the number of parameters in the model. 192 | For non-embedding count (default) 193 | The token embeddings would too, except due to the parameter sharing these 194 | params are actually used as weights in the final layer, so we include them. 
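(Unlike the GPT-2 base model there is no separate position-embedding table to
subtract here, since positions are handled by RoPE, so the count is simply the
total number of parameters.)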
195 | """ 196 | n_params = sum(p.numel() for p in self.parameters()) 197 | return n_params 198 | 199 | def forward(self, idx, targets=None, get_logits=False): 200 | device = idx.device 201 | b, t = idx.size() 202 | assert ( 203 | t <= self.config.sequence_length 204 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}" 205 | # shape (1, t) 206 | pos = torch.arange(0, t, dtype=torch.long, device=device) 207 | 208 | # forward the GPT model itself 209 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 210 | 211 | x = self.transformer.drop(tok_emb) 212 | freqs_cis = self.freqs_cis.to(x.device)[pos] 213 | 214 | for block_idx, block in enumerate(self.transformer.h): 215 | x = block(x, freqs_cis=freqs_cis) 216 | x = self.transformer.ln_f(x) 217 | 218 | if targets is not None: 219 | # if we are given some desired targets also calculate the loss 220 | logits = self.lm_head(x) 221 | loss = F.cross_entropy( 222 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 223 | ) 224 | else: 225 | # inference-time mini-optimization: only forward the lm_head on the very last position 226 | logits = self.lm_head( 227 | x[:, [-1], :] 228 | ) # note: using list [-1] to preserve the time dim 229 | loss = None 230 | 231 | logits = logits if get_logits else None 232 | 233 | return { 234 | "logits": logits, 235 | "loss": loss, 236 | } 237 | -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .llama import Llama, RMSNorm 3 | from .base import GPTBase, LayerNorm 4 | 5 | 6 | BLACKLIST_WEIGHT_MODULES = ( 7 | torch.nn.LayerNorm, 8 | LayerNorm, 9 | RMSNorm, 10 | torch.nn.Embedding, 11 | ) 12 | 13 | 14 | def get_model(args): 15 | """ Return the right model """ 16 | if args.model == 'base': 17 | model = GPTBase(args) 18 | return model 19 | elif args.model == 'llama2': 20 | model = Llama(args) 21 | return model 22 | else: 23 | raise KeyError(f"Unknown model '{args.model}'.") 24 | -------------------------------------------------------------------------------- /src/optim/base.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext 2 | from data.utils import get_dataloader 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import wandb 7 | import time 8 | import itertools 9 | import copy 10 | import random 11 | import os 12 | import numpy as np 13 | from .utils import eval, get_batch, save_checkpoint 14 | 15 | 16 | def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend,extra_args, itr=0,rng_state_dict=None): 17 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu' 18 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast( 19 | device_type=device_type, dtype=torch.float16) # extra_args.dtype) 20 | best_val_loss, text_table = float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible 21 | substep = itr * acc_steps 22 | data["train"], train_sampler = get_dataloader( 23 | data["train"], 24 | sequence_length=sequence_length, 25 | batch_size=batch_size, 26 | seed=data_seed, 27 | distributed_backend=distributed_backend, 28 | ) 29 | 30 | data["val"], val_sampler = get_dataloader( 31 | data["val"], 32 | sequence_length=sequence_length, 33 | 
batch_size=batch_size,
34 |         seed=data_seed,
35 |     )
36 | 
37 |     num_substeps_per_epoch = len(data["train"])
38 |     train_epochs = substep // num_substeps_per_epoch
39 | 
40 |     if rng_state_dict is not None and rng_state_dict.get("train_sampler_state", None) is not None:
41 |         train_sampler.generator.set_state(rng_state_dict["train_sampler_state"])
42 |         if hasattr(train_sampler, "set_epoch"):
43 |             train_sampler.set_epoch(train_epochs)
44 |         else:
45 |             sampler_state_before_iter = train_sampler.generator.get_state()
46 |     data_train_iter = iter(data["train"])
47 | 
48 | 
49 |     # for val data we don't track epochs: just cycle through it (no need to call set_epoch to reshuffle)
50 |     data_val_iter = itertools.cycle(data["val"])
51 | 
52 |     stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []}
53 | 
54 | 
55 | 
56 |     if extra_args.compile:
57 |         print("Compiling model ...")
58 |         model = torch.compile(model)  # requires pytorch 2.0+
59 | 
60 |     model.train()
61 | 
62 |     t0 = time.time()
63 | 
64 |     if rng_state_dict is not None:
65 |         torch.set_rng_state(rng_state_dict["cpu_rng_state"])
66 |         torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
67 |         np.random.set_state(rng_state_dict["numpy_rng_state"])
68 |         random.setstate(rng_state_dict["py_rng_state"])
69 |     for _ in range(substep % num_substeps_per_epoch):  # fast-forward the dataloader when resuming mid-epoch
70 |         get_batch(data_train_iter, device=extra_args.device)
71 | 
72 | 
73 |     while itr < iterations:
74 | 
75 |         for microstep_idx in range(acc_steps):  # gradient accumulation
76 |             x, y = get_batch(data_train_iter, device=extra_args.device)
77 | 
78 |             with type_ctx:
79 |                 with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
80 |                     outputs = model(x, targets=y)
81 | 
82 |             loss = outputs['loss'] / acc_steps
83 |             loss.backward()
84 |             substep += 1
85 |             if substep % len(data["train"]) == 0:
86 |                 train_epochs += 1
87 |                 print(f"Train epoch {train_epochs} done (full pass over training data)")
88 |                 if hasattr(train_sampler, "set_epoch"):
89 |                     # set epoch for reshuffling between epochs
90 |                     train_sampler.set_epoch(train_epochs)
91 |                     sampler_state_before_iter = None
92 |                 else:
93 |                     sampler_state_before_iter = train_sampler.generator.get_state()
94 |                 data_train_iter = iter(data["train"])
95 | 
96 | 
97 |         if extra_args.grad_clip != 0.0:
98 |             torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip)
99 |         opt.step()
100 |         scheduler.step()
101 |         opt.zero_grad(set_to_none=True)
102 |         itr += 1
103 | 
104 |         if itr % eval_freq == 0 or itr == iterations:  # from here it's only evaluation code, all the training is above
105 |             if distributed_backend.is_master_process():
106 |                 t1 = time.time()
107 |                 dt = t1 - t0
108 |                 epoch = substep // num_substeps_per_epoch
109 | 
110 |                 model.eval()
111 |                 train_loss = loss.detach().cpu().item() * acc_steps  # undo the 1/acc_steps scaling for logging
112 |                 current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
113 | 
114 |                 eval_steps = (
115 |                     24 if itr < iterations else len(data["val"])
116 |                 )
117 |                 val_acc, val_loss, val_perplexity = eval(
118 |                     model,
119 |                     data_val_iter,
120 |                     extra_args.device,
121 |                     max_num_batches=eval_steps,
122 |                     ctx=type_ctx,
123 |                 )
124 | 
125 |                 print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
126 |                 print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
127 |                 if scheduler is not None:
128 |                     print_string += f" [lr] {current_lr:.5f}"
129 |                 print(print_string)
130 | 
131 |                 if extra_args.wandb:
132 |                     logs = {
133 |                         "iter": itr,
134 |                         "train/loss": train_loss,
135 |                         "val/loss": val_loss,
136 |                         "val/perplexity": val_perplexity,
137 |                         "val/acc": val_acc,
138 |                         "lr": current_lr,
139 |                     }
140 | 
141 |                     if itr == iterations:
142 |                         logs["val/final-ppl"] = val_perplexity
143 |                         logs["val/final-acc"] = val_acc
144 |                         logs["val/final-loss"] = val_loss
145 | 
146 |                     wandb.log(logs)
147 | 
148 |                     if extra_args.eval_seq_prefix != 'none' and (itr % (eval_freq * 5) == 0 or itr == iterations):
149 |                         if text_table is None:
150 |                             text_table = wandb.Table(columns=["itr", "val-pp", "text"])
151 | 
152 |                         out_str = distributed_backend.get_raw_model(model).generate_from_string(
153 |                             extra_args.eval_seq_prefix, max_new_tokens=40, temperature=0.9, top_k=None)
154 |                         text_table.add_data(itr, val_perplexity, out_str)
155 |                         # why a copy? see github.com/wandb/wandb/issues/2981
156 |                         wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)})
157 | 
158 |                 model.train()
159 |                 t0 = time.time()
160 |         if distributed_backend.is_master_process():
161 |             if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
162 |                 print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt")
163 |                 save_checkpoint(distributed_backend=distributed_backend,
164 |                                 model=model,
165 |                                 opt=opt,
166 |                                 scheduler=scheduler,
167 |                                 itr=itr,
168 |                                 cpu_rng_state=torch.get_rng_state(),
169 |                                 gpu_rng_state=torch.cuda.get_rng_state(),
170 |                                 numpy_rng_state=np.random.get_state(),
171 |                                 py_rng_state=random.getstate(),
172 |                                 train_sampler_state=sampler_state_before_iter,
173 |                                 ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
174 | 
175 |     if distributed_backend.is_master_process():
176 |         print(f"saving checkpoint to {ckpt_path}")
177 |         save_checkpoint(distributed_backend=distributed_backend,
178 |                         model=model,
179 |                         opt=opt,
180 |                         scheduler=scheduler,
181 |                         itr=itr,
182 |                         ckpt_path=ckpt_path)
183 | 
184 |     return stats
185 | 
--------------------------------------------------------------------------------
/src/optim/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from contextlib import nullcontext, contextmanager, ExitStack
5 | 
6 | 
7 | def get_batch(dataloader, device="cpu"):
8 |     x, y = next(dataloader)
9 |     if "cuda" in torch.device(device).type:
10 |         # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
11 |         x = x.pin_memory().to(device, non_blocking=True)
12 |         y = y.pin_memory().to(device, non_blocking=True)
13 |     else:
14 |         x = x.to(device)
15 |         y = y.to(device)
16 |     return x, y
17 | 
18 | 
19 | @torch.no_grad()
20 | def eval(model, data_val_iter, device='cpu', max_num_batches=24, ctx=nullcontext()):
21 |     assert not model.training
22 | 
23 |     loss_list_val, acc_list = [], []
24 | 
25 |     for _ in range(max_num_batches):
26 |         x, y = get_batch(data_val_iter, device=device)
27 |         with ctx:
28 |             outputs = model(x, targets=y, get_logits=True)
29 |         val_loss = outputs['loss']
30 |         loss_list_val.append(val_loss)
31 |         acc_list.append((outputs['logits'].argmax(-1) == y).float().mean())
32 | 
33 |     val_acc = torch.stack(acc_list).mean().item()
34 |     val_loss = torch.stack(loss_list_val).mean().item()
35 |     val_perplexity = np.exp(val_loss)  # perplexity = exp(mean cross-entropy loss)
36 | 
37 |     return val_acc, val_loss, val_perplexity
38 | 
39 | 
40 | def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args):
41 | 
42 |     checkpoint = dict({
43 |         'model': distributed_backend.get_raw_model(model).state_dict(),
44 |         'optimizer': opt.state_dict(),
45 |         'scheduler': scheduler.state_dict(),
46 |         'itr': itr,
47 |     }, **extra_args)
48 | 
49 |     torch.save(checkpoint, ckpt_path)
50 | 
--------------------------------------------------------------------------------
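
Since `save_checkpoint` writes a plain `torch.save` dictionary, resuming a run mirrors the call above. The sketch below is a minimal, illustrative helper, not part of this repository: the `load_checkpoint` name and its return convention are assumptions, and it only shows how the keys written by `save_checkpoint` could be restored before re-entering the training loop.

```python
import torch


def load_checkpoint(model, opt, scheduler, ckpt_path, device="cpu"):
    # Hypothetical counterpart to save_checkpoint: restores the keys it wrote.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model"])          # the raw (unwrapped) model
    opt.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    itr = checkpoint["itr"]
    # RNG/sampler states are only present in the periodic checkpoints, where
    # they were passed to save_checkpoint via **extra_args.
    rng_keys = ("cpu_rng_state", "gpu_rng_state", "numpy_rng_state",
                "py_rng_state", "train_sampler_state")
    rng_state_dict = {k: checkpoint[k] for k in rng_keys if k in checkpoint}
    return itr, rng_state_dict or None
```

Because the checkpoint stores `distributed_backend.get_raw_model(model).state_dict()`, the weights should be loaded into the unwrapped model before it is re-wrapped for DDP; the returned `itr` and `rng_state_dict` can then be passed back into the training loop to continue from the saved iteration.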