├── requirements.txt
├── bert.py
├── classify_demo.py
├── pretrain_demo.py
├── README.md
├── .gitignore
├── utils.py
└── tokenization.py


/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy==1.24.1 # used for the actual model code/weights
 2 | six==1.16.0 # used by tokenizer
 3 | requests==2.27.1 # used to download the bert model files from google
 4 | tqdm==4.64.0 # progress bar to keep your sanity
 5 | fire==0.5.0 # easy CLI creation
 6 | 
 7 | # used to load the bert weights from the google tf checkpoint
 8 | # M1 Macbooks require tensorflow-macos
 9 | tensorflow==2.11.0; sys_platform != 'darwin' or platform_machine != 'arm64'
10 | tensorflow-macos==2.11.0; sys_platform == 'darwin' and platform_machine == 'arm64'
11 | 
--------------------------------------------------------------------------------
/bert.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | def gelu(x):
 4 |     return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
 5 | 
 6 | def softmax(x):
 7 |     exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
 8 |     return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
 9 | 
10 | def layer_norm(x, g, b, eps=1e-12):
11 |     mean = np.mean(x, axis=-1, keepdims=True)
12 |     variance = np.var(x, axis=-1, keepdims=True)
13 |     return g * (x - mean) / np.sqrt(variance + eps) + b
14 | 
15 | def linear(x, w, b):
16 |     return x @ w + b
17 | 
18 | def ffn(x, c_fc, c_proj):
19 |     return linear(gelu(linear(x, **c_fc)), **c_proj)
20 | 
21 | def attention(q, k, v):
22 |     return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v  # no causal mask, unlike GPT
23 | 
24 | def mha(x, c_attn, c_proj, n_head):
25 |     x = linear(x, **c_attn)
26 |     qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
27 |     out_heads = [attention(q, k, v) for q, k, v in zip(*qkv_heads)]
28 |     x = linear(np.hstack(out_heads), **c_proj)
29 |     return x
30 | 
31 | def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
32 |     x = layer_norm(x + mha(x, **attn, n_head=n_head), **ln_1)
33 |     x = layer_norm(x + ffn(x, **mlp), **ln_2)
34 |     return x
35 | 
36 | def bert(input_ids, segment_ids, wte, wpe, wse, ln_e, blocks, pooler, n_head):
37 |     x = wte[input_ids] + wpe[range(len(input_ids))] + wse[segment_ids]  # token + position + segment embeddings
38 |     x = layer_norm(x, **ln_e)
39 | 
40 |     for block in blocks:
41 |         x = transformer_block(x, **block, n_head=n_head)
42 | 
43 |     seq_output = x
44 |     pooled_output = np.tanh(linear(x[0], **pooler))  # pooled output is computed from the [CLS] token
45 |     return seq_output, pooled_output
46 | 
--------------------------------------------------------------------------------
/classify_demo.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from datasets import load_dataset
 3 | from sklearn.linear_model import LogisticRegression
 4 | from sklearn.metrics import classification_report
 5 | from sklearn.model_selection import train_test_split
 6 | from tqdm import tqdm
 7 | 
 8 | from bert import bert
 9 | from utils import load_tokenizer_hparams_and_params, tokenize
10 | 
11 | 
12 | def main(
13 |     dataset_name: str = "imdb",
14 |     N: int = 1000,
15 |     test_ratio: float = 0.2,
16 |     model_name: str = "bert-base-uncased",
17 |     models_dir: str = "models",
18 |     seed: int = 123,
19 | ):
20 |     np.random.seed(seed)
21 | 
22 |     # load tokenizer, hparams, and params
23 |     tokenizer, hparams, params = load_tokenizer_hparams_and_params(model_name, models_dir)
24 |     n_head = hparams["num_attention_heads"]
25 |     max_len = hparams["max_position_embeddings"]
26 | 
27 |     # load dataset
28 |     dataset = load_dataset(dataset_name, split="train").shuffle()
29 | 
30 |     # extract bert features
31 |     X, y = [], []
32 |     for text, label in tqdm(zip(dataset[:N]["text"], dataset[:N]["label"]), total=N):
33 |         _, input_ids, segment_ids = tokenize(tokenizer, text)
34 |         input_ids, segment_ids = input_ids[:max_len], segment_ids[:max_len]
35 |         _, pooled_output = bert(input_ids, segment_ids, **params["bert"], n_head=n_head)
36 |         X.append(pooled_output)
37 |         y.append(label)
38 | 
39 |     # train test split
40 |     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio, stratify=y)
41 | 
42 |     # train classifier
43 |     classifier = LogisticRegression()
44 |     classifier.fit(X_train, y_train)
45 | 
46 |     # predictions
47 |     preds = classifier.predict(X_test)
48 |     print(classification_report(y_test, preds))
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     import fire
53 | 
54 |     fire.Fire(main)
55 | 
--------------------------------------------------------------------------------
/pretrain_demo.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | from bert import bert, layer_norm, linear
 4 | from utils import load_tokenizer_hparams_and_params, tokenize
 5 | 
 6 | 
 7 | def mask_tokens(tokenizer, input_ids, mask_prob):
 8 |     CLS_ID, SEP_ID, MASK_ID = tokenizer.vocab["[CLS]"], tokenizer.vocab["[SEP]"], tokenizer.vocab["[MASK]"]
 9 | 
10 |     masked_indices = np.random.choice(len(input_ids), int(len(input_ids) * mask_prob), replace=False)
11 |     masked_indices = list(filter(lambda i: input_ids[i] not in {CLS_ID, SEP_ID}, masked_indices))  # dont mask cls or sep
12 | 
13 |     masked_input_ids = input_ids[:]
14 |     for i in masked_indices:
15 |         masked_input_ids[i] = MASK_ID
16 | 
17 |     return masked_input_ids, masked_indices
18 | 
19 | 
20 | def nsp_head(pooled_output, fc):
21 |     return linear(pooled_output, **fc)
22 | 
23 | 
24 | def mlm_head(seq_output, wte, fc, ln, bias):
25 |     x = linear(seq_output, **fc)
26 |     x = layer_norm(x, **ln)
27 |     return linear(x, wte.T, bias)  # output projection is tied to the word embeddings
28 | 
29 | 
30 | def main(
31 |     text_a: str,
32 |     text_b: str = None,
33 |     model_name: str = "bert-base-uncased",
34 |     models_dir: str = "models",
35 |     mask_prob: float = 0.15,
36 |     seed: int = 123,
37 |     verbose: bool = False,
38 | ):
39 |     np.random.seed(seed)
40 | 
41 |     tokenizer, hparams, params = load_tokenizer_hparams_and_params(model_name, models_dir)
42 | 
43 |     tokens, input_ids, segment_ids = tokenize(tokenizer, text_a, text_b)
44 |     masked_input_ids, masked_indices = mask_tokens(tokenizer, input_ids, mask_prob)
45 | 
46 |     assert len(input_ids) <= hparams["max_position_embeddings"]
47 |     seq_output, pooled_output = bert(
48 |         masked_input_ids,
49 |         segment_ids,
50 |         **params["bert"],
51 |         n_head=hparams["num_attention_heads"],
52 |     )
53 | 
54 |     if mask_prob > 0:
55 |         mlm_logits = mlm_head(seq_output, params["bert"]["wte"], **params["mlm"])
56 |         correct = [input_ids[i] == np.argmax(mlm_logits[i]) for i in masked_indices]
57 | 
58 |         if verbose:
59 |             preds = np.argmax(mlm_logits, axis=-1)
60 |             print(f"input = {tokenizer.convert_ids_to_tokens(masked_input_ids)}\n")
61 |             for i in sorted(masked_indices):
62 |                 print(f"actual: {tokens[i]}\npred: {tokenizer.inv_vocab[preds[i]]}\n")
63 | 
64 |         print(f"mlm_accuracy = {sum(correct)}/{len(correct)} = {sum(correct)/len(correct)}")
65 |     if text_b:
66 |         nsp_logits = nsp_head(pooled_output, **params["nsp"])
67 |         print(f"is_next_sentence = {np.argmax(nsp_logits) == 0}")
68 | 
69 | 
70 | if __name__ == "__main__":
71 |     import fire
72 | 
73 |     fire.Fire(main)
74 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # picoBERT
 2 | Like [picoGPT](https://github.com/jaymody/picoGPT), but for [BERT](https://arxiv.org/pdf/1810.04805.pdf).
 3 | 
 4 | #### Dependencies
 5 | ```bash
 6 | pip install -r requirements.txt
 7 | ```
 8 | Tested on `Python 3.9.10`.
 9 | 
10 | #### Usage
11 | * `bert.py` contains the actual BERT model code.
12 | * `utils.py` includes utility code to download, load, and tokenize stuff.
13 | * `tokenization.py` includes the BERT WordPiece tokenizer code.
14 | * `pretrain_demo.py` contains code to demo BERT on its pre-training tasks (MLM and NSP).
15 | * `classify_demo.py` contains code to demo training an SKLearn classifier using the BERT output embeddings as input features. This is not the same as actually fine-tuning the BERT model.
16 | 
17 | To demo BERT on pre-training tasks:
18 | 
19 | ```bash
20 | python pretrain_demo.py \
21 |     --text_a "The apple doesn't fall far from the tree." \
22 |     --text_b "Instead, it falls on Newton's head." \
23 |     --model_name "bert-base-uncased" \
24 |     --mask_prob 0.20
25 | ```
26 | 
27 | Which outputs:
28 | 
29 | ```text
30 | mlm_accuracy = 3/4 = 0.75
31 | is_next_sentence = True
32 | ```
33 | 
34 | If we add the `--verbose` flag, we can also see where the model went wrong with masked language modeling:
35 | 
36 | ```text
37 | input = ['[CLS]', 'the', 'apple', 'doesn', "'", '[MASK]', 'fall', 'far', 'from', 'the', 'tree', '.', '[SEP]', 'instead', ',', 'it', 'falls', 'on', '[MASK]', "'", '[MASK]', '[MASK]', '.', '[SEP]']
38 | 
39 | actual: t
40 | pred: t
41 | 
42 | actual: newton
43 | pred: one
44 | 
45 | actual: s
46 | pred: s
47 | 
48 | actual: head
49 | pred: head
50 | ```
51 | 
52 | Instead of predicting the word "newton", the model predicted the word "one", which still gives a valid sentence: "Instead, it falls on one's head."
53 | 
54 | For a demo of training an SKLearn classifier on the [IMDB dataset](https://huggingface.co/datasets/imdb), using the BERT output embeddings as input to the classifier (this demo additionally requires the `datasets` and `scikit-learn` packages, which are not in `requirements.txt`):
55 | ```bash
56 | python classify_demo.py \
57 |     --dataset_name "imdb" \
58 |     --N 1000 \
59 |     --test_ratio 0.2 \
60 |     --model_name "bert-base-uncased" \
61 |     --models_dir "models"
62 | ```
63 | 
64 | Which outputs (note, it takes a while to run the BERT model and extract all the embeddings):
65 | 
66 | ```text
67 |               precision    recall  f1-score   support
68 | 
69 |            0       0.78      0.85      0.81       104
70 |            1       0.82      0.74      0.78        96
71 | 
72 |     accuracy                           0.80       200
73 |    macro avg       0.80      0.79      0.79       200
74 | weighted avg       0.80      0.80      0.79       200
75 | 
76 | ```
77 | Not bad, 80% accuracy using only 800 training examples and a simple SKLearn model. Of course, fine-tuning the entire model over all the training examples would yield much better results.
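#### Using the model programmatically

The demo scripts wrap everything in a CLI, but under the hood it is just three function calls. A minimal sketch (not part of the demos) of extracting a sentence embedding directly, using the same functions the demos use; the model files download to `models/` on the first call:

```python
from bert import bert
from utils import load_tokenizer_hparams_and_params, tokenize

# load the WordPiece tokenizer, the bert_config.json hparams, and the numpy weights
tokenizer, hparams, params = load_tokenizer_hparams_and_params("bert-base-uncased", "models")

# tokenize a sentence (optionally pass a second sentence as text_b)
tokens, input_ids, segment_ids = tokenize(tokenizer, "The apple doesn't fall far from the tree.")

# run the model: seq_output has one vector per token, pooled_output is the [CLS]-based sentence embedding
seq_output, pooled_output = bert(
    input_ids, segment_ids, **params["bert"], n_head=hparams["num_attention_heads"]
)
print(pooled_output.shape)  # (hidden_size,), e.g. (768,) for bert-base
```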
78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Project ### 2 | /models/ 3 | 4 | 5 | ### Python ### 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 111 | __pypackages__/ 112 | 113 | # Celery stuff 114 | celerybeat-schedule 115 | celerybeat.pid 116 | 117 | # SageMath parsed files 118 | *.sage.py 119 | 120 | # Environments 121 | .env 122 | .venv 123 | env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | # pytype static type analyzer 148 | .pytype/ 149 | 150 | # Cython debug symbols 151 | cython_debug/ 152 | 153 | # PyCharm 154 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 156 | # and can be added to the global gitignore or merged into this file. For a more nuclear 157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 158 | #.idea/ 159 | 160 | 161 | ### macOS ### 162 | 163 | # General 164 | .DS_Store 165 | .AppleDouble 166 | .LSOverride 167 | 168 | # Icon must end with two \r 169 | Icon 170 | 171 | 172 | # Thumbnails 173 | ._* 174 | 175 | # Files that might appear in the root of a volume 176 | .DocumentRevisions-V100 177 | .fseventsd 178 | .Spotlight-V100 179 | .TemporaryItems 180 | .Trashes 181 | .VolumeIcon.icns 182 | .com.apple.timemachine.donotpresent 183 | 184 | # Directories potentially created on remote AFP share 185 | .AppleDB 186 | .AppleDesktop 187 | Network Trash Folder 188 | Temporary Items 189 | .apdisk 190 | 191 | 192 | ### Linux ### 193 | 194 | *~ 195 | 196 | # temporary files which can be created if a process still has a handle open of a deleted file 197 | .fuse_hidden* 198 | 199 | # KDE directory preferences 200 | .directory 201 | 202 | # Linux trash folder which might appear on any partition or disk 203 | .Trash-* 204 | 205 | # .nfs files are created when an open file is removed but is still being accessed 206 | .nfs* 207 | 208 | 209 | ### Windows ### 210 | 211 | # Windows thumbnail cache files 212 | Thumbs.db 213 | Thumbs.db:encryptable 214 | ehthumbs.db 215 | ehthumbs_vista.db 216 | 217 | # Dump file 218 | *.stackdump 219 | 220 | # Folder config file 221 | [Dd]esktop.ini 222 | 223 | # Recycle Bin used on file shares 224 | $RECYCLE.BIN/ 225 | 226 | # Windows Installer files 227 | *.cab 228 | *.msi 229 | *.msix 230 | *.msm 231 | *.msp 232 | 233 | # Windows shortcuts 234 | *.lnk 235 | 236 | 237 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import zipfile 5 | 6 | import numpy as np 7 | import requests 8 | import tensorflow as tf 9 | from tqdm import tqdm 10 | 11 | from tokenization import FullTokenizer 12 | 13 | model_name_to_url = { 14 | "bert-tiny-uncased": "https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip", 15 | "bert-mini-uncased": "https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-4_H-256_A-4.zip", 16 | "bert-small-uncased": "https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-4_H-512_A-8.zip", 17 | "bert-medium-uncased": "https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-8_H-512_A-8.zip", 18 | "bert-base-uncased": 
"https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip", 19 | "bert-base-cased": "https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip", 20 | "bert-large-uncased": "https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip", 21 | "bert-large-cased": "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip", 22 | "bert-base-multilingual-cased": "https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip", 23 | } 24 | 25 | params_key_map = { 26 | # block 27 | "attention/output/LayerNorm/gamma": "ln_1/g", 28 | "attention/output/LayerNorm/beta": "ln_1/b", 29 | "attention/output/dense/kernel": "attn/c_proj/w", 30 | "attention/output/dense/bias": "attn/c_proj/b", 31 | "attention/self/query/kernel": "attn/q/w", 32 | "attention/self/query/bias": "attn/q/b", 33 | "attention/self/key/kernel": "attn/k/w", 34 | "attention/self/key/bias": "attn/k/b", 35 | "attention/self/value/kernel": "attn/v/w", 36 | "attention/self/value/bias": "attn/v/b", 37 | "intermediate/dense/kernel": "mlp/c_fc/w", 38 | "intermediate/dense/bias": "mlp/c_fc/b", 39 | "output/dense/kernel": "mlp/c_proj/w", 40 | "output/dense/bias": "mlp/c_proj/b", 41 | "output/LayerNorm/gamma": "ln_2/g", 42 | "output/LayerNorm/beta": "ln_2/b", 43 | # top level 44 | "bert/embeddings/LayerNorm/gamma": "bert/ln_e/g", 45 | "bert/embeddings/LayerNorm/beta": "bert/ln_e/b", 46 | "bert/embeddings/position_embeddings": "bert/wpe", 47 | "bert/embeddings/word_embeddings": "bert/wte", 48 | "bert/embeddings/token_type_embeddings": "bert/wse", 49 | "bert/pooler/dense/kernel": "bert/pooler/w", 50 | "bert/pooler/dense/bias": "bert/pooler/b", 51 | "cls/predictions/output_bias": "mlm/bias", 52 | "cls/predictions/transform/dense/kernel": "mlm/fc/w", 53 | "cls/predictions/transform/dense/bias": "mlm/fc/b", 54 | "cls/predictions/transform/LayerNorm/gamma": "mlm/ln/g", 55 | "cls/predictions/transform/LayerNorm/beta": "mlm/ln/b", 56 | "cls/seq_relationship/output_weights": "nsp/fc/w", 57 | "cls/seq_relationship/output_bias": "nsp/fc/b", 58 | } 59 | 60 | 61 | def download_bert_files(model_name, model_dir): 62 | url = model_name_to_url[model_name] 63 | zip_fpath = os.path.join(model_dir, os.path.basename(url)) 64 | 65 | r = requests.get(url, stream=True) 66 | r.raise_for_status() 67 | file_size = int(r.headers["content-length"]) 68 | chunk_size = 1000 69 | with open(zip_fpath, "wb") as f: 70 | with tqdm(ncols=100, desc=f"downloading zip files.", total=file_size, unit_scale=True) as pbar: 71 | for chunk in r.iter_content(chunk_size=chunk_size): 72 | f.write(chunk) 73 | pbar.update(chunk_size) 74 | 75 | with zipfile.ZipFile(zip_fpath) as f: 76 | # we do this hack instead of simply f.extractall(model_dir) since the older 77 | # zipfiles released by google are nested inside a folder 78 | for finfo in f.infolist(): 79 | if finfo.filename[-1] == "/": 80 | continue 81 | finfo.filename = os.path.basename(finfo.filename) 82 | f.extract(finfo, model_dir) 83 | 84 | 85 | def load_bert_params_from_tf_ckpt(tf_ckpt_path, hparams): 86 | def set_in_nested_dict(d, keys, val): 87 | if not keys: 88 | return val 89 | if keys[0] not in d: 90 | d[keys[0]] = {} 91 | d[keys[0]] = set_in_nested_dict(d[keys[0]], keys[1:], val) 92 | return d 93 | 94 | params = {"bert": {"blocks": [{} for _ in range(hparams["num_hidden_layers"])]}} 95 | for name, _ in tf.train.list_variables(tf_ckpt_path): 96 | array = np.squeeze(tf.train.load_variable(tf_ckpt_path, name)) 97 | 
if name.startswith("bert/encoder/layer_"): 98 | m = re.match(r"bert/encoder/layer_([0-9]+)/(.*)", name) 99 | n = int(m[1]) 100 | sub_name = params_key_map[m[2]] 101 | set_in_nested_dict(params["bert"]["blocks"][n], sub_name.split("/"), array) 102 | else: 103 | name = params_key_map[name] 104 | set_in_nested_dict(params, name.split("/"), array) 105 | 106 | # combine the q, k, v weights and biases 107 | for i, block in enumerate(params["bert"]["blocks"]): 108 | attn = block["attn"] 109 | q, k, v = attn.pop("q"), attn.pop("k"), attn.pop("v") 110 | params["bert"]["blocks"][i]["attn"]["c_attn"] = {} 111 | params["bert"]["blocks"][i]["attn"]["c_attn"]["w"] = np.concatenate([q["w"], k["w"], v["w"]], axis=-1) 112 | params["bert"]["blocks"][i]["attn"]["c_attn"]["b"] = np.concatenate([q["b"], k["b"], v["b"]], axis=-1) 113 | 114 | # we need to transpose the following parameter 115 | params["nsp"]["fc"]["w"] = params["nsp"]["fc"]["w"].T 116 | 117 | return params 118 | 119 | 120 | def load_tokenizer_hparams_and_params(model_name, models_dir): 121 | assert model_name in model_name_to_url 122 | 123 | is_uncased = "uncased" in model_name 124 | 125 | model_dir = os.path.join(models_dir, model_name) 126 | hparams_path = os.path.join(model_dir, "bert_config.json") 127 | vocab_path = os.path.join(model_dir, "vocab.txt") 128 | tf_ckpt_path = os.path.join(model_dir, "bert_model.ckpt") 129 | 130 | if not os.path.isfile(hparams_path): # download files if necessary 131 | os.makedirs(model_dir, exist_ok=True) 132 | download_bert_files(model_name, model_dir) 133 | 134 | tokenizer = FullTokenizer(vocab_path, do_lower_case=is_uncased) 135 | hparams = json.load(open(hparams_path)) 136 | params = load_bert_params_from_tf_ckpt(tf_ckpt_path, hparams) 137 | 138 | return tokenizer, hparams, params 139 | 140 | 141 | def tokenize(tokenizer, text_a, text_b=None): 142 | tokens_a = ["[CLS]"] + tokenizer.tokenize(text_a) + ["[SEP]"] 143 | tokens_b = (tokenizer.tokenize(text_b) + ["[SEP]"]) if text_b else [] 144 | 145 | tokens = tokens_a + tokens_b 146 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 147 | segment_ids = [0] * len(tokens_a) + [1] * len(tokens_b) 148 | 149 | return tokens, input_ids, segment_ids 150 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | """BERT WordPiece Tokenizer. 2 | 3 | Copied from: https://github.com/google-research/bert/blob/master/tokenization.py 4 | """ 5 | # coding=utf-8 6 | # Copyright 2018 The Google AI Language Team Authors. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from __future__ import absolute_import, division, print_function 21 | 22 | import collections 23 | import re 24 | import unicodedata 25 | 26 | import six 27 | 28 | 29 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 30 | """Checks whether the casing config is consistent with the checkpoint name.""" 31 | 32 | # The casing has to be passed in by the user and there is no explicit check 33 | # as to whether it matches the checkpoint. The casing information probably 34 | # should have been stored in the bert_config.json file, but it's not, so 35 | # we have to heuristically detect it to validate. 36 | 37 | if not init_checkpoint: 38 | return 39 | 40 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 41 | if m is None: 42 | return 43 | 44 | model_name = m.group(1) 45 | 46 | lower_models = [ 47 | "uncased_L-24_H-1024_A-16", 48 | "uncased_L-12_H-768_A-12", 49 | "multilingual_L-12_H-768_A-12", 50 | "chinese_L-12_H-768_A-12", 51 | ] 52 | 53 | cased_models = [ 54 | "cased_L-12_H-768_A-12", 55 | "cased_L-24_H-1024_A-16", 56 | "multi_cased_L-12_H-768_A-12", 57 | ] 58 | 59 | is_bad_config = False 60 | if model_name in lower_models and not do_lower_case: 61 | is_bad_config = True 62 | actual_flag = "False" 63 | case_name = "lowercased" 64 | opposite_flag = "True" 65 | 66 | if model_name in cased_models and do_lower_case: 67 | is_bad_config = True 68 | actual_flag = "True" 69 | case_name = "cased" 70 | opposite_flag = "False" 71 | 72 | if is_bad_config: 73 | raise ValueError( 74 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 75 | "However, `%s` seems to be a %s model, so you " 76 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 77 | "how the model was pre-training. If this error is wrong, please " 78 | "just comment out this check." 79 | % (actual_flag, init_checkpoint, model_name, case_name, opposite_flag) 80 | ) 81 | 82 | 83 | def convert_to_unicode(text): 84 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 85 | if six.PY3: 86 | if isinstance(text, str): 87 | return text 88 | elif isinstance(text, bytes): 89 | return text.decode("utf-8", "ignore") 90 | else: 91 | raise ValueError("Unsupported string type: %s" % (type(text))) 92 | elif six.PY2: 93 | if isinstance(text, str): 94 | return text.decode("utf-8", "ignore") 95 | elif isinstance(text, unicode): 96 | return text 97 | else: 98 | raise ValueError("Unsupported string type: %s" % (type(text))) 99 | else: 100 | raise ValueError("Not running on Python2 or Python 3?") 101 | 102 | 103 | def printable_text(text): 104 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 105 | 106 | # These functions want `str` for both Python2 and Python3, but in one case 107 | # it's a Unicode string and in the other it's a byte string. 
108 | if six.PY3: 109 | if isinstance(text, str): 110 | return text 111 | elif isinstance(text, bytes): 112 | return text.decode("utf-8", "ignore") 113 | else: 114 | raise ValueError("Unsupported string type: %s" % (type(text))) 115 | elif six.PY2: 116 | if isinstance(text, str): 117 | return text 118 | elif isinstance(text, unicode): 119 | return text.encode("utf-8") 120 | else: 121 | raise ValueError("Unsupported string type: %s" % (type(text))) 122 | else: 123 | raise ValueError("Not running on Python2 or Python 3?") 124 | 125 | 126 | def load_vocab(vocab_file): 127 | """Loads a vocabulary file into a dictionary.""" 128 | vocab = collections.OrderedDict() 129 | index = 0 130 | with open(vocab_file, "r") as reader: 131 | while True: 132 | token = convert_to_unicode(reader.readline()) 133 | if not token: 134 | break 135 | token = token.strip() 136 | vocab[token] = index 137 | index += 1 138 | return vocab 139 | 140 | 141 | def convert_by_vocab(vocab, items): 142 | """Converts a sequence of [tokens|ids] using the vocab.""" 143 | output = [] 144 | for item in items: 145 | output.append(vocab[item]) 146 | return output 147 | 148 | 149 | def convert_tokens_to_ids(vocab, tokens): 150 | return convert_by_vocab(vocab, tokens) 151 | 152 | 153 | def convert_ids_to_tokens(inv_vocab, ids): 154 | return convert_by_vocab(inv_vocab, ids) 155 | 156 | 157 | def whitespace_tokenize(text): 158 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 159 | text = text.strip() 160 | if not text: 161 | return [] 162 | tokens = text.split() 163 | return tokens 164 | 165 | 166 | class FullTokenizer(object): 167 | """Runs end-to-end tokenziation.""" 168 | 169 | def __init__(self, vocab_file, do_lower_case=True): 170 | self.vocab = load_vocab(vocab_file) 171 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 172 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 173 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 174 | 175 | def tokenize(self, text): 176 | split_tokens = [] 177 | for token in self.basic_tokenizer.tokenize(text): 178 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 179 | split_tokens.append(sub_token) 180 | 181 | return split_tokens 182 | 183 | def convert_tokens_to_ids(self, tokens): 184 | return convert_by_vocab(self.vocab, tokens) 185 | 186 | def convert_ids_to_tokens(self, ids): 187 | return convert_by_vocab(self.inv_vocab, ids) 188 | 189 | 190 | class BasicTokenizer(object): 191 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 192 | 193 | def __init__(self, do_lower_case=True): 194 | """Constructs a BasicTokenizer. 195 | Args: 196 | do_lower_case: Whether to lower case the input. 197 | """ 198 | self.do_lower_case = do_lower_case 199 | 200 | def tokenize(self, text): 201 | """Tokenizes a piece of text.""" 202 | text = convert_to_unicode(text) 203 | text = self._clean_text(text) 204 | 205 | # This was added on November 1st, 2018 for the multilingual and Chinese 206 | # models. This is also applied to the English models now, but it doesn't 207 | # matter since the English models were not trained on any Chinese data 208 | # and generally don't have any Chinese data in them (there are Chinese 209 | # characters in the vocabulary because Wikipedia does have some Chinese 210 | # words in the English Wikipedia.). 
211 | text = self._tokenize_chinese_chars(text) 212 | 213 | orig_tokens = whitespace_tokenize(text) 214 | split_tokens = [] 215 | for token in orig_tokens: 216 | if self.do_lower_case: 217 | token = token.lower() 218 | token = self._run_strip_accents(token) 219 | split_tokens.extend(self._run_split_on_punc(token)) 220 | 221 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 222 | return output_tokens 223 | 224 | def _run_strip_accents(self, text): 225 | """Strips accents from a piece of text.""" 226 | text = unicodedata.normalize("NFD", text) 227 | output = [] 228 | for char in text: 229 | cat = unicodedata.category(char) 230 | if cat == "Mn": 231 | continue 232 | output.append(char) 233 | return "".join(output) 234 | 235 | def _run_split_on_punc(self, text): 236 | """Splits punctuation on a piece of text.""" 237 | chars = list(text) 238 | i = 0 239 | start_new_word = True 240 | output = [] 241 | while i < len(chars): 242 | char = chars[i] 243 | if _is_punctuation(char): 244 | output.append([char]) 245 | start_new_word = True 246 | else: 247 | if start_new_word: 248 | output.append([]) 249 | start_new_word = False 250 | output[-1].append(char) 251 | i += 1 252 | 253 | return ["".join(x) for x in output] 254 | 255 | def _tokenize_chinese_chars(self, text): 256 | """Adds whitespace around any CJK character.""" 257 | output = [] 258 | for char in text: 259 | cp = ord(char) 260 | if self._is_chinese_char(cp): 261 | output.append(" ") 262 | output.append(char) 263 | output.append(" ") 264 | else: 265 | output.append(char) 266 | return "".join(output) 267 | 268 | def _is_chinese_char(self, cp): 269 | """Checks whether CP is the codepoint of a CJK character.""" 270 | # This defines a "chinese character" as anything in the CJK Unicode block: 271 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 272 | # 273 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 274 | # despite its name. The modern Korean Hangul alphabet is a different block, 275 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 276 | # space-separated words, so they are not treated specially and handled 277 | # like the all of the other languages. 278 | if ( 279 | (cp >= 0x4E00 and cp <= 0x9FFF) 280 | or (cp >= 0x3400 and cp <= 0x4DBF) # 281 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 282 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 283 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 284 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 285 | or (cp >= 0xF900 and cp <= 0xFAFF) 286 | or (cp >= 0x2F800 and cp <= 0x2FA1F) # 287 | ): # 288 | return True 289 | 290 | return False 291 | 292 | def _clean_text(self, text): 293 | """Performs invalid character removal and whitespace cleanup on text.""" 294 | output = [] 295 | for char in text: 296 | cp = ord(char) 297 | if cp == 0 or cp == 0xFFFD or _is_control(char): 298 | continue 299 | if _is_whitespace(char): 300 | output.append(" ") 301 | else: 302 | output.append(char) 303 | return "".join(output) 304 | 305 | 306 | class WordpieceTokenizer(object): 307 | """Runs WordPiece tokenziation.""" 308 | 309 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 310 | self.vocab = vocab 311 | self.unk_token = unk_token 312 | self.max_input_chars_per_word = max_input_chars_per_word 313 | 314 | def tokenize(self, text): 315 | """Tokenizes a piece of text into its word pieces. 316 | This uses a greedy longest-match-first algorithm to perform tokenization 317 | using the given vocabulary. 
318 | For example: 319 | input = "unaffable" 320 | output = ["un", "##aff", "##able"] 321 | Args: 322 | text: A single token or whitespace separated tokens. This should have 323 | already been passed through `BasicTokenizer. 324 | Returns: 325 | A list of wordpiece tokens. 326 | """ 327 | 328 | text = convert_to_unicode(text) 329 | 330 | output_tokens = [] 331 | for token in whitespace_tokenize(text): 332 | chars = list(token) 333 | if len(chars) > self.max_input_chars_per_word: 334 | output_tokens.append(self.unk_token) 335 | continue 336 | 337 | is_bad = False 338 | start = 0 339 | sub_tokens = [] 340 | while start < len(chars): 341 | end = len(chars) 342 | cur_substr = None 343 | while start < end: 344 | substr = "".join(chars[start:end]) 345 | if start > 0: 346 | substr = "##" + substr 347 | if substr in self.vocab: 348 | cur_substr = substr 349 | break 350 | end -= 1 351 | if cur_substr is None: 352 | is_bad = True 353 | break 354 | sub_tokens.append(cur_substr) 355 | start = end 356 | 357 | if is_bad: 358 | output_tokens.append(self.unk_token) 359 | else: 360 | output_tokens.extend(sub_tokens) 361 | return output_tokens 362 | 363 | 364 | def _is_whitespace(char): 365 | """Checks whether `chars` is a whitespace character.""" 366 | # \t, \n, and \r are technically contorl characters but we treat them 367 | # as whitespace since they are generally considered as such. 368 | if char == " " or char == "\t" or char == "\n" or char == "\r": 369 | return True 370 | cat = unicodedata.category(char) 371 | if cat == "Zs": 372 | return True 373 | return False 374 | 375 | 376 | def _is_control(char): 377 | """Checks whether `chars` is a control character.""" 378 | # These are technically control characters but we count them as whitespace 379 | # characters. 380 | if char == "\t" or char == "\n" or char == "\r": 381 | return False 382 | cat = unicodedata.category(char) 383 | if cat in ("Cc", "Cf"): 384 | return True 385 | return False 386 | 387 | 388 | def _is_punctuation(char): 389 | """Checks whether `chars` is a punctuation character.""" 390 | cp = ord(char) 391 | # We treat all non-letter/number ASCII as punctuation. 392 | # Characters such as "^", "$", and "`" are not in the Unicode 393 | # Punctuation class but we treat them as punctuation anyways, for 394 | # consistency. 395 | if ( 396 | (cp >= 33 and cp <= 47) 397 | or (cp >= 58 and cp <= 64) 398 | or (cp >= 91 and cp <= 96) 399 | or (cp >= 123 and cp <= 126) 400 | ): 401 | return True 402 | cat = unicodedata.category(char) 403 | if cat.startswith("P"): 404 | return True 405 | return False 406 | --------------------------------------------------------------------------------
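The `WordpieceTokenizer.tokenize` docstring above describes the greedy longest-match-first algorithm. A minimal sketch of exercising the tokenizer on its own (assumes the `bert-base-uncased` vocab has already been downloaded to `models/bert-base-uncased/vocab.txt`, which is where `utils.py` puts it):

```python
from tokenization import FullTokenizer

# basic tokenization (lower-casing, punctuation splitting) followed by WordPiece
tokenizer = FullTokenizer("models/bert-base-uncased/vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("Instead, it falls on Newton's head.")
print(tokens)  # e.g. ['instead', ',', 'it', 'falls', 'on', 'newton', "'", 's', 'head', '.']

ids = tokenizer.convert_tokens_to_ids(tokens)
print(ids)     # vocabulary indices, used as input_ids by bert()
```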