├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── loss_slimpajama.png
│   └── pplx_slimpajama.png
├── requirements.txt
├── scripts
│   └── train-baselines-example.sh
└── src
    ├── config
    │   ├── __init__.py
    │   └── base.py
    ├── data
    │   ├── arxiv.py
    │   ├── benchmarks.py
    │   ├── openwebtext2.py
    │   ├── shakespeare.py
    │   ├── slimpajama.py
    │   ├── utils.py
    │   └── wikitext.py
    ├── distributed
    │   ├── __init__.py
    │   ├── backend.py
    │   ├── ddp.py
    │   └── single.py
    ├── main.py
    ├── models
    │   ├── base.py
    │   ├── llama.py
    │   └── utils.py
    └── optim
        ├── base.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Dataset folder
2 | src/data/datasets/
3 | wandb/
4 | exps/
5 | scripts/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 EPFL MLO Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM-baselines
2 |
3 | A modular codebase to experiment with transformers, inspired by NanoGPT.
4 |
5 | ## Quickstart
6 |
7 | Install dependencies:
8 |
9 | ```
10 | pip install -r requirements.txt
11 | ```
12 |
13 | Run a simple training on the SlimPajama dataset ([6B subset](https://huggingface.co/datasets/DKYoon/SlimPajama-6B), 24 GB decompressed; it takes a few minutes to download):
14 |
15 | ```sh
16 | python ./src/main.py --config_format base
17 | ```
18 |
19 | The above command trains a 123.59M-parameter model. It trains for 25k iterations with an effective batch size of 128 = 32 x 4 (micro-batch size 32 with 4 gradient accumulation steps), using a cosine schedule with a maximum learning rate of 1e-3 that decays to 1e-4 by the end of training. Checkpoints are saved in the `./exps` folder.
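The same run can also be spelled out with explicit flags. A minimal sketch, assuming the defaults described above (double-check `config/base.py` for the authoritative values):

```sh
python ./src/main.py --config_format base \
  --batch_size 32 --acc_steps 4 \
  --iterations 25000 --lr 1e-3 --scheduler cos \
  --results_base_folder ./exps
```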
20 |
21 | This training takes roughly 3 hours on a single A100 (80GB) GPU. The training and validation loss curves should look roughly like this:
22 |
23 | ![Training and validation loss on SlimPajama](assets/loss_slimpajama.png)
24 | ![Validation perplexity on SlimPajama](assets/pplx_slimpajama.png)
25 | 
26 | You can check out the wandb run for yourself [here](https://wandb.ai/haeggee/llm-lauzhack/runs/lm2obqy9?nw=nwuserhaeggee).
27 |
28 |
29 | ## Less quick start
30 |
31 | Here are the parameters you can use (adapted from `config/base.py`; see that file for the authoritative defaults):
32 |
33 | ```python
34 | # General training params
35 | parser.add_argument('--batch_size', default=32, type=int)
36 | parser.add_argument('--acc_steps', default=4, type=int)
37 | parser.add_argument('--seed', default=0, type=int) # random seed for the parameters
38 | parser.add_argument('--data_seed', default=1337, type=int) # random seed defining the data ordering
39 | parser.add_argument('--device', default='cuda:0', type=str) # see below to run on multiple GPUs
40 | parser.add_argument('--iterations', default=25000, type=int) # total number of training iterations
41 | parser.add_argument('--lr', default=1e-3, type=float)
42 | parser.add_argument('--warmup_percent', default=0.05, type=float) # the total number of warmup steps is iterations * warmup_percent
43 | parser.add_argument('--weight_decay', default=0.1, type=float) # I recommend you keep this value, else instabilities might arise
44 | parser.add_argument('--beta1', default=0.9, type=float) # adam parameter
45 | parser.add_argument('--beta2', default=0.95, type=float) # adam parameter
46 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none'])
47 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd'])
48 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations
49 | parser.add_argument('--results_base_folder', default="./exps", type=str) # where the checkpoints will be saved
50 | parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT
51 | # Dataset params
52 | parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2'])
53 | parser.add_argument('--vocab_size', default=50304, type=int)
54 | parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, you most likely do not need this
55 | # Model params
56 | parser.add_argument('--model', default='base', choices=['base', 'llama2'])
57 | parser.add_argument('--use_pretrained', default="none", type=str) # 'none', 'gpt-2' or a path to the pretrained model
58 | parser.add_argument('--dropout', default=0.0, type=float) # keep to 0 unless in low data regime (e.g. wikitext)
59 | parser.add_argument('--n_head', default=12, type=int)
60 | parser.add_argument('--n_layer', default=12, type=int) # depth in (att + ff) blocks
61 | parser.add_argument('--n_embd', default=768, type=int) # hidden size ...
62 | parser.add_argument('--sequence_length', default=512, type=int)
63 | parser.add_argument('--dtype', default=torch.bfloat16, type=torch.dtype)
64 | parser.add_argument('--bias', default=False, type=bool)
65 | parser.add_argument('--compile', action='store_true') # if true then model is compiled
66 | parser.add_argument('--rmsnorm_eps', default=1e-5, type=float) # used by the llama model
67 | parser.add_argument('--multiple_of', default=256, type=int) # used by the llama model make SwiGLU hidden layer size multiple of large power of 2
68 | # logging params (WandB)
69 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not
70 | parser.add_argument('--wandb_project', default="my-project", type=str)
71 | parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name
72 | parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences
73 | # Distributed args
74 | parser.add_argument('--distributed_backend', default=None, type=str, required=False,
75 | choices=distributed.registered_backends()) # distributed backend type (e.g. nccl)
76 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
77 | ```
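Any of these flags can be overridden on the command line. As an illustrative sketch (the values are arbitrary, but every flag and choice appears in the list above), a smaller llama2-style model trained on wikitext could look like:

```sh
python ./src/main.py --config_format base --model llama2 --dataset wikitext \
  --n_layer 6 --n_head 6 --n_embd 384 --sequence_length 256 --dropout 0.1 --lr 5e-4
```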
78 |
79 | ## Using WandB
80 |
81 | You need to provide your wandb authorization key in order to send data to your wandb account. If you start jobs on a server without access to an interactive prompt, you can set the `WANDB_API_KEY` variable within your script:
82 |
83 | ```bash
84 | # this is a script that could be executed on a server
85 | pip install -r requirements.txt # install req.
86 | export WANDB_API_KEY="put your authorize key here, to find it: https://wandb.ai/authorize"
87 | python ./src/main.py --config_format base --wandb --wandb_project "my awesome project" --n_layer 7 --model base --seed 123
88 | ```
89 |
90 | ## How to add your own transformer architecture?
91 |
92 | The structure of the project is the following:
93 |
94 | ```sh
95 | src/
96 | main.py # pick the right data, model, and training function
97 | config/
98 | __init__.py # contains CONFIG_FORMAT_TO_MODULE_MAP mapping the name given to the --config_format flag with a python conf file
99 | base.py # config for the base model
100 | data/
101 | utils.py # contains the get_dataset function
102 | wikitext.py # load/process wikitext
103 | arxiv.py # load/process arxiv
104 | shakespeare.py # load/process the Shakespeare dataset
105 | slimpajama.py
106 | ...
107 | models/
108 | utils.py # contains the get_model function
109 | base.py # contains the standard transformer base architecture
110 | llama.py # llama architecture
111 | optim/
112 | utils.py # contains eval and get_batch functions
113 | base.py # training function for the base and llama models
114 | distributed/
115 | # code to enable simple distributed training
116 | ```
117 |
118 | Given the above structure, to add your own model, you can simply fork the `./src/models/base.py` file and make your modifications, then, if you need a custom training loop or evaluation, fork `./src/optim/base.py` as well. You also need to fork the `./src/config/base.py` file to add your own parameters, which implies adding your new config to the `CONFIG_FORMAT_TO_MODULE_MAP` mapping in `./src/config/__init__.py`, as sketched below. To add a new dataset, create a new file in the `data` folder; check `wikitext.py` for the expected format.
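A minimal sketch of that registration step, assuming a hypothetical new config module `./src/config/mymodel.py` that exposes the same `parse_args(base_parser, args, namespace)` entry point as `base.py`:

```python
# src/config/__init__.py ("mymodel" is a hypothetical example module)
from . import base
from . import mymodel

CONFIG_FORMAT_TO_MODULE_MAP = {
    "base": base,
    "mymodel": mymodel,  # now selectable via --config_format mymodel
}
```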
119 |
120 | ## Multi-GPU training
121 |
122 | Given a multi-GPU machine with e.g. 4 GPUs, one can distribute the training using data-parallelism:
123 |
124 | ```sh
125 | torchrun --nproc_per_node=4 ./src/main.py --config_format base --distributed_backend nccl --dataset slimpajama --model base
126 | ```
127 |
128 | When using multiple GPUs, the data is distributed among the GPUs by dividing the number of gradient accumulation steps by the number of GPUs. For instance, if we train with a batch size of 32 and 4 accumulation steps on 4 GPUs, each GPU processes batches of 32 elements and performs a single accumulation step. For this reason we require `acc_steps` to be a multiple of the number of GPUs, as sketched below.
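A minimal sketch of that bookkeeping (illustrative only; the variable names are not taken from the codebase):

```python
# Per-GPU accumulation steps shrink while the effective batch stays the same.
batch_size, acc_steps, num_gpus = 32, 4, 4  # example values from above

assert acc_steps % num_gpus == 0, "acc_steps must be a multiple of the number of GPUs"
local_acc_steps = acc_steps // num_gpus   # 1 accumulation step per GPU
effective_batch = batch_size * acc_steps  # 128, identical to the single-GPU setting

print(local_acc_steps, effective_batch)   # -> 1 128
```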
129 |
130 |
131 | ## Experimenting locally on your device with CPU
132 | If you do not have access to a GPU or just want to try the code locally on your device, you can use the Shakespeare dataset with character-level tokens:
133 |
134 | ```sh
135 | python ./src/main.py --n_layer=2 --n_head=4 --n_embd=128 --sequence_length=256 --dataset=shakespeare-char --device=cpu --vocab_size=96
136 | ```
137 |
--------------------------------------------------------------------------------
/assets/loss_slimpajama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/llm-baselines/2d172dfd06fb45dc379cae5f34abcb0872b2a93a/assets/loss_slimpajama.png
--------------------------------------------------------------------------------
/assets/pplx_slimpajama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfml/llm-baselines/2d172dfd06fb45dc379cae5f34abcb0872b2a93a/assets/pplx_slimpajama.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tiktoken
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==2.0.0+cu118
4 | torchaudio==2.0.0+cu118
5 | torchvision==0.15.0+cu118
6 | tqdm==4.65.0
7 | transformers
8 | wandb
9 | datasets
10 | zstandard
--------------------------------------------------------------------------------
/scripts/train-baselines-example.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd /scratch/homes/sfan/llm-baselines
4 | pip install -r requirements.txt
5 |
6 | export WANDB_API_KEY="put your wandb api key here"
7 | python ./src/main.py --model base --n_embd 768 --n_head 6 --wandb_run_prefix h768_nh12_nlyr24_sl512_d005 --n_layer 6 --batch_size 50 --sequence_length 512 --wandb --acc_steps 4 --dropout 0.05
8 |
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
2 |
3 | CONFIG_FORMAT_TO_MODULE_MAP = {
4 | "base": base,
5 | }
6 |
7 |
8 | def parse_args_with_format(format, base_parser, args, namespace):
9 | return CONFIG_FORMAT_TO_MODULE_MAP[format].parse_args(base_parser, args, namespace)
10 |
11 |
12 | def registered_formats():
13 | return CONFIG_FORMAT_TO_MODULE_MAP.keys()
14 |
--------------------------------------------------------------------------------
/src/config/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import distributed
4 |
5 | def none_or_str(value):
6 | if value == 'None':
7 | return None
8 | return value
9 |
10 | def parse_args(base_parser, args, namespace):
11 | parser = base_parser
12 | # General training params
13 | parser.add_argument('--batch_size', default=32, type=int)
14 | parser.add_argument('--acc_steps', default=4, type=int)
15 | parser.add_argument('--seed', default=0, type=int)
16 | parser.add_argument('--data_seed', default=1337, type=int)
17 | parser.add_argument('--device', default='cuda:0', type=str)
18 | parser.add_argument('--iterations', default=25000, type=int)
19 | parser.add_argument('--lr', default=1e-3, type=float)
20 | parser.add_argument('--warmup_percent', default=0.05, type=float)
21 | parser.add_argument('--weight_decay', default=0.1, type=float)
22 | parser.add_argument('--beta1', default=0.9, type=float)
23 | parser.add_argument('--beta2', default=0.95, type=float)
24 | parser.add_argument('--scheduler', default='cos', choices=['linear', 'cos', 'none'])
25 | parser.add_argument('--opt', default='adamw', choices=['adamw', 'sgd'])
26 | parser.add_argument('--eval_freq', default=200, type=int) # in iterations
27 | parser.add_argument('--results_base_folder', default="./exps", type=str)
28 | parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT
29 | # Dataset params
30 | parser.add_argument('--dataset', default='mathqa', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2', 'mathqa'])
31 | parser.add_argument('--vocab_size', default=50304, type=int)
32 | parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2
33 | # Model params
34 | parser.add_argument('--model', default='base', choices=['base', 'llama2'])
35 |     parser.add_argument('--use_pretrained', default="auto", type=none_or_str) # 'none', 'gpt-2' or a path to the pretrained model
36 | parser.add_argument('--dropout', default=0.0, type=float)
37 | parser.add_argument('--n_head', default=6, type=int)
38 |     parser.add_argument('--n_layer', default=6, type=int) # depth in (att + ff) blocks
39 | parser.add_argument('--n_embd', default=768, type=int) # embedding size / hidden size ...
40 | parser.add_argument('--sequence_length', default=512, type=int)
41 | parser.add_argument('--dtype', default=torch.float16, type=torch.dtype)
42 | parser.add_argument('--bias', default=False, type=bool)
43 | parser.add_argument('--compile', action='store_true') # if true then model is compiled
44 | parser.add_argument("--rmsnorm_eps", default=1e-5, type=float)
45 | parser.add_argument(
46 | "--multiple_of", # make SwiGLU hidden layer size multiple of large power of 2
47 | default=256,
48 | type=int,
49 | )
50 | parser.add_argument('--run_prefix', default=None, type=str, required=False) # is added before the autogenerated experiment name
51 | parser.add_argument('--exp_name', default=None, type=str, required=False)
52 | # logging params (WandB)
53 | parser.add_argument('--wandb', action='store_true') # whether to use wandb or not
54 | parser.add_argument('--wandb_project', default="my-project", type=str)
55 | parser.add_argument('--wandb_run_prefix', default="none", type=str) # is added before the autogenerated experiment name
56 | parser.add_argument('--eval_seq_prefix', default="Once upon a time", type=str) # prefix used to generate sequences
57 | # Distributed args
58 | parser.add_argument('--distributed_backend', default=None, type=str, required=False,
59 | choices=distributed.registered_backends()) # distributed backend type
60 | parser.add_argument('--save_checkpoint_freq', default=None, type=int, required=False)
61 |
62 | args = parser.parse_args(args, namespace)
63 |
64 | if args.exp_name is None:
65 | special_name_handle_fields = {"model", "lr", "batch_size",
66 | "acc_steps", "seed", "exp_name",
67 | "wandb", "wandb_project", "eval_seq_prefix",
68 | "run_prefix", "distributed_backend", "config_format",
69 | "sequence_length"}
70 | overriden_values = []
71 | for key in vars(args):
72 | if key in special_name_handle_fields:
73 | continue
74 | if getattr(args, key) != parser.get_default(key):
75 | overriden_values.append((key, getattr(args, key)))
76 | chunk_len = 10
77 | overriden_values_str_parts = []
78 | for chunk_id in range(0, len(overriden_values), chunk_len):
79 | overriden_values_str = "_".join(["{}={}".format(key, value) for key, value in overriden_values[chunk_id:chunk_id+chunk_len]])
80 | overriden_values_str_parts.append(overriden_values_str)
81 | overriden_values_str = "/".join(overriden_values_str_parts)
82 | exp_name = ""
83 | if args.run_prefix is not None:
84 | exp_name += f"{args.run_prefix}_"
85 | exp_name += f"{args.model}_lr{args.lr}_bs{args.batch_size}x{args.acc_steps}_seqlen{args.sequence_length}/{overriden_values_str}_seed={args.seed}"
86 | args.exp_name = exp_name
87 |
88 | if args.dtype == "torch.bfloat16":
89 | args.dtype = torch.bfloat16
90 | elif args.dtype == "torch.float16":
91 | args.dtype = torch.float16
92 | elif args.dtype == "torch.float32":
93 | args.dtype = torch.float32
94 |
95 | return args
96 |
--------------------------------------------------------------------------------
/src/data/arxiv.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tarfile
3 | import logging
4 | from pathlib import Path
5 | from typing import Optional
6 | from multiprocessing import Pool
7 | from tempfile import NamedTemporaryFile
8 | from subprocess import Popen, TimeoutExpired, PIPE
9 | from typing import Tuple, List
10 |
11 | import numpy as np
12 | import requests
13 | from tqdm.auto import tqdm
14 | import tiktoken
15 |
16 |
17 | def convert_to_markdown(args: Tuple[Path, Path]):
18 | texfile, mdroot = args
19 | mdfile = mdroot/f"{texfile.name}.md"
20 | with Popen(["pandoc", "--wrap=none", "--from", "latex", texfile,
21 | "--output", mdfile], stderr=PIPE) as proc:
22 | try:
23 | proc.communicate(timeout=1)
24 | except TimeoutExpired:
25 | proc.kill()
26 |
27 |
28 |
29 | def fetch_arxiv(root: Path, year: int):
30 | # download latex
31 | url = f"https://www.cs.cornell.edu/projects/kddcup/download/hep-th-{year}.tar.gz"
32 | texroot = root/"tex"
33 | print("Downloading Arxiv year", year)
34 | req = requests.get(url, timeout=60)
35 | with NamedTemporaryFile(suffix=".tar.gz") as f:
36 | f.write(req.content)
37 | logging.debug("Tar saved in tempfile %s" % f.name)
38 | with tarfile.open(f.name) as tar:
39 | logging.debug("Extracting tarfile")
40 | tar.extractall(texroot)
41 |
42 | # convert to markdown
43 | mdroot = root/"md"/str(year)
44 | mdroot.mkdir(parents=True)
45 | files = list((texroot/str(year)).iterdir())
46 | with Pool(os.cpu_count()) as p:
47 | args = [(texfile, mdroot) for texfile in files]
48 | for _ in tqdm(p.imap_unordered(convert_to_markdown, args),
49 | desc="Converting to markdown", total=len(files)):
50 | pass
51 |
52 |
53 | def tokenize_arxiv(root: Path, year: int):
54 | tokenizer = tiktoken.get_encoding("gpt2")
55 | tokens = []
56 | tokens_val = []
57 | tokens_test = []
58 | mds = root/"md"/str(year)
59 |
60 | # tokenize
61 | desc = f"Tokenizing {year}"
62 | for i, mdpath in enumerate(tqdm(list(mds.iterdir()), desc=desc)):
63 | with open(mdpath, encoding="utf8") as f:
64 | text = "".join(f.readlines())
65 | if i % 10 <= 6: # train split
66 | tokens += tokenizer.encode(text)
67 | elif i % 10 <= 8: # val split
68 | tokens_val += tokenizer.encode(text)
69 | else: # test split
70 | tokens_test += tokenizer.encode(text)
71 |
72 | # save to dir
73 | tpath = root/str(year)
74 | tpath.mkdir(parents=True)
75 | for x, name in zip([tokens, tokens_val, tokens_test],
76 | ["train", "val", "test"]):
77 | mem = np.memmap(tpath/f"{name}.npy", dtype=np.uint16, mode="w+",
78 | shape=len(x))
79 | for i, v in enumerate(x):
80 | mem[i] = v
81 |
82 |
83 | def load_arxiv(cachedir: Path, years: Optional[List[int]] = None):
84 | all_years = list(range(1992, 2004))
85 | if years is None:
86 | years = all_years
87 | assert set(years) <= set(all_years)
88 | root = cachedir/"arxiv"
89 | root.mkdir(exist_ok=True, parents=True)
90 |
91 | # download all years requested that are not present
92 | for year in years:
93 | if not (root/"md"/str(year)).exists():
94 | fetch_arxiv(root, year)
95 |
96 | # tokenize all years not previously tokenized
97 | for year in years:
98 | if not (root/str(year)).exists():
99 | tokenize_arxiv(root, year)
100 |
101 | # load meta
102 | ret = {}
103 | for split in ["train", "val"]:
104 | paths = [root/str(year)/f"{split}.npy" for year in years]
105 | x = [np.memmap(path, dtype=np.uint16, mode="r") for path in paths]
106 | ret[split] = np.concatenate(x)
107 | return ret
108 |
109 |
110 | def get_arxiv_2000():
111 | return load_arxiv(Path(os.path.dirname(__file__))/"datasets", [2000])
112 |
113 |
114 | def get_arxiv_full():
115 | return load_arxiv(Path(os.path.dirname(__file__))/"datasets")
116 |
--------------------------------------------------------------------------------
/src/data/benchmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import importlib
3 |
4 | ## load arc
5 | import os
6 | from tqdm import tqdm
7 | import numpy as np
8 | import tiktoken
9 | from datasets import load_dataset
10 | import torch
11 |
12 | # Initialize the tokenizer
13 | tknzr = tiktoken.get_encoding("gpt2")
14 | ARC_E_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/arc_easy/")
15 | ARC_C_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/arc_challenge/")
16 | HELLASWAG_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/hellaswag/")
17 | LOGIQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/logiqa/")
18 | PIQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/piqa/")
19 | SCIQ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/sciq/")
20 | HUMANEVAL_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/humaneval/")
21 | KODCODE_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/kodcode/")
22 | GSM8K_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/gsm8k/")
23 | MATHQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/mathqa/")
24 | MEDQA_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/medqa/")
25 |
26 |
27 | def tokenize_with_pad(text, pad_to_multiple=1024):
28 | ids = tknzr.encode_ordinary(text)
29 | ids.append(tknzr.eot_token)
30 | pad_token_id = tknzr.eot_token
31 |     # Calculate padding length: round up to the next multiple of pad_to_multiple (e.g. 1500 tokens -> 2048 when pad_to_multiple=1024)
32 | padded_length = ((len(ids) + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple
33 | # Initialize the padded array with pad token (not zeros)
34 | padded_tokens = np.ones(padded_length, dtype=np.uint16) * pad_token_id
35 | padded_tokens[:len(ids)] = ids
36 | return padded_tokens
37 |
38 |
39 | def get_humaneval(num_proc=10, return_torch=False, pad_to_multiple=1024):
40 | """
41 | Load and process the HumanEval (for coding) dataset.
42 | Tokenize the text and store it in binary format for efficient loading.
43 | """
44 | if not os.path.exists(os.path.join(HUMANEVAL_DATA_PATH, 'val.bin')):
45 | os.makedirs(HUMANEVAL_DATA_PATH, exist_ok=True)
46 |
47 | # Load the HumanEval dataset from Hugging Face Datasets
48 | human_eval_path = "openai/openai_humaneval"
49 | dataset = load_dataset(human_eval_path, trust_remote_code=True)
50 | split_dataset = dataset["test"].train_test_split(test_size=0.5, seed=2357, shuffle=True)
51 | split_dataset['val'] = split_dataset.pop('test')
52 | data_dict = {
53 | 'train': split_dataset['train'],
54 | 'val': split_dataset['val'],
55 | }
56 |
57 | def process(example):
58 | """
59 | Tokenize the example text by encoding it into token IDs.
60 | """
61 | prompt = example['prompt']
62 | completion = example['canonical_solution']
63 |
64 | concatenated_text = f"{prompt} \n {completion}"
65 | # print(concatenated_text)
66 | ids = tokenize_with_pad(text=concatenated_text,
67 | pad_to_multiple=pad_to_multiple)
68 | return {'ids': ids, 'len': len(ids)}
69 |
70 | # Tokenize and map the dataset
71 | tokenized = {}
72 | for split, dset in data_dict.items():
73 | tokenized[split] = dset.map(
74 | process,
75 | remove_columns=['task_id', 'prompt', 'canonical_solution', 'test', 'entry_point'],
76 | desc=f"Tokenizing {split} split",
77 | num_proc=num_proc
78 | )
79 |
80 | # Concatenate all the token IDs into one large binary file per split
81 | for split, dset in tokenized.items():
82 | # Save token IDs length
83 | len_arr = np.array(dset['len'], dtype=np.uint16)
84 | with open(os.path.join(HUMANEVAL_DATA_PATH, f'{split}.len'), 'wb') as f:
85 | np.save(f, len_arr)
86 | # Total number of tokens
87 | arr_len = np.sum(dset['len'])
88 | filename = os.path.join(HUMANEVAL_DATA_PATH, f'{split}.bin')
89 | dtype = np.uint16
90 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
91 | total_batches = 10
92 |
93 | idx = 0
94 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
95 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
96 | arr_batch = np.concatenate(batch['ids'])
97 | arr[idx: idx + len(arr_batch)] = arr_batch
98 | idx += len(arr_batch)
99 | arr.flush()
100 |
101 | # Load tokenized binary files for training, validation
102 | train_data = np.memmap(os.path.join(HUMANEVAL_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
103 | val_data = np.memmap(os.path.join(HUMANEVAL_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
104 |
105 | if return_torch:
106 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
107 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
108 | print(f'Benchmark HumanEval: train[{len(train_data)}] | val[{len(val_data)}]')
109 | return {
110 | 'train': train_data,
111 | 'train_len': np.load(os.path.join(HUMANEVAL_DATA_PATH, 'train.len')),
112 | 'val': val_data,
113 | 'val_len': np.load(os.path.join(HUMANEVAL_DATA_PATH, 'val.len')),
114 | }
115 |
116 |
117 | def get_kodcode(num_proc=10, return_torch=False, pad_to_multiple=1024):
118 | """
119 | Load and process the KodCode (for code reasoning) dataset.
120 | Tokenize the text and store it in binary format for efficient loading.
121 | """
122 | if not os.path.exists(os.path.join(KODCODE_DATA_PATH, 'val.bin')):
123 | os.makedirs(KODCODE_DATA_PATH, exist_ok=True)
124 |
125 | # Load the GSM8K dataset from Hugging Face Datasets
126 | dataset = load_dataset("KodCode/KodCode-V1-SFT-R1", trust_remote_code=True)
127 | dataset = dataset["train"].train_test_split(test_size=0.1, seed=2357, shuffle=True)
128 | data_dict = {
129 | 'train': dataset["train"],
130 | 'val': dataset["test"],
131 | }
132 |
133 | def process(example):
134 | """
135 | Tokenize the example text by encoding it into token IDs.
136 | """
137 | question = example['question']
138 | answer = example['solution']
139 |
140 | concatenated_text = f"{question}\n{answer}"
141 | # print(concatenated_text)
142 | ids = tokenize_with_pad(text=concatenated_text,
143 | pad_to_multiple=pad_to_multiple)
144 | return {'ids': ids, 'len': len(ids)}
145 |
146 | # Tokenize and map the dataset
147 | tokenized = {}
148 | for split, dset in data_dict.items():
149 | tokenized[split] = dset.map(
150 | process,
151 | remove_columns=['style', 'question_id', 'subset', 'question', 'solution', 'test_code', 'test_info',
152 | 'gpt_pass_sequence', 'gpt_pass_trial_num', 'gpt_difficulty', 'gpt_pass_percentage',
153 | 'r1_pass_sequence', 'r1_pass_trial_num', 'r1_correctness', 'r1_solution', 'metadata', 'conversations'],
154 | desc=f"Tokenizing {split} split",
155 | num_proc=num_proc
156 | )
157 |
158 | # Concatenate all the token IDs into one large binary file per split
159 | for split, dset in tokenized.items():
160 | # Save token IDs length
161 | len_arr = np.array(dset['len'], dtype=np.uint16)
162 | with open(os.path.join(KODCODE_DATA_PATH, f'{split}.len'), 'wb') as f:
163 | np.save(f, len_arr)
164 | # Total number of tokens
165 | arr_len = np.sum(dset['len'])
166 | filename = os.path.join(KODCODE_DATA_PATH, f'{split}.bin')
167 | dtype = np.uint16
168 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
169 | total_batches = 10
170 |
171 | idx = 0
172 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
173 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
174 | arr_batch = np.concatenate(batch['ids'])
175 | arr[idx: idx + len(arr_batch)] = arr_batch
176 | idx += len(arr_batch)
177 | arr.flush()
178 |
179 | # Load tokenized binary files for training, validation
180 | train_data = np.memmap(os.path.join(KODCODE_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
181 | val_data = np.memmap(os.path.join(KODCODE_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
182 |
183 | if return_torch:
184 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
185 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
186 | print(f'Benchmark KodCode: train[{len(train_data)}] | val[{len(val_data)}]')
187 | return {
188 | 'train': train_data,
189 | 'train_len': np.load(os.path.join(KODCODE_DATA_PATH, 'train.len')),
190 | 'val': val_data,
191 | 'val_len': np.load(os.path.join(KODCODE_DATA_PATH, 'val.len')),
192 | }
193 |
194 |
195 | def get_gsm8k(num_proc=10, return_torch=False, pad_to_multiple=1024):
196 | """
197 | Load and process the GSM8K (for math) dataset.
198 | Tokenize the text and store it in binary format for efficient loading.
199 | """
200 | if not os.path.exists(os.path.join(GSM8K_DATA_PATH, 'val.bin')):
201 | os.makedirs(GSM8K_DATA_PATH, exist_ok=True)
202 |
203 | # Load the GSM8K dataset from Hugging Face Datasets
204 | dataset = load_dataset("openai/gsm8k", "main", trust_remote_code=True)
205 | data_dict = {
206 | 'train': dataset["train"],
207 | 'val': dataset["test"],
208 | }
209 |
210 | def process(example):
211 | """
212 | Tokenize the example text by encoding it into token IDs.
213 | """
214 | question = example['question']
215 | answer = example['answer']
216 |
217 | concatenated_text = f"{question}\n{answer}"
218 | # print(concatenated_text)
219 | ids = tokenize_with_pad(text=concatenated_text,
220 | pad_to_multiple=pad_to_multiple)
221 | return {'ids': ids, 'len': len(ids)}
222 |
223 | # Tokenize and map the dataset
224 | tokenized = {}
225 | for split, dset in data_dict.items():
226 | tokenized[split] = dset.map(
227 | process,
228 | remove_columns=['question', 'answer'],
229 | desc=f"Tokenizing {split} split",
230 | num_proc=num_proc
231 | )
232 |
233 | # Concatenate all the token IDs into one large binary file per split
234 | for split, dset in tokenized.items():
235 | # Save token IDs length
236 | len_arr = np.array(dset['len'], dtype=np.uint16)
237 | with open(os.path.join(GSM8K_DATA_PATH, f'{split}.len'), 'wb') as f:
238 | np.save(f, len_arr)
239 | # Total number of tokens
240 | arr_len = np.sum(dset['len'])
241 | filename = os.path.join(GSM8K_DATA_PATH, f'{split}.bin')
242 | dtype = np.uint16
243 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
244 | total_batches = 10
245 |
246 | idx = 0
247 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
248 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
249 | arr_batch = np.concatenate(batch['ids'])
250 | arr[idx: idx + len(arr_batch)] = arr_batch
251 | idx += len(arr_batch)
252 | arr.flush()
253 |
254 | # Load tokenized binary files for training, validation
255 | train_data = np.memmap(os.path.join(GSM8K_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
256 | val_data = np.memmap(os.path.join(GSM8K_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
257 |
258 | if return_torch:
259 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
260 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
261 | print(f'Benchmark GSM8k: train[{len(train_data)}] | val[{len(val_data)}]')
262 | return {
263 | 'train': train_data,
264 | 'train_len': np.load(os.path.join(GSM8K_DATA_PATH, 'train.len')),
265 | 'val': val_data,
266 | 'val_len': np.load(os.path.join(GSM8K_DATA_PATH, 'val.len')),
267 | }
268 |
269 | def get_mathqa(num_proc=10, return_torch=False, pad_to_multiple=1024):
270 | """
271 | Load and process the MATH-QA dataset.
272 | Tokenize the text and store it in binary format for efficient loading.
273 | """
274 | if not os.path.exists(os.path.join(MATHQA_DATA_PATH, 'val.bin')):
275 | os.makedirs(MATHQA_DATA_PATH, exist_ok=True)
276 |
277 | # Load the MATH-QA dataset from Hugging Face Datasets
278 | dataset = load_dataset("allenai/math_qa",trust_remote_code=True)
279 | data_dict = {
280 | 'train': dataset["train"],
281 | 'val': dataset["test"],
282 | }
283 |
284 | def process(example):
285 | """
286 | Tokenize the example text by encoding it into token IDs.
287 | """
288 | question = example['Problem']
289 | choices = {d.strip()[0]:d.split(")")[-1].strip() for d in example['options'].split(",")}
290 | answer = choices.get(example['correct'])
291 |
292 | concatenated_text = f"{question} {answer}"
293 | # print(concatenated_text)
294 | ids = tokenize_with_pad(text=concatenated_text,
295 | pad_to_multiple=pad_to_multiple)
296 | return {'ids': ids, 'len': len(ids)}
297 |
298 | # Tokenize and map the dataset
299 | tokenized = {}
300 | for split, dset in data_dict.items():
301 | tokenized[split] = dset.map(
302 | process,
303 | remove_columns=['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
304 | desc=f"Tokenizing {split} split",
305 | num_proc=num_proc
306 | )
307 |
308 | # Concatenate all the token IDs into one large binary file per split
309 | for split, dset in tokenized.items():
310 | # Save token IDs length
311 | len_arr = np.array(dset['len'], dtype=np.uint16)
312 | with open(os.path.join(MATHQA_DATA_PATH, f'{split}.len'), 'wb') as f:
313 | np.save(f, len_arr)
314 | # Total number of tokens
315 | arr_len = np.sum(dset['len'])
316 | filename = os.path.join(MATHQA_DATA_PATH, f'{split}.bin')
317 | dtype = np.uint16
318 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
319 | total_batches = 10
320 |
321 | idx = 0
322 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
323 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
324 | arr_batch = np.concatenate(batch['ids'])
325 | arr[idx: idx + len(arr_batch)] = arr_batch
326 | idx += len(arr_batch)
327 | arr.flush()
328 |
329 | # Load tokenized binary files for training, validation
330 | train_data = np.memmap(os.path.join(MATHQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
331 | val_data = np.memmap(os.path.join(MATHQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
332 |
333 | if return_torch:
334 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
335 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
336 | print(f'Benchmark MATHQA: train[{len(train_data)}] | val[{len(val_data)}]')
337 | return {
338 | 'train': train_data,
339 | 'train_len': np.load(os.path.join(MATHQA_DATA_PATH, 'train.len')),
340 | 'val': val_data,
341 | 'val_len': np.load(os.path.join(MATHQA_DATA_PATH, 'val.len')),
342 | }
343 |
344 | def get_medqa(num_proc=10, return_torch=False, pad_to_multiple=1024):
345 | """
346 | Load and process the MEDQA dataset.
347 | Tokenize the text and store it in binary format for efficient loading.
348 | """
349 | if not os.path.exists(os.path.join(MEDQA_DATA_PATH, 'val.bin')):
350 | os.makedirs(MEDQA_DATA_PATH, exist_ok=True)
351 |
352 | # Load the MATH-QA dataset from Hugging Face Datasets
353 | dataset = load_dataset("bigbio/med_qa",trust_remote_code=True)
354 | data_dict = {
355 | 'train': dataset["train"],
356 | 'val': dataset["test"],
357 | }
358 |
359 | def process(example):
360 | """
361 | Tokenize the example text by encoding it into token IDs.
362 | """
363 | question = example["question"]
364 | answer = example["answer"]
365 |
366 | concatenated_text = f"{question} {answer}"
367 | # print(concatenated_text)
368 | ids = tokenize_with_pad(text=concatenated_text,
369 | pad_to_multiple=pad_to_multiple)
370 | return {'ids': ids, 'len': len(ids)}
371 |
372 | # Tokenize and map the dataset
373 | tokenized = {}
374 | for split, dset in data_dict.items():
375 | tokenized[split] = dset.map(
376 | process,
377 | remove_columns=['meta_info', 'question', 'answer_idx', 'answer', 'options'],
378 | desc=f"Tokenizing {split} split",
379 | num_proc=num_proc
380 | )
381 |
382 | # Concatenate all the token IDs into one large binary file per split
383 | for split, dset in tokenized.items():
384 | # Save token IDs length
385 | len_arr = np.array(dset['len'], dtype=np.uint16)
386 | with open(os.path.join(MEDQA_DATA_PATH, f'{split}.len'), 'wb') as f:
387 | np.save(f, len_arr)
388 | # Total number of tokens
389 | arr_len = np.sum(dset['len'])
390 | filename = os.path.join(MEDQA_DATA_PATH, f'{split}.bin')
391 | dtype = np.uint16
392 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
393 | total_batches = 10
394 |
395 | idx = 0
396 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
397 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
398 | arr_batch = np.concatenate(batch['ids'])
399 | arr[idx: idx + len(arr_batch)] = arr_batch
400 | idx += len(arr_batch)
401 | arr.flush()
402 |
403 | # Load tokenized binary files for training, validation
404 | train_data = np.memmap(os.path.join(MEDQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
405 | val_data = np.memmap(os.path.join(MEDQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
406 |
407 | if return_torch:
408 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
409 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
410 | print(f'Benchmark MedQA: train[{len(train_data)}] | val[{len(val_data)}]')
411 | return {
412 | 'train': train_data,
413 | 'train_len': np.load(os.path.join(MEDQA_DATA_PATH, 'train.len')),
414 | 'val': val_data,
415 | 'val_len': np.load(os.path.join(MEDQA_DATA_PATH, 'val.len')),
416 | }
417 |
418 | def get_arc_easy(num_proc=10, return_torch=False, pad_to_multiple=1024):
419 | """
420 | Load and process the AI2-ARC dataset.
421 | Tokenize the text and store it in binary format for efficient loading.
422 | """
423 | if not os.path.exists(os.path.join(ARC_E_DATA_PATH, 'val.bin')):
424 | os.makedirs(ARC_E_DATA_PATH, exist_ok=True)
425 |
426 | # Load the AI2-ARC dataset from Hugging Face Datasets
427 | dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split=['train', 'test'])
428 | data_dict = {
429 | 'train': dataset[0],
430 | 'val': dataset[1],
431 | }
432 |
433 | def process(example):
434 | """
435 | Tokenize the example text by encoding it into token IDs.
436 | """
437 | question = example['question']
438 | choices = example['choices']
439 | answer = dict(zip(choices["label"], choices["text"])).get(example['answerKey'])
440 |
441 | concatenated_text = f"{question} {answer}"
442 | # print(concatenated_text)
443 | ids = tokenize_with_pad(text=concatenated_text,
444 | pad_to_multiple=pad_to_multiple)
445 | return {'ids': ids, 'len': len(ids)}
446 |
447 | # Tokenize and map the dataset
448 | tokenized = {}
449 | for split, dset in data_dict.items():
450 | tokenized[split] = dset.map(
451 | process,
452 | remove_columns=['question', 'choices', 'answerKey'],
453 | desc=f"Tokenizing {split} split",
454 | num_proc=num_proc
455 | )
456 |
457 | # Concatenate all the token IDs into one large binary file per split
458 | for split, dset in tokenized.items():
459 | # Save token IDs length
460 | len_arr = np.array(dset['len'], dtype=np.uint16)
461 | with open(os.path.join(ARC_E_DATA_PATH, f'{split}.len'), 'wb') as f:
462 | np.save(f, len_arr)
463 | # Total number of tokens
464 | arr_len = np.sum(dset['len'])
465 | filename = os.path.join(ARC_E_DATA_PATH, f'{split}.bin')
466 | dtype = np.uint16
467 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
468 | total_batches = 10
469 |
470 | idx = 0
471 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
472 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
473 | arr_batch = np.concatenate(batch['ids'])
474 | arr[idx: idx + len(arr_batch)] = arr_batch
475 | idx += len(arr_batch)
476 | arr.flush()
477 |
478 | # Load tokenized binary files for training, validation
479 | train_data = np.memmap(os.path.join(ARC_E_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
480 | val_data = np.memmap(os.path.join(ARC_E_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
481 |
482 | if return_torch:
483 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
484 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
485 | print(f'Benchmark ARC-Easy: train[{len(train_data)}] | val[{len(val_data)}]')
486 | return {
487 | 'train': train_data,
488 | 'train_len': np.load(os.path.join(ARC_E_DATA_PATH, 'train.len')),
489 | 'val': val_data,
490 | 'val_len': np.load(os.path.join(ARC_E_DATA_PATH, 'val.len')),
491 | }
492 |
493 | def get_arc_challenge(num_proc=10, return_torch=False, pad_to_multiple=1024):
494 | """
495 | Load and process the AI2-ARC dataset.
496 | Tokenize the text and store it in binary format for efficient loading.
497 | """
498 | if not os.path.exists(os.path.join(ARC_C_DATA_PATH, 'val.bin')):
499 | os.makedirs(ARC_C_DATA_PATH, exist_ok=True)
500 |
501 | # Load the AI2-ARC dataset from Hugging Face Datasets
502 | dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", split=['train', 'test'])
503 | data_dict = {
504 | 'train': dataset[0],
505 | 'val': dataset[1],
506 | }
507 |
508 | def process(example):
509 | """
510 | Tokenize the example text by encoding it into token IDs.
511 | """
512 | question = example['question']
513 | choices = example['choices']
514 | answer = dict(zip(choices["label"], choices["text"])).get(example['answerKey'])
515 |
516 | concatenated_text = f"{question} {answer}"
517 | # print(concatenated_text)
518 | ids = tokenize_with_pad(text=concatenated_text,
519 | pad_to_multiple=pad_to_multiple)
520 | return {'ids': ids, 'len': len(ids)}
521 |
522 | # Tokenize and map the dataset
523 | tokenized = {}
524 | for split, dset in data_dict.items():
525 | tokenized[split] = dset.map(
526 | process,
527 | remove_columns=['question', 'choices', 'answerKey'],
528 | desc=f"Tokenizing {split} split",
529 | num_proc=num_proc
530 | )
531 |
532 | # Concatenate all the token IDs into one large binary file per split
533 | for split, dset in tokenized.items():
534 | # Save token IDs length
535 | len_arr = np.array(dset['len'], dtype=np.uint16)
536 | with open(os.path.join(ARC_C_DATA_PATH, f'{split}.len'), 'wb') as f:
537 | np.save(f, len_arr)
538 | # Total number of tokens
539 | arr_len = np.sum(dset['len'])
540 | filename = os.path.join(ARC_C_DATA_PATH, f'{split}.bin')
541 | dtype = np.uint16
542 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
543 | total_batches = 10
544 |
545 | idx = 0
546 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
547 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
548 | arr_batch = np.concatenate(batch['ids'])
549 | arr[idx: idx + len(arr_batch)] = arr_batch
550 | idx += len(arr_batch)
551 | arr.flush()
552 |
553 | # Load tokenized binary files for training, validation
554 | train_data = np.memmap(os.path.join(ARC_C_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
555 | val_data = np.memmap(os.path.join(ARC_C_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
556 |
557 | if return_torch:
558 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
559 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
560 | print(f'Benchmark ARC-Challenge: train[{len(train_data)}] | val[{len(val_data)}]')
561 | return {
562 | 'train': train_data,
563 | 'train_len': np.load(os.path.join(ARC_C_DATA_PATH, 'train.len')),
564 | 'val': val_data,
565 | 'val_len': np.load(os.path.join(ARC_C_DATA_PATH, 'val.len')),
566 | }
567 |
568 | def get_hellaswag(num_proc=10, return_torch=False, pad_to_multiple=1024):
569 | """
570 | Load and process the HellaSwag dataset.
571 | Tokenize the text and store it in binary format for efficient loading.
572 | """
573 | if not os.path.exists(os.path.join(HELLASWAG_DATA_PATH, 'val.bin')):
574 | os.makedirs(HELLASWAG_DATA_PATH, exist_ok=True)
575 |
576 | # Load the HellaSwag dataset from Hugging Face Datasets
577 | dataset = load_dataset("Rowan/hellaswag", split=['train', 'validation'])
578 | data_dict = {
579 | 'train': dataset[0],
580 | 'val': dataset[1],
581 | }
582 |
583 | def process(example):
584 | """
585 | Tokenize the example text by encoding it into token IDs.
586 | """
587 | context = example['ctx']
588 | ending_options = example['endings']
589 | answer = ending_options[int(example['label'])]
590 | concatenated_text = f"{context} {answer}"
591 | # print(concatenated_text)
592 | ids = tokenize_with_pad(text=concatenated_text,
593 | pad_to_multiple=pad_to_multiple)
594 | return {'ids': ids, 'len': len(ids)}
595 |
596 | # Tokenize and map the dataset
597 | tokenized = {}
598 | for split, dset in data_dict.items():
599 | tokenized[split] = dset.map(
600 | process,
601 | remove_columns=['ctx', 'endings', 'label'],
602 | desc=f"Tokenizing {split} split",
603 | num_proc=num_proc
604 | )
605 |
606 | # Concatenate all the token IDs into one large binary file per split
607 | for split, dset in tokenized.items():
608 | # Save token IDs length
609 | len_arr = np.array(dset['len'], dtype=np.uint16)
610 | with open(os.path.join(HELLASWAG_DATA_PATH, f'{split}.len'), 'wb') as f:
611 | np.save(f, len_arr)
612 | # Total number of tokens
613 | arr_len = np.sum(dset['len'])
614 | filename = os.path.join(HELLASWAG_DATA_PATH, f'{split}.bin')
615 | dtype = np.uint16
616 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
617 | total_batches = 10
618 |
619 | idx = 0
620 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
621 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
622 | arr_batch = np.concatenate(batch['ids'])
623 | arr[idx: idx + len(arr_batch)] = arr_batch
624 | idx += len(arr_batch)
625 | arr.flush()
626 |
627 | # Load tokenized binary files for training, validation
628 | train_data = np.memmap(os.path.join(HELLASWAG_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
629 | val_data = np.memmap(os.path.join(HELLASWAG_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
630 |
631 | if return_torch:
632 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
633 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
634 |
635 | print(f'Benchmark Hellaswag: train[{len(train_data)}] | val[{len(val_data)}]')
636 | return {
637 | 'train': train_data,
638 | 'train_len': np.load(os.path.join(HELLASWAG_DATA_PATH, 'train.len')),
639 | 'val': val_data,
640 | 'val_len': np.load(os.path.join(HELLASWAG_DATA_PATH, 'val.len')),
641 | }
642 |
643 | def get_logiqa(num_proc=10, return_torch=False, pad_to_multiple=1024):
644 | """
645 | Load and process the LogiQA dataset.
646 | Tokenize the text and store it in binary format for efficient loading.
647 | """
648 | if not os.path.exists(os.path.join(LOGIQA_DATA_PATH, 'val.bin')):
649 | os.makedirs(LOGIQA_DATA_PATH, exist_ok=True)
650 |
651 | # Load the LogiQA dataset from HuggingFace Datasets
652 | dataset = load_dataset("lucasmccabe/logiqa", split=['train', 'test'])
653 | data_dict = {
654 | 'train': dataset[0],
655 | 'val': dataset[1],
656 | }
657 |
658 | def process(example):
659 | """
660 | Tokenize the example text by encoding it into token IDs.
661 | """
662 | context = example['context']
663 | query = example['query']
664 | correct_option = example['correct_option']
665 | options = example['options']
666 | answer = options[correct_option]
667 | concatenated_text = f"{context} {query} {answer}"
668 | # concatenated_text = f"Context: {context} Query: {query} Answer: {answer}" # TODO : try this ?
669 | ids = tokenize_with_pad(text=concatenated_text,
670 | pad_to_multiple=pad_to_multiple)
671 | return {'ids': ids, 'len': len(ids)}
672 |
673 | # Tokenize and map the dataset
674 | tokenized = {}
675 | for split, dset in data_dict.items():
676 | tokenized[split] = dset.map(
677 | process,
678 | remove_columns=['context', 'query', 'correct_option', 'options'],
679 | desc=f"Tokenizing {split} split",
680 | num_proc=num_proc
681 | )
682 |
683 | # Concatenate all the token IDs into one large binary file per split
684 | for split, dset in tokenized.items():
685 | # Save token IDs length
686 | len_arr = np.array(dset['len'], dtype=np.uint16)
687 | with open(os.path.join(LOGIQA_DATA_PATH, f'{split}.len'), 'wb') as f:
688 | np.save(f, len_arr)
689 | # total number of tokens
690 | arr_len = np.sum(dset['len'])
691 | filename = os.path.join(LOGIQA_DATA_PATH, f'{split}.bin')
692 | dtype = np.uint16
693 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
694 | total_batches = 10
695 |
696 | idx = 0
697 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
698 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
699 | arr_batch = np.concatenate(batch['ids'])
700 | arr[idx: idx + len(arr_batch)] = arr_batch
701 | idx += len(arr_batch)
702 | arr.flush()
703 |
704 | # Load tokenized binary files for training, validation
705 | train_data = np.memmap(os.path.join(LOGIQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
706 | val_data = np.memmap(os.path.join(LOGIQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
707 |
708 | if return_torch:
709 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
710 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
711 | print(f'Benchmark LogiQA: train[{len(train_data)}] | val[{len(val_data)}]')
712 | return {
713 | 'train': train_data,
714 | 'train_len': np.load(os.path.join(LOGIQA_DATA_PATH, 'train.len')),
715 | 'val': val_data,
716 | 'val_len': np.load(os.path.join(LOGIQA_DATA_PATH, 'val.len')),
717 | }
718 |
719 | def get_piqa(num_proc=10, return_torch=False, pad_to_multiple=1024):
720 | """
721 | Load and process the PIQA dataset.
722 | Tokenize the text and store it in binary format for efficient loading.
723 | """
724 | if not os.path.exists(os.path.join(PIQA_DATA_PATH, 'val.bin')):
725 | os.makedirs(PIQA_DATA_PATH, exist_ok=True)
726 |
727 | # Load the PIQA dataset from HuggingFace Datasets
728 | dataset = load_dataset("piqa", split=['train', 'test'])
729 |
730 | data_dict = {
731 | 'train': dataset[0],
732 | 'val': dataset[1],
733 | }
734 |
735 | def process(example):
736 | """
737 | Tokenize the example text by encoding it into token IDs.
738 | """
739 | goal = example['goal']
740 | sols = [example['sol1'], example['sol2']]
741 | label = example['label']
742 | sol = sols[label]
743 | concatenated_text = f"{goal} {sol}" # only include the correct solution
744 | ids = tokenize_with_pad(text=concatenated_text,
745 | pad_to_multiple=pad_to_multiple)
746 | return {'ids': ids, 'len': len(ids)}
747 |
748 | # Tokenize and map the dataset
749 | tokenized = {}
750 | for split, dset in data_dict.items():
751 | tokenized[split] = dset.map(
752 | process,
753 | remove_columns=['goal', 'sol1', 'sol2', 'label'],
754 | desc=f"Tokenizing {split} split",
755 | num_proc=num_proc
756 | )
757 |
758 | # Concatenate all the token IDs into one large binary file per split
759 | for split, dset in tokenized.items():
760 | # Save token IDs length
761 | len_arr = np.array(dset['len'], dtype=np.uint16)
762 | with open(os.path.join(PIQA_DATA_PATH, f'{split}.len'), 'wb') as f:
763 | np.save(f, len_arr)
764 | # total number of tokens
765 | arr_len = np.sum(dset['len'])
766 | filename = os.path.join(PIQA_DATA_PATH, f'{split}.bin')
767 | dtype = np.uint16
768 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
769 | total_batches = 10
770 |
771 | idx = 0
772 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
773 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
774 | arr_batch = np.concatenate(batch['ids'])
775 | arr[idx: idx + len(arr_batch)] = arr_batch
776 | idx += len(arr_batch)
777 | arr.flush()
778 |
779 | # Load tokenized binary files for training, validation
780 | train_data = np.memmap(os.path.join(PIQA_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
781 | val_data = np.memmap(os.path.join(PIQA_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
782 | print(f'Benchmark PIQA: train[{len(train_data)}] | val[{len(val_data)}]')
783 | if return_torch:
784 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
785 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
786 |
787 | return {
788 | 'train': train_data,
789 | 'train_len': np.load(os.path.join(PIQA_DATA_PATH, 'train.len')),
790 | 'val': val_data,
791 | 'val_len': np.load(os.path.join(PIQA_DATA_PATH, 'val.len')),
792 | }
793 |
794 | def get_sciq(num_proc=10, return_torch=False, pad_to_multiple=1024):
795 | """
796 | Load and process the SciQ dataset.
797 | Tokenize the text and store it in binary format for efficient loading.
798 | """
799 | if not os.path.exists(os.path.join(SCIQ_DATA_PATH, 'val.bin')):
800 | os.makedirs(SCIQ_DATA_PATH, exist_ok=True)
801 |
802 | # Load the SciQ dataset from HuggingFace Datasets
803 | dataset = load_dataset("sciq", split=['train', 'test'])
804 |
805 | data_dict = {
806 | 'train': dataset[0],
807 | 'val': dataset[1],
808 | }
809 |
810 | def process(example):
811 | """
812 | Tokenize the example text by encoding it into token IDs.
813 | """
814 | question = example['question']
815 | answer = example['correct_answer']
816 | # explanation = example.get('support', '')  # Explanation text (optional)
817 | explanation = example['support']
818 | # concatenated_text = f"Question: {question} Answer: {answer} {explanation}"  # TODO: Try this?
819 | concatenated_text = f"{explanation}{question}{answer}"
820 | ids = tokenize_with_pad(text=concatenated_text,
821 | pad_to_multiple=pad_to_multiple)
822 | return {'ids': ids, 'len': len(ids)}
823 |
824 | # Tokenize and map the dataset
825 | tokenized = {}
826 | for split, dset in data_dict.items():
827 | tokenized[split] = dset.map(
828 | process,
829 | remove_columns=['question', 'distractor1', 'distractor2', 'distractor3','correct_answer', 'support'],
830 | desc=f"Tokenizing {split} split",
831 | num_proc=num_proc
832 | )
833 |
834 | # Concatenate all the token IDs into one large binary file per split
835 | for split, dset in tokenized.items():
836 | # Save token IDs length
837 | len_arr = np.array(dset['len'], dtype=np.uint16)
838 | with open(os.path.join(SCIQ_DATA_PATH, f'{split}.len'), 'wb') as f:
839 | np.save(f, len_arr)
840 | # total number of tokens
841 | arr_len = np.sum(dset['len'])
842 | filename = os.path.join(SCIQ_DATA_PATH, f'{split}.bin')
843 | dtype = np.uint16
844 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
845 | total_batches = 10
846 |
847 | idx = 0
848 | for batch_idx in tqdm(range(total_batches), desc=f'Writing {filename}'):
849 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy').to_dict()
850 | arr_batch = np.concatenate(batch['ids'])
851 | arr[idx: idx + len(arr_batch)] = arr_batch
852 | idx += len(arr_batch)
853 | arr.flush()
854 |
855 |
856 | # Load tokenized binary files for training, validation, and test
857 | train_data = np.memmap(os.path.join(SCIQ_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
858 | val_data = np.memmap(os.path.join(SCIQ_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
859 | print(f'Benchmark SCIQ: train[{len(train_data)}] | val[{len(val_data)}]')
860 |
861 | if return_torch:
862 | train_data = torch.tensor(np.array(train_data, dtype=np.uint16))
863 | val_data = torch.tensor(np.array(val_data, dtype=np.uint16))
864 |
865 | return {
866 | 'train': train_data,
867 | 'train_len': np.load(os.path.join(SCIQ_DATA_PATH, 'train.len')),
868 | 'val': val_data,
869 | 'val_len': np.load(os.path.join(SCIQ_DATA_PATH, 'val.len')),
870 | }
871 |
872 |
873 |
874 | SUPPORTED_TASK_MAP = {"arc_easy": get_arc_easy,
875 | "arc_challenge": get_arc_challenge,
876 | "hellaswag": get_hellaswag,
877 | "logiqa": get_logiqa,
878 | "piqa": get_piqa,
879 | "sciq": get_sciq,
880 | "humaneval": get_humaneval,
881 | "gsm8k": get_gsm8k,
882 | "kodcode": get_kodcode,
883 | "mathqa": get_mathqa,
884 | "medqa": get_medqa}
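# Usage sketch (illustrative): every getter in this map follows the same return contract,
# a dict of uint16 token arrays 'train'/'val' plus per-example 'train_len'/'val_len' arrays, e.g.
#   data = SUPPORTED_TASK_MAP["piqa"](num_proc=4)
#   print(len(data["train"]), data["train_len"][:5])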
--------------------------------------------------------------------------------
/src/data/openwebtext2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | import tiktoken
5 | from datasets import load_dataset
6 |
7 |
8 | OWT2_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/openwebtext2/")
9 | tknzr = tiktoken.get_encoding("gpt2")
10 |
11 | def get_openwebtext2_data(num_proc=40):
12 | """ https://openwebtext2.readthedocs.io/en/latest/
13 | """
14 | if not os.path.exists(os.path.join(OWT2_DATA_PATH, 'train.bin')):
15 | os.makedirs(OWT2_DATA_PATH, exist_ok=True)
16 | dataset = load_dataset("the_pile_openwebtext2")
17 |
18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
19 | split_dataset['val'] = split_dataset.pop('test')
20 |
21 | def process(example):
22 | ids = tknzr.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
23 | ids.append(tknzr.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
24 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
25 | out = {'ids': ids, 'len': len(ids)}
26 | return out
27 |
28 | # tokenize the dataset
29 | tokenized = split_dataset.map(
30 | process,
31 | remove_columns=['text'],
32 | desc="tokenizing the splits",
33 | num_proc=num_proc,
34 | )
35 |
36 | # concatenate all the ids in each dataset into one large file we can use for training
37 | for split, dset in tokenized.items():
38 | arr_len = np.sum(dset['len'])
39 | filename = os.path.join(OWT2_DATA_PATH, f'{split}.bin')
40 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
41 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
42 | total_batches = 1024
43 |
44 | idx = 0
45 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
46 | # Batch together samples for faster write
47 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
48 | arr_batch = np.concatenate(batch['ids'])
49 | # Write into mmap
50 | arr[idx : idx + len(arr_batch)] = arr_batch
51 | idx += len(arr_batch)
52 | arr.flush()
53 |
54 | train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
55 | val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
56 |
57 | return {'train': train_data, 'val': val_data}
58 |
59 |
--------------------------------------------------------------------------------
/src/data/shakespeare.py:
--------------------------------------------------------------------------------
1 | import os
2 | from string import ascii_letters, digits, punctuation
3 |
4 | import numpy as np
5 | import requests
6 |
7 |
8 | _char_decode = dict(enumerate(sorted(set(ascii_letters + digits + punctuation + " \n"))))
9 | _char_encode = {char: i for i, char in _char_decode.items()}
10 |
11 |
12 | def char_tknzr(txt: str):
13 | return [_char_encode[char] for char in txt if char in _char_encode]
14 |
15 |
16 | DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets", "shakespeare")
17 |
18 | def get_shakespeare_data():
19 | """Inspired from https://github.com/karpathy/nanoGPT/"""
20 | raw_path = os.path.join(DATA_PATH, "raw.txt")
21 | train_path = os.path.join(DATA_PATH, "train.npy")
22 | test_path = os.path.join(DATA_PATH, "test.npy")
23 |
24 | # if path is not even there, download all data
25 | if not os.path.exists(DATA_PATH):
26 | print("Downloading raw Shakespeare texts")
27 | url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
28 | os.makedirs(DATA_PATH, exist_ok=True)
29 | text = requests.get(url, timeout=60).text
30 | with open(raw_path, "w+", encoding="utf8") as f:
31 | f.write(text)
32 |
33 | # attempt to find cached version for current tokenizer
34 | if not os.path.exists(train_path) or not os.path.exists(test_path):
35 | print("Tokenizing Shakespeare texts")
36 | # load text
37 | with open(raw_path, encoding="utf8") as f:
38 | text = "".join(f.readlines())
39 | i = int(0.8*len(text))
40 | # encode text
41 | x = np.array(char_tknzr(text[:i]), dtype=np.uint16)
42 | x_test = np.array(char_tknzr(text[i:]), dtype=np.uint16)
43 | # map memory
44 | mem = np.memmap(train_path, dtype=np.uint16, mode="w+", shape=x.shape)
45 | mem[:] = x
46 | mem = np.memmap(test_path, dtype=np.uint16, mode="w+", shape=x_test.shape)
47 | mem[:] = x_test
48 |
49 | # at this point we know that the binfile was properly created so we load it
50 | return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"),
51 | "val": np.memmap(test_path, dtype=np.uint16, mode="r")}
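# Illustrative note: the character vocabulary covers ASCII letters, digits, punctuation,
# space and newline (96 symbols), so e.g. char_tknzr("To be") yields five small integer ids
# that can be mapped back to characters through _char_decode.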
52 |
--------------------------------------------------------------------------------
/src/data/slimpajama.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import tiktoken
4 | from datasets import load_dataset
5 | import os
6 |
7 |
8 | SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/")
9 | SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1")
10 |
11 |
12 | tknzr = tiktoken.get_encoding("gpt2")
13 |
14 |
15 | def get_slimpajama_data(num_proc=40):
16 | if not os.path.exists(os.path.join(SPJ_DATA_PATH, "train.bin")):
17 | os.makedirs(SPJ_DATA_PATH, exist_ok=True)
18 | dataset = load_dataset("DKYoon/SlimPajama-6B")
19 |
20 | split_dataset = dataset["train"].train_test_split(
21 | test_size=0.0005, seed=2357, shuffle=True
22 | )
23 | split_dataset["val"] = split_dataset.pop("test")
24 |
25 | def process(example):
26 | ids = tknzr.encode_ordinary(
27 | example["text"]
28 | ) # encode_ordinary ignores any special tokens
29 | ids.append(
30 | tknzr.eot_token
31 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
32 | out = {"ids": ids, "len": len(ids)}
33 | return out
34 |
35 | # tokenize the dataset
36 | tokenized = split_dataset.map(
37 | process,
38 | remove_columns=["text"],
39 | desc="tokenizing the splits",
40 | num_proc=num_proc,
41 | )
42 |
43 | # concatenate all the ids in each dataset into one large file we can use for training
44 | for split, dset in tokenized.items():
45 | arr_len = np.sum(dset["len"])
46 | filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin")
47 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
48 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
49 | total_batches = min(1024, len(dset))
50 |
51 | idx = 0
52 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
53 | # Batch together samples for faster write
54 | batch = dset.shard(
55 | num_shards=total_batches, index=batch_idx, contiguous=True
56 | ).with_format("numpy")
57 | arr_batch = np.concatenate(batch["ids"])
58 | # Write into mmap
59 | arr[idx : idx + len(arr_batch)] = arr_batch
60 | idx += len(arr_batch)
61 | arr.flush()
62 |
63 | train_data = np.memmap(
64 | os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
65 | )
66 | val_data = np.memmap(
67 | os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
68 | )
69 |
70 | return {"train": train_data, "val": val_data}
71 |
72 |
73 | def get_slimpajama_chunk1(num_proc=40):
74 | if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")):
75 | os.makedirs(SPJ_CHUNK_1_DATA_PATH, exist_ok=True)
76 | dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1")
77 |
78 | split_dataset = dataset.train_test_split(  # `split=` was requested, so load_dataset returns a Dataset, not a DatasetDict
79 | test_size=0.0005, seed=2357, shuffle=True
80 | )
81 | split_dataset["val"] = split_dataset.pop("test")
82 |
83 | def process(example):
84 | ids = tknzr.encode_ordinary(
85 | example["text"]
86 | ) # encode_ordinary ignores any special tokens
87 | ids.append(
88 | tknzr.eot_token
89 | ) # add the end of text token, e.g. 50256 for gpt2 bpe
90 | out = {"ids": ids, "len": len(ids)}
91 | return out
92 |
93 | # tokenize the dataset
94 | tokenized = split_dataset.map(
95 | process,
96 | remove_columns=["text"],
97 | desc="tokenizing the splits",
98 | num_proc=num_proc,
99 | )
100 |
101 | # concatenate all the ids in each dataset into one large file we can use for training
102 | for split, dset in tokenized.items():
103 | arr_len = np.sum(dset["len"])
104 | filename = os.path.join(SPJ_CHUNK_1_DATA_PATH, f"{split}.bin")
105 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
106 | arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
107 | total_batches = min(1024, len(dset))
108 |
109 | idx = 0
110 | for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
111 | # Batch together samples for faster write
112 | batch = dset.shard(
113 | num_shards=total_batches, index=batch_idx, contiguous=True
114 | ).with_format("numpy")
115 | arr_batch = np.concatenate(batch["ids"])
116 | # Write into mmap
117 | arr[idx : idx + len(arr_batch)] = arr_batch
118 | idx += len(arr_batch)
119 | arr.flush()
120 |
121 | train_data = np.memmap(
122 | os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
123 | )
124 | val_data = np.memmap(
125 | os.path.join(SPJ_CHUNK_1_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
126 | )
127 |
128 | return {"train": train_data, "val": val_data}
129 |
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Dict
3 | import torch
4 |
5 | from .shakespeare import get_shakespeare_data
6 | from .wikitext import get_wikitext_data
7 | from .arxiv import get_arxiv_2000, get_arxiv_full
8 | from .openwebtext2 import get_openwebtext2_data
9 | from .slimpajama import get_slimpajama_data
10 | from .benchmarks import get_mathqa
11 |
12 |
13 | def get_dataset(args) -> Dict[str, np.ndarray]:
14 | """ Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
15 | contained in its own python file. The expected format at the moment is a dictionary of np.memmap
16 | containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """
17 | if args.dataset == 'wikitext':
18 | return get_wikitext_data()
19 | if args.dataset == "shakespeare-char":
20 | return get_shakespeare_data()
21 | if args.dataset == "arxiv2000":
22 | return get_arxiv_2000()
23 | if args.dataset == "arxiv":
24 | return get_arxiv_full()
25 | if args.dataset == "arxiv+wiki":
26 | arxiv_data = get_arxiv_full()
27 | wiki_data = get_wikitext_data()
28 | train_data = np.concatenate((arxiv_data['train'], wiki_data['train']))
29 | val_data = np.concatenate((arxiv_data['val'], wiki_data['val']))
30 | return {'train': train_data, 'val': val_data}
31 | if args.dataset == 'openwebtext2':
32 | return get_openwebtext2_data()
33 | if args.dataset == "slimpajama":
34 | return get_slimpajama_data()
35 | if args.dataset == "mathqa":
36 | return get_mathqa()
37 | else:
38 | raise NotImplementedError(f"Unknown dataset key '{args.dataset}'")
39 |
40 | class Dataset(torch.utils.data.Dataset):
41 | def __init__(self, data, sequence_length):
42 | super().__init__()
43 | self.data = data
44 | self.sequence_length = sequence_length
45 |
46 | def __len__(self):
47 | total_length = len(self.data)
48 | # chunk the data into sequences of length `sequence_length`
49 | # NOTE: we discard the last remaining sequence if it's not of length `sequence_length`
50 | return (total_length - 1) // self.sequence_length
51 |
52 | def __getitem__(self, idx):
53 | seq_length = self.sequence_length
54 | idx = idx * seq_length
55 | x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64))
56 |
57 | y = torch.from_numpy(
58 | (self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64)
59 | )
60 | return x, y
61 |
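# Worked example (illustrative): with len(data) == 1001 tokens and sequence_length == 100,
# __len__ gives (1001 - 1) // 100 == 10 items; item idx returns x = data[idx*100 : idx*100 + 100]
# and y as the same window shifted right by one token, both cast to int64 tensors.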
62 |
63 | def get_dataloader(data, sequence_length, batch_size, seed=0, distributed_backend=None):
64 | """Create a DataLoader for the given data. If distributed_backend is provided and is truly
65 | distributed (world size > 1), the DataLoader will be created with a DistributedSampler that
66 | splits the data across the processes (in conjunction with DDP).
67 | Otherwise, use a RandomSampler with the specified seed.
68 |
69 | Returns both the dataloader and the sampler.
70 | """
71 | dataset = Dataset(data, sequence_length=sequence_length)
72 | if distributed_backend and distributed_backend.get_world_size() > 1:
73 | sampler = torch.utils.data.DistributedSampler(
74 | dataset,
75 | shuffle=True,
76 | seed=seed,
77 | )
78 | else:
79 | g = torch.Generator()
80 | g.manual_seed(seed)
81 | sampler = torch.utils.data.RandomSampler(
82 | dataset, replacement=False, generator=g
83 | )
84 |
85 | loader = torch.utils.data.DataLoader(
86 | dataset,
87 | sampler=sampler,
88 | batch_size=batch_size,
89 | num_workers=4,
90 | )
91 | return loader, sampler
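# Usage sketch (illustrative; `tokens` stands for any 1-D numpy array of token ids,
# single-process training assumed so a seeded RandomSampler is used):
#   loader, sampler = get_dataloader(tokens, sequence_length=512, batch_size=8, seed=0)
#   x, y = next(iter(loader))  # both of shape (8, 512), dtype int64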
92 |
--------------------------------------------------------------------------------
/src/data/wikitext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import urllib.request
4 | import numpy as np
5 | import tiktoken
6 |
7 |
8 | WIKITEXT_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/wikitext/")
9 |
10 |
11 | def get_wikitext_data():
12 | """ Inspired from https://github.com/tysam-code/hlb-gpt """
13 | if not os.path.exists(WIKITEXT_DATA_PATH):
14 | os.makedirs(WIKITEXT_DATA_PATH, exist_ok=True)
15 | print("downloading data and tokenizing (1-2 min)")
16 | raw_data_source = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip'
17 | urllib.request.urlretrieve(raw_data_source, os.path.join(WIKITEXT_DATA_PATH,'data.zip'))
18 |
19 | with zipfile.ZipFile(os.path.join(WIKITEXT_DATA_PATH, "data.zip"), 'r') as zip_ref:
20 | zip_ref.extractall(WIKITEXT_DATA_PATH)
21 |
22 | with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.train.raw"), 'r') as data_file:
23 | raw_train_data = data_file.read()
24 |
25 | with open(os.path.join(WIKITEXT_DATA_PATH, "wikitext-103-raw/wiki.valid.raw"), 'r') as data_file:
26 | raw_eval_data = data_file.read()
27 |
28 | tokenizer = tiktoken.get_encoding("gpt2")
29 | raw_tokenized_train = tokenizer.encode_ordinary(raw_train_data)
30 | raw_tokenized_eval = tokenizer.encode_ordinary(raw_eval_data)
31 |
32 | train_tokenized = np.array(raw_tokenized_train, dtype=np.uint16)
33 | eval_tokenized = np.array(raw_tokenized_eval, dtype=np.uint16)
34 |
35 | train_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'))
36 | eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'))
37 | print("completed the tokenization process!")
38 |
39 | train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
40 | val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
41 |
42 | return {'train': train_data, 'val': val_data}
43 |
--------------------------------------------------------------------------------
/src/distributed/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from . import ddp
3 | from . import single
4 |
5 | BACKEND_TYPE_TO_MODULE_MAP = {
6 | "nccl": ddp.DataParallelDistributedBackend,
7 | None: single.SingleNodeBackend,
8 | }
9 |
10 |
11 | def make_backend_from_args(args):
12 | return BACKEND_TYPE_TO_MODULE_MAP[args.distributed_backend](args)
13 |
14 |
15 | def registered_backends():
16 | return BACKEND_TYPE_TO_MODULE_MAP.keys()
17 |
--------------------------------------------------------------------------------
/src/distributed/backend.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List
3 |
4 |
5 | class DistributedBackend(object):
6 |
7 | def __init__(self, args):
8 | pass
9 |
10 | def transform_model(self, model):
11 | raise NotImplementedError
12 |
13 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
14 | raise NotImplementedError
15 |
16 | def is_master_process(self) -> bool:
17 | raise NotImplementedError
18 |
19 | def get_adjusted_args_for_process(self, args):
20 | raise NotImplementedError
21 |
22 | def get_raw_model(self, model):
23 | raise NotImplementedError
24 |
25 | def translate_model_parameter_name_for_node(self, parameter_name) -> List[str]:
26 | raise NotImplementedError
27 |
28 | def get_world_size(self):
29 | raise NotImplementedError
30 |
31 | def finalize(self):
32 | pass
33 |
34 | def sync(self):
35 | pass
36 |
--------------------------------------------------------------------------------
/src/distributed/ddp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from contextlib import contextmanager
4 |
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 | from torch.distributed import init_process_group, destroy_process_group, get_world_size, barrier
7 |
8 | from .backend import DistributedBackend
9 |
10 |
11 | class DataParallelDistributedBackend(DistributedBackend):
12 |
13 | def __init__(self, args):
14 | self.rank = int(os.environ.get('RANK', -1))
15 | assert self.rank != -1, "DDP backend cannot be used without a rank"
16 | assert "cuda" in args.device, "DDP backend cannot be used on non-CUDA devices"
17 | init_process_group(backend=args.distributed_backend)
18 | self.local_rank = int(os.environ['LOCAL_RANK'])
19 |
20 | def get_adjusted_args_for_process(self, args):
21 | effective_batch_size = args.batch_size * args.acc_steps
22 | world_size = self.get_world_size()
23 | if args.acc_steps % world_size != 0:
24 | raise ValueError(f"Number of accumulation steps "
25 | f"{args.acc_steps} is not divisible "
26 | f"by the world size {world_size}.")
27 | if effective_batch_size % world_size != 0:
28 | raise ValueError(f"Effective batch size "
29 | f"{effective_batch_size} is not divisible "
30 | f"by the world size {world_size}.")
31 | acc_steps_div = math.gcd(args.acc_steps, world_size)
32 | args.acc_steps = args.acc_steps // acc_steps_div
33 | args.batch_size = args.batch_size // (world_size // acc_steps_div)
34 | args.device = f'cuda:{self.local_rank}'
35 | args.seed = args.seed + self.local_rank
36 | return args
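# Worked example (illustrative): with batch_size=32, acc_steps=4 and world_size=2,
# gcd(4, 2) == 2, so each rank runs acc_steps=2 with batch_size=32, preserving the
# effective batch size of 32 * 4 = 128 across the two processes.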
37 |
38 | def transform_model(self, model):
39 | return DDP(model, device_ids=[self.local_rank])
40 |
41 | @contextmanager
42 | def get_context_for_microstep_forward(self, model, microstep_idx, gradient_accumulation_steps):
43 | model.require_backward_grad_sync = (
44 | microstep_idx == gradient_accumulation_steps - 1)
45 | yield
46 |
47 | def is_master_process(self) -> bool:
48 | return self.rank == 0
49 |
50 | def get_raw_model(self, model):
51 | return model.module
52 |
53 | def translate_model_parameter_name_for_node(self, parameter_name):
54 | return [f'module.{parameter_name}']
55 |
56 | def get_world_size(self):
57 | return get_world_size()
58 |
59 | def finalize(self):
60 | destroy_process_group()
61 |
62 | def sync(self):
63 | barrier()
64 |
--------------------------------------------------------------------------------
/src/distributed/single.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 |
3 | from .backend import DistributedBackend
4 |
5 |
6 | class SingleNodeBackend(DistributedBackend):
7 |
8 | def transform_model(self, model):
9 | return model
10 |
11 | def get_context_for_microstep_forward(self, *args, **kwargs):
12 | return nullcontext()
13 |
14 | def get_adjusted_args_for_process(self, args):
15 | return args
16 |
17 | def is_master_process(self) -> bool:
18 | return True
19 |
20 | def get_raw_model(self, model):
21 | return model
22 |
23 | def get_world_size(self):
24 | return 1
25 |
26 | def translate_model_parameter_name_for_node(self, parameter_name):
27 | return [parameter_name]
28 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import inspect
6 | import json
7 | import copy
8 | import argparse
9 | import random
10 | import wandb
11 |
12 | import config
13 | from models.utils import get_model
14 | from data.utils import get_dataset
15 | from optim.base import train_base
16 | import distributed
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser(allow_abbrev=False)
21 | parser.add_argument('--config_format', default='base', choices=config.registered_formats())
22 |
23 | args, rem_args = parser.parse_known_args()
24 |
25 | return config.parse_args_with_format(format=args.config_format, base_parser=parser, args=rem_args, namespace=args)
26 |
27 |
28 | def main(args):
29 |
30 | torch.backends.cuda.matmul.allow_tf32 = True # allows us to make sure we're able to use tensorfloat32 during training
31 | torch.backends.cudnn.allow_tf32 = True
32 |
33 | distributed_backend = distributed.make_backend_from_args(args)
34 | args = distributed_backend.get_adjusted_args_for_process(args)
35 |
36 | args.device = torch.device(args.device)
37 | device_type = "cuda" if "cuda" in str(args.device) else "cpu"
38 | if device_type == "cuda":
39 | torch.cuda.set_device(args.device)
40 |
41 | torch.manual_seed(args.seed)
42 | random.seed(args.seed)
43 | np.random.seed(args.seed)
44 |
45 | print(f"Loading dataset '{args.dataset}'")
46 |
47 | data = get_dataset(args) # data is a dict: {'train': train_tokenized, 'val': eval_tokenized}
48 | if args.data_in_ram:
49 | data = {'train': np.array(data['train']), 'val': np.array(data['val'])}
50 |
51 | print(f"Num training tokens: {len(data['train'])}")
52 | print(f"Num validation tokens: {len(data['val'])}")
53 |
54 | model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none'
55 |
56 | model = distributed_backend.transform_model(model)
57 |
58 | group_specs = distributed_backend.get_raw_model(model).get_parameter_group_specs()
59 | param_name_mapping = {p_name: p for p_name, p in model.named_parameters()}
60 | optimized_params_cnt = 0
61 | for g in group_specs:
62 | params = []
63 | for p_name in g["params"]:
64 | translated_p_names = distributed_backend.translate_model_parameter_name_for_node(p_name)
65 | params += [param_name_mapping[p_name] for p_name in translated_p_names]
66 | g["params"] = params
67 | optimized_params_cnt += sum([p.numel() for p in g["params"]])
68 | print("number of optimized parameters: %.2fM" % (optimized_params_cnt/1e6,))
69 | if args.opt == 'adamw':
70 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
71 | print(f"using fused AdamW: {use_fused}")
72 | extra_args = dict(fused=True) if use_fused else dict()
73 | opt = torch.optim.AdamW(group_specs, lr=args.lr, betas=(args.beta1, args.beta2),
74 | weight_decay=args.weight_decay, **extra_args)
75 | else:
76 | opt = torch.optim.SGD(group_specs, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
77 |
78 | if args.scheduler != 'none':
79 | if args.scheduler in ['cos', 'linear']:
80 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=args.lr, total_steps=args.iterations,
81 | pct_start=args.warmup_percent, anneal_strategy=args.scheduler,
82 | cycle_momentum=False, div_factor=1e2, final_div_factor=.1)
83 | else:
84 | raise NotImplementedError(f"Unknown scheduler type: {args.scheduler}.")
85 | else:
86 | scheduler = None
87 |
88 | args.world_size = distributed_backend.get_world_size()
89 | exp_name = args.exp_name
90 | if distributed_backend.is_master_process() and args.wandb:
91 | params_copy = copy.deepcopy(vars(args))
92 | del params_copy['device']
93 | wandb.init(project=args.wandb_project, name=exp_name, config=params_copy)
94 |
95 | ckpt_path = os.path.join(args.results_base_folder, args.dataset, args.model, exp_name)
96 | if not os.path.exists(ckpt_path):
97 | if distributed_backend.is_master_process():
98 | os.makedirs(ckpt_path)
99 | distributed_backend.sync()
100 | elif os.path.isfile(os.path.join(ckpt_path, "summary.json")): # the experiment was already completed
101 | print(f"Already found experiment '{ckpt_path}'.\nSkipping.")
102 | sys.exit(0)
103 |
104 | itr = 0
105 | rng_state_dict = None
106 | if args.use_pretrained == "auto":
107 | checkpoints = [file for file in os.listdir(ckpt_path) if 'ckpt_' in file]
108 | if checkpoints:
109 | args.use_pretrained = sorted(checkpoints, key=lambda ckpt: int(ckpt.split('_')[1].split('.')[0]))[-1]  # latest by iteration number, not lexicographic order
110 | else:
111 | args.use_pretrained = None
112 |
113 | if args.use_pretrained is not None:
114 | last_ckpt_path = args.use_pretrained
115 | print(f"Resuming from {last_ckpt_path}")
116 | checkpoint = torch.load(os.path.join(ckpt_path, last_ckpt_path))
117 | model_state_dict = {distributed_backend.translate_model_parameter_name_for_node(k.replace("_orig_mod.", ""))[0]:v for k,v in checkpoint['model'].items()}
118 | # FIXME checkpoints from compiled model have _orig_mod keyword
119 |
120 | optimizer_state_dict = checkpoint['optimizer']
121 | rng_state_dict = {
122 | module: checkpoint[module] for module in [
123 | "cpu_rng_state",
124 | "gpu_rng_state",
125 | "numpy_rng_state",
126 | "py_rng_state",
127 | "train_sampler_state"
128 | ]
129 | }
130 |
131 | model.load_state_dict(model_state_dict)
132 | opt.load_state_dict(optimizer_state_dict)
133 | itr = checkpoint['itr']
134 | if scheduler is not None:
135 | scheduler_state_dict = checkpoint['scheduler']
136 | scheduler.load_state_dict(scheduler_state_dict)
137 |
138 | if args.model in ['base', 'llama2']: # all train functions have the same interface
139 | train = train_base
140 | else:
141 | raise NotImplementedError(f"No training method implemented for model type '{args.model}'.")
142 |
143 | print(f"\nTraining model={args.model} \n{vars(args)}\n")
144 |
145 | stats = train(model, opt, data, args.data_seed, scheduler, args.iterations, args.acc_steps, args.batch_size, args.sequence_length,
146 | eval_freq=args.eval_freq,
147 | distributed_backend=distributed_backend,
148 | ckpt_path=f"{ckpt_path}/ckpt.pt", itr=itr, rng_state_dict=rng_state_dict, extra_args=args)
149 |
150 | args.device = None
151 | args.dtype = None
152 | stats['args'] = vars(args)
153 | if distributed_backend.is_master_process():
154 | with open(f"{ckpt_path}/summary.json", "w") as fs:
155 | json.dump(stats, fs)
156 | distributed_backend.finalize()
157 |
158 |
159 | if __name__ == "__main__":
160 | args = get_args()
161 | main(args)
162 |
--------------------------------------------------------------------------------
/src/models/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Full definition of a GPT Language Model, all of it in this single file.
3 | References:
4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5 | https://github.com/openai/gpt-2/blob/master/src/model.py
6 | 2) huggingface/transformers PyTorch implementation:
7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8 | """
9 |
10 | import math
11 | import inspect
12 |
13 | import tiktoken
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
21 |
22 | def __init__(self, ndim, bias):
23 | super().__init__()
24 | self.weight = nn.Parameter(torch.ones(ndim))
25 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
26 |
27 | def forward(self, input):
28 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
29 |
30 |
31 | class CausalSelfAttention(nn.Module):
32 |
33 | def __init__(self, config):
34 | super().__init__()
35 | assert config.n_embd % config.n_head == 0
36 | # key, query, value projections for all heads, but in a batch
37 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
38 | # output projection
39 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
40 | # regularization
41 | self.attn_dropout = nn.Dropout(config.dropout)
42 | self.resid_dropout = nn.Dropout(config.dropout)
43 | self.n_head = config.n_head
44 | self.n_embd = config.n_embd
45 | self.dropout = config.dropout
46 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
47 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
48 | if not self.flash:
49 | print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
50 | # causal mask to ensure that attention is only applied to the left in the input sequence
51 | self.register_buffer("bias", torch.tril(torch.ones(config.sequence_length, config.sequence_length))
52 | .view(1, 1, config.sequence_length, config.sequence_length))
53 |
54 | def forward(self, x):
55 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
56 |
57 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
58 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
59 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
61 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
62 |
63 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
64 | if self.flash:
65 | # efficient attention using Flash Attention CUDA kernels
66 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
67 | else:
68 | # manual implementation of attention
69 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
70 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
71 | att = F.softmax(att, dim=-1)
72 | att = self.attn_dropout(att)
73 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
74 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
75 |
76 | # output projection
77 | y = self.resid_dropout(self.c_proj(y))
78 | return y
79 |
80 |
81 | class MLP(nn.Module):
82 |
83 | def __init__(self, config):
84 | super().__init__()
85 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
86 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
87 | self.dropout = nn.Dropout(config.dropout)
88 | self.activation = nn.GELU()
89 |
90 | def forward(self, x):
91 | x = self.c_fc(x)
92 | x = self.activation(x)
93 | x = self.c_proj(x)
94 | x = self.dropout(x)
95 | return x
96 |
97 |
98 | class Block(nn.Module):
99 |
100 | def __init__(self, config):
101 | super().__init__()
102 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
103 | self.attn = CausalSelfAttention(config)
104 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
105 | self.mlp = MLP(config)
106 |
107 | def forward(self, x):
108 | x = x + self.attn(self.ln_1(x))
109 | x = x + self.mlp(self.ln_2(x))
110 | return x
111 |
112 |
113 | class GPTBase(nn.Module):
114 |
115 | def __init__(self, config):
116 | super().__init__()
117 | assert config.vocab_size is not None
118 | assert config.sequence_length is not None
119 | self.config = config
120 | self.tokenizer = tiktoken.get_encoding("gpt2")
121 |
122 | self.transformer = nn.ModuleDict(dict(
123 | wte = nn.Embedding(config.vocab_size, config.n_embd),
124 | wpe = nn.Embedding(config.sequence_length, config.n_embd),
125 | drop = nn.Dropout(config.dropout),
126 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
127 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
128 | ))
129 |
130 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
131 | # with weight tying when using torch.compile() some warnings get generated:
132 | # "UserWarning: functional_call was passed multiple values for tied weights.
133 | # This behavior is deprecated and will be an error in future versions"
134 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
135 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
136 |
137 | # init all weights
138 | self.apply(self._init_weights)
139 | # apply special scaled init to the residual projections, per GPT-2 paper
140 | for pn, p in self.named_parameters():
141 | if pn.endswith('c_proj.weight'):
142 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
143 |
144 | # report number of parameters
145 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
146 |
147 | def get_num_params(self, non_embedding=True):
148 | """
149 | Return the number of parameters in the model.
150 | For non-embedding count (default), the position embeddings get subtracted.
151 | The token embeddings would too, except due to the parameter sharing these
152 | params are actually used as weights in the final layer, so we include them.
153 | """
154 | n_params = sum(p.numel() for p in self.parameters())
155 | if non_embedding:
156 | n_params -= self.transformer.wpe.weight.numel()
157 | return n_params
158 |
159 | def _init_weights(self, module):
160 | if isinstance(module, nn.Linear):
161 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
162 | if module.bias is not None:
163 | torch.nn.init.zeros_(module.bias)
164 | elif isinstance(module, nn.Embedding):
165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
166 |
167 | def forward(self, idx, targets=None, get_logits=False):
168 | device = idx.device
169 | b, t = idx.size()
170 | assert t <= self.config.sequence_length, f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
171 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
172 |
173 | # forward the GPT model itself
174 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
175 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
176 | x = self.transformer.drop(tok_emb + pos_emb)
177 | for block in self.transformer.h:
178 | x = block(x)
179 | x = self.transformer.ln_f(x)
180 |
181 | if targets is not None:
182 | # if we are given some desired targets also calculate the loss
183 | logits = self.lm_head(x)
184 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
185 | else:
186 | # inference-time mini-optimization: only forward the lm_head on the very last position
187 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
188 | loss = None
189 | logits = logits if get_logits else None
190 | return {'logits': logits, 'loss': loss}
191 |
192 | def crop_sequence_length(self, sequence_length):
193 | # model surgery to decrease the block size if necessary
194 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
195 | # but want to use a smaller block size for some smaller, simpler model
196 | assert sequence_length <= self.config.sequence_length
197 | self.config.sequence_length = sequence_length
198 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:sequence_length])
199 | for block in self.transformer.h:
200 | block.attn.bias = block.attn.bias[:,:,:sequence_length,:sequence_length]
201 |
202 | @classmethod
203 | def from_pretrained(cls, model_type, override_args=None):
204 | # TODO
205 | pass
206 |
207 | def get_parameter_group_specs(self):
208 | """
209 | This long function is unfortunately doing something very simple and is being very defensive:
210 | We are separating out all parameters of the model into two buckets: those that will experience
211 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
212 | We then return the parameter group specs (lists of parameter names); the optimizer itself is built in main.py.
213 | """
214 |
215 | # separate out all parameters to those that will and won't experience regularizing weight decay
216 | decay = set()
217 | no_decay = set()
218 | whitelist_weight_modules = (torch.nn.Linear,)
219 | # need to do import here to avoid circular import (since llama imports from base here)
220 | from .utils import BLACKLIST_WEIGHT_MODULES
221 |
222 | for mn, m in self.named_modules():
223 | for pn, p in m.named_parameters():
224 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
225 | # random note: because named_modules and named_parameters are recursive
226 | # we will see the same tensors p many many times. but doing it this way
227 | # allows us to know which parent module any tensor p belongs to...
228 | if pn.endswith("bias"):
229 | # all biases will not be decayed
230 | no_decay.add(fpn)
231 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
232 | # weights of whitelist modules will be weight decayed
233 | decay.add(fpn)
234 | elif pn.endswith("weight") and isinstance(m, BLACKLIST_WEIGHT_MODULES):
235 | # weights of blacklist modules will NOT be weight decayed
236 | no_decay.add(fpn)
237 |
238 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
239 | # will appear in the no_decay and decay sets respectively after the above.
240 | # In addition, because named_parameters() doesn't return duplicates, it
241 | # will only return the first occurrence, keyed by 'transformer.wte.weight', below.
242 | # so let's manually remove 'lm_head.weight' from decay set. This will include
243 | # this tensor into optimization via transformer.wte.weight only, and not decayed.
244 | decay.remove("lm_head.weight")
245 |
246 | # validate that we considered every parameter
247 | param_dict = {pn: p for pn, p in self.named_parameters()}
248 | inter_params = decay & no_decay
249 | union_params = decay | no_decay
250 | assert (
251 | len(inter_params) == 0
252 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
253 | assert (
254 | len(param_dict.keys() - union_params) == 0
255 | ), "parameters %s were not separated into either decay/no_decay set!" % (
256 | str(param_dict.keys() - union_params),
257 | )
258 |
259 | # create the pytorch optimizer object
260 | return [
261 | {"params": sorted(list(decay))},
262 | {"params": sorted(list(no_decay)), "weight_decay": 0.0},
263 | ]
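# Illustrative example of the returned structure (parameter *names*, not tensors):
#   [{"params": ["transformer.h.0.attn.c_attn.weight", ...]},
#    {"params": ["transformer.h.0.ln_1.weight", "transformer.wte.weight", ...], "weight_decay": 0.0}]
# main.py later swaps each name for its tensor before constructing the AdamW/SGD optimizer.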
264 |
265 |
266 | @torch.no_grad()
267 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
268 | """
269 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
270 | the sequence max_new_tokens times, feeding the predictions back into the model each time.
271 | Most likely you'll want to make sure to be in model.eval() mode of operation for this.
272 | """
273 | for _ in range(max_new_tokens):
274 | # if the sequence context is growing too long we must crop it at sequence_length
275 | idx_cond = idx if idx.size(1) <= self.config.sequence_length else idx[:, -self.config.sequence_length:]
276 | # forward the model to get the logits for the index in the sequence
277 | logits = self(idx_cond, get_logits=True)['logits']
278 | # pluck the logits at the final step and scale by desired temperature
279 | logits = logits[:, -1, :] / temperature
280 | # optionally crop the logits to only the top k options
281 | if top_k is not None:
282 | v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
283 | logits[logits < v[:, [-1]]] = -float('Inf')
284 | # apply softmax to convert logits to (normalized) probabilities
285 | probs = F.softmax(logits, dim=-1)
286 | # sample from the distribution
287 | idx_next = torch.multinomial(probs, num_samples=1)
288 | # append sampled index to the running sequence and continue
289 | idx = torch.cat((idx, idx_next), dim=1)
290 |
291 | return idx
292 |
293 | @torch.no_grad()
294 | def generate_from_string(self, in_str, max_new_tokens, temperature=1.0, top_k=None):
295 | idx = torch.tensor(self.tokenizer.encode(in_str, allowed_special={"<|endoftext|>"})).view(1,-1).to(self.lm_head.weight.device)
296 | out_idx = self.generate(idx, max_new_tokens, temperature, top_k).view(-1).to('cpu').numpy()
297 | return self.tokenizer.decode(out_idx)
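# Usage sketch (illustrative; assumes a trained GPTBase/Llama instance on some device):
#   model.eval()
#   print(model.generate_from_string("O Romeo", max_new_tokens=20, temperature=0.8, top_k=50))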
298 |
--------------------------------------------------------------------------------
/src/models/llama.py:
--------------------------------------------------------------------------------
1 | """
2 | Llama style Language Model.
3 | References:
4 | 1) Llama inference code:
5 | https://github.com/facebookresearch/llama/blob/main/llama/model.py
6 | 2) Mistral one file ref:
7 | https://github.com/mistralai/mistral-src/blob/main/one_file_ref.py
8 | 3) Llama paper:
9 | https://arxiv.org/pdf/2302.13971.pdf
10 |
11 | Main differences from GPT2:
12 | * Uses RMSNorm instead of LayerNorm
13 | * Uses a slightly different MLP (SwiGLU)
14 | * rotary embeddings (RoPE)
15 | """
16 |
17 | import math
18 |
19 | import tiktoken
20 | import torch
21 | import torch.nn as nn
22 | from torch.nn import functional as F
23 | from models.base import CausalSelfAttention, GPTBase
24 |
25 |
26 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
27 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
28 | t = torch.arange(end, device=freqs.device) # type: ignore
29 | freqs = torch.outer(t, freqs).float() # type: ignore
30 | return torch.polar(torch.ones_like(freqs), freqs) # complex64
31 |
32 |
33 | def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
34 | """
35 | freqs_cis: complex - (seq_len, head_dim / 2)
36 | x: complex - (bsz, seq_len, n_head, head_dim / 2)
37 | """
38 | ndim = x.ndim
39 | assert 1 < ndim
40 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
41 | freqs_cis.shape,
42 | (x.shape[1], x.shape[-1]),
43 | )
44 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
45 | return freqs_cis.view(*shape)
46 |
47 |
48 | def apply_rotary_emb(q, k, freqs_cis):
49 | # q, k: (B, T, nh, hs)
50 | # freqs_cis: (T, hs // 2), complex
51 | # return: (B, T, nh, hs), (B, T, nh, hs)
52 | q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
53 | k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
54 | freqs_cis = _reshape_for_broadcast(freqs_cis, q_)
55 | xq_out = torch.view_as_real(q_ * freqs_cis).flatten(3)
56 | xk_out = torch.view_as_real(k_ * freqs_cis).flatten(3)
57 | return xq_out.type_as(q), xk_out.type_as(k)
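# Note (illustrative): precompute_freqs_cis(head_dim, T)[t, j] equals exp(i * t * theta**(-2j/head_dim)),
# a (T, head_dim // 2) complex64 tensor; apply_rotary_emb views each head of q and k as head_dim // 2
# complex pairs and multiplies them by these unit phases, i.e. position-dependent 2-D rotations.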
58 |
59 |
60 | class RMSNorm(nn.Module):
61 | def __init__(self, dim: int, eps: float = 1e-6):
62 | super().__init__()
63 | self.eps = eps
64 | self.weight = nn.Parameter(torch.ones(dim))
65 |
66 | def _norm(self, x):
67 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
68 |
69 | def forward(self, x):
70 | output = self._norm(x.float()).type_as(x)
71 | return output * self.weight
72 |
73 |
74 | class LlamaMLP(nn.Module):
75 | def __init__(self, config):
76 | super().__init__()
77 |
78 | hidden_dim = config.n_embd * 4
79 | hidden_dim = int(2 * hidden_dim / 3)
80 | hidden_dim = config.multiple_of * (
81 | (hidden_dim + config.multiple_of - 1) // config.multiple_of
82 | )
83 |
84 | self.w1 = nn.Linear(config.n_embd, hidden_dim, bias=False)
85 | self.w2 = nn.Linear(config.n_embd, hidden_dim, bias=False)
86 | self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
87 |
88 | def forward(self, x):
89 | return self.c_proj(nn.functional.silu(self.w1(x)) * self.w2(x))
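# Worked example (illustrative): with n_embd=768 and multiple_of=256, hidden_dim starts at
# 4 * 768 = 3072, is scaled to 2 * 3072 / 3 = 2048, then rounded up to a multiple of 256 (still 2048),
# giving two 768->2048 input projections (w1, w2) and one 2048->768 output projection (c_proj).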
90 |
91 |
92 | class LlamaAttention(CausalSelfAttention):
93 | def forward(self, x, freqs_cis):
94 | # batch size, sequence length, embedding dimensionality (n_embd)
95 | (
96 | B,
97 | T,
98 | C,
99 | ) = x.size()
100 |
101 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
102 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
103 | # (B, T, nh, hs)
104 | k = k.view(B, T, self.n_head, C // self.n_head)
105 | q = q.view(B, T, self.n_head, C // self.n_head)
106 | q, k = apply_rotary_emb(q, k, freqs_cis)
107 | # (B, nh, T, hs)
108 | q, k = q.transpose(1, 2), k.transpose(1, 2)
109 |
110 | # (B, nh, T, hs)
111 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
112 |
113 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
114 | if self.flash:
115 | # efficient attention using Flash Attention CUDA kernels
116 | y = torch.nn.functional.scaled_dot_product_attention(
117 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
118 | )
119 | else:
120 | # manual implementation of attention
121 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
122 | att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
123 | att = F.softmax(att, dim=-1)
124 | att = self.attn_dropout(att)
125 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
126 | y = (
127 | y.transpose(1, 2).contiguous().view(B, T, C)
128 | ) # re-assemble all head outputs side by side
129 |
130 | # output projection
131 | y = self.resid_dropout(self.c_proj(y))
132 | return y
133 |
134 |
135 | class LlamaBlock(nn.Module):
136 | def __init__(self, config):
137 | super().__init__()
138 | self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
139 | self.attn = LlamaAttention(config)
140 | self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)
141 | self.mlp = LlamaMLP(config)
142 |
143 | def forward(self, x, freqs_cis):
144 | x = x + self.attn(self.ln_1(x), freqs_cis)
145 | x = x + self.mlp(self.ln_2(x))
146 | return x
147 |
148 |
149 | class Llama(GPTBase):
150 | def __init__(self, config):
151 | super().__init__(config)
152 | assert config.vocab_size is not None
153 | assert config.sequence_length is not None
154 | self.config = config
155 | self.tokenizer = tiktoken.get_encoding("gpt2")
156 |
157 | # create the token and position embeddings
158 | self.head_dim = config.n_embd // config.n_head
159 | self.freqs_cis = precompute_freqs_cis(self.head_dim, config.sequence_length)
160 |
161 | self.transformer = nn.ModuleDict(
162 | dict(
163 | wte=nn.Embedding(config.vocab_size, config.n_embd),
164 | drop=nn.Dropout(config.dropout),
165 | h=nn.ModuleList([LlamaBlock(config) for _ in range(config.n_layer)]),
166 | ln_f=RMSNorm(config.n_embd, eps=config.rmsnorm_eps),
167 | )
168 | )
169 |
170 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
171 | # with weight tying when using torch.compile() some warnings get generated:
172 | # "UserWarning: functional_call was passed multiple values for tied weights.
173 | # This behavior is deprecated and will be an error in future versions"
174 | # not 100% sure what this is, so far seems to be harmless. TODO investigate
175 | self.transformer.wte.weight = (
176 | self.lm_head.weight
177 | ) # https://paperswithcode.com/method/weight-tying
178 |
179 | # init all weights
180 | self.apply(self._init_weights)
181 | # apply special scaled init to the residual projections, per GPT-2 paper
182 | for pn, p in self.named_parameters():
183 | if pn.endswith("c_proj.weight"):
184 | torch.nn.init.normal_(
185 | p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
186 | )
187 |
188 |
189 | def get_num_params(self, non_embedding=True):
190 | """
191 | Return the number of parameters in the model.
192 | Unlike the GPT base model there are no learned position embeddings to subtract.
193 | The token embeddings are tied to the lm_head weights, so they are counted as well;
194 | hence all parameters are returned regardless of `non_embedding`.
195 | """
196 | n_params = sum(p.numel() for p in self.parameters())
197 | return n_params
198 |
199 | def forward(self, idx, targets=None, get_logits=False):
200 | device = idx.device
201 | b, t = idx.size()
202 | assert (
203 | t <= self.config.sequence_length
204 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.sequence_length}"
205 | # shape (1, t)
206 | pos = torch.arange(0, t, dtype=torch.long, device=device)
207 |
208 | # forward the GPT model itself
209 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
210 |
211 | x = self.transformer.drop(tok_emb)
212 | freqs_cis = self.freqs_cis.to(x.device)[pos]
213 |
214 | for block_idx, block in enumerate(self.transformer.h):
215 | x = block(x, freqs_cis=freqs_cis)
216 | x = self.transformer.ln_f(x)
217 |
218 | if targets is not None:
219 | # if we are given some desired targets also calculate the loss
220 | logits = self.lm_head(x)
221 | loss = F.cross_entropy(
222 | logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
223 | )
224 | else:
225 | # inference-time mini-optimization: only forward the lm_head on the very last position
226 | logits = self.lm_head(
227 | x[:, [-1], :]
228 | ) # note: using list [-1] to preserve the time dim
229 | loss = None
230 |
231 | logits = logits if get_logits else None
232 |
233 | return {
234 | "logits": logits,
235 | "loss": loss,
236 | }
237 |
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .llama import Llama, RMSNorm
3 | from .base import GPTBase, LayerNorm
4 |
5 |
6 | BLACKLIST_WEIGHT_MODULES = (
7 | torch.nn.LayerNorm,
8 | LayerNorm,
9 | RMSNorm,
10 | torch.nn.Embedding,
11 | )
12 |
13 |
14 | def get_model(args):
15 | """ Return the right model """
16 | if args.model == 'base':
17 | model = GPTBase(args)
18 | return model
19 | elif args.model == 'llama2':
20 | model = Llama(args)
21 | return model
22 | else:
23 | raise KeyError(f"Unknown model '{args.model}'.")
24 |
--------------------------------------------------------------------------------
/src/optim/base.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from data.utils import get_dataloader
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import wandb
7 | import time
8 | import itertools
9 | import copy
10 | import random
11 | import os
12 | import numpy as np
13 | from .utils import eval, get_batch, save_checkpoint
14 |
15 |
16 | def train_base(model, opt, data, data_seed, scheduler, iterations, acc_steps, batch_size, sequence_length, eval_freq, ckpt_path, distributed_backend, extra_args, itr=0, rng_state_dict=None):
17 | device_type = 'cuda' if 'cuda' in str(extra_args.device) else 'cpu'
18 | type_ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(
19 | device_type=device_type, dtype=torch.float16)  # TODO: use extra_args.dtype instead of hard-coding float16
20 | best_val_loss, text_table = float('inf'), None # best_val_loss not used atm, early stopping not recommended but possible
21 | substep = itr * acc_steps
22 | data["train"], train_sampler = get_dataloader(
23 | data["train"],
24 | sequence_length=sequence_length,
25 | batch_size=batch_size,
26 | seed=data_seed,
27 | distributed_backend=distributed_backend,
28 | )
29 |
30 | data["val"], val_sampler = get_dataloader(
31 | data["val"],
32 | sequence_length=sequence_length,
33 | batch_size=batch_size,
34 | seed=data_seed,
35 | )
36 |
37 | num_substeps_per_epoch = len(data["train"])
38 | train_epochs = substep//num_substeps_per_epoch
39 |
40 | if rng_state_dict is not None and rng_state_dict.get("train_sampler_state", None) is not None:
41 | train_sampler.generator.set_state(rng_state_dict["train_sampler_state"])
42 | if hasattr(train_sampler, "set_epoch"):
43 | train_sampler.set_epoch(train_epochs)
44 | else:
45 | sampler_state_before_iter = train_sampler.generator.get_state()
46 | data_train_iter = iter(data["train"])
47 |
48 |
49 | # for val data we don't care about epochs? just cycle through (no need to set_epoch to reshuffle)
50 | data_val_iter = itertools.cycle(data["val"])
51 |
52 | stats = {"train_loss": [], "val_loss": [], "val_pp": [], "val_acc": []}
53 |
54 |
55 |
56 | if extra_args.compile:
57 | print(f"Compiling model ...")
58 | model = torch.compile(model) # requires pytorch 2.0+
59 |
60 | model.train()
61 |
62 | t0 = time.time()
63 |
64 | if rng_state_dict is not None:
65 | torch.set_rng_state(rng_state_dict["cpu_rng_state"])
66 | torch.cuda.set_rng_state(rng_state_dict["gpu_rng_state"])
67 | np.random.set_state(rng_state_dict["numpy_rng_state"])
68 | random.setstate(rng_state_dict["py_rng_state"])
69 | for _ in range(substep % num_substeps_per_epoch):
70 | get_batch(data_train_iter, device=extra_args.device)
71 |
72 |
73 | while itr < iterations:
74 |
75 | for microstep_idx in range(acc_steps): # gradient accumulation
76 | x, y = get_batch(data_train_iter, device=extra_args.device)
77 |
78 | with type_ctx:
79 | with distributed_backend.get_context_for_microstep_forward(model=model, microstep_idx=microstep_idx, gradient_accumulation_steps=acc_steps):
80 | outputs = model(x, targets=y)
81 |
82 | loss = outputs['loss'] / acc_steps
83 | loss.backward()
84 | substep += 1
85 | if substep % len(data["train"]) == 0:
86 | train_epochs += 1
87 | print(f"Train epoch {train_epochs} done (full pass over training data)")
88 | if hasattr(train_sampler, "set_epoch"):
89 | # set epoch for reshuffling between epochs
90 | train_sampler.set_epoch(train_epochs)
91 | sampler_state_before_iter = None
92 | else:
93 | sampler_state_before_iter = train_sampler.generator.get_state()
94 | data_train_iter = iter(data["train"])
95 |
96 |
97 | if extra_args.grad_clip != 0.0:
98 | torch.nn.utils.clip_grad_norm_(model.parameters(), extra_args.grad_clip)
99 | opt.step()
100 | if scheduler is not None: scheduler.step()  # scheduler is None when args.scheduler == 'none'
101 | opt.zero_grad(set_to_none=True)
102 | itr += 1
103 |
104 | if itr % eval_freq == 0 or itr == iterations: # from here it's only evaluation code, all the training is above
105 | if distributed_backend.is_master_process():
106 | t1 = time.time()
107 | dt = t1 - t0
108 | epoch = substep//num_substeps_per_epoch
109 |
110 | model.eval()
111 |                 train_loss = loss.detach().cpu().item() * acc_steps  # undo the acc_steps scaling (loss of the last micro-batch)
112 | current_lr = scheduler.get_last_lr()[0] if scheduler is not None else extra_args.lr
113 |
114 |                 # evaluate on a small fixed number of val batches during training,
115 |                 # and on the full validation set at the final iteration
116 |                 eval_steps = 24 if itr < iterations else len(data["val"])
117 | val_acc, val_loss, val_perplexity = eval(
118 | model,
119 | data_val_iter,
120 | extra_args.device,
121 | max_num_batches=eval_steps,
122 | ctx=type_ctx,
123 | )
124 |
125 |                 print_string = f"{epoch}/{itr} [train] loss={train_loss:.3f} [val] loss={val_loss:.3f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
126 | print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
127 | if scheduler is not None:
128 | print_string += f" [lr] {current_lr:.5f}"
129 | print(print_string)
130 |
131 | if extra_args.wandb:
132 | logs = {
133 | "iter": itr,
134 | "train/loss": train_loss,
135 | "val/loss": val_loss,
136 | "val/perplexity": val_perplexity,
137 | "val/acc": val_acc,
138 | "lr": current_lr,
139 | }
140 |
141 | if itr == iterations:
142 | logs["val/final-ppl"] = val_perplexity
143 | logs["val/final-acc"] = val_acc
144 | logs["val/final-loss"] = val_loss
145 |
146 | wandb.log(logs)
147 |
148 | if extra_args.eval_seq_prefix != 'none' and (itr % (eval_freq * 5) == 0 or itr == iterations):
149 | if text_table is None:
150 | text_table = wandb.Table(columns=["itr", "val-pp", "text"])
151 |
152 | out_str = distributed_backend.get_raw_model(model).generate_from_string(
153 | extra_args.eval_seq_prefix, max_new_tokens=40, temperature=0.9, top_k=None)
154 | text_table.add_data(itr, val_perplexity, out_str)
155 | # why a copy? see github.com/wandb/wandb/issues/2981
156 | wandb.log({f"generated-text-{wandb.run.name}": copy.copy(text_table)})
157 |
158 | model.train()
159 | t0 = time.time()
160 | if distributed_backend.is_master_process():
161 | if extra_args.save_checkpoint_freq is not None and itr % extra_args.save_checkpoint_freq == 0:
162 | print(f"saving checkpoint to {os.path.dirname(ckpt_path)}/ckpt_{itr}.pt")
163 | save_checkpoint(distributed_backend=distributed_backend,
164 | model=model,
165 | opt=opt,
166 | scheduler=scheduler,
167 | itr=itr,
168 | cpu_rng_state=torch.get_rng_state(),
169 | gpu_rng_state=torch.cuda.get_rng_state(),
170 | numpy_rng_state=np.random.get_state(),
171 | py_rng_state=random.getstate(),
172 | train_sampler_state=sampler_state_before_iter,
173 | ckpt_path=os.path.join(os.path.dirname(ckpt_path), f"ckpt_{itr}.pt"))
174 |
175 | if distributed_backend.is_master_process():
176 | print(f"saving checkpoint to {ckpt_path}")
177 | save_checkpoint(distributed_backend=distributed_backend,
178 | model=model,
179 | opt=opt,
180 | scheduler=scheduler,
181 | itr=itr,
182 | ckpt_path=ckpt_path)
183 |
184 | return stats
185 |
--------------------------------------------------------------------------------
/src/optim/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from contextlib import nullcontext
5 |
6 |
7 | def get_batch(dataloader, device="cpu"):
8 | x, y = next(dataloader)
9 | if "cuda" in torch.device(device).type:
10 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
11 | x = x.pin_memory().to(device, non_blocking=True)
12 | y = y.pin_memory().to(device, non_blocking=True)
13 | else:
14 | x = x.to(device)
15 | y = y.to(device)
16 | return x, y
17 |
18 |
19 | @torch.no_grad()
20 | def eval(model, data_val_iter, device='cpu', max_num_batches=24, ctx=nullcontext()):
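   |     # Runs the model on up to `max_num_batches` validation batches and returns the mean
   |     # next-token accuracy, the mean loss, and the perplexity (exp of the mean loss).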
21 |     assert not model.training
22 |
23 | loss_list_val, acc_list = [], []
24 |
25 | for _ in range(max_num_batches):
26 | x, y = get_batch(data_val_iter, device=device)
27 | with ctx:
28 | outputs = model(x, targets=y, get_logits=True)
29 | val_loss = outputs['loss']
30 | loss_list_val.append(val_loss)
31 | acc_list.append((outputs['logits'].argmax(-1) == y).float().mean())
32 |
33 | val_acc = torch.stack(acc_list).mean().item()
34 | val_loss = torch.stack(loss_list_val).mean().item()
35 |     val_perplexity = float(np.exp(val_loss))  # perplexity = exp(mean validation loss)
36 |
37 | return val_acc, val_loss, val_perplexity
38 |
39 |
40 | def save_checkpoint(distributed_backend, model, opt, scheduler, itr, ckpt_path, **extra_args):
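   |     # Collects the model state (unwrapped from the distributed backend), optimizer and
   |     # scheduler state, the current iteration, and any extra kwargs (e.g. RNG and sampler
   |     # states), and writes them to `ckpt_path` with torch.save.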
41 |
42 | checkpoint = dict({
43 | 'model': distributed_backend.get_raw_model(model).state_dict(),
44 | 'optimizer': opt.state_dict(),
45 | 'scheduler': scheduler.state_dict(),
46 | 'itr': itr,
47 | }, **extra_args)
48 |
49 | torch.save(checkpoint, ckpt_path)
50 |
--------------------------------------------------------------------------------