├── .gitignore
├── README.md
├── env.yaml
├── planets-sim
│   ├── draw_sample.py
│   └── rebound_sim.py
├── tokenize-dataset.py
├── tokenizer.json
├── xval
│   ├── __init__.py
│   ├── analyze.py
│   ├── make_tokenizer.py
│   ├── numformer.py
│   └── preprocess.py
└── xval_demo.ipynb

/.gitignore:
--------------------------------------------------------------------------------
xval/data/*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
scratch/
settings.json
chkpt.pt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# xVal

Repository for the code used in the xVal paper.

### Instructions

This repository contains both the xVal number-encoding preprocessing and a transformer-based model designed to consume it. Both are demonstrated with a worked example in the xval_demo.ipynb notebook.
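### Quick start

A minimal sketch of how the pieces fit together (file paths here are illustrative; the notebook is the authoritative walkthrough):

```python
from transformers import PreTrainedTokenizerFast

from xval.make_tokenizer import make_tokenizer

# Build a vocabulary in which every number is a single [NUM] token. The numeric
# values themselves travel alongside the token ids and scale the [NUM] embedding
# inside the model (see xval/numformer.py).
make_tokenizer(encoding="xval", save_file="./tokenizer.json")

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tokenizer.json",
    bos_token="[END]",
    eos_token="[END]",
    mask_token="[MASK]",
    pad_token="[PAD]",
)
print(len(tokenizer.vocab))
```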
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
name: xval

channels:
  - conda-forge
  - pytorch
  - nvidia
  - huggingface

dependencies:
  - python=3.9
  - pytorch::pytorch-cuda=11.8
  - pytorch::pytorch=2.0
  - conda-forge::numpy
  - huggingface::transformers
  - huggingface::datasets
  - conda-forge::tiktoken
  - conda-forge::wandb
  - conda-forge::tqdm
  - conda-forge::matplotlib
  - conda-forge::jupyter
  - conda-forge::ipywidgets
  - conda-forge::pip
--------------------------------------------------------------------------------
/planets-sim/draw_sample.py:
--------------------------------------------------------------------------------
# %%
import os
import sys

from rebound_sim import construct_example
import numpy as np
import uuid

path = "."
os.makedirs(path, exist_ok=True)


def reformat_data(data):
    data["description"].pop("seed")
    stepsize = data["description"].pop("stepsize")
    keys_list = data["description"].keys()

    rescale = {"m": 1e5, "a": 1, "e": 20}

    desc_dict = {
        "planet" + str(i): {key: val * rescale[key] for key, val in zip(keys_list, el)}
        for i, el in enumerate(zip(*data["description"].values()))
    }

    desc_dict["stepsize"] = stepsize

    data = [[[l2["x"], l2["y"]] for l2 in l1] for l1 in data["data"]]

    return str({"description": desc_dict, "data": data}).replace(" ", "")


if __name__ == "__main__":
    num_draws = int(sys.argv[1])

    filename = str(uuid.uuid4())

    samples = []

    num_prints = 10
    # Guard against a zero divisor when num_draws < num_prints.
    print_every = max(num_draws // num_prints, 1)

    print("\nStarting sample generation...\n")
    drawn_samples = 0
    while drawn_samples < num_draws:
        sample = reformat_data(construct_example())
        max_value = np.abs(np.array(eval(sample)["data"])).max()
        if max_value < 10:
            drawn_samples += 1
            # Keep the sample that passed the check rather than drawing a new one.
            samples.append(sample)
        else:
            print("Sample rejected, max value: {}".format(max_value))

        if drawn_samples % print_every == 0:
            print("Sample {}/{}".format(drawn_samples, num_draws))

    out_path = os.path.join(path, filename)
    print("\nWriting to file: {}\n".format(out_path))

    with open(out_path, "w") as f:
        for el in samples:
            # write each item on a new line
            f.write("{}\n".format(el))

    print("Done!")

# %%
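The script is invoked as `python draw_sample.py <num_draws>` and writes one sample string per line to a UUID-named file. A quick way to inspect a single draw and the rejection criterion above (a hypothetical snippet; it assumes `rebound` is installed and is run from inside planets-sim):

```python
import numpy as np

from draw_sample import reformat_data
from rebound_sim import construct_example

# Draw one system and apply the same |coordinate| < 10 acceptance test.
sample = reformat_data(construct_example(seed=0))
max_value = np.abs(np.array(eval(sample)["data"])).max()
print(max_value, "-> accepted" if max_value < 10 else "-> rejected")
```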
--------------------------------------------------------------------------------
/planets-sim/rebound_sim.py:
--------------------------------------------------------------------------------
# %%
# Here, we generate some data using rebound.
import rebound
import numpy as np


# %%
def make_sim(rstate):
    sim = rebound.Simulation()
    sim.add(m=1.0)
    nplanet = rstate.randint(2, 6)

    m = rstate.uniform(1e-5, 5e-5, nplanet)
    a_final = rstate.uniform(1.5, 3.0)
    a = np.linspace(1.0, a_final, nplanet)
    e = rstate.uniform(0.0, 0.1, nplanet)
    theta = rstate.uniform(0.0, 2 * np.pi, nplanet)

    # with probability 0.3, set all theta to 0
    if rstate.uniform() < 0.3:
        theta = np.zeros(nplanet)

    # random permutation of the planet order
    perm = rstate.permutation(nplanet)
    m, a, e, theta = m[perm], a[perm], e[perm], theta[perm]

    for i in range(nplanet):
        sim.add(m=m[i], a=a[i], e=e[i], theta=theta[i])

    return sim, {"m": list(m), "a": list(a), "e": list(e)}


# %%
def integrate(sim, t):
    outputs = []
    for ti in t:
        # Counter-intuitively, sim.integrate integrates from 0 to its input,
        # not from its current time to its input.
        sim.integrate(ti)
        cur_out = []
        for i, particle in enumerate(sim.particles):
            if i == 0:
                continue

            cur_out.append(
                dict(
                    a=particle.a,
                    e=particle.e,
                    # Omega=particle.Omega,  # longitude of ascending node
                    omega=particle.omega,  # argument of periapsis
                    f=particle.f,  # true anomaly
                    # (Just for plotting)
                    x=particle.x,
                    y=particle.y,
                )
            )
        outputs.append(cur_out)
    return outputs


# %%
def construct_example(seed=None):
    rstate = np.random.RandomState(seed)
    sim, meta = make_sim(rstate)
    revs = rstate.randint(15, 25)
    steps = rstate.randint(30, 60)
    t = np.linspace(0, revs, steps)
    data = integrate(sim, t)
    return {
        "data": data,
        "description": {"seed": seed, **meta, "stepsize": t[1] - t[0]},
    }


# %%
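A single call produces a dictionary of orbital-element trajectories. For example (a sketch; the exact values depend on the seed, and `rebound` must be installed):

```python
from rebound_sim import construct_example

example = construct_example(seed=42)
print(example["description"])  # seed, per-planet m/a/e lists, and the timestep
print(len(example["data"]))    # one entry per saved timestep (between 30 and 59)
print(example["data"][0][0])   # dict of a, e, omega, f, x, y for the first planet
```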
--------------------------------------------------------------------------------
/tokenize-dataset.py:
--------------------------------------------------------------------------------
import os, argparse
import numpy as np
from xval.preprocess import extract_all_keys, tokenize_fnc, convert_num_string
from xval.make_tokenizer import make_tokenizer
from datasets import DatasetDict
from transformers import PreTrainedTokenizerFast
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--datapath", type=str, help="Path to data")
parser.add_argument("--savepath", type=str, help="Save path for tokenizer and tokenized dataset.")
parser.add_argument("--encoding", type=str, choices=['xval', 'fp15', 'p10', 'p1000', 'b1999'], default='xval', help="Choose your encoding scheme. (xVal is ours.)")
args = parser.parse_args()

data_path = args.datapath
save_path = args.savepath

print(f"Using encoding: {args.encoding}")

if save_path is None:
    save_path = data_path  # use same path by default

# Keep only plain-text split files. The parentheses matter: `and` binds more
# tightly than `or`, so without them the extension filters would only apply to
# the "val" clause.
files = [
    f
    for f in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, f))
    and ("test" in f or "train" in f or "val" in f)
    and "json" not in f
    and "token" not in f
    and "yaml" not in f
]

print(f"\nLoading dataset from {data_path} with splits: {files}")
ds = DatasetDict.from_text({file: data_path + "/" + file for file in files})

print("\nExtracting keys from the train set...")
ds_keys = ds["train"].map(
    lambda x: {"keys": extract_all_keys(x["text"])},
    num_proc=30,
    remove_columns=["text"],
    load_from_cache_file=False,
)
sample_keys = list(set([item for sublist in ds_keys["keys"] for item in sublist]))
print(f"\nExtracted keys from dataset: {sample_keys}\n")

tokenizer_path = save_path + "/tokenizer_{}.json".format(args.encoding)
os.makedirs(save_path, exist_ok=True)
make_tokenizer(
    encoding=args.encoding,
    save_file=tokenizer_path,
    efficient_json=True,
    sample_keys=sample_keys
)

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=tokenizer_path,
    bos_token="[END]",
    eos_token="[END]",
    mask_token="[MASK]",
    pad_token="[PAD]",
)

print("Tokenizer saved to: ", tokenizer_path)
print("Vocab size: ", len(tokenizer.vocab))

print("\nStarting tokenization...")
if args.encoding == 'xval':
    tokenize_lambda = lambda x: tokenize_fnc(x, tokenizer)
    batched = False
    batch_size = None

else:
    def tokenize_lambda(samples):
        out = []
        for sample in samples["text"]:
            out.append(tokenizer.encode(convert_num_string(sample)))
        return {"input_ids": out}
    batched = True
    batch_size = 100

tokenized_ds = ds.map(
    tokenize_lambda,
    batched=batched,
    num_proc=30,
    batch_size=batch_size,
    remove_columns=["text"],
    load_from_cache_file=False,
)

if args.encoding == "xval":
    max_len = max([max(tokenized_ds[key]["len"]) for key in tokenized_ds.keys()])
    tokenized_ds = DatasetDict(
        {key: val.remove_columns(["len"]) for key, val in tokenized_ds.items()}
    )
else:
    lens = tokenized_ds["train"].map(
        lambda x: {"len": len(x["input_ids"])},
        num_proc=30,
        remove_columns="input_ids",
        load_from_cache_file=False,
    )["len"]
    max_len = max(lens)

print(f"Longest sequence length: {max_len}")

print("\nTokenization finished. Saving...")
full_save_path = save_path + "/tokenized_ds_" + str(args.encoding)
tokenized_ds.save_to_disk(full_save_path)
print("Tokenized dataset saved to: ", full_save_path)

config = {
    "vocab_size": len(tokenizer.vocab),
    "block_size": max_len,
    "tokenizer": tokenizer_path,
    "dataset": full_save_path,
    "dataset_type": "hf_fullsample_numbers_mlm",
    "mask_token": "[MASK]",
}

config_path = save_path + f"/config_{args.encoding}.yaml"
with open(config_path, "w") as file:
    yaml.dump(config, file)

print("\nConfiguration saved to: ", config_path)
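Once the script has run, its outputs can be reloaded from the saved config (a sketch; the paths assume `--datapath data` and the default xVal encoding):

```python
from datasets import DatasetDict
from transformers import PreTrainedTokenizerFast
import yaml

# The script writes a small YAML config pointing at the tokenizer and dataset.
with open("data/config_xval.yaml") as f:
    config = yaml.safe_load(f)

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=config["tokenizer"],
    bos_token="[END]",
    eos_token="[END]",
    mask_token="[MASK]",
    pad_token="[PAD]",
)
tokenized_ds = DatasetDict.load_from_disk(config["dataset"])
print(tokenized_ds)  # DatasetDict with the train/val/test splits found above
```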
"single_word": false, 100 | "lstrip": false, 101 | "rstrip": false, 102 | "normalized": true, 103 | "special": false 104 | }, 105 | { 106 | "id": 11, 107 | "content": "]]", 108 | "single_word": false, 109 | "lstrip": false, 110 | "rstrip": false, 111 | "normalized": true, 112 | "special": false 113 | }, 114 | { 115 | "id": 12, 116 | "content": "[[", 117 | "single_word": false, 118 | "lstrip": false, 119 | "rstrip": false, 120 | "normalized": true, 121 | "special": false 122 | }, 123 | { 124 | "id": 13, 125 | "content": "]]],[[[", 126 | "single_word": false, 127 | "lstrip": false, 128 | "rstrip": false, 129 | "normalized": true, 130 | "special": false 131 | }, 132 | { 133 | "id": 14, 134 | "content": "]],[[", 135 | "single_word": false, 136 | "lstrip": false, 137 | "rstrip": false, 138 | "normalized": true, 139 | "special": false 140 | }, 141 | { 142 | "id": 15, 143 | "content": "],[", 144 | "single_word": false, 145 | "lstrip": false, 146 | "rstrip": false, 147 | "normalized": true, 148 | "special": false 149 | }, 150 | { 151 | "id": 16, 152 | "content": "'e':", 153 | "single_word": false, 154 | "lstrip": false, 155 | "rstrip": false, 156 | "normalized": true, 157 | "special": false 158 | }, 159 | { 160 | "id": 17, 161 | "content": "'planet4':", 162 | "single_word": false, 163 | "lstrip": false, 164 | "rstrip": false, 165 | "normalized": true, 166 | "special": false 167 | }, 168 | { 169 | "id": 18, 170 | "content": "'description':", 171 | "single_word": false, 172 | "lstrip": false, 173 | "rstrip": false, 174 | "normalized": true, 175 | "special": false 176 | }, 177 | { 178 | "id": 19, 179 | "content": "'data':", 180 | "single_word": false, 181 | "lstrip": false, 182 | "rstrip": false, 183 | "normalized": true, 184 | "special": false 185 | }, 186 | { 187 | "id": 20, 188 | "content": "'a':", 189 | "single_word": false, 190 | "lstrip": false, 191 | "rstrip": false, 192 | "normalized": true, 193 | "special": false 194 | }, 195 | { 196 | "id": 21, 197 | "content": "'m':", 198 | "single_word": false, 199 | "lstrip": false, 200 | "rstrip": false, 201 | "normalized": true, 202 | "special": false 203 | }, 204 | { 205 | "id": 22, 206 | "content": "'planet3':", 207 | "single_word": false, 208 | "lstrip": false, 209 | "rstrip": false, 210 | "normalized": true, 211 | "special": false 212 | }, 213 | { 214 | "id": 23, 215 | "content": "'stepsize':", 216 | "single_word": false, 217 | "lstrip": false, 218 | "rstrip": false, 219 | "normalized": true, 220 | "special": false 221 | }, 222 | { 223 | "id": 24, 224 | "content": "'planet1':", 225 | "single_word": false, 226 | "lstrip": false, 227 | "rstrip": false, 228 | "normalized": true, 229 | "special": false 230 | }, 231 | { 232 | "id": 25, 233 | "content": "'planet2':", 234 | "single_word": false, 235 | "lstrip": false, 236 | "rstrip": false, 237 | "normalized": true, 238 | "special": false 239 | }, 240 | { 241 | "id": 26, 242 | "content": "'planet0':", 243 | "single_word": false, 244 | "lstrip": false, 245 | "rstrip": false, 246 | "normalized": true, 247 | "special": false 248 | } 249 | ], 250 | "normalizer": null, 251 | "pre_tokenizer": { 252 | "type": "ByteLevel", 253 | "add_prefix_space": false, 254 | "trim_offsets": true, 255 | "use_regex": true 256 | }, 257 | "post_processor": { 258 | "type": "ByteLevel", 259 | "add_prefix_space": true, 260 | "trim_offsets": false, 261 | "use_regex": true 262 | }, 263 | "decoder": { 264 | "type": "ByteLevel", 265 | "add_prefix_space": true, 266 | "trim_offsets": true, 267 | "use_regex": true 268 | }, 269 | "model": { 
  "model": {
    "type": "BPE",
    "dropout": null,
    "unk_token": null,
    "continuing_subword_prefix": null,
    "end_of_word_suffix": null,
    "fuse_unk": false,
    "byte_fallback": false,
    "vocab": {},
    "merges": []
  }
}
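Note that the BPE model itself is empty; the entire vocabulary lives in `added_tokens`, with `[NUM]` standing in for every number after preprocessing. A quick sanity check (a sketch; the exact ids follow the file above):

```python
from transformers import PreTrainedTokenizerFast

tok = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    bos_token="[END]",
    eos_token="[END]",
    mask_token="[MASK]",
    pad_token="[PAD]",
)
print(len(tok.vocab))  # 27: the added tokens are the whole vocabulary

# A string built entirely from added tokens round-trips cleanly.
ids = tok.encode("{'planet0':{'m':[NUM],'a':[NUM],'e':[NUM]}}")
print(ids)
print(tok.decode(ids))
```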
np.array(masked_sample["labels"]) - 100) 74 | masked_sample["labels"][n] = sample["input_ids"][n] 75 | masked_sample["text"] = tokenizer.decode(sample["input_ids"]) 76 | masked_sample["masked_text"] = tokenizer.decode(masked_sample["input_ids"]) 77 | masked_sample["ans"] = masked_sample["numbers"][n] 78 | return masked_sample 79 | 80 | 81 | ### Each number 82 | def predict(model, masked_sample, device="cuda"): 83 | model.eval() 84 | model.to(device) 85 | input = { 86 | "x": torch.tensor(masked_sample["input_ids"]).view(1, -1).to(device), 87 | # "y": torch.tensor(masked_sample["labels"]).view(1, -1).to(device), 88 | "x_num": torch.tensor(masked_sample["masked_numbers"]).view(1, -1).to(device), 89 | # "y_num": torch.tensor(masked_sample["masked_numbers"]).view(1, -1).to(device), 90 | } 91 | out = model(**input) 92 | return out 93 | 94 | 95 | ### Each row 96 | def predict_numbers(model, sample, tokenizer, n_list, device, all_at_once=False): 97 | num_pred_list = [] 98 | num_true_list = [] 99 | if all_at_once: 100 | masked_sample = mask_numbers(sample, tokenizer, n_list) 101 | out = predict(model, masked_sample, device) 102 | for n in n_list: 103 | num_pred_list.append(out[1][0][n].item()) 104 | num_true_list.append(masked_sample["numbers"][n]) 105 | else: 106 | for n in n_list: 107 | masked_sample = mask_nth_number(sample, tokenizer, n) 108 | out = predict(model, masked_sample, device) 109 | num_pred_list.append(out[1][0][n].item()) 110 | num_true_list.append(masked_sample["numbers"][n]) 111 | 112 | return { 113 | "num_pred_list": num_pred_list, 114 | "num_true_list": num_true_list, 115 | } 116 | 117 | 118 | ### Run on whole dataset 119 | def slow_eval_numbers( 120 | model, dataset, tokenizer, n_list, device, num_samples=None, all_at_once=False 121 | ): 122 | model.eval() 123 | model.to(device) 124 | 125 | if num_samples is None: 126 | num_samples = len(dataset) 127 | 128 | with torch.no_grad(): 129 | out = [] 130 | for i in tqdm(range(num_samples)): 131 | sample = dataset[i] 132 | out.append( 133 | predict_numbers(model, sample, tokenizer, n_list, device, all_at_once) 134 | ) 135 | 136 | pd_out = pd.DataFrame(out) 137 | return pd_out 138 | -------------------------------------------------------------------------------- /xval/make_tokenizer.py: -------------------------------------------------------------------------------- 1 | from tokenizers import ( 2 | decoders, 3 | models, 4 | processors, 5 | Tokenizer, 6 | pre_tokenizers, 7 | ) 8 | import numpy as np 9 | 10 | 11 | def make_tokenizer( 12 | vocab_words=[], 13 | encoding="xval", 14 | save_file="./tokenizer.json", 15 | special_tokens=["[END]", "[MASK]", "[PAD]"], 16 | efficient_json=True, 17 | sample_keys=None, 18 | ): 19 | if encoding == "xval": 20 | special_tokens += ["[NUM]"] 21 | full_vocab = {} 22 | 23 | else: 24 | vocab = ["{", "}", "[", "]", ",", "-", "+", "#"] 25 | full_vocab = {el: i for i, el in enumerate(vocab)} 26 | 27 | if encoding == "fp15": 28 | vocab_words += ( 29 | [ 30 | "]]", 31 | "[[", 32 | "]]]", 33 | "[[[", 34 | "],[", 35 | "]],[[", 36 | "]]],[[[", 37 | "'data':", 38 | "'description':", 39 | "+0.00e+0", 40 | ] 41 | + [ 42 | f"{s}{n:.2f}e+{i}" 43 | for n in np.arange(1, 10, 0.01) 44 | for i in range(0, 8) 45 | for s in ["+", "-"] 46 | ] 47 | + [ 48 | f"{s}{n:.2f}e-{i}" 49 | for n in np.arange(1, 10, 0.01) 50 | for i in range(1, 8) 51 | for s in ["+", "-"] 52 | ] 53 | ) 54 | 55 | elif encoding == "p10": 56 | vocab_words += ( 57 | [ 58 | "]]", 59 | "[[", 60 | "]]]", 61 | "[[[", 62 | "],[", 63 | "]],[[", 64 | "]]],[[[", 65 
| "'data':", 66 | "'description':", 67 | "+0.00e+0", 68 | "+", 69 | "-", 70 | ] 71 | + [str(d) for d in range(10)] 72 | + [f"e+{i}" for i in range(0, 8)] 73 | + [f"e-{i}" for i in range(1, 8)] 74 | ) 75 | 76 | elif encoding == "p1000": 77 | vocab_words = ( 78 | [ 79 | "]]", 80 | "[[", 81 | "]]]", 82 | "[[[", 83 | "],[", 84 | "]],[[", 85 | "]]],[[[", 86 | "'data':", 87 | "'description':", 88 | "0.00", 89 | ] 90 | + [f"{n:.2f}" for n in np.arange(1, 10, 0.01)] 91 | + [f"e+{i}" for i in range(0, 8)] 92 | + [f"e-{i}" for i in range(1, 8)] 93 | ) 94 | 95 | elif encoding == "b1999": 96 | vocab_words = ( 97 | [ 98 | "]]", 99 | "[[", 100 | "]]]", 101 | "[[[", 102 | "],[", 103 | "]],[[", 104 | "]]],[[[", 105 | "'data':", 106 | "'description':", 107 | "+0.00", 108 | ] 109 | + [f"{s}{n:.2f}" for n in np.arange(1, 10, 0.01) for s in ["+", "-"]] 110 | + [f"e+{i}" for i in range(0, 8)] 111 | + [f"e-{i}" for i in range(1, 8)] 112 | ) 113 | 114 | if efficient_json: 115 | efficient_vocab = [ 116 | "{", 117 | "}", 118 | "[", 119 | "]", 120 | ",", 121 | "]]]", 122 | "[[[", 123 | "]]", 124 | "[[", 125 | "]]],[[[", 126 | "]],[[", 127 | "],[", 128 | ] 129 | vocab_words += efficient_vocab 130 | 131 | if sample_keys is not None: 132 | vocab_words += [f"'{el}':" for el in sample_keys] 133 | 134 | tokenizer = Tokenizer(models.BPE(vocab=full_vocab, merges=[])) 135 | tokenizer.add_special_tokens(special_tokens) 136 | tokenizer.add_tokens(vocab_words) 137 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 138 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=False) 139 | tokenizer.decoder = decoders.ByteLevel() 140 | tokenizer.save(save_file) 141 | -------------------------------------------------------------------------------- /xval/numformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils.rnn import pad_sequence 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | from datasets import DatasetDict 7 | from transformers import PreTrainedTokenizerFast 8 | import torch.optim as optim 9 | 10 | 11 | class Numformer(nn.Module): 12 | def __init__( 13 | self, 14 | vocab_size, 15 | d_model=768, 16 | nhead=6, 17 | num_layers=6, 18 | dim_feedforward=3072, 19 | dropout=0.1, 20 | activation=nn.GELU(), 21 | layer_norm_eps=1e-05, 22 | batch_first=True, 23 | norm_first=True, 24 | transformer_bias=False, 25 | numhead_bias=True, 26 | context_length=1024, 27 | is_causal=False, 28 | ): 29 | super().__init__() 30 | encoder = nn.TransformerEncoderLayer( 31 | d_model=d_model, 32 | nhead=nhead, 33 | dim_feedforward=dim_feedforward, 34 | dropout=dropout, 35 | activation=activation, 36 | layer_norm_eps=layer_norm_eps, 37 | batch_first=batch_first, 38 | norm_first=norm_first, 39 | # bias=transformer_bias, 40 | ) 41 | self.encoder_stack = nn.TransformerEncoder( 42 | encoder_layer=encoder, num_layers=num_layers, enable_nested_tensor=False 43 | ) 44 | self.token_embed = nn.Embedding(vocab_size, d_model) 45 | self.position_embed = nn.Embedding(context_length, d_model) 46 | self.lm_head = nn.Sequential( 47 | nn.Linear(d_model, dim_feedforward, bias=transformer_bias), 48 | nn.GELU(), 49 | nn.Linear(dim_feedforward, vocab_size, bias=transformer_bias), 50 | ) 51 | self.num_head = nn.Sequential( 52 | nn.Linear(d_model, dim_feedforward, bias=numhead_bias), 53 | nn.GELU(), 54 | nn.Linear(dim_feedforward, 1, bias=numhead_bias), 55 | ) 56 | self.is_causal = is_causal 57 | 58 | def 
--------------------------------------------------------------------------------
/xval/numformer.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import DatasetDict
from transformers import PreTrainedTokenizerFast
import torch.optim as optim


class Numformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=768,
        nhead=6,
        num_layers=6,
        dim_feedforward=3072,
        dropout=0.1,
        activation=nn.GELU(),
        layer_norm_eps=1e-05,
        batch_first=True,
        norm_first=True,
        transformer_bias=False,
        numhead_bias=True,
        context_length=1024,
        is_causal=False,
    ):
        super().__init__()
        encoder = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            # bias=transformer_bias,
        )
        self.encoder_stack = nn.TransformerEncoder(
            encoder_layer=encoder, num_layers=num_layers, enable_nested_tensor=False
        )
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.position_embed = nn.Embedding(context_length, d_model)
        self.lm_head = nn.Sequential(
            nn.Linear(d_model, dim_feedforward, bias=transformer_bias),
            nn.GELU(),
            nn.Linear(dim_feedforward, vocab_size, bias=transformer_bias),
        )
        self.num_head = nn.Sequential(
            nn.Linear(d_model, dim_feedforward, bias=numhead_bias),
            nn.GELU(),
            nn.Linear(dim_feedforward, 1, bias=numhead_bias),
        )
        self.is_causal = is_causal

    def forward(self, x, x_num):
        # The xVal encoding: every number in the text is the single [NUM]
        # token, and its embedding is scaled by the number's value (x_num is
        # 1.0 at non-numeric positions, so those embeddings pass through
        # unchanged).
        x = self.token_embed(x) * x_num.unsqueeze(-1)
        x = x + self.position_embed.weight[: x.shape[1]].unsqueeze(0)
        x = self.encoder_stack(x, is_causal=self.is_causal)
        logit_preds = self.lm_head(x)
        num_preds = self.num_head(x)
        return logit_preds, num_preds


### Define collator and data loaders
def define_masked_num_collator(pad_token_id, mask_token_id, mlm_probability):
    def masked_num_collator(batch):
        x = [torch.tensor(sample["input_ids"]) for sample in batch]
        x_num = [torch.tensor(sample["numbers"]) for sample in batch]
        x = pad_sequence(x, batch_first=True, padding_value=pad_token_id)
        x_num = pad_sequence(x_num, batch_first=True, padding_value=1)
        probability_matrix = torch.full(x.shape, mlm_probability)
        mask = torch.bernoulli(probability_matrix).bool()
        y = x.clone()
        y_num = x_num.clone()
        y[~mask] = -100
        x[mask] = mask_token_id
        x_num[mask] = 1
        return {"x": x, "x_num": x_num, "y": y, "y_num": y_num, "mask": mask}

    return masked_num_collator
--------------------------------------------------------------------------------
/xval/preprocess.py:
--------------------------------------------------------------------------------
# Defining the regular expression replacement to replace numbers with ¬s
# and add the numbers to a list
import re
import numpy as np


def replace(text, numbers, num_token="[NUM]"):
    # The "¬¬" replacement runs twice because str.replace does not revisit
    # overlapping matches in a single pass.
    text = text.replace(num_token, "¬").replace("¬¬", "¬, ¬").replace("¬¬", "¬, ¬")
    for number in numbers:
        text = text.replace("¬", str(number), 1)
    return text


def compress_matrix(text):
    # Each replacement is applied twice for the same overlapping-run reason
    # as in replace() above.
    text = (
        text.replace("¬, ¬", "¬¬")
        .replace("¬, ¬", "¬¬")
        .replace("¬,¬", "¬¬")
        .replace("¬,¬", "¬¬")
    )
    return text


def extract(text, num_token="[NUM]"):
    import re

    # this regular expression is intended to match numerical values in various forms
    # like integers, floating-point numbers, or scientific notation, while avoiding
    # matching numbers that are part of strings.
    pattern = r"(?
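A shape check of Numformer on toy data (a sketch; the token ids and values are made up, with 1.0 at non-numeric positions as the collator convention requires, and the pad/mask ids match tokenizer.json):

```python
import torch

from xval.numformer import Numformer, define_masked_num_collator

model = Numformer(vocab_size=27, context_length=1024)
collator = define_masked_num_collator(
    pad_token_id=2, mask_token_id=1, mlm_probability=0.3
)

# Two ragged samples; the collator pads ids with [PAD] and numbers with 1.0.
batch = collator([
    {"input_ids": [4, 19, 3, 8, 3, 5], "numbers": [1.0, 1.0, 0.5, 1.0, -2.3, 1.0]},
    {"input_ids": [4, 19, 3, 5], "numbers": [1.0, 1.0, 7.1, 1.0]},
])
logit_preds, num_preds = model(batch["x"], batch["x_num"])
print(logit_preds.shape, num_preds.shape)  # (2, 6, 27) and (2, 6, 1)
```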