├── .gitignore ├── .gitmodules ├── README.md ├── encode.py ├── notebooks ├── model_scalar.ipynb └── superbpe_encoding_efficiency_analysis.ipynb ├── requirements.txt ├── scripts ├── extend_tokenizer.sh └── train_tokenizer.sh ├── train_tokenizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | olmo-mix-1124-subset-p99/ 2 | *.egg-info/ 3 | .DS_Store 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tokenizers-superbpe"] 2 | path = tokenizers_superbpe 3 | url = https://github.com/alisawuffles/tokenizers-superbpe.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SuperBPE: Space Travel for Language Models 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2503.13423-b31b1b.svg)](https://arxiv.org/pdf/2503.13423) [![website](https://img.shields.io/badge/Website-superbpe.github.io-C16C8A)](https://superbpe.github.io/) [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Collection-FFD21E)](https://huggingface.co/collections/UW/superbpe-67db2338062faa07c7473ffa) 4 | 5 | This repository contains the tokenizer training code. Code for other aspects of the project (e.g. evals, model scaling, data processing, wandb, train configs) will be added soon! 6 | 7 | ## Setup 8 | First, clone the project with: 9 | ```bash 10 | git clone --recurse-submodules https://github.com/PythonNut/superbpe.git 11 | ``` 12 | We use a custom [fork](https://github.com/alisawuffles/tokenizers-superbpe) of [huggingface/tokenizers](https://github.com/huggingface/tokenizers) which conflicts with the original. 13 | Because of this, we recommend *always installing this project in its own virtual environment.* 14 | 15 | ### Setup virtual environment 16 | 17 | #### Using `conda` 18 | ```bash 19 | conda create -n superbpe python=3.12 rust 20 | conda activate superbpe 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | #### Using `venv` 25 | You will need to [install rust](https://www.rust-lang.org/tools/install) and Python 3.12. 26 | Then, you can do: 27 | ``` 28 | python3.12 -m venv .venv 29 | source .venv/bin/activate 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ### Data download 34 | Our tokenizer training data is available [here](https://huggingface.co/datasets/UW/olmo-mix-1124-subset-p99). 35 | You can download it using [`huggingface-cli`](https://huggingface.co/docs/huggingface_hub/en/guides/cli) (after logging into your HuggingFace account) using: 36 | ``` 37 | mkdir olmo-mix-1124-subset-p99 38 | cd olmo-mix-1124-subset-p99 39 | huggingface-cli download UW/olmo-mix-1124-subset-p99 --repo-type dataset --local-dir . 40 | ``` 41 | 42 | ## Tokenizer training 43 | Training a SuperBPE tokenizer involves two stages: 44 | 45 | 1. **Stage 1:** Learn subwords by enforcing whitespace pretokenization (equivalent to regular BPE training). 46 | 47 | ```bash 48 | python -m train_tokenizer \ 49 | --output_dir tokenizers/olmo2_bpe \ 50 | --corpus_dir olmo-mix-1124-subset-p99/train \ 51 | --num_bytes $((10**10)) \ 52 | --vocab_size 200000 \ 53 | --do_whitespace_pretokenization true 54 | ``` 55 | 56 | 2. **Stage 2:** Learn superwords by resuming tokenizer training, but this time skip the whitespace pretokenization step. 
57 | 58 | ```bash 59 | orig_tokenizer_dir=tokenizers/olmo2_bpe 60 | num_inherit_merges=180000 61 | output_dir=tokenizers/olmo2_superbpe 62 | 63 | mkdir -p $output_dir 64 | 65 | # inherit the first num_inherit_merges from the BPE tokenizer 66 | head -n $num_inherit_merges $orig_tokenizer_dir/merges.txt > $output_dir/merges.txt 67 | 68 | # specifies the same training files used in stage 1 69 | cp $orig_tokenizer_dir/meta.json $output_dir/meta.json 70 | 71 | python -m train_tokenizer \ 72 | --output_dir $output_dir \ 73 | --vocab_size 200000 \ 74 | --do_whitespace_pretokenization false 75 | ``` 76 | 77 | After tokenizer training, you need to update the `decoder` field in the `tokenizer.json` to make sure it looks like this. 78 | 79 | ``` 80 | "decoder": { 81 | "type": "ByteLevel", 82 | "add_prefix_space": true, 83 | "trim_offsets": true, 84 | "use_regex": true 85 | } 86 | ``` 87 | 88 | ## Citation 89 | 90 | If you found this codebase helpful, please cite 91 | 92 | ``` 93 | @article{liu2025superbpe, 94 | title={{SuperBPE}: Space travel for language models}, 95 | author={Alisa Liu and Jonathan Hayase and Valentin Hofmann and Sewoong Oh and Noah A Smith and Yejin Choi}, 96 | journal={arXiv preprint arXiv:2503.13423}, 97 | year={2025}, 98 | url={https://arxiv.org/abs/2503.13423} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /encode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate encoding efficiency over a corpus of text. 3 | """ 4 | 5 | import json 6 | from pathlib import Path 7 | from tokenizers import Tokenizer 8 | import click 9 | import random 10 | import regex as re 11 | 12 | import os 13 | from tqdm import tqdm 14 | from collections import Counter 15 | from utils import ( 16 | get_files_with_num_bytes, 17 | read_json, 18 | ensure_dir, 19 | get_pretokenization_regex, 20 | ) 21 | 22 | RANDOM_SEED = 5 23 | NUM_BYTES = 10**9 24 | 25 | 26 | @click.command() 27 | @click.option( 28 | "--tokenizer_path", 29 | type=str, 30 | help="Path to tokenizer.json", 31 | ) 32 | @click.option( 33 | "--corpus_dir", 34 | type=str, 35 | default=None, 36 | help="Directory of text files to encode.", 37 | ) 38 | @click.option( 39 | "--file_path", 40 | type=str, 41 | default=None, 42 | help="Path to a single text file to encode. One of corpus_dir or file_path must be provided.", 43 | ) 44 | @click.option("--output_dir", type=str, default=None) 45 | @click.option( 46 | "--num_bytes", 47 | type=int, 48 | default=NUM_BYTES, 49 | help="Size of text (in bytes) to encode. 
If -1, will encode all files in corpus_dir.", 50 | ) 51 | @click.option( 52 | "--vocab_size", 53 | type=int, 54 | default=None, 55 | ) 56 | @click.option( 57 | "--dropout", type=float, help="Dropout rate for the tokenizer.", default=None 58 | ) 59 | @click.option( 60 | "--save_token_stats", 61 | is_flag=True, 62 | help="Save token counts for each file.", 63 | default=False, 64 | ) 65 | @click.option( 66 | "--save_bytes_per_token", 67 | is_flag=True, 68 | help="Save bytes per token stats.", 69 | default=False, 70 | ) 71 | def main( 72 | tokenizer_path: str, 73 | corpus_dir: str, 74 | file_path: str, 75 | output_dir: str, 76 | num_bytes: int, 77 | vocab_size: int, 78 | dropout: float, 79 | save_token_stats: bool, 80 | save_bytes_per_token: bool, 81 | ): 82 | random.seed(RANDOM_SEED) 83 | if corpus_dir: 84 | corpus_dir = Path(corpus_dir) 85 | tokenizer = Tokenizer.from_file(tokenizer_path) 86 | tokenizer_name = os.path.basename(os.path.dirname(tokenizer_path)) 87 | tokenizer_json = read_json(tokenizer_path) 88 | 89 | # if vocab_size is given, construct tokenizer with the desired vocab_size 90 | if vocab_size and vocab_size <= tokenizer.get_vocab_size(): 91 | print(f"We will only use the top {vocab_size} merges for encoding.", flush=True) 92 | merges = tokenizer_json["model"]["merges"] 93 | tokenizer_json["model"]["merges"] = merges[:vocab_size] 94 | tokenizer_json["model"]["ignore_merges"] = False 95 | 96 | # create new tokenizer file with truncated vocabulary (hacky) 97 | tokenizer_path = tokenizer_path.replace( 98 | "tokenizer.json", f"tokenizer_{vocab_size}.json" 99 | ) 100 | with open(tokenizer_path, "w") as fout: 101 | json.dump(tokenizer_json, fout, indent=4, ensure_ascii=False) 102 | 103 | tokenizer = Tokenizer.from_file(tokenizer_path) 104 | count_pretokens = False 105 | elif vocab_size: 106 | raise ValueError( 107 | f"Vocab size ({vocab_size}) > tokenizer vocab size ({tokenizer.get_vocab_size()})." 108 | ) 109 | else: 110 | count_pretokens = True 111 | 112 | print(f"Using tokenizer from {tokenizer_path}", flush=True) 113 | 114 | if dropout: 115 | print(f"Setting dropout to {dropout}", flush=True) 116 | tokenizer.model.dropout = dropout 117 | 118 | pretok_regex = get_pretokenization_regex(tokenizer_json) 119 | 120 | def encode_file(file, count_pretokens=False): 121 | """ 122 | Encode file and return the number of tokens. 123 | """ 124 | with open(file, "r") as fin: 125 | text = fin.read() 126 | 127 | # Split into chunks so we don't OOM 128 | # This is ok bc tokenizer training splits on newline 129 | tokens = [] 130 | num_pretokens = 0 131 | pps = text.split("\n\n") 132 | chunk_size = max(len(pps) // 20, 100) 133 | for i in tqdm(range(0, len(pps), chunk_size), desc=os.path.basename(file)): 134 | chunk = "\n\n".join(pps[i : i + chunk_size]) + "\n\n" 135 | encoded = tokenizer.encode(chunk) 136 | tokens.extend(encoded.ids) 137 | if count_pretokens: 138 | num_pretokens += len( 139 | [match for match in re.finditer(pretok_regex, text)] 140 | ) 141 | # Note to self: num_pretokens will not be completely accurate for superword tokenizers because 142 | # the tokenizers training library splits on newline (separately from pretokenization). However, 143 | # the upper bound calculation is mainly for pretok tokenizers anyway, so we won't worry too 144 | # much about this case. 
145 | 146 | return tokens, num_pretokens 147 | 148 | # Collect list of files to be encoded 149 | if corpus_dir: 150 | if num_bytes == -1: 151 | num_bytes = None 152 | file_list, byte_count = get_files_with_num_bytes( 153 | corpus_dir, num_bytes, loop_around=False 154 | ) 155 | elif file_path: 156 | file_list = [file_path] 157 | byte_count = os.path.getsize(file_path) 158 | else: 159 | raise ValueError("Either corpus_dir or file_path must be provided.") 160 | 161 | # Count tokens in files 162 | token_count = 0 163 | pretoken_count = 0 164 | for file in file_list: 165 | tokens, num_pretokens = encode_file(file, count_pretokens=count_pretokens) 166 | token_count += len(tokens) 167 | pretoken_count += num_pretokens 168 | if save_token_stats: 169 | filename = os.path.basename(file).split(".txt")[0] 170 | ensure_dir(f"encoded/{tokenizer_name}") 171 | with open(f"encoded/{tokenizer_name}/{filename}.json", "w") as fout: 172 | token_counter = Counter(tokens) 173 | json.dump(token_counter, fout, indent=5) 174 | 175 | # Save encoding efficiency stats to output_dir 176 | if save_bytes_per_token: 177 | assert output_dir is not None 178 | output_dir = Path(output_dir) 179 | ensure_dir(output_dir) 180 | 181 | if vocab_size: 182 | out_filename = f"token_byte_counts_{vocab_size}.json" 183 | else: 184 | out_filename = "token_byte_counts.json" 185 | 186 | with open(output_dir / out_filename, "w") as fout: 187 | d = { 188 | "test_files": file_list, 189 | "token_count": token_count, 190 | "byte_count": byte_count, 191 | } 192 | json.dump(d, fout, indent=5) 193 | 194 | if count_pretokens: 195 | with open(output_dir / "pretoken_byte_counts.json", "w") as fout: 196 | d = { 197 | "test_files": file_list, 198 | "pretoken_count": pretoken_count, 199 | "byte_count": byte_count, 200 | } 201 | json.dump(d, fout, indent=5) 202 | 203 | print(f"Saved to {output_dir / out_filename}", flush=True) 204 | 205 | if vocab_size and vocab_size <= tokenizer.get_vocab_size(): 206 | os.remove(tokenizer_path) 207 | 208 | 209 | if __name__ == "__main__": 210 | main() 211 | -------------------------------------------------------------------------------- /notebooks/model_scalar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "2714fa7b-bfc4-4bcd-979a-bf7c15707a22", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import functools\n", 11 | "import itertools as it\n", 12 | "import os\n", 13 | "from copy import deepcopy\n", 14 | "\n", 15 | "import numpy as np\n", 16 | "import torch\n", 17 | "import warnings\n", 18 | "import math\n", 19 | "\n", 20 | "import olmo\n", 21 | "\n", 22 | "os.environ[\"SCRATCH_DIR\"] = \"no_exist\"" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 4, 28 | "id": "0a03f529-6fcc-4b56-a975-ee8ea3cb4008", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def scale_config(\n", 33 | " config,\n", 34 | " flops_ratio,\n", 35 | " axis_divisor=128,\n", 36 | " ceil=False,\n", 37 | " mode=\"inference-flops\",\n", 38 | " other_updates={},\n", 39 | "):\n", 40 | " assert flops_ratio\n", 41 | " config = deepcopy(config)\n", 42 | " config.init_device = \"meta\"\n", 43 | " head_dim = config.d_model // config.n_heads\n", 44 | " if not config.mlp_hidden_size:\n", 45 | " config.mlp_hidden_size = config.d_model * config.mlp_ratio\n", 46 | "\n", 47 | " # estimate the flops of a config\n", 48 | " def f(C):\n", 49 | " C.init_device = \"meta\"\n", 50 | " model = 
olmo.model.OLMo(C)\n", 51 | " if mode == \"inference-flops\":\n", 52 | " return model.num_fwd_flops\n", 53 | " elif mode == \"params\":\n", 54 | " return model.num_params()\n", 55 | " elif mode == \"params-non-embedding\":\n", 56 | " return model.num_params(include_embedding=False)\n", 57 | " elif mode == \"train-flops\":\n", 58 | " return model.num_fwd_flops + model.num_bck_flops\n", 59 | " else:\n", 60 | " raise NotImplementedError(f\"Unknown mode {mode}\")\n", 61 | "\n", 62 | " def make_config(d_model, n_layers, mlp_hidden_size, do_updates=True):\n", 63 | " C = deepcopy(config)\n", 64 | " C.d_model = d_model\n", 65 | " C.n_heads = C.d_model // head_dim\n", 66 | " C.mlp_hidden_size = mlp_hidden_size\n", 67 | " C.n_layers = n_layers\n", 68 | "\n", 69 | " if do_updates:\n", 70 | " for key, val in other_updates.items():\n", 71 | " setattr(C, key, val)\n", 72 | " return C\n", 73 | "\n", 74 | " # reparameterize so only valid configs are reachable\n", 75 | " def param(d, n, h):\n", 76 | " return (\n", 77 | " config.d_model + head_dim * d,\n", 78 | " config.n_layers + n,\n", 79 | " config.mlp_hidden_size + axis_divisor * h,\n", 80 | " )\n", 81 | "\n", 82 | " def r(d, n, h):\n", 83 | " ratios = np.array(\n", 84 | " [\n", 85 | " 1 + head_dim * d / config.d_model,\n", 86 | " 1 + n / config.n_layers,\n", 87 | " 1 + axis_divisor * h / config.mlp_hidden_size,\n", 88 | " ]\n", 89 | " )\n", 90 | " with warnings.catch_warnings():\n", 91 | " warnings.simplefilter(\"ignore\")\n", 92 | " return np.log(ratios)\n", 93 | "\n", 94 | " def g(d, n, h, do_updates=True):\n", 95 | " return f(make_config(*param(d, n, h), do_updates=do_updates))\n", 96 | "\n", 97 | " base_flops = g(0, 0, 0, do_updates=False)\n", 98 | " target_flops = base_flops * flops_ratio\n", 99 | "\n", 100 | " # fit a polynomial to g\n", 101 | " Q = np.array(list(it.product(*[[0, 1, 2]] * 3)))\n", 102 | " one = np.ones_like(Q[:, 0])\n", 103 | " QQ = np.vstack([Q[:, 0] ** a * Q[:, 1] ** b * Q[:, 2] ** c for a, b, c in Q]).T\n", 104 | " Qg = np.array([g(*row) / base_flops for row in Q])\n", 105 | " coeff = np.linalg.lstsq(QQ, Qg, rcond=None)[0]\n", 106 | "\n", 107 | " def g2(d, n, h):\n", 108 | " return (\n", 109 | " np.array([d**a * n**b * h**c for a, b, c in Q]).dot(coeff)\n", 110 | " * base_flops\n", 111 | " )\n", 112 | "\n", 113 | " # double check the predictions are matching\n", 114 | " assert round(g2(3, 4, 5)) == g(3, 4, 5)\n", 115 | " assert round(g2(5, 4, 6)) == g(5, 4, 6)\n", 116 | " assert round(g2(2, 7, 3)) == g(2, 7, 3)\n", 117 | "\n", 118 | " # given d and n, solve for h\n", 119 | " def solve_h(d, n):\n", 120 | " f0, f1 = g2(d, n, 0), g2(d, n, 1)\n", 121 | " slope = f1 - f0\n", 122 | " rounder = np.ceil if ceil else np.floor\n", 123 | " return int(rounder((target_flops - f0) / slope))\n", 124 | "\n", 125 | " # enumerate all viable d and n\n", 126 | " best, best_l = None, float(\"inf\")\n", 127 | " \n", 128 | " d2 = 0\n", 129 | " while True:\n", 130 | " if g2(d2, 0, 0) > target_flops:\n", 131 | " break\n", 132 | "\n", 133 | " n2 = 0\n", 134 | " while True:\n", 135 | " if g2(d2, n2, 0) > target_flops:\n", 136 | " break\n", 137 | "\n", 138 | " h2 = solve_h(d2, n2)\n", 139 | " r2 = r(d2, n2, h2)\n", 140 | " l2 = r2.std()\n", 141 | " if l2 < best_l:\n", 142 | " best_l, best = l2, (d2, n2, h2)\n", 143 | "\n", 144 | " n2 += 1\n", 145 | " d2 += 1\n", 146 | " \n", 147 | " d2 = 0\n", 148 | " while True:\n", 149 | " if g2(d2 - 1, 0, 0) < target_flops:\n", 150 | " break\n", 151 | "\n", 152 | " n2 = 0\n", 153 | " while True:\n", 154 | " if g2(d2, 
n2 - 1, 0) < target_flops:\n", 155 | " break\n", 156 | " \n", 157 | " h2 = solve_h(d2, n2)\n", 158 | " r2 = r(d2, n2, h2)\n", 159 | " if not np.isinf(r2).any():\n", 160 | " l2 = r2.std()\n", 161 | " if l2 < best_l:\n", 162 | " best_l, best = l2, (d2, n2, h2)\n", 163 | "\n", 164 | " n2 -= 1\n", 165 | " d2 -= 1\n", 166 | "\n", 167 | " opt_d, opt_n, opt_hsize = param(*best)\n", 168 | " ratios = tuple(r(*best).tolist())\n", 169 | " rel_flops = (g(*best) - target_flops) / target_flops\n", 170 | " return (\n", 171 | " (opt_d, opt_d // head_dim, opt_n, opt_hsize),\n", 172 | " ratios,\n", 173 | " rel_flops,\n", 174 | " make_config(*param(*best)),\n", 175 | " )\n", 176 | "\n", 177 | "# scale_config(\n", 178 | "# BASE_CONFIG.model, 0.7, mode=\"train-flops\", other_updates=dict(max_sequence_length=1376)\n", 179 | "# )" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 5, 185 | "id": "a7786f35-5c11-453a-8275-d95c99dd599d", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "def tokens_per_step(config):\n", 190 | " return config.global_train_batch_size * config.model.max_sequence_length\n", 191 | "\n", 192 | "def test_flops_per_step(config):\n", 193 | " model = olmo.model.OLMo(config.model)\n", 194 | " return tokens_per_step(config) * model.num_fwd_flops\n", 195 | " \n", 196 | "def train_flops_per_step(config):\n", 197 | " model = olmo.model.OLMo(config.model)\n", 198 | " return tokens_per_step(config) * (model.num_fwd_flops + model.num_bck_flops)\n", 199 | " \n", 200 | "def bytes_per_step(config, encoding_efficiency):\n", 201 | " result = config.global_train_batch_size * (\n", 202 | " config.model.max_sequence_length\n", 203 | " * encoding_efficiency\n", 204 | " )\n", 205 | " return float(result)\n", 206 | " \n", 207 | "def num_fwd_flops(model):\n", 208 | " # embedding table is just a lookup in the forward pass\n", 209 | " n_params = model.num_params(include_embedding=False)\n", 210 | " # the number of parameters is approximately the number of multiply-accumulates (MAC) in the network\n", 211 | " # each MAC has 2 FLOPs - we multiply by 2 ie 2 * n_param\n", 212 | " # this gets us FLOPs / token\n", 213 | " params_flops_per_token = 2 * n_params\n", 214 | " # there are 2 FLOPS per mac; there is A=Q*K^T and out=A*V ops (ie mult by 2)\n", 215 | " attn_flops_per_token = (\n", 216 | " model.config.n_layers * 2 * 2 * (model.config.d_model * model.config.max_sequence_length)\n", 217 | " )\n", 218 | " return params_flops_per_token, attn_flops_per_token\n", 219 | "\n", 220 | "def num_bck_flops(model):\n", 221 | " n_params = model.num_params()\n", 222 | " params_flops_per_token = 4 * n_params\n", 223 | " attn_flops_per_token = model.config.n_layers * 8 * (model.config.d_model * model.config.max_sequence_length)\n", 224 | " return params_flops_per_token, attn_flops_per_token\n", 225 | " \n", 226 | "def model_num_params(config):\n", 227 | " model = olmo.model.OLMo(config.model)\n", 228 | " return model.num_params()\n", 229 | "\n", 230 | "def max_steps(config):\n", 231 | " if isinstance(config.max_duration, int):\n", 232 | " return config.max_duration\n", 233 | " elif isinstance(config.max_duration, str):\n", 234 | " if config.max_duration.endswith(\"T\"):\n", 235 | " # convert to float *first* to handle scientific notation\n", 236 | " max_tokens = int(float(config.max_duration[:-1].strip()))\n", 237 | " return math.ceil(max_tokens / (config.global_train_batch_size * config.model.max_sequence_length))\n", 238 | " elif config.max_duration.endswith(\"ep\"):\n", 239 | " raise 
NotImplementedError\n", 240 | " else:\n", 241 | " # convert to float *first* to handle scientific notation\n", 242 | " return int(float(config.max_duration))\n", 243 | " else:\n", 244 | " raise TypeError(f\"expected int or str for 'max_duration', found {type(config.max_duration)}\")\n" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 38, 250 | "id": "3c1dd33e-a9fc-4bd8-bdf6-e5a6c79d4203", 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "Model config:\n", 258 | "TARGET_CONFIG.model.d_model=1280\n", 259 | "TARGET_CONFIG.model.n_heads=20\n", 260 | "TARGET_CONFIG.model.n_layers=19\n", 261 | "TARGET_CONFIG.model.mlp_hidden_size=10240\n", 262 | "TARGET_CONFIG.model.weight_tying=False\n", 263 | "TARGET_CONFIG.model.max_sequence_length=2048\n", 264 | "TARGET_CONFIG.model.vocab_size=200005\n", 265 | "TARGET_CONFIG.model.embedding_size=200064\n", 266 | "TARGET_CONFIG.max_duration=10572\n", 267 | "\n", 268 | "Model ratios: (0.22314355131420974, 0.17185025692665923, 0.22314355131420974)\n", 269 | "Tokens: 22,171,090,944\n", 270 | "Params: 1,010,336,000\n", 271 | "T/P ratio: 21.9443 (0.99747x Chinchilla)\n" 272 | ] 273 | } 274 | ], 275 | "source": [ 276 | "# Define the \"base\" config we are going to scale\n", 277 | "BASE_CONFIG = olmo.config.TrainConfig.load(\n", 278 | " \"pretokenization-scripts/nairr/configs/OM2-300M-generic200k.yaml\"\n", 279 | " # \"pretokenization-scripts/nairr/configs/OM2-1B-pt200k.yaml\"\n", 280 | " # \"../OLMo-alisa/configs/alisa/OLMo2-7B-generic200k.yaml\"\n", 281 | " # \"configs/official-1124/OLMo2-7B-stage1.yaml\"\n", 282 | ")\n", 283 | "# BASE_CONFIG.max_duration = 76543\n", 284 | "BASE_TRAIN_FLOPS = max_steps(BASE_CONFIG) * train_flops_per_step(BASE_CONFIG)\n", 285 | "BASE_ENCODING_EFFICIENCY = 4.458679110036542\n", 286 | "\n", 287 | "# Specify the desired scaling parameters\n", 288 | "TOKENIZER_VOCAB_SIZE = BASE_CONFIG.model.vocab_size\n", 289 | "TOKENIZER_ENCODING_EFFICIENCY = 6.0887434010717465\n", 290 | "TOKENIZER_ENCODING_EFFICIENCY = BASE_ENCODING_EFFICIENCY\n", 291 | "# TOKENIZER_ENCODING_EFFICIENCY = 6.6421079664426035\n", 292 | "\n", 293 | "TARGET_TRAIN_FLOPS = BASE_TRAIN_FLOPS \n", 294 | "# TARGET_MODEL_SCALE = TOKENIZER_ENCODING_EFFICIENCY/BASE_ENCODING_EFFICIENCY\n", 295 | "TARGET_MODEL_SCALE = 1.58\n", 296 | "\n", 297 | "\n", 298 | "# Do the calculations\n", 299 | "TARGET_CONFIG = deepcopy(BASE_CONFIG)\n", 300 | "TARGET_CONFIG.model.vocab_size = TOKENIZER_VOCAB_SIZE\n", 301 | "TARGET_CONFIG.model.embedding_size = (\n", 302 | " TOKENIZER_VOCAB_SIZE + (-TOKENIZER_VOCAB_SIZE) % 128\n", 303 | ")\n", 304 | "TARGET_CONFIG.model.max_sequence_length = int(\n", 305 | " np.ceil(\n", 306 | " BASE_CONFIG.model.max_sequence_length\n", 307 | " * BASE_ENCODING_EFFICIENCY\n", 308 | " / TOKENIZER_ENCODING_EFFICIENCY\n", 309 | " )\n", 310 | ")\n", 311 | "\n", 312 | "if TARGET_MODEL_SCALE:\n", 313 | " params, scales, error, new_model_config = scale_config(\n", 314 | " BASE_CONFIG.model,\n", 315 | " TARGET_MODEL_SCALE,\n", 316 | " mode=\"inference-flops\",\n", 317 | " other_updates=dict(max_sequence_length=TARGET_CONFIG.model.max_sequence_length),\n", 318 | " )\n", 319 | " TARGET_CONFIG.model = new_model_config\n", 320 | "\n", 321 | "TARGET_CONFIG.max_duration = int(\n", 322 | " np.floor(TARGET_TRAIN_FLOPS / train_flops_per_step(TARGET_CONFIG))\n", 323 | ")\n", 324 | "TARGET_NUM_PARAMS = model_num_params(TARGET_CONFIG)\n", 325 | "TARGET_TOKENS = TARGET_CONFIG.max_duration * 
tokens_per_step(TARGET_CONFIG)\n", 326 | "TARGET_TOKEN_PARAM_RATIO = TARGET_TOKENS / TARGET_NUM_PARAMS\n", 327 | "TARGET_TOTAL_BYTES = TARGET_CONFIG.max_duration * bytes_per_step(TARGET_CONFIG, TOKENIZER_ENCODING_EFFICIENCY)\n", 328 | "# hardcoded for OLMo Mix 2\n", 329 | "if TARGET_TOTAL_BYTES > 1748032475185 * 0.99:\n", 330 | " print(\"Warning: subset does not have enough training bytes!\")\n", 331 | "print(\"Model config:\")\n", 332 | "print(f\"{TARGET_CONFIG.model.d_model=}\")\n", 333 | "print(f\"{TARGET_CONFIG.model.n_heads=}\")\n", 334 | "print(f\"{TARGET_CONFIG.model.n_layers=}\")\n", 335 | "if TARGET_CONFIG.model.mlp_hidden_size:\n", 336 | " print(f\"{TARGET_CONFIG.model.mlp_hidden_size=}\")\n", 337 | "else:\n", 338 | " print(f\"{TARGET_CONFIG.model.mlp_ratio=}\")\n", 339 | "\n", 340 | "if not TARGET_MODEL_SCALE:\n", 341 | " print(\"[The above should be unchanged from the baseline.]\")\n", 342 | "print(f\"{TARGET_CONFIG.model.weight_tying=}\")\n", 343 | "print(f\"{TARGET_CONFIG.model.max_sequence_length=}\")\n", 344 | "print(f\"{TARGET_CONFIG.model.vocab_size=}\")\n", 345 | "print(f\"{TARGET_CONFIG.model.embedding_size=}\")\n", 346 | "print(f\"{TARGET_CONFIG.max_duration=}\")\n", 347 | "print()\n", 348 | "if TARGET_MODEL_SCALE:\n", 349 | " print(f\"Model ratios: {scales}\")\n", 350 | "print(f\"Tokens: {TARGET_TOKENS:,}\")\n", 351 | "print(f\"Params: {TARGET_NUM_PARAMS:,}\")\n", 352 | "print(\n", 353 | " f\"T/P ratio: {TARGET_TOKEN_PARAM_RATIO:.06} ({TARGET_TOKEN_PARAM_RATIO/22:.05}x Chinchilla)\"\n", 354 | ")" 355 | ] 356 | } 357 | ], 358 | "metadata": { 359 | "kernelspec": { 360 | "display_name": "Python 3 (ipykernel)", 361 | "language": "python", 362 | "name": "python3" 363 | }, 364 | "language_info": { 365 | "codemirror_mode": { 366 | "name": "ipython", 367 | "version": 3 368 | }, 369 | "file_extension": ".py", 370 | "mimetype": "text/x-python", 371 | "name": "python", 372 | "nbconvert_exporter": "python", 373 | "pygments_lexer": "ipython3", 374 | "version": "3.12.7" 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 5 379 | } 380 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt 2 | 3 | # Core dependencies 4 | click 5 | filelock 6 | huggingface-hub 7 | pysimdjson; python_version >= "3.9" and python_version < "3.13" 8 | ai2-olmo 9 | 10 | # Editable install for tokenizer 11 | -e tokenizers_superbpe/bindings/python/ 12 | -------------------------------------------------------------------------------- /scripts/extend_tokenizer.sh: -------------------------------------------------------------------------------- 1 | dataset_name=olmo2_p99_truncate 2 | orig_tokenizer_dir=tokenizer_json/${dataset_name}_pretok_10G_200K 3 | num_inherit_merges=180000 4 | vocab_size=200000 5 | 6 | # create a str called num_inherit_merges_str, which turns 100000 into 100K 7 | if [ $num_inherit_merges -ge 1000 ]; then 8 | num_inherit_merges_str=$(($num_inherit_merges / 1000))K 9 | else 10 | num_inherit_merges_str=${num_inherit_merges} 11 | fi 12 | 13 | # convert vocab_size to something like 100K, depending on the value 14 | if [ $vocab_size -ge 1000 ]; then 15 | vocab_size_str=$(($vocab_size / 1000))K 16 | else 17 | vocab_size_str=${vocab_size} 18 | fi 19 | 20 | output_dir=tokenizer_json/${dataset_name}_10G_${num_inherit_merges_str}_extend_${vocab_size_str}_mw4_colon 21 | echo "output_dir: $output_dir" 22 | 23 | mkdir -p $output_dir 24 | head 
-n $num_inherit_merges $orig_tokenizer_dir/merges.txt > $output_dir/merges.txt 25 | cp $orig_tokenizer_dir/meta.json $output_dir/meta.json 26 | 27 | python -m train_tokenizer \ 28 | --output_dir $output_dir \ 29 | --vocab_size $vocab_size \ 30 | --do_whitespace_pretokenization false 31 | -------------------------------------------------------------------------------- /scripts/train_tokenizer.sh: -------------------------------------------------------------------------------- 1 | dataset_name=olmo2_p99_truncate 2 | do_whitespace_pretokenization=true 3 | vocab_size=200000 4 | num_bytes=$((10**10)) 5 | corpus_dir=/gscratch/xlab/alisaliu/pretokenization/data/${dataset_name}/train # a directory containing txt files for tokenizer training 6 | 7 | # convert num_bytes to something like 10G or 100M, depending on the value 8 | if [ $num_bytes -ge $((10**9)) ]; then 9 | num_bytes_str=$(($num_bytes / 10**9))G 10 | elif [ $num_bytes -ge $((10**6)) ]; then 11 | num_bytes_str=$(($num_bytes / 10**6))M 12 | elif [ $num_bytes -ge $((10**3)) ]; then 13 | num_bytes_str=$(($num_bytes / 10**3))K 14 | else 15 | num_bytes_str=${num_bytes} 16 | fi 17 | 18 | # convert vocab_size to something like 100K, depending on the value 19 | if [ $vocab_size -ge 1000 ]; then 20 | vocab_size_str=$(($vocab_size / 1000))K 21 | else 22 | vocab_size_str=${vocab_size} 23 | fi 24 | 25 | # if do_whitespace_pretokenization is true, set pretok_str to "pretok", else "nopretok" 26 | if [ $do_whitespace_pretokenization == true ]; then 27 | pretok_str=pretok 28 | else 29 | pretok_str=nopretok 30 | fi 31 | 32 | output_dir=tokenizer_json/${dataset_name}_${pretok_str}_${num_bytes_str}_${vocab_size_str} 33 | echo "output_dir: $output_dir" 34 | 35 | python -m train_tokenizer \ 36 | --output_dir $output_dir \ 37 | --corpus_dir $corpus_dir \ 38 | --num_bytes $num_bytes \ 39 | --vocab_size $vocab_size \ 40 | --do_whitespace_pretokenization $do_whitespace_pretokenization 41 | -------------------------------------------------------------------------------- /train_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train tokenizer on a single data category. 
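Example invocation (a sketch mirroring Stage 1 of the README; the corpus path is a placeholder for your own data directory):

    python -m train_tokenizer \
        --output_dir tokenizers/olmo2_bpe \
        --corpus_dir olmo-mix-1124-subset-p99/train \
        --num_bytes $((10**10)) \
        --vocab_size 200000 \
        --do_whitespace_pretokenization true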
3 | """ 4 | 5 | import os 6 | from pathlib import Path 7 | import time 8 | import json 9 | import re 10 | import random 11 | import click 12 | from utils import ( 13 | ensure_dir, 14 | get_files_with_num_bytes, 15 | get_truncated_file, 16 | train_or_extend_tokenizer, 17 | ) 18 | 19 | random.seed(0) 20 | 21 | 22 | @click.command() 23 | @click.option( 24 | "--output_dir", 25 | type=str, 26 | help="Where to save the trained tokenizer.", 27 | ) 28 | @click.option( 29 | "--num_bytes", 30 | type=int, 31 | default=None, 32 | help="The maximum number of bytes to use for tokenizer training.", 33 | ) 34 | @click.option( 35 | "--corpus_dir", 36 | type=str, 37 | default=None, 38 | help="Directory containing text files to use for training the tokenizer.", 39 | ) 40 | @click.option( 41 | "--vocab_size", 42 | type=int, 43 | default=100000, 44 | help="The number of tokens in the vocabulary.", 45 | ) 46 | @click.option( 47 | "--do_whitespace_pretokenization", 48 | type=bool, 49 | default=True, 50 | help="Whether to do whitespace pretokenization.", 51 | ) 52 | def main( 53 | output_dir: str, 54 | num_bytes: int, 55 | corpus_dir: str, 56 | vocab_size: int, 57 | do_whitespace_pretokenization: bool, 58 | ): 59 | output_dir = Path(output_dir) 60 | ensure_dir(output_dir) 61 | print(f"We are training a tokenizer for {output_dir}", flush=True) 62 | 63 | # We look for merges.txt in the current dir to determine whether we are extending 64 | # the tokenizer or training from scratch, so we need to cd into the output directory. 65 | os.chdir(output_dir) 66 | 67 | if os.path.exists("meta.json"): 68 | print( 69 | "Output directory contains meta.json, so we will use the files from there." 70 | ) 71 | meta = json.load(open("meta.json")) 72 | train_files, actual_num_bytes = meta["train_files"], meta["total_bytes"] 73 | for file in train_files: 74 | if not os.path.exists(file): 75 | assert "truncated" in file, f"{file} not found" 76 | wanted_filesize = int(re.search(r"_truncated_(\d+)", file).group(1)) 77 | file = re.sub(r"_truncated_\d+", "", file) 78 | get_truncated_file(file, wanted_filesize) 79 | else: 80 | train_files, actual_num_bytes = get_files_with_num_bytes(corpus_dir, num_bytes) 81 | 82 | # Write metadata 83 | with open("meta.json", "w") as fo: 84 | meta = {} 85 | meta["total_bytes"] = actual_num_bytes 86 | meta["train_files"] = train_files 87 | if os.path.exists("merges.txt"): 88 | os.system("cp merges.txt initial_merges.txt") 89 | meta["num_initial_merges"] = ( 90 | sum(1 for line in open("initial_merges.txt")) - 1 91 | ) 92 | json.dump(meta, fo, indent=5) 93 | 94 | # Train tokenizer 95 | start_time = time.time() 96 | 97 | print("Training with HF tokenizers...") 98 | tokenizer = train_or_extend_tokenizer( 99 | train_files, 100 | vocab_size=vocab_size, 101 | do_whitespace_pretokenization=do_whitespace_pretokenization, 102 | ) 103 | tokenizer.model.save(".") # saves merges.txt and vocab.json 104 | tokenizer.save("tokenizer.json") 105 | 106 | print(f"Train time: {time.time() - start_time}", flush=True) 107 | print("Tokenizer info saved to " + str(output_dir), flush=True) 108 | 109 | # Delete files that were constructed just for this 110 | # for f in train_files: 111 | # if "truncated" in f: 112 | # os.remove(f) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import 
random 5 | from pathlib import Path 6 | from filelock import FileLock 7 | 8 | import simdjson as json 9 | from tqdm import tqdm 10 | from tokenizers.models import BPE, Unigram 11 | 12 | from tokenizers import Tokenizer, pre_tokenizers, Regex 13 | from tokenizers.pre_tokenizers import ByteLevel, Split, Digits 14 | from tokenizers.trainers import BpeTrainer, UnigramTrainer 15 | 16 | 17 | def ensure_dir(d): 18 | if not os.path.exists(d): 19 | os.makedirs(d, exist_ok=True) 20 | 21 | 22 | def read_json(file): 23 | return json.load(open(file)) 24 | 25 | 26 | def read_merges_txt(path_to_txt): 27 | with open(path_to_txt) as fin: 28 | merges = fin.readlines()[1:] 29 | merges = [m.rsplit("\n", 1)[0] for m in merges] 30 | return merges 31 | 32 | 33 | def get_pretokenization_regex(tokenizer_json): 34 | if isinstance(tokenizer_json, str): 35 | tokenizer_json = read_json(tokenizer_json) 36 | 37 | split_pretokenizer = [ 38 | p 39 | for p in tokenizer_json["pre_tokenizer"]["pretokenizers"] 40 | if p["type"] == "Split" 41 | ][0] 42 | pretok_regex = split_pretokenizer["pattern"]["Regex"] 43 | return pretok_regex 44 | 45 | 46 | def train_or_extend_tokenizer( 47 | text_files: str, 48 | vocab_size: int = 100000, 49 | do_whitespace_pretokenization: bool = True, 50 | regex_string: str = None, 51 | tokenizer_type: str = "bpe", 52 | ): 53 | if tokenizer_type == "bpe": 54 | tokenizer = Tokenizer(BPE()) 55 | trainer = BpeTrainer(show_progress=True, vocab_size=vocab_size) 56 | elif tokenizer_type == "unigram": 57 | tokenizer = Tokenizer(Unigram()) 58 | trainer = UnigramTrainer(show_progress=True, vocab_size=vocab_size) 59 | 60 | if not regex_string: 61 | regex_string = "(?=(\d{3})+(?!\d))" # pretokenize digits in groups of 3 from right to left (from Luca) 62 | 63 | if do_whitespace_pretokenization: 64 | if regex_string: 65 | regex_string += "|" 66 | regex_string += ( 67 | " ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" # GPT-2 pretokenization 68 | ) 69 | 70 | pretokenizers = [ 71 | Digits(individual_digits=False), 72 | Split( 73 | pattern=Regex(regex_string), 74 | behavior="isolated", 75 | invert=False, 76 | ), 77 | ByteLevel( 78 | add_prefix_space=False, 79 | trim_offsets=True, 80 | use_regex=False, 81 | ), 82 | ] 83 | tokenizer.pre_tokenizer = pre_tokenizers.Sequence(pretokenizers) 84 | tokenizer.train(text_files, trainer) 85 | 86 | return tokenizer 87 | 88 | 89 | def bytes_to_unicode(): 90 | """ 91 | MJ: STOLEN DIRECTLY FROM https://github.com/openai/gpt-2/blob/master/src/encoder.py#L9 92 | -------------- 93 | Returns list of utf-8 byte and a corresponding list of unicode strings. 94 | The reversible bpe codes work on unicode strings. 95 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 96 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 97 | This is a signficant percentage of your normal, say, 32K bpe vocab. 98 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 99 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
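For example, under this mapping the space byte 0x20 is rendered as 'Ġ' (U+0120) and the newline byte 0x0A as 'Ċ' (U+010A), which is why those characters appear throughout merges.txt and vocab.json wherever a token starts with a space or spans a newline.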
100 | """ 101 | bs = ( 102 | list(range(ord("!"), ord("~") + 1)) 103 | + list(range(ord("¡"), ord("¬") + 1)) 104 | + list(range(ord("®"), ord("ÿ") + 1)) 105 | ) 106 | cs = bs[:] 107 | n = 0 108 | for b in range(2**8): 109 | if b not in bs: 110 | bs.append(b) 111 | cs.append(2**8 + n) 112 | n += 1 113 | cs = [chr(n) for n in cs] 114 | return dict(zip(bs, cs)) 115 | 116 | 117 | def is_valid_unicode(data): 118 | try: 119 | data.decode("utf-8") 120 | return True 121 | except UnicodeDecodeError: 122 | return False 123 | 124 | 125 | def get_truncated_file(filepath, wanted_filesize): 126 | """ 127 | Create a copy of the given file and truncates it to the desired size. 128 | """ 129 | if os.path.getsize(filepath) < wanted_filesize: 130 | raise ValueError("File is already smaller than desired filesize") 131 | 132 | filename, ext = os.path.splitext(filepath) 133 | truncated_filepath = Path(os.path.dirname(filepath)) / ( 134 | f"{filename}_truncated_{wanted_filesize}{ext}" 135 | ) 136 | 137 | # we want to make sure that multiple scripts don't create a truncated file at the same time 138 | lock = FileLock(str(truncated_filepath) + ".lock") 139 | with lock: 140 | if not os.path.exists(truncated_filepath): 141 | print(f"Truncating {filepath} to {wanted_filesize} bytes") 142 | 143 | os.system(f"cp {filepath} {truncated_filepath}") 144 | 145 | # adjust wanted_filesize to the next valid unicode character 146 | with open(truncated_filepath, "rb") as f: 147 | f.seek(wanted_filesize) 148 | data = f.read(1) 149 | while data and not is_valid_unicode(data): 150 | data = f.read(1) 151 | wanted_filesize += 1 152 | 153 | with open(truncated_filepath, "r+", encoding="utf-8") as fin: 154 | fin.truncate(wanted_filesize) 155 | else: 156 | print(f"Truncated file already exists: {truncated_filepath}") 157 | 158 | return str(truncated_filepath), wanted_filesize 159 | 160 | 161 | def get_files_with_num_bytes(data_dir, num_bytes=None, loop_around=True): 162 | """Return a list of files inside data_dir that contain num_bytes worth of data.""" 163 | file_list, byte_count = [], 0 164 | data_dir = Path(data_dir) 165 | 166 | all_files = [ 167 | f 168 | for f in os.listdir(data_dir) 169 | if f.endswith(".txt") and ("truncated" not in f) and ("split" not in f) 170 | ] 171 | 172 | if not num_bytes: # if num_bytes is not specified, use all text data 173 | file_list = [str(data_dir / f) for f in all_files] 174 | byte_count = sum(os.path.getsize(data_dir / f) for f in all_files) 175 | print(f"Using all {len(file_list)} files in {data_dir}") 176 | else: 177 | random.shuffle(all_files) 178 | counter = 0 179 | tqdm_bar = tqdm(total=num_bytes, desc="Loading text data") 180 | while byte_count < num_bytes: 181 | fname = all_files[counter % len(all_files)] 182 | filesize = os.path.getsize(data_dir / fname) 183 | if byte_count + filesize <= num_bytes: 184 | file_list.append(str(data_dir / fname)) 185 | byte_count += filesize 186 | tqdm_bar.update(filesize) 187 | else: 188 | wanted_filesize = num_bytes - byte_count 189 | truncated_filepath, true_filesize = get_truncated_file( 190 | data_dir / fname, wanted_filesize 191 | ) 192 | file_list.append(truncated_filepath) 193 | byte_count += true_filesize 194 | tqdm_bar.update(true_filesize) 195 | counter += 1 196 | if not loop_around and counter >= len(all_files): 197 | break 198 | return file_list, byte_count 199 | --------------------------------------------------------------------------------
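The encoding-efficiency figures used in notebooks/model_scalar.ipynb (e.g. `BASE_ENCODING_EFFICIENCY`) appear to be bytes-per-token values. A minimal sketch of how such a value can be read back from the `token_byte_counts.json` that encode.py writes when run with `--save_bytes_per_token` is shown below; the directory name is a hypothetical `--output_dir`, not a file shipped with the repository.

```python
# Sketch only: recover bytes-per-token encoding efficiency from encode.py output.
# The directory below is a hypothetical --output_dir, not part of the repo.
import json

with open("encoded_stats/olmo2_superbpe/token_byte_counts.json") as fin:
    stats = json.load(fin)

# encode.py records the total UTF-8 bytes encoded and the number of tokens produced.
bytes_per_token = stats["byte_count"] / stats["token_count"]
print(f"Encoding efficiency: {bytes_per_token:.4f} bytes/token")
```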