├── .DS_Store ├── src ├── sage_tokenizer │ ├── __init__.py │ ├── Word2VecParams.py │ ├── paths.py │ ├── HFEncoding.py │ ├── embeddings.py │ ├── model.py │ ├── SaGeVocabBuilder.py │ └── utils.py └── main.py ├── run.sh ├── sage_v1 ├── Python-Modules │ ├── Parameters.py │ ├── Logger.py │ ├── Corpus.py │ ├── Embeddings.py │ ├── SG_BPE.py │ └── Utils.py ├── LICENSE ├── README.md └── Main.py ├── pyproject.toml ├── LICENSE ├── .gitignore └── README.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MeLeLBGU/SaGe/HEAD/.DS_Store -------------------------------------------------------------------------------- /src/sage_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .SaGeVocabBuilder import SaGeVocabBuilder 2 | from .model import SaGeTokenizer 3 | from .paths import setSageFolder 4 | -------------------------------------------------------------------------------- /src/sage_tokenizer/Word2VecParams.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | 3 | class Word2VecParams: 4 | def __init__(self, D, N, ALPHA, window_size, min_count, sg): 5 | self.D = D 6 | self.N = N 7 | self.ALPHA = ALPHA 8 | self.window_size = window_size 9 | self.min_count = min_count 10 | self.sg = sg 11 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python main.py exp_name \ 2 | --corpus_filepath data/wiki_lines.txt \ 3 | --initial_vocabulary_filepath data/initial_vocab_hex.vocab \ 4 | --vocabulary_schedule 262144 229376 196608 163840 131072 98304 65536 57344 49152 40960 32768 16384 \ 5 | --embeddings_schedule 262144 131072 65536 49152 40960 32768 \ 6 | --partial_corpus_filepath data/wiki_lines_partial.txt \ 7 | --partial_corpus_line_number 500 \ 8 | --max_len 17 \ 9 | --workers 4 \ 10 | --random_seed 692653 11 | -------------------------------------------------------------------------------- /src/sage_tokenizer/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | PATH_SAGE = Path(os.getcwd()) 6 | def setSageFolder(path: Path): 7 | global PATH_SAGE 8 | PATH_SAGE = path 9 | 10 | def getDataFolder() -> Path: 11 | path = PATH_SAGE / "data" 12 | path.mkdir(exist_ok=True) 13 | return path 14 | 15 | def getResultsFolder() -> Path: 16 | path = PATH_SAGE / "results" 17 | path.mkdir(exist_ok=True) 18 | return path 19 | 20 | def getLogsFolder() -> Path: 21 | path = PATH_SAGE / "logs" 22 | path.mkdir(exist_ok=True) 23 | return path 24 | -------------------------------------------------------------------------------- /sage_v1/Python-Modules/Parameters.py: -------------------------------------------------------------------------------- 1 | ################################################################################################# 2 | ################# Training parameters (Embeddings.py) ########################################### 3 | ################################################################################################# 4 | 5 | # D = Embedding vector length 6 | D = 50 7 | 8 | # N = Number of negative samples 9 | N = 15 10 | 11 | # ALPHA = Initial learning rate 12 | ALPHA = 0.025 13 | 14 | 
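# Note: these defaults mirror the word2vec_D / word2vec_N / word2vec_ALPHA defaults exposed by SaGe 2.0's src/main.py (50, 15 and 0.025 respectively).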
################################################################################################# 15 | -------------------------------------------------------------------------------- /sage_v1/Python-Modules/Logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig( 4 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 5 | datefmt="%m/%d/%Y %H:%M:%S", 6 | level=logging.INFO, 7 | ) 8 | 9 | class Logger: 10 | def __init__(self, name): 11 | self.logger = logging.getLogger(name) 12 | 13 | def info(self, message, *args, **kwargs): 14 | self.logger.info(message, *args, **kwargs) 15 | 16 | def warning(self, message, *args, **kwargs): 17 | self.logger.warn(message, *args, **kwargs) 18 | 19 | def error(self, message, *args, **kwargs): 20 | self.logger.error(message, *args, **kwargs) 21 | 22 | def log_separator(self): 23 | self.logger.info("---------------------") 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "sage-tokenizer" 7 | version = "2.0.0" 8 | dependencies = [ 9 | "gensim == 4.3.2", 10 | "scipy == 1.12.0" 11 | ] 12 | authors = [ 13 | { name="Shaked Yehezkel"}, 14 | { name="Kensho Technologies"}, 15 | { name="Bar Gazit"}, 16 | { name="Yuval Pinter"} 17 | ] 18 | description = "SaGe subword tokenizer - version 2.0" 19 | readme = "README.md" 20 | requires-python = ">=3.8" 21 | classifiers = [ 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | "Operating System :: OS Independent", 25 | ] 26 | 27 | [project.urls] 28 | Homepage = "https://github.com/MeLeLBGU/SaGe" 29 | Issues = "https://github.com/MeLeLBGU/SaGe/issues" -------------------------------------------------------------------------------- /sage_v1/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MeLeL lab, Ben-Gurion University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MeLeL lab, Ben-Gurion University and Kensho Technologies, LLC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /sage_v1/Python-Modules/Corpus.py: -------------------------------------------------------------------------------- 1 | import Utils 2 | import random 3 | 4 | class Corpus: 5 | def __init__(self, corpus_filepath, partial_corpus_filepath, partial_corpus_lines_number, log): 6 | self._log = log 7 | 8 | with open(corpus_filepath) as full_corpus: 9 | self._full_corpus_data = full_corpus.readlines() 10 | self._log.info("original corpus num of lines: {}".format(len(self._full_corpus_data))) 11 | random.shuffle(self._full_corpus_data) 12 | self._partial_corpus_data = self._full_corpus_data[:partial_corpus_lines_number] 13 | self._log.info("num of lines in corpus: {}".format(len(self._partial_corpus_data))) 14 | 15 | with open(partial_corpus_filepath, "w+") as partial_corpus_file: 16 | partial_corpus_file.writelines(self._partial_corpus_data) 17 | 18 | def get_full_corpus(self): 19 | return self._full_corpus_data 20 | 21 | def get_partial_corpus(self): 22 | return self._partial_corpus_data 23 | 24 | def get_corpus(self, partial=True): 25 | if partial: 26 | cor = self.get_partial_corpus() 27 | else: 28 | cor = self.get_full_corpus() 29 | 30 | self._log.info("corpus num of lines: {}".format(len(cor))) 31 | return cor 32 | 33 | def compute_window(self, token_index, tokens_in_line): 34 | return Utils.compute_window(token_index, tokens_in_line) 35 | 36 | -------------------------------------------------------------------------------- /src/sage_tokenizer/HFEncoding.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | 3 | # map all bytes to valid utf-8 characters 4 | # in the same way that the huggingface tokenizers byte level pre-tokenizer does 5 | class HFEncoding: 6 | 7 | # translated from rust code found here: 8 | # https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs 9 | @staticmethod 10 | def bytes_char(): 11 | 12 | bs = [] 13 | bs.extend(range(ord('!'), ord('~') + 1)) 14 | bs.extend(range(0xA1, 0xAC + 1)) 15 | 
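# bytes in these printable ranges ('!'..'~', 0xA1..0xAC, 0xAE..0xFF) are mapped to their own code points;
# every byte excluded here is remapped below to chr(256 + n), so all 256 byte values end up as printable characters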
bs.extend(range(0xAE, 0xFF + 1)) 16 | cs = [b for b in bs] 17 | 18 | n = 0 19 | for b in range(256): 20 | if b not in bs: 21 | bs.append(b) 22 | cs.append(2 ** 8 + n) 23 | n += 1 24 | 25 | return {bytes([f]): chr(t) for f, t in zip(bs, cs)} 26 | 27 | def __init__(self): 28 | # map any byte to the corresponding character 29 | self.byte_map = HFEncoding.bytes_char() 30 | # the inverse character to byte mapping 31 | self.inv_byte_map = {v: k for k, v in self.byte_map.items()} 32 | 33 | # convert an encoded string of our mapped characters back to the original bytes 34 | def to_bytes(self, s: str) -> bytes: 35 | return b"".join([self.inv_byte_map[c] for c in s]) 36 | 37 | # convert a byte string into an encoded string of valid characters 38 | def to_encoded(self, byte_str: bytes) -> str: 39 | return "".join([self.byte_map[bytes([c])] for c in byte_str]) 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # pycharm project settings 121 | .idea 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | data 135 | logs 136 | results -------------------------------------------------------------------------------- /sage_v1/Python-Modules/Embeddings.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import Parameters 3 | import gensim.models 4 | 5 | ############################################################################ 6 | 7 | class CorpusIteratorForGensim: 8 | def __init__(self, corpus, model): 9 | self._corpus = corpus 10 | self._model = model 11 | 12 | def __iter__(self): 13 | corpus_lines = self._corpus.get_corpus() 14 | 15 | for line in corpus_lines: 16 | # yield tokenized line, when representing tokens as strings (not ints) 17 | tokens_in_line = [self._model.id_to_piece(x) for x in self._model.encode(line, out_type=int)] 18 | yield tokens_in_line 19 | 20 | ############################################################################ 21 | 22 | class EmbeddingsTrainer: 23 | def __init__(self, model, corpus, window_size, log): 24 | self._model = model 25 | self._corpus = corpus 26 | self._window_size = window_size 27 | self._log = log 28 | 29 | ###################################### 30 | 31 | def construct_word2vec_embeddings(self, word2vec_model): 32 | vocab_size = self._model.vocab_size() 33 | embeddings = np.zeros(shape=(vocab_size, Parameters.D)) 34 | 35 | for i in range(vocab_size): 36 | ith_token = self._model.id_to_piece(i) 37 | if ith_token in word2vec_model.wv.key_to_index.keys(): 38 | embeddings[i] = word2vec_model.wv[ith_token] 39 | else: 40 | self._log.warning("No wv for token {}. Assigning random vector...".format(ith_token)) 41 | embeddings[i] = np.random.uniform(low=-0.5 / Parameters.D, high=0.5 / Parameters.D, size=(1, Parameters.D)) 42 | 43 | return embeddings 44 | 45 | ###################################### 46 | 47 | def train_embeddings(self): 48 | self._log.info("Training embeddings.") 49 | sentences = CorpusIteratorForGensim(self._corpus, self._model) 50 | self._log.info("Built CorpusIteratorForGensim") 51 | word2vec_model = gensim.models.Word2Vec(sentences=sentences, \ 52 | vector_size=Parameters.D, \ 53 | window=self._window_size, 54 | min_count=0, 55 | sg=1, 56 | negative=Parameters.N) 57 | self._log.info("gensim.models.Word2Vec Finished") 58 | 59 | # now extract the embeddings in the format we expect to 60 | embeddings = self.construct_word2vec_embeddings(word2vec_model) 61 | self._log.info("Finished construct_word2vec_embeddings") 62 | 63 | # in gensim there is no way to get context embeddings - so we do target=context=word vectors. 
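# both returned values refer to the same matrix; callers such as Model in SG_BPE.py simply use them as target and context embeddings respectively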
64 | return embeddings, embeddings 65 | 66 | ############################################################################ 67 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | 3 | import argparse 4 | 5 | from sage_tokenizer.SaGeVocabBuilder import SaGeVocabBuilder 6 | 7 | 8 | def load_args(): 9 | parser = argparse.ArgumentParser(description="Optimized implementation of SaGe method") 10 | parser.add_argument("experiment_name", 11 | help="name of experiment, will save results under that name.") 12 | parser.add_argument("--corpus_filepath", required=True, 13 | help="filepath for full corpus (e.g. wiki corpus)") 14 | parser.add_argument("--initial_vocabulary_filepath", required=True, 15 | help="initial vocabulary, hex formatted, one vocab word per line") 16 | parser.add_argument("--vocabulary_schedule", nargs="+", type=int, required=True, 17 | help="what vocabulary sizes are we aiming for. Tokenization won't be done on the last value") 18 | parser.add_argument("--embeddings_schedule", nargs="+", type=int, required=True, 19 | help="from vocabulary_schedule, in which steps we should re-run embeddings") 20 | parser.add_argument("--partial_corpus_filepath", default="", 21 | help="where to create / load partial corpus file. " 22 | "Default is empty string for creating partial corpus under 'data' folder") 23 | parser.add_argument("--partial_corpus_line_number", type=int, default=1000, 24 | help="number of lines for partial corpus - in thousands. Default is 1000") 25 | parser.add_argument("--max_len", type=int, default=16, 26 | help="max length of tokens in bytes. Default is 16") 27 | parser.add_argument("--workers", type=int, default=1, 28 | help="number of worker threads to use. Default is 1") 29 | parser.add_argument("--random_seed", type=int, default=692653, 30 | help="random seed value. Default is 692653") 31 | 32 | # word2vec params 33 | parser.add_argument("--word2vec_D", type=int, default=50, 34 | help="word2vec embedding vector length. Default is 50") 35 | parser.add_argument("--word2vec_N", type=int, default=15, 36 | help="word2vec number of negative samples. Default is 15") 37 | parser.add_argument("--word2vec_ALPHA", type=float, default=0.025, 38 | help="word2vec Initial learning rate. Default is 0.025") 39 | parser.add_argument("--word2vec_window_size", type=int, default=5, 40 | help="word2vec context window size. Default is 5") 41 | parser.add_argument("--word2vec_min_count", type=int, default=1, 42 | help="word2vec minimum count of word. Default is 1, i.e. must be used at least once") 43 | parser.add_argument("--word2vec_sg", type=int, default=1, 44 | help="word2vec skip-gram if 1; otherwise CBOW. 
Default is 1") 45 | 46 | return vars(parser.parse_args()) 47 | 48 | 49 | if __name__ == '__main__': 50 | args = load_args() 51 | vocab_builder = SaGeVocabBuilder( 52 | args['vocabulary_schedule'], 53 | args['embeddings_schedule'], 54 | args['max_len'], 55 | args['workers'], 56 | args['random_seed'], 57 | args['word2vec_D'], 58 | args['word2vec_N'], 59 | args['word2vec_ALPHA'], 60 | args['word2vec_window_size'], 61 | args['word2vec_min_count'], 62 | args['word2vec_sg'] 63 | ) 64 | 65 | vocab_builder.build_vocab( 66 | args['experiment_name'], 67 | args['corpus_filepath'], 68 | args['initial_vocabulary_filepath'], 69 | args['partial_corpus_filepath'], 70 | args['partial_corpus_line_number'], 71 | ) 72 | -------------------------------------------------------------------------------- /sage_v1/README.md: -------------------------------------------------------------------------------- 1 | # SaGe 2 | Code for the SaGe subword tokenizer ([EACL 2023](https://aclanthology.org/2023.eacl-main.45/)). Downstream applications of the tokenizer, i.e. pre-training an LLM model and evaluating on benchmarks, are independent of the tokenizer code - in the paper we used [academic budget BERT](https://github.com/IntelLabs/academic-budget-bert). 3 | 4 | 5 | ## Requirements 6 | 1. `sentencepiece` (version 0.1.95) 7 | 2. `gensim` 8 | 3. a prepared corpus - see `Dataset` section. 9 | 10 | This code does not require GPU\TPU. 11 | 12 | ## Dataset 13 | In the paper, we used wikipedia latest dumps - `wget https://dumps.wikimedia.org/wiki/latest/wiki-latest-pages-articles.xml.bz2`, where `` should be the language code (`en`, `tr`, etc.). 14 | 15 | You can use any other dataset instead, the format for the script is one file with lines of raw text. 16 | 17 | ## Notes 18 | This script can be re-executed "from checkpoint" - 19 | The vocabulary creation script saves several files ("checkpoints") to be able to later continue - for example it saves the seed, the embeddings and sentencepiece models, and even a list of tokens sorted according to the Skipgram objective. 20 | 21 | ## Execution 22 | Execute `Main.py` from its working directory. 23 | The command line parameters are: 24 | 1. `experiment_name`: positional first parameter. A unique name for the experiment, results will be saved under that name (in the `results` directory). 25 | 26 | Required arguments: 27 | ``` 28 | --final_vocab_size: expected final vocabulary size. 29 | --initial_vocab_size: initial vocabulary size, from which ablation start. 30 | --tokens_to_prune_in_iteration: number of tokens to prune in each iteration ($k$ in paper). 31 | --tokens_to_consider_in_iteration: number of tokens to consider in each iteration ($M$ in paper). 32 | --iterations_until_reranking: number of iterations until reranking ($m$ in paper). 33 | --corpus_filepath: filepath for the full corpus (e.g. wiki corpus). 34 | --partial_corpus_filepath: where to create a partial corpus file in case of thousands_of_corpus_lines argument supplied. 35 | ``` 36 | 37 | Default override arguments: 38 | ``` 39 | --is_continue: is this execution continuing a former execiution of the same experiment: [Y/N]. default="N". 40 | --thousands_of_corpus_lines: number of corpus lines - in thousands. default=200. 41 | --max_lines_per_token: max number of lines to consider for each token in the objective calculation (not mentioned in the paper, affects a small portion of the vocabulary responsible for many unnecessary calculations). default=1000. 
42 | --window_size: window size for the Skipgram objective calculation, as well as for the embeddings calculation. default=5. 43 | ``` 44 | 45 | Example: 46 | ``` 47 | python Main.py \ 48 | exp_name \ 49 | --final_vocab_size 16000 \ 50 | --initial_vocab_size 20000 \ 51 | --tokens_to_prune_in_iteration 50 \ 52 | --tokens_to_consider_in_iteration 2000 \ 53 | --iterations_until_reranking 15 \ 54 | --corpus_filepath data/wiki_lines.txt \ 55 | --partial_corpus_filepath data/wiki_lines_partial.txt \ 56 | --thousands_of_corpus_lines 2000 57 | ``` 58 | 59 | To re-execute from a checkpoint, just execute the same command with `--is_continue Y`. The script will then search for those files under the `results/exp_name` directory. 60 | 61 | ## Citation 62 | ``` 63 | @inproceedings{yehezkel-pinter-2023-incorporating, 64 | title = "Incorporating Context into Subword Vocabularies", 65 | author = "Yehezkel, Shaked and 66 | Pinter, Yuval", 67 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics", 68 | month = may, 69 | year = "2023", 70 | address = "Dubrovnik, Croatia", 71 | publisher = "Association for Computational Linguistics", 72 | url = "https://aclanthology.org/2023.eacl-main.45", 73 | pages = "623--635", 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /src/sage_tokenizer/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | 3 | import time 4 | import logging 5 | 6 | import gensim.models 7 | import numpy as np 8 | 9 | from typing import List 10 | from pathlib import Path 11 | 12 | # class CorpusIterator: 13 | # def __init__(self, model, corpus_filepath): 14 | # self._model = model 15 | # self.corpus_filepath = corpus_filepath 16 | 17 | # def __iter__(self): 18 | # with open(self.corpus_filepath) as f: 19 | # corpus_data = f.readlines() 20 | # for line in corpus_data: 21 | # # convert bytes to tokens in encoded string form, for gensim 22 | # yield self._model.tokenize_to_encoded_str(bytes(line, 'utf-8')) 23 | from .Word2VecParams import Word2VecParams 24 | from .model import SaGeTokenizer 25 | from .paths import getDataFolder 26 | 27 | 28 | def get_embeddings(vocab_size: int, embeddings_folder: Path, partial_corpus: List[str], sage_model: SaGeTokenizer, workers_number: int, word2vec_params: Word2VecParams) -> np.ndarray: 29 | logging.info(f"training Embeddings at vocab size {vocab_size}") 30 | # is there an embedding of this size 31 | embeddings_filepath = Path(embeddings_folder) / f"embeddings_{vocab_size}.npy" 32 | if embeddings_filepath.exists(): 33 | logging.info(f"Found trained embeddings. 
Loading it from {embeddings_filepath.as_posix()}") 34 | # context and target embeddings are the same so just keep one copy around 35 | embeddings = np.load(embeddings_filepath.as_posix()) 36 | else: 37 | logging.info(f"Start training embeddings with Word2Vec...") 38 | start_time = time.time() 39 | embeddings = train_embeddings(sage_model, partial_corpus, workers_number, word2vec_params) 40 | logging.info(f"Embeddings time: {time.time() - start_time}") 41 | logging.info(f"Save embeddings to {embeddings_filepath.as_posix()}") 42 | np.save(embeddings_filepath.as_posix(), embeddings, allow_pickle=True) 43 | return embeddings 44 | 45 | 46 | def train_embeddings(sage_model: SaGeTokenizer, partial_corpus: List[str], workers: int, word2vec_params: Word2VecParams) -> np.ndarray: 47 | # sentences = CorpusIterator(model, corpus_filepath) 48 | 49 | # also save in a version of this with a sentence per line, whitespace per token 50 | # which gensim's word2vec wants to process things in parallel 51 | # see https://github.com/RaRe-Technologies/gensim/releases/tag/3.6.0 52 | # and https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/Any2Vec_Filebased.ipynb 53 | gensim_corpus_filepath = getDataFolder() / f"gensim_{sage_model.vocab_size()}.txt" 54 | 55 | # if already exists, use otherwise create the file 56 | if gensim_corpus_filepath.exists(): 57 | logging.info(f"Gensim format data file already exists: {gensim_corpus_filepath.as_posix()}") 58 | else: 59 | gensim_start = time.time() 60 | logging.info(f"starting tokenization of {len(partial_corpus)} lines for gensim") 61 | with open(gensim_corpus_filepath, "w", encoding="utf-8") as gensim_file: 62 | for i, line in enumerate(partial_corpus): 63 | if i % 1_000_000 == 0: 64 | logging.info(f"tokenizing line {i}, time: {(time.time() - gensim_start):.2f}") 65 | gensim_file.write(" ".join(sage_model.tokenize_to_encoded_str(bytes(line, 'utf-8'))) + "\n") 66 | logging.info(f"Gensim format data written: {gensim_corpus_filepath.as_posix()}, time: {(time.time()-gensim_start):.2f}") 67 | 68 | word2vec_model = gensim.models.Word2Vec(corpus_file=gensim_corpus_filepath.as_posix(), 69 | vector_size=word2vec_params.D, 70 | negative=word2vec_params.N, 71 | alpha=word2vec_params.ALPHA, 72 | window=word2vec_params.window_size, 73 | min_count=word2vec_params.min_count, 74 | sg=word2vec_params.sg, 75 | workers=workers) 76 | 77 | embeddings = np.zeros(shape=(sage_model.vocab_size(), word2vec_params.D)) 78 | 79 | for idx, token in sage_model.inv_str_vocab.items(): 80 | if token in word2vec_model.wv.key_to_index.keys(): 81 | embeddings[idx] = word2vec_model.wv[token] 82 | else: 83 | # some may not have made the min_count value, so will be missing 84 | # Embeddings not found for this token. Assign a random vector 85 | # doing this the same way as the old SaGe code 86 | embeddings[idx] = np.random.uniform(low=-0.5/word2vec_params.D, high=0.5/word2vec_params.D, size=(1, word2vec_params.D)) 87 | 88 | # just return one copy that we'll use for both context and target embeddings 89 | return embeddings 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SaGe 2.0 2 | Version 2.0 for the SaGe subword tokenizer ([EACL 2023](https://aclanthology.org/2023.eacl-main.45/)), excelling in [morphological segmentation](https://aclanthology.org/2024.acl-short.73/). Downstream applications of the tokenizer, i.e. 
pre-training an LLM model and evaluating on benchmarks, are independent of the tokenizer code - in the paper we used [academic budget BERT](https://github.com/IntelLabs/academic-budget-bert). 3 | 4 | Pre-trained SaGe-based models are available in [this](https://github.com/kensho-technologies/timtc_vocabs_models) repository. 5 | The large versions (2.4B params) produced the best results over BPE, UnigramLM, and PathPiece---see Table 14 in the Appendix [here](https://aclanthology.org/2024.emnlp-main.40/). 6 | 7 | SaGe 2.0 implements a faster, parallelizable version of the vocabulary learning algorithm. 8 | 9 | ```python 10 | from sage_tokenizer.SaGeVocabBuilder import SaGeVocabBuilder 11 | vocab_builder = SaGeVocabBuilder(full_vocab_schedule=[262144, 229376, 196608, 163840, 131072, 98304, 65536, 57344, 49152, 40960, 32768, 16384], 12 | embeddings_schedule=[262144, 131072, 65536, 49152, 40960, 32768], 13 | workers_number=4) 14 | 15 | vocab_builder.build_vocab(experiment_name='experiment_name', 16 | corpus_filepath='data/wiki_lines.txt', 17 | vocabulary_filepath='data/initial_vocab_hex.vocab') 18 | ``` 19 | The `.vocab` file can then be loaded as-is into most tokenization toolkits, such as Huggingface's `tokenizers`. 20 | 21 | SaGe tokenizer can be installed from PyPI: 22 | ``` 23 | pip install sage-tokenizer 24 | ``` 25 | 26 | ## Requirements 27 | 1. `gensim==4.3.2` 28 | 2. `scipy==1.12.0` 29 | 3. a prepared corpus - see `Dataset` section. 30 | 4. an initial vocabulary - see `Dataset` section. 31 | 32 | This code does not require GPU\TPU. 33 | 34 | ## Dataset 35 | In the paper, we used wikipedia latest dumps - `wget https://dumps.wikimedia.org/wiki/latest/wiki-latest-pages-articles.xml.bz2`, where `` should be the language code (`en`, `tr`, etc.). 36 | We used them to create the corpus. From that corpus we use BPE tokenizer to create our initial vocabulary. 37 | 38 | You can use any other dataset instead. 39 | The expected format for the corpus is one file with lines of raw text. 40 | The expected format for the initial vocabulary is one vocab word per line, hex formatted. 41 | 42 | ## Notes 43 | This script can be re-executed "from checkpoint" - 44 | The vocabulary creation script saves several files ("checkpoints") to be able to later continue - for example it saves the partial corpus used, seed, the embeddings, and even a list of tokens sorted according to the Skipgram objective. 45 | 46 | ## Arguments 47 | 48 | Required arguments: 49 | 50 | - `experiment_name`: Positional first parameter - a unique name for the experiment. Results will be saved under that name (in the `results` directory). 51 | - `corpus_filepath`: filepath for the full corpus (e.g. wiki corpus). Format is lines of raw text. A random subset from this file is used to create the partial corpus, which serves as the actual corpus for training. 52 | - `initial_vocabulary_filepath`: initial vocabulary, hex formatted, one vocab word per line. 53 | - `vocabulary_schedule`: what vocabulary sizes are we aiming for. **Note:** Tokenization won't be done for the last vocab size. 54 | - `embeddings_schedule`: from vocabulary_schedule, in which steps we should re-run embeddings (similar to *l* in paper). 55 | 56 | Default override arguments: 57 | 58 | - `partial_corpus_filepath`: where to create / load partial corpus file. Default is `''` for creating partial corpus under 'data' folder. The partial corpus is a random subset of the full corpus and serves as the actual corpus used for training. 
59 | - `partial_corpus_line_number`: number of lines for partial corpus - in thousands. Default is `1000`. 60 | - `max_len`: max length of tokens in bytes. Default is `16`. 61 | - `workers`: number of worker threads to use. Default is `1`. 62 | - `random_seed`: random seed value. Default is `692653`. 63 | 64 | - **word2vec arguments:** 65 | - `word2vec_D`: word2vec embedding vector length. Default is `50` 66 | - `word2vec_N`: word2vec number of negative samples. Default is `15` 67 | - `word2vec_ALPHA`: word2vec Initial learning rate. Default is `0.025` 68 | - `word2vec_window_size`: word2vec context window size. Default is `5` 69 | - `word2vec_min_count`: word2vec minimum count of word. Default is `1`, i.e. must be used at least once 70 | - `word2vec_sg`: word2vec skip-gram if 1; otherwise CBOW. Default is `1` 71 | 72 | For execution via command line, run `main.py` from its working directory. 73 | 74 | Example: 75 | ``` 76 | python main.py \ 77 | exp_name \ 78 | --corpus_filepath data/wiki_lines.txt \ 79 | --initial_vocabulary_filepath data/initial_vocab_hex.vocab \ 80 | --vocabulary_schedule 262144 229376 196608 163840 131072 98304 65536 57344 49152 40960 32768 16384 \ 81 | --embeddings_schedule 262144 131072 65536 49152 40960 32768 \ 82 | --partial_corpus_filepath data/wiki_lines_partial.txt \ 83 | --partial_corpus_line_number 500 \ 84 | --max_len 17 \ 85 | --workers 4 \ 86 | --random_seed 692653 87 | ``` 88 | 89 | ### API differences from **SaGe 1.0**: 90 | - **Argument `final_vocab_size` is obsolete**, and replaced by the last value of `vocabulary_schedule`. 91 | - In **SaGe 1.0** the algorithm started with running *BPE* to create an initial vocab in a desired size (*V×n* in paper). 92 | Now, the algorithm accepts an already created vocabulary file as input, **making argument `initial_vocab_size` obsolete**. 93 | - In **SaGe 1.0**, after the initial vocabulary created, we iteratively ablated constant number of tokens (*k* in paper), 94 | until the final vocab size (*V*) was reached. 95 | Now, the user can directly choose the intermediate vocabulary sizes, thereby defining the ablation schedule (i.e., the difference between each adjacent vocab sizes), **making the ablation size dynamic and `tokens_to_prune_in_iteration` obsolete**. 96 | - Due to performance improvement, reranking happens every iteration and all tokens are considered for ablation, **making arguments `iterations_until_reranking` (*m* in paper) and `tokens_to_consider_in_iteration` (*M* in paper) obsolete**. 97 | - In order to re-execute from a checkpoint, just execute the same command. By default, the script searches for already existing files under `results/exp_name` directory. **Argument `is_continue` is obsolete**. 98 | - Argument `thousands_of_corpus_lines` name changed to `partial_corpus_line_number`. 99 | - Argument `max_lines_per_token` is obsolete, all lines will be considered for each token in the objective calculation. 100 | 101 | ## Citation 102 | 103 | Version 2.0 was mostly developed by Kensho Technologies, LLC, and ported by Bar Gazit. Citation TBA. 
104 | 105 | Version 1.0 was developed by Shaked Yehezkel and Yuval Pinter, please use this citation: 106 | ``` 107 | @inproceedings{yehezkel-pinter-2023-incorporating, 108 | title = "Incorporating Context into Subword Vocabularies", 109 | author = "Yehezkel, Shaked and 110 | Pinter, Yuval", 111 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics", 112 | month = may, 113 | year = "2023", 114 | address = "Dubrovnik, Croatia", 115 | publisher = "Association for Computational Linguistics", 116 | url = "https://aclanthology.org/2023.eacl-main.45", 117 | pages = "623--635", 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /src/sage_tokenizer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | 3 | from typing import List, Dict, Union, Tuple 4 | 5 | import numpy as np 6 | 7 | from .HFEncoding import HFEncoding 8 | 9 | 10 | Tokenizable = Union[str,bytes] 11 | 12 | 13 | class SaGeTokenizer: 14 | 15 | def __init__(self, initial_vocabulary, max_len: int=16): 16 | self.hfe = HFEncoding() 17 | self.byte_vocab: Dict[bytes, int] = None 18 | self.inv_byte_vocab: Dict[int, bytes] = None 19 | self.str_vocab: Dict[str, int] = None 20 | self.inv_str_vocab: Dict[int, str] = None 21 | 22 | self.set_vocabulary(initial_vocabulary) 23 | self.max_len = max_len 24 | 25 | # given an order list of bytes for vocabulary 26 | # create our internal structures 27 | # overwriting any previous values 28 | def set_vocabulary(self, new_vocab: List[bytes]): 29 | self.byte_vocab = self.set_bytes_vocab(new_vocab) 30 | 31 | # make sure we always have all single bytes in vocabulary 32 | verify_all_single_byte_exist_in_vocab(self.byte_vocab) 33 | 34 | # inverted map of inv_byte_vocab (int index : bytes) 35 | self.inv_byte_vocab = {v: k for (k, v) in self.byte_vocab.items()} 36 | 37 | # encoded str : int index 38 | # convert bytes to our encoded form for keys 39 | self.str_vocab = {self.hfe.to_encoded(k): v for (k, v) in self.byte_vocab.items()} 40 | # int index : encoded str 41 | self.inv_str_vocab = {v: k for (k, v) in self.str_vocab.items()} 42 | 43 | @staticmethod 44 | def set_bytes_vocab(new_vocab: List[bytes]) -> Dict[bytes, int]: 45 | # bytes : int index 46 | byte_vocab = {} 47 | for idx, token in enumerate(new_vocab): 48 | # token should have been converted to bytes 49 | assert type(token) == bytes 50 | byte_vocab[token] = idx 51 | return byte_vocab 52 | 53 | def id_to_bytes(self, token_id: int) -> bytes: 54 | return self.inv_byte_vocab[token_id] 55 | 56 | def id_to_encoded(self, token_id: int) -> str: 57 | return self.inv_str_vocab[token_id] 58 | 59 | def get_vocabulary(self) -> Dict[bytes, int]: 60 | return self.byte_vocab 61 | 62 | def vocab_size(self) -> int: 63 | return len(self.byte_vocab) 64 | 65 | def print_tokens(self, ids: List[int]) -> List[bytes]: 66 | """ 67 | Human readable for debugging 68 | """ 69 | return [self.inv_byte_vocab[i] for i in ids] 70 | 71 | def add_all_byte_ids(self, vocab: Dict[int,float], score: float=1e400): 72 | """ 73 | For each single byte, look up its id, then assign the given score to that id in the given dictionary. 
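The vocabulary builder passes a very large score here so that single-byte tokens are never selected for pruning and every byte sequence remains tokenizable.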
74 | """ 75 | for i in range(256): 76 | # what is the corresponding token id 77 | tid = self.byte_vocab[bytes([i])] 78 | # add that with a "good" score 79 | vocab[tid] = score 80 | 81 | def tokenize(self, sent: Tokenizable, tokens_only: bool=False) -> Union[List[int], List[Tuple[int, int, int]]]: 82 | """ 83 | Split the gives sentence into tokens and convert them to IDs. 84 | """ 85 | if isinstance(sent, str): 86 | sent = bytes(sent, encoding='utf-8') 87 | data = [] 88 | i = 0 89 | while i < len(sent): # Iterate through the sentence input 90 | for j in range(self.max_len, 0, -1): # Find the longest possible token 91 | tok = sent[i:i + j] 92 | if tok in self.byte_vocab: 93 | if tokens_only: 94 | # Only add token_id to results 95 | data.append(self.byte_vocab[tok]) 96 | else: 97 | # Add (token_id, token_start_idx, token_width) to results 98 | data.append((self.byte_vocab[tok], i, len(tok))) 99 | i += j # advance to next token 100 | break # the for loop 101 | return data 102 | 103 | def tokenize_to_encoded_str(self, sent: Tokenizable) -> List[str]: 104 | """ 105 | Return the tokenization as tokens in encoded str form 106 | """ 107 | return [self.inv_str_vocab[token_id] for token_id in self.tokenize(sent, tokens_only=True)] 108 | 109 | def tokenize_to_bytes(self, sent: Tokenizable) -> List[bytes]: 110 | """ 111 | Same as tokenize_to_encoded_strbut return byte form 112 | """ 113 | return [self.inv_byte_vocab[token_id] for token_id in self.tokenize(sent, tokens_only=True)] 114 | 115 | @staticmethod 116 | def do_triples(combined, pad: int, padleft, padright, cur_id, sign, triples): 117 | """ 118 | Add the appropriate (t,v,v') triples to our dictionary where t, v, and v' are all int indices. 119 | """ 120 | # where the right padding starts 121 | right_ind = len(combined) - padright 122 | 123 | # iterate over the targets 124 | # note that the padding elements now have different contexts in 125 | # center section, so need to let them be targets too 126 | for t, target in enumerate(combined): 127 | # the contexts, need pad here not padleft or padright, 128 | # since some context may be within the combined 129 | for c in range(t - pad, t + pad + 1): 130 | # context is in range and distinct from target 131 | # ignore the case where both c and t are in padding since that cancels 132 | if c >= 0 and c != t and c < len(combined) and \ 133 | ((c >= padleft and c < right_ind) or (t >= padleft and t < right_ind)): 134 | trip = (cur_id, target, combined[c]) 135 | # add sign to the triples 136 | triples[trip] = triples.get(trip, 0) + sign 137 | 138 | def fast_sage(self, sent: Tokenizable, triples, ablated_sizes, pad: int=2, verbose: bool=False) -> int: 139 | """ 140 | Tokenize the sentence `sent`, add to the counts in the triples dict tracking the (cur_id,t,c) for the ablated 141 | token cur_id, with target token t and context token c. Also updates the statistics in ablated_sizes. 142 | Returns the total_tokens from tokenizing `sent`. 
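`pad` is the number of neighboring tokens used as context on each side of a target. Ablated tokenizations are counted with sign +1 and the corresponding base tokenizations with sign -1, so each triple count records the net change in (target, context) co-occurrences caused by removing `cur_id`.

Illustrative usage (assuming `vocab` is a list of byte tokens that covers all single bytes):
    sage = SaGeTokenizer(vocab)
    triples, ablated_sizes = {}, {}
    total = sage.fast_sage(b"some text", triples, ablated_sizes)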
143 | """ 144 | n = len(sent) 145 | 146 | # returns triples of (ids, start_index, width) 147 | values = self.tokenize(sent) 148 | ids, start_indices, widths = zip(*values) 149 | # if you use np.array here, remember to fix concatenation below 150 | 151 | # note, these are arrays over the tokens so len(values) < n 152 | total_tokens = len(values) 153 | 154 | # tuples to list 155 | ids = list(ids) 156 | start_indices = list(start_indices) 157 | widths = list(widths) 158 | 159 | max_len = 0 160 | 161 | # have a constant time lookup on whether we're at a token on the base tokenization 162 | # if >= 0, is the index of the token in ids or widths 163 | on_base = np.zeros(n, dtype=int) - 1 164 | for j, si in enumerate(start_indices): 165 | on_base[si] = j 166 | # now we can just produce our ablated tokenizations 167 | # quite efficiently 168 | for loc, (cur_id, start_index, width) in enumerate(values): 169 | # skip single bytes 170 | if width > 1: 171 | 172 | ablated_tokenization = [] 173 | 174 | # find the next token with width-1 or less 175 | # starting at start_index 176 | i = start_index 177 | for j in range(width - 1, 0, -1): 178 | tok = sent[i:i + j] 179 | if tok in self.byte_vocab: 180 | ablated_tokenization.append(self.byte_vocab[tok]) # keep the ids 181 | i += j # advance to next token 182 | break # the for loop 183 | 184 | # now extend as normal until we get back on the old path 185 | while i < n: 186 | for j in range(min(self.max_len, n - i), 0, -1): 187 | tok = sent[i:i + j] 188 | if tok in self.byte_vocab: 189 | ablated_tokenization.append(self.byte_vocab[tok]) 190 | i += j # advance to next token 191 | break # the for loop 192 | 193 | # we never got back on the path, so set beyond to n 194 | if i >= n: 195 | beyond = n 196 | break 197 | 198 | # we get to a spot on the current longest path 199 | # we're back to the old tokenization, set beyond accordingly 200 | if on_base[i] != -1: 201 | beyond = on_base[i] 202 | break 203 | 204 | if verbose: 205 | print(self.print_tokens(ablated_tokenization)) 206 | 207 | # track how many tokens were required for the ablation 208 | lat = len(ablated_tokenization) 209 | ablated_sizes[lat] = ablated_sizes.get(lat, 0) + 1 210 | max_len = max(max_len, lat) 211 | 212 | # note: on_base[i] is one beyond the last difference 213 | base_tok = ids[loc:beyond] 214 | if verbose: 215 | print(self.print_tokens(base_tok)) 216 | 217 | # can we do any padding on left or right 218 | padleft = min(pad, loc) 219 | padright = min(pad, len(values) - beyond) 220 | left_pad = ids[loc - padleft:loc] 221 | # print(print_tokens(left_pad)) 222 | # note: beyond is one beyond the last difference 223 | right_pad = ids[beyond:beyond + padright] 224 | # print(print_tokens(right_pad)) 225 | 226 | # combine with the padding, and work out the context triples 227 | combined_ab = left_pad + ablated_tokenization + right_pad 228 | self.do_triples(combined_ab, pad, padleft, padright, cur_id, 1, triples) 229 | 230 | # and same for the base tokenization 231 | combined_base = left_pad + base_tok + right_pad 232 | self.do_triples(combined_base, pad, padleft, padright, cur_id, -1, triples) 233 | 234 | if verbose: 235 | print("base:", self.print_tokens(left_pad), self.print_tokens(base_tok), 236 | self.print_tokens(right_pad)) 237 | print("ab: ", self.print_tokens(left_pad), self.print_tokens(ablated_tokenization), 238 | self.print_tokens(right_pad)) 239 | print("comb base:", self.print_tokens(combined_base)) 240 | print("comb ab:", self.print_tokens(combined_ab)) 241 | print() 242 | 243 | # log 
some of these 244 | if max_len > 200: 245 | # remember to convert from bytes 246 | print("long max_len:", max_len, '"' + sent.decode('utf-8') + '"') 247 | 248 | return total_tokens 249 | 250 | 251 | def verify_all_single_byte_exist_in_vocab(vocab): 252 | for i in range(256): 253 | b = bytes([i]) 254 | if b not in vocab: 255 | raise Exception(f"missing byte {b}") 256 | -------------------------------------------------------------------------------- /sage_v1/Python-Modules/SG_BPE.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import random 4 | import sentencepiece as spm 5 | import math 6 | import numpy as np 7 | import Utils 8 | import multiprocessing as mp 9 | import os 10 | 11 | class SG_BPE_Models: 12 | def __init__(self, experiment_name, is_continue_execution, final_vocab_size, initial_vocab_size, partial_corpus_filepath): 13 | # SG-BPE model - this we are gonna update 14 | sg_bpe_model_prefix = "results/{}/sg_bpe".format(experiment_name) 15 | sg_bpe_model = sg_bpe_model_prefix + ".model" 16 | 17 | if not is_continue_execution: 18 | spm.SentencePieceTrainer.train(input=partial_corpus_filepath, model_prefix=sg_bpe_model_prefix, vocab_size=initial_vocab_size, model_type="bpe") 19 | self._sg_bpe_model = spm.SentencePieceProcessor(model_file=sg_bpe_model) 20 | 21 | # Vanilla BPE model - for comparisons 22 | vanilla_bpe_model_prefix = "results/{}/bpe_vanilla".format(experiment_name) 23 | vanilla_bpe_model = vanilla_bpe_model_prefix + ".model" 24 | 25 | if not is_continue_execution: 26 | spm.SentencePieceTrainer.train(input=partial_corpus_filepath, model_prefix=vanilla_bpe_model_prefix, vocab_size=final_vocab_size, model_type="bpe") 27 | self._bpe_vanilla_model = spm.SentencePieceProcessor(model_file=vanilla_bpe_model) 28 | 29 | def get_sg_bpe_model(self): 30 | return self._sg_bpe_model 31 | 32 | def get_bpe_vanilla_model(self): 33 | return self._bpe_vanilla_model 34 | 35 | 36 | class Model: 37 | def __init__(self, experiment_name, log, model, model_name, \ 38 | target_embeddings, context_embeddings, corpus_lines, max_lines_per_token, window_size, is_continue_execution=False, vocab_filepath=False): 39 | 40 | self._experiment_name = experiment_name 41 | self._model = model 42 | self._model_name = model_name 43 | self._target_embeddings = target_embeddings 44 | self._context_embeddings = context_embeddings 45 | self._log = log 46 | self._current_vocab = None 47 | self._should_compute_vocab = True 48 | self._corpus_lines = corpus_lines 49 | self._max_lines_per_token = max_lines_per_token 50 | self._window_size = window_size 51 | 52 | if is_continue_execution: 53 | with open(vocab_filepath + ".bin", "rb") as vocab_file: 54 | current_bpe_vocab = pickle.load(vocab_file) 55 | self.set_vocab(current_bpe_vocab) 56 | 57 | def initialize_encoded_form_for_corpus_lines(self): 58 | model_encoded_coprus_lines_token_ids = [] 59 | model_encoded_coprus_lines_token_pieces = [] 60 | for line in self._corpus_lines: 61 | tokens_in_line_ints = self._model.encode(line, out_type=int) 62 | tokens_in_line_pieces = [self._model.id_to_piece(x) for x in tokens_in_line_ints] 63 | 64 | model_encoded_coprus_lines_token_ids.append(tokens_in_line_ints) 65 | model_encoded_coprus_lines_token_pieces.append(tokens_in_line_pieces) 66 | 67 | self._model_encoded_corpus_lines_token_ids = model_encoded_coprus_lines_token_ids 68 | self._model_encoded_corpus_lines_token_pieces = model_encoded_coprus_lines_token_pieces 69 | 70 | def 
initialize_token_to_line_indices_dictionary(self, current_vocab, corpus_lines, experiment_name, is_continue_execution): 71 | token_to_line_indices_dict_filepath = "results/{}/token_to_line_indices_dict.bin".format(experiment_name) 72 | if is_continue_execution and os.path.exists(token_to_line_indices_dict_filepath): 73 | with open(token_to_line_indices_dict_filepath, "rb") as token_to_line_indices_dict_file: 74 | self._token_to_line_indices_dict = pickle.load(token_to_line_indices_dict_file) 75 | else: 76 | self._token_to_line_indices_dict = Utils.token_to_line_indices_dictionary(current_vocab, corpus_lines) 77 | with open(token_to_line_indices_dict_filepath, "wb") as token_to_line_indices_dict_file: 78 | pickle.dump(self._token_to_line_indices_dict, token_to_line_indices_dict_file) 79 | 80 | def update_encoded_form_for_corpus_lines(self, tokens_pruned): 81 | for token in tokens_pruned: 82 | line_indices_with_token = self._token_to_line_indices_dict[token] 83 | for index in line_indices_with_token: 84 | current_line = self._corpus_lines[index] 85 | self._model_encoded_corpus_lines_token_ids[index] = self._model.encode(current_line, out_type=int) 86 | self._model_encoded_corpus_lines_token_pieces[index] = [self._model.id_to_piece(x) for x in self._model_encoded_corpus_lines_token_ids[index]] 87 | 88 | # We use the "get_current_vocab" method because of sentencepiece tricky way to remove tokens from the vocabulary. 89 | def get_current_vocab(self): 90 | if not self._should_compute_vocab: 91 | return self._current_vocab 92 | 93 | model_vocab = [] 94 | for i in range(self._model.vocab_size()): 95 | model_vocab.append(self._model.id_to_piece(i)) 96 | 97 | return model_vocab 98 | 99 | def log_experiments_model_results(self, training_filepath, model_name_override=None): 100 | if model_name_override: 101 | model_name = model_name_override 102 | else: 103 | model_name = self._model_name 104 | 105 | ## Log model vocabulary 106 | model_vocab = self.get_current_vocab() 107 | model_vocab_filepath = ("./results/{}/" + model_name + "_vocab.txt").format(self._experiment_name) 108 | with open(model_vocab_filepath, "w+") as model_vocab_results_file: 109 | model_vocab_results_file.write(json.dumps(model_vocab, indent=4)) 110 | 111 | ## Log encoding of input file with model vocabulary 112 | with open(training_filepath) as input_file: 113 | input_file_data = input_file.read() 114 | 115 | encoded_data = [self._model.id_to_piece(x) for x in self._model.encode(input_file_data)] 116 | model_encoding_filepath = ("./results/{}/" + model_name + "_encoding.txt").format(self._experiment_name) 117 | with open(model_encoding_filepath, "w+") as encoding_results_file: 118 | encoding_results_file.write(' '.join(encoded_data)) 119 | 120 | def sg_for_window(self, target_token, window): 121 | current_p = 0 122 | for w in window: 123 | # calculate current value to add 124 | dot_product = np.dot(self._target_embeddings[target_token], self._context_embeddings[w]) 125 | try: 126 | current_p += math.log(Utils.sigmoid(dot_product)) 127 | except: 128 | pass 129 | 130 | return (-1) * current_p 131 | 132 | def token_context_sg_log_prob(self, token_int, i, tokens_in_line_ints): 133 | # get context window tokens 134 | window, _, _ = Utils.compute_window(i, tokens_in_line_ints, self._window_size) 135 | return self.sg_for_window(token_int, window) 136 | 137 | def total_sg_log_prob(self, training_filepath): 138 | # getting lines from corpus 139 | with open(training_filepath, "r") as training_file: 140 | corpus_lines = 
training_file.readlines() 141 | 142 | p = 0 143 | for line in corpus_lines: 144 | current_p = 0 145 | tokens_in_line_ints = self._model.encode(line, out_type=int) 146 | for i, token_int in enumerate(tokens_in_line_ints): 147 | current_p += self.token_context_sg_log_prob(token_int, i, tokens_in_line_ints) 148 | 149 | p += current_p 150 | 151 | return p 152 | 153 | # Compute total_sg_log_prob without each token 154 | # -- This is executed once in iteration -- 155 | # Multi Processing version of this method ########### 156 | def get_sg_log_prob_without_tokens_mp(self, current_total_sg, training_filepath, nat_list=None, dict_of_top_tokens=None): 157 | token_and_sg_log_prob_without_it = {} 158 | current_vocab = self.get_current_vocab() 159 | 160 | if dict_of_top_tokens: 161 | current_vocab = dict_of_top_tokens 162 | 163 | print("\nCurrent vocab len - {}".format(len(current_vocab))) 164 | 165 | if nat_list: 166 | current_nat_list = Utils.get_not_ablateable_tokens_list(current_vocab) 167 | current_vocab = [t for t in current_vocab if t not in current_nat_list] 168 | 169 | process_pool = mp.Pool(mp.cpu_count()) 170 | params = [(self._model, token, current_total_sg, \ 171 | current_vocab, training_filepath, \ 172 | self._target_embeddings, self._context_embeddings, self._log, self._corpus_lines, self._window_size) for token in current_vocab] 173 | res = process_pool.starmap(Utils.sg_wo_token_mp, params) 174 | 175 | for r in res: 176 | token = r[0] 177 | current_sg_log_prob = r[1] 178 | token_and_sg_log_prob_without_it[token] = current_sg_log_prob 179 | 180 | return token_and_sg_log_prob_without_it 181 | 182 | # More optimized form of "get_sg_log_prob_without_tokens_mp" 183 | def get_sg_log_prob_without_tokens_mp2(self, current_total_sg, nat_list=None, dict_of_top_tokens=None): 184 | token_and_sg_log_prob_without_it = {} 185 | 186 | current_vocab = self.get_current_vocab() 187 | if dict_of_top_tokens: 188 | current_vocab = dict_of_top_tokens 189 | 190 | if nat_list: 191 | current_nat_list = Utils.get_not_ablateable_tokens_list(current_vocab) 192 | current_vocab = [t for t in current_vocab if t not in current_nat_list] 193 | 194 | process_pool = mp.Pool(mp.cpu_count()) 195 | lines_per_token = {} 196 | 197 | params = [] 198 | for token in current_vocab: 199 | lines_to_consider = self._token_to_line_indices_dict[token] 200 | if len(lines_to_consider) > self._max_lines_per_token: 201 | lines_to_consider = random.sample(lines_to_consider, self._max_lines_per_token) 202 | 203 | for line_index in lines_to_consider: 204 | params.append((self._model, line_index, \ 205 | self._corpus_lines, self._model_encoded_corpus_lines_token_ids, self._model_encoded_corpus_lines_token_pieces, \ 206 | token, current_vocab, \ 207 | self._target_embeddings, self._context_embeddings, self._log, self._window_size)) 208 | 209 | lines_per_token[token] = len(lines_to_consider) 210 | 211 | res = process_pool.starmap(Utils.get_diff_sg_wo_token_for_line, params) 212 | 213 | for r in res: 214 | token = r[0] 215 | sg_wo_diff = r[1] 216 | 217 | if token not in token_and_sg_log_prob_without_it.keys(): 218 | token_and_sg_log_prob_without_it[token] = current_total_sg 219 | 220 | token_and_sg_log_prob_without_it[token] += (float(sg_wo_diff) / lines_per_token[token]) 221 | 222 | return token_and_sg_log_prob_without_it 223 | 224 | def get_model(self): 225 | return self._model 226 | 227 | def set_vocab(self, new_vocabulary): 228 | # Note: this indeed changing the model vocab, but not model.vocab_size(). 
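# (sentencepiece's set_vocabulary only restricts which pieces the encoder may output; the underlying piece table and ids are untouched, which is why vocab_size() stays the same)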
229 | # if you want to test this - remove some token from the vocabulary, and compare encodings before and after removing. 230 | self._model.set_vocabulary(new_vocabulary) 231 | self._current_vocab = new_vocabulary 232 | self._should_compute_vocab = False 233 | 234 | -------------------------------------------------------------------------------- /src/sage_tokenizer/SaGeVocabBuilder.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | from typing import List, Union, Optional 3 | from pathlib import Path 4 | 5 | import logging 6 | 7 | from .Word2VecParams import Word2VecParams 8 | from .embeddings import get_embeddings 9 | from .model import SaGeTokenizer 10 | from .utils import init_logger, set_random_seed, load_vocab, get_output_folder, load_corpus, run_sage_parallel, \ 11 | save_sorted_losses, save_stats, write_vocab 12 | 13 | 14 | class SaGeVocabBuilder: 15 | 16 | def __init__(self, full_vocab_schedule: List[int], embeddings_schedule: List[int], 17 | max_len: int=16, workers_number: int=1, random_seed: int=692653, 18 | word2vec_d: int=50, word2vec_n: int=15, word2vec_alpha: float=0.025, word2vec_window_size: int=5, word2vec_min_count: int=1, word2vec_sg: bool=True): 19 | self.full_vocab_schedule = full_vocab_schedule 20 | self.embeddings_schedule = embeddings_schedule 21 | self.max_len = max_len 22 | self.workers_number = workers_number 23 | self.random_seed = random_seed 24 | self.word2vec_params = Word2VecParams( 25 | D=word2vec_d, 26 | N=word2vec_n, 27 | ALPHA=word2vec_alpha, 28 | window_size=word2vec_window_size, 29 | min_count=word2vec_min_count, 30 | sg=int(word2vec_sg) # 1 uses skip-gram, 0 uses CBoW. 31 | ) 32 | 33 | def build_vocab(self, experiment_name: str, 34 | corpus_filepath: Union[str,Path], vocabulary_filepath: Union[str,Path], 35 | partial_corpus_filepath: Optional[Union[str,Path]]=None, partial_corpus_line_number: int=1000): 36 | corpus_filepath = Path(corpus_filepath) 37 | vocabulary_filepath = Path(vocabulary_filepath) 38 | partial_corpus_filepath = Path(partial_corpus_filepath) if isinstance(partial_corpus_filepath, str) else None 39 | 40 | init_logger(experiment_name) 41 | logging.info(f"Start experiment {experiment_name}") 42 | logging.info(f"Process will use up to {self.workers_number} worker threads.") 43 | 44 | logging.info("Getting output directories") 45 | embeddings_folder, stats_folder, vocab_folder = get_output_folder(experiment_name) 46 | logging.info("Setting random seed") 47 | set_random_seed(experiment_name, self.random_seed) 48 | logging.info(f"Loading initial vocabulary from file {vocabulary_filepath.as_posix()}") 49 | byte_vocab = load_vocab(vocabulary_filepath) 50 | logging.info(f"Finished loading initial vocabulary. 
Vocabulary size: {len(byte_vocab)}") 51 | 52 | actual_max_len = max([len(v) for v in byte_vocab]) 53 | if self.max_len != actual_max_len: 54 | logging.warning(f"max_len parameter value {self.max_len} doesn't match actual max {actual_max_len}") 55 | 56 | logging.info("Initializing tokenizer") 57 | sage_model = SaGeTokenizer(byte_vocab, self.max_len) 58 | 59 | logging.info(f"Loading Corpus from {corpus_filepath.as_posix()}") 60 | partial_corpus = load_corpus(corpus_filepath, partial_corpus_filepath, partial_corpus_line_number) 61 | logging.info("Starting the training loop") 62 | vocab_schedule = self.full_vocab_schedule 63 | 64 | if not len(vocab_schedule) >= 2: 65 | raise Exception("Vocabulary schedule must contain more than 2 vocabulary sizes!") 66 | 67 | vocab_schedule.sort(reverse=True) # largest first 68 | logging.info(f"initial vocab_schedule is {vocab_schedule[0]} vs actual size {sage_model.vocab_size()}") 69 | 70 | embedding_sizes = set(self.embeddings_schedule) 71 | 72 | # initialize embeddings for first iteration 73 | embeddings = get_embeddings(vocab_schedule[0], embeddings_folder, partial_corpus, sage_model, 74 | self.workers_number, self.word2vec_params) 75 | 76 | # skipping the initial vocab size here 77 | i = 0 78 | # stop one before the end, since we do i+1, 79 | # so we'll make the vocab of that final size, but won't do the tokenization on it 80 | while i < len(vocab_schedule) - 1: 81 | current_step_vocab_size = vocab_schedule[i] # this will be the label used for files 82 | target_vocab_size = vocab_schedule[i + 1] 83 | actual_vocab_size = sage_model.vocab_size() 84 | logging.info(f"\nRound {i} - Start: " 85 | f"\n\tCurrent step vocabulary size: {current_step_vocab_size}, " 86 | f"\n\tTarget vocabulary size: {target_vocab_size}, " 87 | f"\n\tActual vocabulary size: {actual_vocab_size}") 88 | 89 | if vocab_schedule[i] in embedding_sizes: 90 | embeddings = get_embeddings(current_step_vocab_size, embeddings_folder, partial_corpus, sage_model, 91 | self.workers_number, self.word2vec_params) 92 | 93 | if actual_vocab_size <= target_vocab_size: 94 | logging.info(f"Actual vocab is already smaller than target. continue to next iteration ") 95 | i += 1 96 | continue 97 | 98 | # call sage in parallel 99 | logging.info(f"Sage started.") 100 | total_tokens, total_triples, token_to_losses, ablated_sizes = run_sage_parallel(embeddings, 101 | partial_corpus, 102 | sage_model, 103 | self.workers_number) 104 | logging.info(f"Sage finished. total tokens: {total_tokens}, total triplets: {total_triples}") 105 | 106 | # token_to_losses won't include any single byte tokens, but we want to keep those around 107 | # so lets just add them with large scores, so they stay around 108 | vocab_size_before_single_byte_tokens_addition = len(token_to_losses) 109 | sage_model.add_all_byte_ids(token_to_losses, score=1e6) 110 | logging.info(f"Adding single bytes to vocab. 
Size before: {vocab_size_before_single_byte_tokens_addition}, " 111 | f"size after: {len(token_to_losses)}") 112 | 113 | # if a token doesn't appear in token_to_losses then it didn't participate in the tokenization 114 | current_active_vocab_size = len(token_to_losses) 115 | current_inactive_vocab_size = actual_vocab_size - len(token_to_losses) 116 | logging.info(f"Actual vocab size: {actual_vocab_size}, " 117 | f"Target vocab size: {target_vocab_size}, " 118 | f"Active Vocab Size: {current_active_vocab_size}, " 119 | f"Inactive Vocab Size: {current_inactive_vocab_size}") 120 | 121 | # how many of the losses are negative 122 | neg_loss = len([loss for loss in token_to_losses.values() if loss < 0.0]) 123 | zero_loss = len([loss for loss in token_to_losses.values() if loss == 0.0]) 124 | pos_loss = len([loss for loss in token_to_losses.values() if loss > 0.0]) 125 | logging.info(f"Negative losses: {neg_loss}, zero losses: {zero_loss}, positive losses: {pos_loss}") 126 | 127 | # in case the active vocab we found is actually smaller than the target vocab, 128 | # change the target to the next one, until it's smaller than the vocab we found, 129 | # so the ablation part will actually do something 130 | while current_active_vocab_size <= target_vocab_size: 131 | logging.info(f"Active vocab size is {current_active_vocab_size} - " 132 | f"smaller than target {target_vocab_size}. Moving to next target_vocab_size" 133 | f"\n\n(Round number increased to {i + 1})\n") 134 | i += 1 135 | target_vocab_size = vocab_schedule[i + 1] 136 | logging.info(f"New target_vocab_size: {target_vocab_size}") 137 | 138 | num_tokens_to_prune = current_active_vocab_size - target_vocab_size 139 | logging.info(f"Num tokens to prune {num_tokens_to_prune}") 140 | 141 | ###################### 142 | # do the ablation 143 | ###################### 144 | # we want to drop the smallest (negative) values 145 | # these are the ones with the largest decrease in likelihood from dropping the ablated token 146 | sorted_losses = list(sorted([(loss, tid) for (tid, loss) in token_to_losses.items()])) 147 | save_sorted_losses(sage_model, sorted_losses, target_vocab_size, vocab_folder) 148 | 149 | stats = { 150 | "current_step_vocab_size": current_step_vocab_size, "total_tokens": total_tokens, 151 | "total_triples": total_triples, "current_active_vocab_size": current_active_vocab_size, 152 | "current_inactive_vocab_size": current_inactive_vocab_size, "neg_loss": neg_loss, 153 | "zero_loss": zero_loss, "pos_loss": pos_loss, "target_vocab_size": target_vocab_size, 154 | "num_tokens_to_prune": num_tokens_to_prune, "ablated_sizes": ablated_sizes, 155 | } 156 | save_stats(stats, stats_folder, target_vocab_size) 157 | 158 | # these are the tokens to be removed 159 | tokens_to_prune = {sage_model.id_to_bytes(tid) for (loss, tid) in sorted_losses[:num_tokens_to_prune]} 160 | # double check there are no single bytes tokens to prune here 161 | single_byte_tokens_to_prune = [token for token in tokens_to_prune if len(token) == 1] 162 | assert len(single_byte_tokens_to_prune) == 0 163 | 164 | # our active vocabulary *after* pruning 165 | # is active if it has an entry in token_to_losses 166 | active_vocab = {tok: tid for tok, tid in sage_model.get_vocabulary().items() 167 | if tid in token_to_losses and tok not in tokens_to_prune} 168 | 169 | # our overall vocabulary after pruning 170 | target_vocab = {tok: tid for tok, tid in sage_model.get_vocabulary().items() 171 | if tok not in tokens_to_prune} 172 | 173 | # the deleted items 174 | deleted_vocab = 
{tok: tid for tok, tid in sage_model.get_vocabulary().items() 175 | if tok in tokens_to_prune} 176 | 177 | vocab_save_name = vocab_folder / f"sage_vocab_{target_vocab_size}.vocab" 178 | logging.info(f"Saving intermediate vocab of size {len(target_vocab)} to {vocab_save_name.as_posix()}") 179 | write_vocab(target_vocab, vocab_save_name) 180 | 181 | active_save_name = vocab_folder / f"active_vocab_{target_vocab_size}.vocab" 182 | logging.info(f"Saving active vocab of size {len(active_vocab)} to {active_save_name.as_posix()}") 183 | write_vocab(active_vocab, active_save_name) 184 | 185 | # save the deleted ones too for analysis, with the original size 186 | deleted_save_name = vocab_folder / f"deleted_vocab_{target_vocab_size}.vocab" 187 | logging.info(f"Saving deleted vocab of size {len(deleted_vocab)} to {deleted_save_name.as_posix()}") 188 | write_vocab(deleted_vocab, deleted_save_name) 189 | 190 | # now update the internal state of sage_model to use the new smaller vocab 191 | # pass in list of bytes keys, which keep insertion order 192 | sage_model.set_vocabulary(list(target_vocab.keys())) 193 | 194 | logging.info(f"\nRound {i} - End: " 195 | f"\n\tCurrent step vocabulary size: {current_step_vocab_size}, " 196 | f"\n\tTarget vocabulary size: {target_vocab_size}, " 197 | f"\n\tActual vocabulary size:{len(active_vocab)}") 198 | 199 | # advance to next smaller size 200 | i += 1 201 | -------------------------------------------------------------------------------- /sage_v1/Python-Modules/Utils.py: -------------------------------------------------------------------------------- 1 | 2 | from scipy.special import expit 3 | import numpy as np 4 | import math 5 | 6 | SPECIAL_TOKENS = ["", "", ""] 7 | WORD_PREFIX_CHAR = "\u2581" 8 | 9 | def sigmoid(num): 10 | s = expit(num) 11 | if s == 1.0: 12 | raise BaseException("sigmoid overflow") 13 | 14 | return expit(num) 15 | 16 | def compute_window(token_index, tokens_in_line, window_size): 17 | # get context window tokens 18 | context_start = max(token_index-window_size, 0) 19 | context_end = min(token_index+window_size+1, len(tokens_in_line)) 20 | window = tokens_in_line[context_start:token_index] + tokens_in_line[token_index+1:context_end] 21 | 22 | return window, context_start, context_end 23 | 24 | # Compose list of tokens we never ablate 25 | def get_not_ablateable_tokens_list(current_vocab): 26 | NAT_list = [] 27 | for t in current_vocab: 28 | if len(t) == 1: 29 | NAT_list.append(t) 30 | continue 31 | 32 | if t.startswith(WORD_PREFIX_CHAR) and len(t) == 2: 33 | NAT_list.append(t) 34 | continue 35 | 36 | if t in SPECIAL_TOKENS: 37 | NAT_list.append(t) 38 | continue 39 | 40 | continue 41 | 42 | return NAT_list 43 | 44 | # calculating 'offset' = number of tokens added because of our new encoding 45 | def calculate_token_offset(i, last_times_accumulate_offset, tokens_in_line, updated_tokens_in_line, log): 46 | # original encoding next_word_index 47 | original_current_index = i+1 48 | if original_current_index == len(tokens_in_line): 49 | original_current_index = i 50 | else: 51 | while not tokens_in_line[original_current_index].startswith("\u2581"): 52 | original_current_index += 1 53 | if original_current_index == len(tokens_in_line): 54 | break 55 | 56 | original_next_word_index = original_current_index 57 | 58 | # and the updated next_word_index 59 | current_index = last_times_accumulate_offset + i+1 60 | if current_index == len(updated_tokens_in_line): 61 | current_index = last_times_accumulate_offset + i 62 | else: 63 | if (current_index >= 
len(updated_tokens_in_line)): 64 | log.info("Fail! current index = {}, length of \"updated_tokens_in_line\" = {}".format(current_index, len(updated_tokens_in_line))) 65 | log.info("token #{}, updated tokens:{}".format(i, updated_tokens_in_line)) 66 | current_index -= 1 67 | else: 68 | while not updated_tokens_in_line[current_index].startswith("\u2581"): 69 | current_index += 1 70 | if current_index == len(updated_tokens_in_line): 71 | break 72 | 73 | updated_next_word_index = current_index 74 | 75 | # the offset is the difference between them 76 | offset = updated_next_word_index - original_next_word_index - last_times_accumulate_offset 77 | 78 | return offset 79 | 80 | def token_to_line_indices_dictionary(current_vocab, corpus_lines): 81 | # first lets hold the instance of line with "/u2581" chars instead of spaces 82 | corpus_lines_encoded = [] 83 | for line in corpus_lines: 84 | enc_line = WORD_PREFIX_CHAR + line 85 | enc_line = enc_line.replace(" ", WORD_PREFIX_CHAR) 86 | corpus_lines_encoded.append(enc_line) 87 | 88 | # and test whether tokens contained in the lines 89 | token_to_line_dict = {} 90 | for token in current_vocab: 91 | lines_with_token_indices = [] 92 | for i, enc_line in enumerate(corpus_lines_encoded): 93 | if token not in enc_line: 94 | continue 95 | lines_with_token_indices.append(i) 96 | token_to_line_dict[token] = lines_with_token_indices 97 | 98 | return token_to_line_dict 99 | 100 | #################################################################### 101 | ## Functions for multiprocessing - cannot be class methods ######### 102 | #################################################################### 103 | 104 | def sg_for_window_mp(target_token, window, target_embeddings, context_embeddings, log): 105 | current_p = 0 106 | for w in window: 107 | # calculate current value to add 108 | dot_product = np.dot(target_embeddings[target_token], context_embeddings[w]) 109 | try: 110 | current_p += math.log(sigmoid(dot_product)) 111 | except: 112 | pass 113 | 114 | return (-1) * current_p 115 | 116 | 117 | def substract_windows_from_sg_mp(i, tokens_in_line_ints, current_sg, target_embeddings, context_embeddings, window_size, log): 118 | # we should substract windows of all tokens in our original window 119 | sg_wo = current_sg 120 | _, context_start, context_end = compute_window(i, tokens_in_line_ints, window_size) 121 | 122 | for index in range(context_start, context_end): 123 | window, _, _ = compute_window(index, tokens_in_line_ints, window_size) 124 | sg_wo -= sg_for_window_mp(tokens_in_line_ints[index], window, target_embeddings, context_embeddings, log) 125 | 126 | return sg_wo 127 | 128 | def add_windows_to_sg_mp(model, updated_i, offset, updated_context_start, \ 129 | updated_context_end, updated_tokens_in_line, current_sg, \ 130 | target_embeddings, context_embeddings, window_size, log): 131 | sg_wo = current_sg 132 | 133 | # assume it now looks like: old[context_start:i] + [ablated_token_encoding] + old[i+1:context_end] 134 | # add the windows before i (ablated token) 135 | for index in range(updated_context_start, updated_i): 136 | try: 137 | window, _, _ = compute_window(index, updated_tokens_in_line, window_size) 138 | sg_wo += sg_for_window_mp(updated_tokens_in_line[index], window, target_embeddings, context_embeddings, log) 139 | except: 140 | print("index {}".format(index)) 141 | print("tokens {}".format(updated_tokens_in_line)) 142 | raise 143 | 144 | # add the ablated token encoding windows 145 | for index in range(updated_i, updated_i+offset): 146 | try: 
147 | window, _, _ = compute_window(index, updated_tokens_in_line, window_size) 148 | sg_wo += sg_for_window_mp(updated_tokens_in_line[index], window, target_embeddings, context_embeddings, log) 149 | except: 150 | print("index {}".format(index)) 151 | print("tokens {}".format(updated_tokens_in_line)) 152 | raise 153 | 154 | # add the windows after the ablated token 155 | for index in range(updated_i+offset, updated_context_end): 156 | try: 157 | window, _, _ = compute_window(index, updated_tokens_in_line, window_size) 158 | sg_wo += sg_for_window_mp(updated_tokens_in_line[index], window, target_embeddings, context_embeddings, log) 159 | except: 160 | print("index {}, len {}".format(index, len(updated_tokens_in_line))) 161 | updated_tokens_in_line_str = [model.id_to_piece(x) for x in updated_tokens_in_line] 162 | print("tokens {}".format(updated_tokens_in_line_str)) 163 | raise 164 | 165 | return sg_wo 166 | 167 | def update_sg_per_instance_of_token_mp(model, token_to_ablate, i, \ 168 | line, tokens_in_line_ints, tokens_in_line_pieces, \ 169 | last_times_accumulate_offset, \ 170 | current_total_sg, current_vocab, \ 171 | target_embeddings, context_embeddings, log, window_size): 172 | 173 | # window of token - before ablation 174 | sg_wo = current_total_sg 175 | 176 | # we should substract windows of all tokens in our original window 177 | original_window, original_context_start, original_context_end = compute_window(i, tokens_in_line_pieces, window_size) 178 | 179 | sg_wo = substract_windows_from_sg_mp(i, tokens_in_line_ints, sg_wo, target_embeddings, context_embeddings, window_size, log) 180 | 181 | ## remove token from current vocab 182 | vocab_without_token = current_vocab.copy() 183 | vocab_without_token.remove(token_to_ablate) 184 | model.set_vocabulary(vocab_without_token) 185 | 186 | # encode with new vocab 187 | updated_tokens_in_line_ints = model.encode(line, out_type=int) 188 | updated_tokens_in_line = [model.id_to_piece(u) for u in updated_tokens_in_line_ints] 189 | 190 | # calculating 'offset' = number of tokens added because of our new encoding 191 | offset = calculate_token_offset(i, last_times_accumulate_offset, tokens_in_line_pieces, updated_tokens_in_line, log) 192 | 193 | ################################################################################# 194 | # make sure the window (before the ablated token) stays the same ################ 195 | ################################################################################# 196 | _, updated_context_start, updated_context_end = compute_window(i+last_times_accumulate_offset, updated_tokens_in_line_ints, window_size) 197 | updated_context_end = min(updated_context_end+offset, len(updated_tokens_in_line)) # real context_en 198 | updated_i = i + last_times_accumulate_offset 199 | updated_window = updated_tokens_in_line[updated_context_start:updated_i] + updated_tokens_in_line[updated_i+offset+1:updated_context_end] 200 | 201 | first_part_of_window_length = min(window_size, updated_i) 202 | 203 | if updated_window[:first_part_of_window_length] != original_window[:first_part_of_window_length]: 204 | return current_total_sg, offset 205 | 206 | ################################################################################# 207 | ################################################################################# 208 | 209 | # and add the updated windows sg 210 | sg_wo = add_windows_to_sg_mp(model, updated_i, offset, updated_context_start, \ 211 | updated_context_end, updated_tokens_in_line_ints, \ 212 | sg_wo, target_embeddings, 
context_embeddings, window_size, log) 213 | 214 | ## revert model to current vocab 215 | model.set_vocabulary(current_vocab) 216 | 217 | return sg_wo, offset 218 | 219 | ## We assume the line contains the token_to_ablate (check it before shooting new process to execute this method) 220 | def get_diff_sg_wo_token_for_line(model, line_index, \ 221 | corpus_lines, model_encoded_corpus_lines_token_ids, model_encoded_corpus_lines_token_pieces, \ 222 | token_to_ablate, current_vocab, target_embeddings, context_embeddings, log, window_size): 223 | sg_wo_diff = 0 224 | 225 | tokens_in_line_ints = model_encoded_corpus_lines_token_ids[line_index] 226 | tokens_in_line_pieces = model_encoded_corpus_lines_token_pieces[line_index] 227 | line = corpus_lines[line_index] 228 | 229 | # For start, assume there are no overlaps in the windows of 2 instances 230 | indices = [i for i, x in enumerate(tokens_in_line_pieces) if x == token_to_ablate] 231 | 232 | last_time_offset = 0 233 | last_times_accumulate_offset = 0 234 | for i in indices: 235 | sg_wo_diff, last_time_offset = update_sg_per_instance_of_token_mp(model, token_to_ablate, i, line, tokens_in_line_ints, tokens_in_line_pieces, last_times_accumulate_offset, sg_wo_diff, current_vocab, target_embeddings, context_embeddings, log, window_size) 236 | last_times_accumulate_offset += last_time_offset 237 | 238 | return token_to_ablate, sg_wo_diff 239 | 240 | def sg_wo_token_mp(model, token_to_ablate, current_total_sg, current_vocab, training_filepath, target_embeddings, context_embeddings, log, corpus_lines, window_size): 241 | sg_wo = current_total_sg 242 | 243 | for line in corpus_lines: 244 | tokens_in_line_pieces = [model.id_to_piece(x) for x in model.encode(line, out_type=int)] 245 | if token_to_ablate not in tokens_in_line_pieces: 246 | continue 247 | 248 | # For start, assume there are no overlaps in the windows of 2 instances 249 | indices = [i for i, x in enumerate(tokens_in_line_pieces) if x == token_to_ablate] 250 | 251 | last_time_offset = 0 252 | last_times_accumulate_offset = 0 253 | for i in indices: 254 | tokens_in_line_ints = [model.piece_to_id(t) for t in tokens_in_line_pieces] 255 | sg_wo, last_time_offset = update_sg_per_instance_of_token_mp(model, token_to_ablate, i, line, tokens_in_line_ints, tokens_in_line_pieces, last_times_accumulate_offset, sg_wo, current_vocab, target_embeddings, context_embeddings, log, window_size) 256 | last_times_accumulate_offset += last_time_offset 257 | 258 | return token_to_ablate, sg_wo 259 | -------------------------------------------------------------------------------- /src/sage_tokenizer/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © 2023 Kensho Technologies, LLC 2 | from typing import Iterable, Tuple, List, Optional, Dict 3 | 4 | import json 5 | import logging 6 | import multiprocessing as mp 7 | import random 8 | import time 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | from scipy.special import expit 13 | 14 | from .model import SaGeTokenizer 15 | 16 | # only log code outside of multiprocessing 17 | # logger = logging.getLogger(__name__) 18 | from .paths import getDataFolder, getLogsFolder, getResultsFolder 19 | 20 | 21 | def write_vocab(vocab: Dict[bytes,int], filename: Path): 22 | """ 23 | Dump the byte vocab to a file, encoded as hex characters inside this function. 24 | Saved in same order by index, so should preserve order. 25 | No special tokens are added. 
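    For example, a token holding the bytes b"the" would be stored as the hex line "746865".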
26 | """ 27 | # write these in increasing index order 28 | # so same as any previous order 29 | byindex = sorted([(idx, token) for token, idx in vocab.items()]) 30 | 31 | with open(filename, 'w', encoding="utf-8") as f: 32 | for _, token in byindex: 33 | f.write(token.hex() + '\n') 34 | 35 | 36 | def save_sorted_losses(sage_model: SaGeTokenizer, sorted_losses, target_vocab_size: int, vocab_folder: Path): 37 | vocab_folder = Path(vocab_folder) 38 | 39 | sorted_losses_filepath = vocab_folder / f"sorted_losses_before_{target_vocab_size}.txt" 40 | worst_500_filepath = vocab_folder / f"worst_500_{target_vocab_size}.txt" 41 | best_500_filepath = vocab_folder / f"best_500_{target_vocab_size}.txt" 42 | 43 | logging.info(f"Saving sorted losses to {sorted_losses_filepath.as_posix()}") 44 | write_sorted_losses_into_file(sorted_losses, sorted_losses_filepath, sage_model) 45 | write_sorted_losses_into_file(sorted_losses[:500], worst_500_filepath, sage_model) 46 | write_sorted_losses_into_file(sorted_losses[-500:], best_500_filepath, sage_model) 47 | 48 | 49 | def write_sorted_losses_into_file(sl: Iterable[Tuple[float,int]], filename: Path, sage_model: SaGeTokenizer): 50 | with open(filename, 'w', encoding="utf-8") as f: 51 | for loss, tid in sl: 52 | f.write(sage_model.id_to_encoded(tid) + "\t" + str(loss) + "\n") 53 | 54 | 55 | def load_vocab(vocab_filepath: Path) -> List[bytes]: 56 | """ 57 | Read our hex formatted vocab file, return a list of bytes objects. 58 | Input file has one vocab word per line, each hex encoded. 59 | """ 60 | vocab_filepath = Path(vocab_filepath) 61 | if not vocab_filepath.exists(): 62 | raise FileNotFoundError(f'Missing vocab file: {vocab_filepath.as_posix()}') 63 | 64 | with open(vocab_filepath, "r") as vocab_file: 65 | # fromhex ignores whitespace from \n at end 66 | initial_vocab = [bytes.fromhex(token) for token in vocab_file.readlines()] 67 | 68 | return initial_vocab 69 | 70 | 71 | def load_corpus(corpus_filepath: Path, partial_corpus_filepath: Optional[Path], partial_corpus_line_number: int) -> List[str]: 72 | corpus_filepath = Path(corpus_filepath) 73 | partial_corpus_filepath = Path(partial_corpus_filepath) if isinstance(partial_corpus_filepath, str) else partial_corpus_filepath 74 | 75 | if partial_corpus_filepath and partial_corpus_filepath.exists(): # Corpus already exists, directly loading 76 | logging.info(f"Found pre-existing partial corpus. Loading from {partial_corpus_filepath.as_posix()}...") 77 | read_start = time.time() 78 | with open(partial_corpus_filepath, "r") as corpus_f: 79 | partial_corpus = corpus_f.readlines() 80 | logging.info(f"Size of Corpus: {len(partial_corpus)}, time: {(time.time() - read_start):.2f}") 81 | else: 82 | read_start = time.time() 83 | with open(corpus_filepath, "r") as full_corpus_f: 84 | corpus = full_corpus_f.readlines() 85 | logging.info(f"Loading from Original Corpus. Number of lines: {len(corpus)}") 86 | 87 | random.shuffle(corpus) 88 | logging.info(f"Original Corpus read and shuffled. 
Time: {(time.time() - read_start):.2f}") 89 | 90 | # may be same as original depending on partial_corpus_line_number 91 | write_start_time = time.time() 92 | partial_corpus = corpus[:partial_corpus_line_number * 1000] 93 | 94 | if partial_corpus_filepath is None: 95 | partial_corpus_filepath = getDataFolder() / f"{corpus_filepath.stem}_{len(partial_corpus)}.txt" 96 | 97 | with open(partial_corpus_filepath, "w+") as partial_corpus_f: 98 | partial_corpus_f.writelines(partial_corpus) 99 | logging.info(f"Partial corpus saved at {partial_corpus_filepath.as_posix()}. " 100 | f"Number of lines: {len(partial_corpus)}, " 101 | f"time: {(time.time() - write_start_time):.2f}") 102 | 103 | return partial_corpus 104 | 105 | 106 | def divide_data_by_num(data: List[str], num_procs: int) -> Iterable[List[str]]: 107 | """ 108 | Split the data given the number of chunks we expect 109 | Returns a generator 110 | """ 111 | size_per_chunk = len(data) // num_procs 112 | for i in range(0, len(data), size_per_chunk + 1): 113 | yield data[i: i + size_per_chunk + 1] 114 | 115 | 116 | def divide_data_by_size(data, size): 117 | """ 118 | Split the data given the size of chunks we expect 119 | Returns a generator 120 | """ 121 | for i in range(0, len(data), size): 122 | yield data[i: i + size] 123 | 124 | 125 | def compute_losses(losses, all_triples, embeddings): 126 | """ 127 | function for computing losses given triple counts and embeddings 128 | losses : accumulate losses per ablated token, excluding the single byte ones, side effect this 129 | all_triples : triple values to aggregate into losses 130 | embeddings : embedding for each token 131 | """ 132 | target_ids, context_ids, count = zip(*[(target_id, context_id, count) for (_, target_id, context_id), count in all_triples.items()]) 133 | target_embeddings = np.array([embeddings[target_id] for target_id in target_ids]) 134 | context_embeddings = np.array([embeddings[context_id] for context_id in context_ids]) 135 | count = np.array(count) 136 | triples_loss = count * np.log(expit(np.einsum('ij,ij->i', target_embeddings, context_embeddings))) 137 | for idx, ((ablated_token_id, target_id, context_id), count) in enumerate(all_triples.items()): 138 | losses[ablated_token_id] = losses.get(ablated_token_id, 0.0) + triples_loss[idx] 139 | 140 | 141 | def run_sage_parallel(embeddings: np.ndarray, partial_corpus: List[str], sage_model: SaGeTokenizer, workers_number: int): 142 | logging.info(f"Splitting data into {workers_number} chunks.") 143 | data_chunk_gen = divide_data_by_num(partial_corpus, workers_number) 144 | 145 | # these get aggregated over each chunk 146 | sage_losses = {} # is token_id : loss 147 | overall_total_tokens = 0 148 | overall_total_triples = 0 149 | ablated_sizes = {} 150 | start_time = time.time() 151 | logging.info(f"Start spawning processes...") 152 | with mp.Pool(processes=workers_number) as pool: 153 | tasks = {} 154 | 155 | for tid, data_chunk in enumerate(data_chunk_gen): 156 | res = pool.apply_async(sage_per_chunk, args=(tid, sage_model, data_chunk, embeddings)) 157 | tasks[res] = tid 158 | 159 | while tasks: 160 | results_ready_list = [] 161 | for res, tid in tasks.items(): 162 | if res.ready(): 163 | results_ready_list.append((res, tid)) 164 | 165 | # process finished task results 166 | for res, tid in results_ready_list: 167 | losses, total_tokens, total_triples, ab_sizes = res.get() 168 | 169 | # just add these to totals/maxes 170 | overall_total_tokens += total_tokens 171 | overall_total_triples += total_triples 172 | 173 | # add to 
the overall tallys 174 | for k, v in losses.items(): 175 | sage_losses[k] = sage_losses.get(k, 0) + v 176 | 177 | # how many tokens needed to be examined 178 | for k, v in ab_sizes.items(): 179 | ablated_sizes[k] = ablated_sizes.get(k, 0) + v 180 | 181 | # all done with this, 182 | # can delete from tasks without messing up iteration over list 183 | del tasks[res] 184 | 185 | logging.info(f"task {tid} finished after {(time.time() - start_time):.2f} seconds. " 186 | f"Tokens:{total_tokens}, triples:{total_triples}, active:{len(sage_losses)}") 187 | 188 | logging.info(f"Sleeping 1 second. Number of still running tasks: {len(tasks)}") 189 | time.sleep(1.0) 190 | return overall_total_tokens, overall_total_triples, sage_losses, ablated_sizes 191 | 192 | 193 | def sage_per_chunk(tid, model: SaGeTokenizer, data, embeddings, chunk_size: int=10000): 194 | """ 195 | function that runs sage on each chunk of data (in parallelization) 196 | note: this is called from multiprocessing, so use print rather than logging 197 | """ 198 | print(f"Starting chunk {tid}, with {len(data)} lines of data") 199 | 200 | start_time = time.time() 201 | 202 | # accumulate over all the data 203 | losses = {} 204 | 205 | # these accumulate over each size 206 | triples = {} 207 | ablated_sizes = {} 208 | total_tokens = 0 209 | total_triples = 0 210 | total_fs_time = 0.0 211 | total_cl_time = 0.0 212 | 213 | fs_start = time.time() 214 | for row, d in enumerate(data): 215 | 216 | total_tokens += model.fast_sage(bytes(d, 'utf-8'), triples, ablated_sizes) 217 | 218 | # if filled up chunk, then compute the losses 219 | # to free up memory 220 | if (row > 0) and (row % chunk_size == 0): 221 | # take the total time here over all calls 222 | fs_time = time.time() - fs_start 223 | total_fs_time += fs_time 224 | # reinitialize fs_start 225 | fs_start = time.time() 226 | 227 | cl_start = time.time() 228 | compute_losses(losses, triples, embeddings) 229 | cl_time = time.time() - cl_start 230 | total_cl_time += cl_time 231 | 232 | print(f"fast_sage {tid}, row {row} of {len(data)}, " 233 | f"fs_time: {fs_time:.2f}, cl_time: {cl_time:.2f}, " 234 | f"triples: {len(triples)}, tokens: {total_tokens}") 235 | 236 | # total these up 237 | total_triples += len(triples) 238 | 239 | # zero out the triples from this chunksize lines 240 | triples = {} 241 | 242 | # compute for final partial chunk 243 | if triples: 244 | compute_losses(losses, triples, embeddings) 245 | total_triples += len(triples) 246 | 247 | # the triples can get quite large, so to avoid merging these 248 | # dict values, let's compute the losses in parallel too 249 | print(f"final fast_sage {tid}, row {row} of {len(data)}, " 250 | f"fs_time: {total_fs_time:.2f}, cl_time: {total_cl_time:.2f}, time: {(time.time() - start_time):.2f}, " 251 | f"triples: {len(triples)}, tokens: {total_tokens}") 252 | 253 | # Extra negative sign for equation (1) in SaGe paper 254 | # track number in cache too 255 | losses = {k: -v for k, v in losses.items()} 256 | 257 | return losses, total_tokens, total_triples, ablated_sizes 258 | 259 | 260 | def init_logger(experiment_name: str): 261 | timestamp_str = time.strftime("%Y%m%d_%H%M%S") 262 | log_filename = getLogsFolder() / f"{experiment_name}_{timestamp_str}.log" 263 | logging.basicConfig(filename=log_filename.as_posix(), 264 | format="%(asctime)s - %(message)s", 265 | datefmt="%m/%d/%Y %H:%M:%S", 266 | level=logging.INFO) 267 | 268 | print(f"Logs will be stored in {log_filename.as_posix()}") 269 | 270 | 271 | def get_output_folder(experiment_name: 
str) -> Tuple[Path, Path, Path]: 272 | results_path = getResultsFolder() / experiment_name 273 | 274 | vocab_folder = results_path / "sage_vocabs" 275 | vocab_folder.mkdir(exist_ok=True) 276 | 277 | stats_folder = results_path / "stats" 278 | stats_folder.mkdir(exist_ok=True) 279 | 280 | embeddings_folder = results_path / "embeddings" 281 | embeddings_folder.mkdir(exist_ok=True) 282 | 283 | return embeddings_folder, stats_folder, vocab_folder 284 | 285 | 286 | def set_random_seed(experiment_name: str, random_seed: int): 287 | # Log seed 288 | seed_filepath = getResultsFolder() / experiment_name / "seed.txt" 289 | with open(seed_filepath, "w+") as f: 290 | f.write(str(random_seed)) 291 | 292 | # Set seed 293 | random.seed(random_seed) 294 | np.random.seed(random_seed) 295 | 296 | 297 | def save_stats(stats: dict, stats_folder: Path, target_vocab_size: int): 298 | stats_folder = Path(stats_folder) 299 | 300 | stats_filename = stats_folder / f"stats_{target_vocab_size}.json" 301 | logging.info(f"Saving stats to {stats_filename.as_posix()}") 302 | with open(stats_filename, "w") as f: 303 | json.dump(stats, f, indent=2) # pretty print a bit 304 | f.write("\n") 305 | -------------------------------------------------------------------------------- /sage_v1/Main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import multiprocessing as mp 4 | import sys 5 | import json 6 | import random 7 | import pickle 8 | import argparse 9 | import numpy as np 10 | from tqdm import tqdm 11 | import sentencepiece as spm 12 | from os.path import exists 13 | import os 14 | 15 | # Import From Python-Modules 16 | sys.path.insert(0, "Python-Modules") 17 | import Utils 18 | import Corpus 19 | import Logger 20 | import SG_BPE 21 | import Embeddings 22 | 23 | ######### Fix Random Seed ########################################################### 24 | 25 | # Fill with your chosen seed 26 | CHOSEN_SEED = 0 27 | 28 | def set_random_seed(chosen_seed): 29 | # setting seed 30 | random.seed(chosen_seed) 31 | spm.SetRandomGeneratorSeed(chosen_seed) 32 | np.random.seed(chosen_seed) 33 | 34 | def fix_random_seed(experiment_name, is_continue_execution, chosen_seed, log): 35 | seed_filepath = "results/{}/seed.txt".format(experiment_name) 36 | if is_continue_execution and exists(seed_filepath): 37 | with open(seed_filepath, "r") as seed_file: 38 | chosen_seed = int(seed_file.read()) 39 | log.info("chosen seed is {}".format(chosen_seed)) 40 | 41 | else: 42 | # saving for continue execution of that experiment, if this is first time execution 43 | with open(seed_filepath, "w+") as seed_file: 44 | seed_file.write(str(chosen_seed)) 45 | 46 | set_random_seed(chosen_seed) 47 | 48 | ######### Log Parameters ##################################################################### 49 | 50 | def log_parameters(log, final_vocab_size, initial_vocab_size, \ 51 | partial_corpus_lines_number, tokens_to_prune_in_iteration, tokens_to_consider_in_iteration, \ 52 | iterations_until_reranking, max_lines_per_token, corpus_filepath, partial_corpus_filepath, window_size): 53 | 54 | log.info("-----------------------------------------") 55 | log.info("Starting Experiment.") 56 | log.info("initial_vocab_size: {}, final_vocab_size: {}".format(initial_vocab_size, final_vocab_size)) 57 | log.info("partial_corpus_lines_number: {}\n".format(partial_corpus_lines_number)) 58 | log.info("tokens_to_prune_in_iteration {}".format(tokens_to_prune_in_iteration)) 59 | 
log.info("tokens_to_consider_in_iteration: {}".format(tokens_to_consider_in_iteration)) 60 | log.info("iterations_until_reranking: {}".format(iterations_until_reranking)) 61 | log.info("max_lines_per_token: {}".format(max_lines_per_token)) 62 | log.info("corpus_filepath: {}".format(corpus_filepath)) 63 | log.info("partial_corpus_filepath: {}".format(partial_corpus_filepath)) 64 | log.info("window_size: {}".format(window_size)) 65 | log.info("Number of processes in pool: {}".format(mp.cpu_count())) 66 | log.info("Not re-calculating embeddings") 67 | log.info("-----------------------------------------") 68 | 69 | ######### Main ##################################################################### 70 | 71 | def main(experiment_name, is_continue_execution, final_vocab_size, \ 72 | initial_vocab_size, partial_corpus_lines_number, tokens_to_prune_in_iteration, tokens_to_consider_in_iteration, \ 73 | iterations_until_reranking, max_lines_per_token, corpus_filepath, partial_corpus_filepath, window_size): 74 | 75 | # Initialize Statistics Logger 76 | print("Initializing statistics logger") 77 | log = Logger.Logger("statistics") 78 | 79 | if not exists("results"): 80 | log.info("Creating 'results' dir") 81 | os.mkdir("results") 82 | 83 | experiment_results_directory = "results/{}".format(experiment_name) 84 | if not exists(experiment_results_directory): 85 | log.info("Creating {} dir".format(experiment_results_directory)) 86 | os.mkdir(experiment_results_directory) 87 | 88 | log.info("Fixing random seed") 89 | fix_random_seed(experiment_name, is_continue_execution, CHOSEN_SEED, log) 90 | 91 | # Preparing Data 92 | log.info("Preparing corpus") 93 | corpus = Corpus.Corpus(corpus_filepath, partial_corpus_filepath, partial_corpus_lines_number, log) 94 | partial_corpus = corpus.get_corpus() 95 | 96 | # Log parameters if not continuation execution 97 | if not is_continue_execution: 98 | log_parameters(log, final_vocab_size, initial_vocab_size, partial_corpus_lines_number, \ 99 | tokens_to_prune_in_iteration, tokens_to_consider_in_iteration, iterations_until_reranking, \ 100 | max_lines_per_token, corpus_filepath, partial_corpus_filepath, window_size) 101 | 102 | # BPE using SentencePiece Python Module 103 | log.info("Preparing SG and BPE Models") 104 | models = SG_BPE.SG_BPE_Models(experiment_name, is_continue_execution, final_vocab_size, initial_vocab_size, partial_corpus_filepath) 105 | sg_bpe_model = models.get_sg_bpe_model() 106 | bpe_vanilla_model = models.get_bpe_vanilla_model() 107 | 108 | # Computing Embedding Matrix 109 | log.info("Computing embeddings") 110 | embeddings_filepath = "results/{}/embeddings.bin".format(experiment_name) 111 | if not is_continue_execution: 112 | wp_trainer = Embeddings.EmbeddingsTrainer(sg_bpe_model, corpus, window_size, log) 113 | target_embeddings, context_embeddings = wp_trainer.train_embeddings() 114 | with open(embeddings_filepath, "wb") as embeddings_file: 115 | pickle.dump(target_embeddings, embeddings_file) 116 | pickle.dump(context_embeddings, embeddings_file) 117 | log.info("dumped embeddings to {}".format(embeddings_filepath)) 118 | else: 119 | with open(embeddings_filepath, "rb") as embeddings_file: 120 | target_embeddings = pickle.load(embeddings_file) 121 | context_embeddings = pickle.load(embeddings_file) 122 | 123 | # Creating model objects 124 | # These will hold needed logic for training loop, logging results and gathering information about models too 125 | log.info("Creating model objects") 126 | vocab_filepath = 
"results/{}/current_vocab".format(experiment_name) 127 | sg_bpe_model_object = SG_BPE.Model(experiment_name, log, sg_bpe_model, "sg_bpe", \ 128 | target_embeddings, context_embeddings, partial_corpus, max_lines_per_token, window_size, is_continue_execution, vocab_filepath) 129 | bpe_vanilla_model_object = SG_BPE.Model(experiment_name, log, bpe_vanilla_model, "bpe_vanilla", \ 130 | target_embeddings, context_embeddings, partial_corpus, max_lines_per_token, window_size) 131 | 132 | # We use the "get_current_vocab" method because of sentencepiece tricky way to remove tokens from the vocabulary. 133 | current_vocab = sg_bpe_model_object.get_current_vocab() 134 | with open(vocab_filepath + ".bin", "wb") as vocab_file: 135 | pickle.dump(current_vocab, vocab_file) 136 | 137 | ########################################################################################################################### 138 | # Prune Tokens - SkipGram Log Probability 139 | # For now we compute SG objective without negative samples - this causes noise, and due to the fact we look at the actual result 140 | # (and don't just refer it as objective) we don't want that noise. 141 | # We say we have better vocabulary when we minimize this objective. 142 | # Thus, in each iteration we will find the tokens that without them the objective is minimal. 143 | ########################################################################################################################### 144 | 145 | log.info("Preparing sorted Tokens-SG-objective list") 146 | sorted_tokens_sg_filepath = "results/{}/sorted_tokens_sg.bin".format(experiment_name) 147 | if is_continue_execution and exists(sorted_tokens_sg_filepath): 148 | with open(sorted_tokens_sg_filepath, "rb") as sorted_tokens_sg_file: 149 | sorted_tokens_sg = pickle.load(sorted_tokens_sg_file) 150 | else: 151 | # Log starting-point vocabulary: 152 | sg_bpe_model_object.log_experiments_model_results(partial_corpus_filepath) 153 | bpe_vanilla_model_object.log_experiments_model_results(partial_corpus_filepath) 154 | 155 | # And Log starting-point SG objective 156 | total_skipgram_ns_probability = sg_bpe_model_object.total_sg_log_prob(partial_corpus_filepath) 157 | log.info("Initial SG-BPE total_log_sg_prob: {}".format(total_skipgram_ns_probability)) 158 | 159 | bpe_vanilla_total_skipgram_ns_probability = bpe_vanilla_model_object.total_sg_log_prob(partial_corpus_filepath) 160 | log.info("Initial Vanilla-BPE total_log_sg_prob: {}".format(bpe_vanilla_total_skipgram_ns_probability)) 161 | 162 | log.log_separator() 163 | 164 | # Log original sg log prob without each token 165 | original_token_and_sg_log_prob_without_it = sg_bpe_model_object.get_sg_log_prob_without_tokens_mp(total_skipgram_ns_probability, partial_corpus_filepath, True) 166 | sorted_tokens_sg = sorted(original_token_and_sg_log_prob_without_it.items(), key=lambda item: item[1]) 167 | sg_log_prob_without_tokens_mp_filepath = "./results/{}/sg_log_probs_without_tokens_mp.txt".format(experiment_name) 168 | with open(sg_log_prob_without_tokens_mp_filepath, "w+") as logprob_file: 169 | logprob_file.write(json.dumps(sorted_tokens_sg, indent=4)) 170 | 171 | with open(sorted_tokens_sg_filepath, "wb") as sorted_tokens_sg_file: 172 | pickle.dump(sorted_tokens_sg, sorted_tokens_sg_file) 173 | 174 | # And start the "Training-Loop" 175 | log.info("Starting the 'Training-Loop'") 176 | 177 | current_vocab = sg_bpe_model_object.get_current_vocab() 178 | current_total_sg = sg_bpe_model_object.total_sg_log_prob(partial_corpus_filepath) 179 | 180 | ## 
prepare from ahead #### 181 | ### now we want dict from token to index of lines - and re-encode when token is ablated 182 | sg_bpe_model_object.initialize_encoded_form_for_corpus_lines() 183 | sg_bpe_model_object.initialize_token_to_line_indices_dictionary(current_vocab, partial_corpus, experiment_name, is_continue_execution) 184 | 185 | # loop and ablate 186 | iteration = 0 187 | current_dict_of_top_tokens = {} 188 | 189 | progress_bar = tqdm(total = (len(current_vocab) - final_vocab_size) / tokens_to_prune_in_iteration) 190 | while len(current_vocab) > final_vocab_size: 191 | log.info("iteration #{}, num of vocab: {}, length of top: {}".format(iteration, len(current_vocab), len(current_dict_of_top_tokens))) 192 | 193 | if (iteration % iterations_until_reranking) == 0: 194 | token_and_sg_log_prob_without_it = sg_bpe_model_object.get_sg_log_prob_without_tokens_mp2(current_total_sg, True) 195 | sorted_tokens_sg = sorted(token_and_sg_log_prob_without_it.items(), key=lambda item: item[1]) 196 | current_dict_of_top_tokens = [t[0] for t in sorted_tokens_sg[:tokens_to_consider_in_iteration]] 197 | else: 198 | token_and_sg_log_prob_without_it = sg_bpe_model_object.get_sg_log_prob_without_tokens_mp2(current_total_sg, True, current_dict_of_top_tokens) 199 | sorted_tokens_sg = sorted(token_and_sg_log_prob_without_it.items(), key=lambda item: item[1]) 200 | 201 | # finding the tokens to prune in that iteration (those that without them the value is minimal) 202 | tokens_to_prune = [t[0] for t in sorted_tokens_sg[:tokens_to_prune_in_iteration]] 203 | log.info("pruning: {}".format(tokens_to_prune)) 204 | 205 | for t in tokens_to_prune: 206 | try: 207 | current_vocab.remove(t) 208 | current_dict_of_top_tokens.remove(t) 209 | except: 210 | log.info("could not remove {}".format(t[0])) 211 | raise 212 | 213 | # update our model vocab 214 | with open(vocab_filepath + ".bin", "wb") as vocab_file: 215 | pickle.dump(current_vocab, vocab_file) 216 | 217 | sg_bpe_model_object.set_vocab(current_vocab) 218 | sg_bpe_model_object.update_encoded_form_for_corpus_lines(tokens_to_prune) 219 | 220 | # upgrade our current total sg 221 | current_total_sg = sg_bpe_model_object.total_sg_log_prob(partial_corpus_filepath) 222 | 223 | # log current sg objective 224 | log.info("current SG-BPE total_log_sg_prob: {}".format(current_total_sg)) 225 | log.info("current size is {}".format(len(current_vocab))) 226 | log.info("Time elapsed (iteration #{}) - {} (minutes)".format(iteration, float(progress_bar.format_dict["elapsed"])/60)) 227 | log.info("Time elapsed (iteration #{}) - {} (minutes)".format(iteration, float(progress_bar.format_dict["elapsed"])/60)) 228 | 229 | # prepare to next iteration 230 | log.log_separator() 231 | progress_bar.update(1) 232 | iteration += 1 233 | 234 | progress_bar.close() 235 | 236 | # Logging experiment results: 237 | sg_bpe_model_object.log_experiments_model_results(partial_corpus_filepath, "sg_bpe") 238 | 239 | final_total_skipgram_ns_probability = sg_bpe_model_object.total_sg_log_prob(partial_corpus_filepath) 240 | log.info("Final SG-BPE total_log_sg_prob: {}".format(final_total_skipgram_ns_probability)) 241 | 242 | final_bpe_vanilla_total_skipgram_ns_probability = bpe_vanilla_model_object.total_sg_log_prob(partial_corpus_filepath) 243 | log.info("Vanilla-BPE total_log_sg_prob: {}".format(final_bpe_vanilla_total_skipgram_ns_probability)) 244 | 245 | # Log difference between vanilla and updated bpe models 246 | # log tokens only in sg-bpe vocab 247 | current_vocab = 
sg_bpe_model_object.get_current_vocab() 248 | log.info("vocab len = {}".format(len(current_vocab))) 249 | bpe_vanilla_vocab = bpe_vanilla_model_object.get_current_vocab() 250 | log.info("bpe vanilla vocab len = {}".format(len(bpe_vanilla_vocab))) 251 | 252 | only_in_sg_bpe = [t for t in current_vocab if t not in bpe_vanilla_vocab] 253 | sg_bpe_only_filepath = "./results/{}/sg_bpe_only.txt".format(experiment_name) 254 | with open(sg_bpe_only_filepath, "w+") as sg_bpe_only: 255 | sg_bpe_only.write(json.dumps(only_in_sg_bpe, indent=4)) 256 | 257 | # log tokens only in bpe-vanilla vocab 258 | only_in_bpe_vanilla = [t for t in bpe_vanilla_vocab if t not in current_vocab] 259 | bpe_vanilla_only_filepath = "./results/{}/bpe_vanilla_only.txt".format(experiment_name) 260 | with open(bpe_vanilla_only_filepath, "w+") as bpe_vanilla_only: 261 | bpe_vanilla_only.write(json.dumps(only_in_bpe_vanilla, indent=4)) 262 | 263 | # Average token length 264 | average_token_length = sum(map(len, current_vocab)) / len(current_vocab) 265 | log.info("SG-BPE average token length: {}".format(average_token_length)) 266 | 267 | # Average token length of vanilla-bpe 268 | vanilla_average_token_length = sum(map(len, bpe_vanilla_vocab)) / len(bpe_vanilla_vocab) 269 | log.info("Vanilla-BPE average token length: {}".format(vanilla_average_token_length)) 270 | 271 | # Length of encoded file data 272 | with open(partial_corpus_filepath) as input_file: 273 | input_file_data = input_file.read() 274 | 275 | encoded_data = [sg_bpe_model.id_to_piece(x) for x in sg_bpe_model.encode(input_file_data)] 276 | log.info("SG-BPE encoding length: {} (file length: {})".format(len(encoded_data), len(input_file_data))) 277 | 278 | # Length of encoded file data of vanilla-bpe 279 | vanilla_encoded_data = [bpe_vanilla_model.id_to_piece(x) for x in bpe_vanilla_model.encode(input_file_data)] 280 | log.info("Vanilla-BPE encoding length: {} (file length: {})".format(len(vanilla_encoded_data), len(input_file_data))) 281 | 282 | def prepare_parameters(): 283 | parser = argparse.ArgumentParser(description="Calculating SaGe vocabulary.") 284 | parser.add_argument("experiment_name", help="name of experiment, will save results under that name.") # Required param (no "--" prefix) 285 | parser.add_argument("--final_vocab_size", required=True, help="final vocabulary size") 286 | parser.add_argument("--initial_vocab_size", required=True, help="initial vocabulary size") 287 | parser.add_argument("--tokens_to_prune_in_iteration", required=True, help="number of tokens to prune in each iteration") 288 | parser.add_argument("--tokens_to_consider_in_iteration", required=True, help="number of tokens to consider in each iteration") 289 | parser.add_argument("--iterations_until_reranking", required=True, help="number of iterations until reranking") 290 | parser.add_argument("--is_continue", default="N", help="is this execution continues former execiution of that experiment: [Y/N]") 291 | parser.add_argument("--corpus_filepath", required=True, help="filepath for full corpus (like wiki corpus)") 292 | parser.add_argument("--thousands_of_corpus_lines", default=200, help="number of corpus lines - in thousands") 293 | parser.add_argument("--partial_corpus_filepath", required=True, help="where to create partial corpus file - with number of lines requested") 294 | parser.add_argument("--max_lines_per_token", default=1000, help="max number of lines to consider in objective calculation, per-token") 295 | parser.add_argument("--window_size", default=5, help="window size for SG 
objective calculation, and also for embeddings calculation") 296 | return vars(parser.parse_args()) 297 | 298 | if __name__ == "__main__": 299 | args = prepare_parameters() 300 | 301 | experiment_name = args["experiment_name"] 302 | print("Starting experiment {}".format(experiment_name)) 303 | 304 | is_continue_execution = True if args["is_continue"] == "Y" else False 305 | print("is_continue={}".format(is_continue_execution)) 306 | 307 | partial_corpus_lines_number = int(args["thousands_of_corpus_lines"]) * 1000 308 | final_vocab_size = int(args["final_vocab_size"]) 309 | initial_vocab_size = int(args["initial_vocab_size"]) 310 | tokens_to_prune_in_iteration = int(args["tokens_to_prune_in_iteration"]) 311 | tokens_to_consider_in_iteration = int(args["tokens_to_consider_in_iteration"]) 312 | iterations_until_reranking = int(args["iterations_until_reranking"]) 313 | max_lines_per_token = int(args["max_lines_per_token"]) 314 | corpus_filepath = args["corpus_filepath"] 315 | partial_corpus_filepath = args["partial_corpus_filepath"] 316 | window_size = int(args["window_size"]) 317 | 318 | main(experiment_name, is_continue_execution, final_vocab_size, \ 319 | initial_vocab_size, partial_corpus_lines_number, tokens_to_prune_in_iteration, tokens_to_consider_in_iteration, \ 320 | iterations_until_reranking, max_lines_per_token, corpus_filepath, partial_corpus_filepath, window_size) 321 | --------------------------------------------------------------------------------
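Usage sketch: the v2 builder under src/sage_tokenizer can also be driven directly from Python. The snippet below is a minimal, illustrative sketch rather than code shipped in this repository: the experiment name, vocabulary schedules, and file paths are placeholders, the package is assumed to be installed (or on PYTHONPATH) as sage_tokenizer, and the default folder layout with results/ under the current working directory is assumed.

from pathlib import Path

from sage_tokenizer.SaGeVocabBuilder import SaGeVocabBuilder

if __name__ == "__main__":  # guard recommended: build_vocab() spawns multiprocessing workers
    # Vocabulary sizes to step through; the builder sorts them largest-first and
    # requires at least two entries. Embeddings are trained at the first size and
    # retrained at any later size that also appears in embeddings_schedule.
    builder = SaGeVocabBuilder(
        full_vocab_schedule=[32768, 24576, 16384],  # placeholder schedule
        embeddings_schedule=[32768, 16384],         # placeholder subset of the schedule
        max_len=16,                                 # should match the longest token in the initial vocabulary
        workers_number=4,
        random_seed=692653,
    )

    # get_output_folder() creates the sage_vocabs/stats/embeddings subfolders but not the
    # experiment folder itself, so (assuming the default layout) it may need to exist first.
    Path("results/demo_experiment").mkdir(parents=True, exist_ok=True)

    builder.build_vocab(
        experiment_name="demo_experiment",               # placeholder name
        corpus_filepath="data/corpus_lines.txt",         # placeholder path: one line of text per row
        vocabulary_filepath="data/initial_vocab.vocab",  # placeholder path: one hex-encoded token per line (see load_vocab)
        partial_corpus_line_number=100,                  # load_corpus keeps this many thousand shuffled lines
    )

Intermediate, active, and deleted vocabularies, per-step stats, and logs are then written under the experiment's results folder, as produced by the build_vocab loop shown earlier in this dump.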