├── CODEOWNERS ├── configs ├── dataset_configs │ ├── example.json │ ├── pile.json │ └── openwebtext2_new_inputs.json ├── gpt2_small.json ├── gpt3_6-7B_256.json ├── gpt3_medium_256.json ├── gpt3_small_256.json ├── gpt3_2-7B_256.json ├── gpt3_PAR_small_256.json ├── gpt3_XL_256_Pile.json ├── gpt3_13B_256_Pile.json ├── gpt3_large_256.json └── gpt3_13B_256.json ├── requirements.txt ├── Dockerfile ├── export.py ├── CITATION.bib ├── .github ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── pytest.yml ├── data ├── encoders.py ├── train_tokenizer.py └── create_tfrecords.py ├── encoders.py ├── LICENSE ├── docker-compose.yml ├── .gitignore ├── configs.py ├── models ├── activations.py ├── utils.py ├── gpt2 │ └── gpt2.py └── layers.py ├── tasks.py ├── test_models.py ├── optimizers.py ├── sample.py ├── run_experiment.py ├── utils.py ├── main.py ├── model_fns.py ├── inputs.py └── README.md /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * EleutherAI/pm-gptneo 2 | -------------------------------------------------------------------------------- /configs/dataset_configs/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_vocab": 32768, 3 | "path": "./tfrecords/openwebtext_*.tfrecords", 4 | "eval_path": "", 5 | "tokenizer_path": "./datasets/openwebtext/byte-level-bpe.tokenizer.json", 6 | "eos_id": 1, 7 | "padding_id": 0 8 | } 9 | -------------------------------------------------------------------------------- /configs/dataset_configs/pile.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_vocab": 50257, 3 | "path": "gs://neo-datasets/pile/pile_*.tfrecords", 4 | "eval_path": "gs://neo-datasets/pile_val.tfrecords", 5 | "tokenizer_is_pretrained": true, 6 | "tokenizer_path": "gpt2", 7 | "eos_id": 50256, 8 | "padding_id": 50257 9 | } 10 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | google-api-python-client 2 | jsonlines 3 | lm_dataformat 4 | mesh-tensorflow==0.1.18 5 | numpy 6 | oauth2client 7 | ortools 8 | pytest 9 | sacred 10 | tensorflow==2.5.1 11 | tensorflow-datasets==3.2.1 12 | tokenizers==0.9.4 13 | transformers==4.1.1 14 | tpunicorn 15 | absl-py 16 | ftfy 17 | sacred 18 | pymongo 19 | -------------------------------------------------------------------------------- /configs/dataset_configs/openwebtext2_new_inputs.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_vocab": 50257, 3 | "path": "gs://neo-datasets/openwebtext2_new_inputs/train/*.tfrecords", 4 | "eval_path": "gs://neo-datasets/openwebtext2_new_inputs/eval/*.tfrecords", 5 | "tokenizer_is_pretrained": true, 6 | "tokenizer_path": "gpt2", 7 | "eos_id": 50256, 8 | "padding_id": 50257 9 | } 10 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/deeplearning-platform-release/tf-cpu.1-15 2 | 3 | WORKDIR /neogpt 4 | 5 | # Make RUN commands use `bash --login`: 6 | SHELL ["/bin/bash", "--login", "-c"] 7 | ENV DEBIAN_FRONTEND=noninteractive 8 | RUN apt-get update -y && apt-get install tmux -y 9 | RUN conda install gcc_linux-64 gxx_linux-64 -y 10 | ADD requirements.txt . 
def export_model(estimator, export_dir, params,
                 checkpoint_path=None):
    """Export `estimator` as a SavedModel and return the exporter's result.

    Args:
        estimator: object exposing `export_saved_model(...)`.
        export_dir: destination passed straight through to the estimator.
        params: mapping; only `params["n_ctx"]` is read, to size the input.
        checkpoint_path: optional checkpoint forwarded to the estimator.

    Returns:
        Whatever `estimator.export_saved_model` returns.
    """

    def _build_receiver():
        # A single int64 placeholder of shape [1, n_ctx] serves both as the
        # model features and as the tensor fed by the serving request.
        token_ids = tf.placeholder(
            dtype=tf.int64,
            shape=[1, params["n_ctx"]],
            name='input_example_tensor',
        )
        return tf.estimator.export.ServingInputReceiver(token_ids, token_ids)

    exported = estimator.export_saved_model(
        export_dir,
        _build_receiver,
        checkpoint_path=checkpoint_path,
    )
    return exported
12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Proposed solution** 24 | If you have an idea for how we can fix this problem, describe it here. 25 | 26 | **Screenshots** 27 | If applicable, add screenshots to help explain your problem. 28 | 29 | **Environment (please complete the following information):** 30 | - GPUs: 31 | - Configs: 32 | 33 | **Additional context** 34 | Add any other context about the problem here. 
def fetch_encoder(params):
    """Build the tokenizer described by the first dataset config in `params`.

    Returns None when `params["no_dataset"]` is truthy. Otherwise the first
    value of `params["dataset_configs"]` supplies `tokenizer_path` and the
    optional `tokenizer_is_pretrained` flag (default False).
    """
    if params.get('no_dataset', False):
        return None

    # Only the first dataset config is consulted.
    first_config = next(iter(params['dataset_configs'].values()))
    tokenizer_path = first_config["tokenizer_path"]

    if not first_config.get("tokenizer_is_pretrained", False):
        # Custom tokenizer serialized to a file.
        return Tokenizer.from_file(tokenizer_path)

    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path)
    # Registers the padding token at run time (the original comment states
    # this gives it id 50257).
    tokenizer.add_special_tokens({'pad_token': '<|padding|>'})
    return tokenizer


def encode(encoder, text):
    """Return the token ids of `text` as produced by `encoder.encode`.

    Some encoders return a plain list of ids; others return an object that
    exposes the ids through an `.ids` attribute. Both are handled.
    """
    encoded = encoder.encode(text)
    return encoded if isinstance(encoded, list) else encoded.ids
def encode(encoder, text, gpt=True):
    """Return the token ids of `text` as produced by `encoder.encode`.

    Some encoders return a plain list of ids; others return an object that
    exposes the ids through an `.ids` attribute. Both are handled.

    `gpt` is accepted for call compatibility but is not read.
    """
    encoded = encoder.encode(text)
    if not isinstance(encoded, list):
        return encoded.ids
    return encoded
"predict_batch_size": 8, 22 | "iterations": 2500, 23 | "n_embd": 768, 24 | "datasets": ["openwebtext2_new_inputs"], 25 | "model_path": "gs://neo-models/GPT2_SMALL", 26 | "n_ctx": 1024, 27 | "n_layer": 12, 28 | "scale_by_depth": true, 29 | "scale_by_in": false, 30 | "attention_types" : [[["global"],12]], 31 | "activation_function": "gelu", 32 | "mesh_shape": "all:64", 33 | "layout": "batch:all", 34 | "recompute_grad": false, 35 | "gradient_clipping": 1.0 36 | } -------------------------------------------------------------------------------- /configs/gpt3_6-7B_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 32, 3 | "n_vocab": 50257, 4 | "embed_dropout": 0, 5 | "lr": 0.00012, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "opt_name": "adam", 12 | "weight_decay": 0.10, 13 | "train_batch_size": 1024, 14 | "attn_dropout": 0, 15 | "train_steps": 143075, 16 | "eval_steps": 0, 17 | "predict_steps": 1, 18 | "res_dropout": 0, 19 | "eval_batch_size": 128, 20 | "predict_batch_size": 1, 21 | "iterations": 500, 22 | "n_embd": 4096, 23 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 24 | "model_path": "gs://neo-models/GPT3_6-7B", 25 | "n_ctx": 2048, 26 | "n_layer": 32, 27 | "scale_by_depth": true, 28 | "scale_by_in": false, 29 | "attention_types" : [[["global"],32]], 30 | "mesh_shape": "x:128,y:2", 31 | "layout": "embd:y,batch:x", 32 | "activation_function": "gelu", 33 | "recompute_grad": true, 34 | "gradient_clipping": 1.0 35 | } 36 | 37 | -------------------------------------------------------------------------------- /configs/gpt3_medium_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 16, 3 | "n_vocab": 50304, 4 | "embed_dropout": 0, 5 | "lr": 0.0003, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 
| "opt_name": "adam", 12 | "weight_decay": 0.10, 13 | "train_batch_size": 256, 14 | "attn_dropout": 0, 15 | "train_steps": 572300, 16 | "eval_steps": 0, 17 | "predict_steps": 1, 18 | "res_dropout": 0, 19 | "eval_batch_size": 64, 20 | "predict_batch_size": 1, 21 | "iterations": 2500, 22 | "n_embd": 1024, 23 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 24 | "model_path": "gs://neo-models/GPT3_MEDIUM", 25 | "n_ctx": 2048, 26 | "n_layer": 24, 27 | "scale_by_depth": true, 28 | "scale_by_in": false, 29 | "attention_types" : [[["global"],24]], 30 | "mesh_shape": "x:64,y:4", 31 | "layout": "batch:x,heads:y,vocab:y", 32 | "activation_function": "gelu", 33 | "recompute_grad": false, 34 | "gradient_clipping": 1.0 35 | } 36 | 37 | -------------------------------------------------------------------------------- /configs/gpt3_small_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 12, 3 | "n_vocab": 50304, 4 | "embed_dropout": 0, 5 | "lr": 0.0006, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "opt_name": "adam", 12 | "weight_decay": 0.10, 13 | "train_batch_size": 256, 14 | "attn_dropout": 0, 15 | "train_steps": 572300, 16 | "eval_steps": 0, 17 | "predict_steps": 1, 18 | "res_dropout": 0, 19 | "eval_batch_size": 64, 20 | "predict_batch_size": 1, 21 | "iterations": 2500, 22 | "n_embd": 768, 23 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 24 | "model_path": "gs://neo-models/GPT3_SMALL", 25 | "n_ctx": 2048, 26 | "n_layer": 12, 27 | "scale_by_depth": true, 28 | "scale_by_in": false, 29 | "attention_types": [[["global"],12]], 30 | "mesh_shape": "x:64,y:4", 31 | "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y", 32 | "activation_function": "gelu", 33 | "recompute_grad": false, 34 | "gradient_clipping": 1.0 35 | } 36 | 37 | 
-------------------------------------------------------------------------------- /configs/gpt3_2-7B_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 32, 3 | "n_vocab": 50257, 4 | "embed_dropout": 0, 5 | "lr": 0.00016, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "ada_epsilon1": 1e-30, 12 | "ada_epsilon2": 1e-3, 13 | "opt_name": "adam", 14 | "weight_decay": 0.10, 15 | "train_batch_size": 512, 16 | "attn_dropout": 0, 17 | "train_steps": 286150, 18 | "eval_steps": 0, 19 | "predict_steps": 1, 20 | "res_dropout": 0, 21 | "eval_batch_size": 128, 22 | "predict_batch_size": 1, 23 | "iterations": 500, 24 | "n_embd": 2560, 25 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 26 | "model_path": "gs://neo-models/GPT3_2-7B", 27 | "n_ctx": 2048, 28 | "n_layer": 32, 29 | "scale_by_depth": true, 30 | "scale_by_in": false, 31 | "attention_types" : [[["global"],32]], 32 | "mesh_shape": "x:128,y:2", 33 | "layout": "embd:y,batch:x", 34 | "activation_function": "gelu", 35 | "recompute_grad": true, 36 | "gradient_clipping": 1.0 37 | } 38 | 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/gpt3_PAR_small_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 12, 3 | "n_vocab": 50304, 4 | "embed_dropout": 0, 5 | "lr": 0.0006, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "opt_name": "adam", 12 | "weight_decay": 0.10, 13 | "train_batch_size": 256, 14 | "attn_dropout": 0, 15 | "train_steps": 572300, 16 | "eval_steps": 0, 17 | "predict_steps": 1, 18 | "res_dropout": 0, 19 | "eval_batch_size": 64, 20 | "predict_batch_size": 1, 21 | "iterations": 1000, 22 | "n_embd": 768, 23 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 24 | "model_path": "gs://neo-models/GPT3_PAR_SMALL", 25 | "n_ctx": 2048, 26 | "n_layer": 19, 27 | "scale_by_depth": true, 28 | "scale_by_in": false, 29 | "attention_types": [[["global", "none", "none"],5], [["none"], 4]], 30 | "mesh_shape": "x:64,y:4", 31 | "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y", 32 | "activation_function": "gelu", 33 | "recompute_grad": false, 34 | "gradient_clipping": 1.0 35 | } 36 | 37 | -------------------------------------------------------------------------------- /configs/gpt3_XL_256_Pile.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 32, 3 | "n_vocab": 50257, 4 | 
"embed_dropout": 0, 5 | "lr": 0.0002, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "opt_name": "adam", 12 | "weight_decay": 0.1, 13 | "train_batch_size": 512, 14 | "attn_dropout": 0, 15 | "train_steps": 286150, 16 | "eval_steps": 10, 17 | "predict_steps": 1, 18 | "res_dropout": 0, 19 | "eval_batch_size": 512, 20 | "predict_batch_size": 1, 21 | "iterations": 500, 22 | "n_embd": 2048, 23 | "datasets": [["pile", 25, "documents_random", 1.0]], 24 | "model_path": "gs://neo-models/GPT3_XL_Pile", 25 | "n_ctx": 2048, 26 | "n_layer": 24, 27 | "scale_by_depth": true, 28 | "scale_by_in": false, 29 | "attention_types" : [[["global"],24]], 30 | "mesh_shape": "x:128,y:2", 31 | "layout": "batch:x,memory_length:y,embd:y", 32 | "activation_function": "gelu", 33 | "recompute_grad": true, 34 | "gradient_clipping": 1.0, 35 | "tokens_per_mb_per_replica": 2048, 36 | "precision": "bfloat16" 37 | } 38 | -------------------------------------------------------------------------------- /configs/gpt3_13B_256_Pile.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "n_head": 40, 4 | "n_vocab": 50257, 5 | "embed_dropout": 0, 6 | "lr": 0.0001, 7 | "lr_decay": "cosine", 8 | "warmup_steps": 3000, 9 | "beta1": 0.9, 10 | "beta2": 0.95, 11 | "epsilon": 1e-8, 12 | "opt_name": "adam", 13 | "weight_decay": 0.1, 14 | "train_batch_size": 1024, 15 | "attn_dropout": 0, 16 | "train_steps": 286150, 17 | "eval_steps": 10, 18 | "predict_steps": 1, 19 | "res_dropout": 0, 20 | "eval_batch_size": 512, 21 | "predict_batch_size": 1, 22 | "iterations": 500, 23 | "n_embd": 5120, 24 | "datasets": [["pile", 25, "documents_random", 1.0]], 25 | "model_path": "gs://neo-models/GPT3_13B_Pile", 26 | "n_ctx": 2048, 27 | "n_layer": 40, 28 | "scale_by_depth": true, 29 | "scale_by_in": false, 30 | "attention_types" : [[["global"],40]], 31 | "mesh_shape": "x:16,y:16", 32 | "layout": 
"batch:x,memory_length:y,embd:y", 33 | "activation_function": "gelu", 34 | "recompute_grad": true, 35 | "gradient_clipping": 1.0, 36 | "tokens_per_mb_per_replica": 2048, 37 | "precision": "bfloat16" 38 | } 39 | -------------------------------------------------------------------------------- /configs/gpt3_large_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 16, 3 | "n_vocab": 50304, 4 | "embed_dropout": 0, 5 | "lr": 0.00025, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "ada_epsilon1": 1e-30, 12 | "ada_epsilon2": 1e-3, 13 | "opt_name": "adam", 14 | "weight_decay": 0.10, 15 | "train_batch_size": 256, 16 | "attn_dropout": 0, 17 | "train_steps": 572300, 18 | "eval_steps": 0, 19 | "predict_steps": 1, 20 | "res_dropout": 0, 21 | "eval_batch_size": 64, 22 | "predict_batch_size": 1, 23 | "iterations": 2500, 24 | "n_embd": 1536, 25 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 26 | "model_path": "gs://neo-models/GPT3_LARGE", 27 | "n_ctx": 2048, 28 | "n_layer": 24, 29 | "scale_by_depth": true, 30 | "scale_by_in": false, 31 | "attention_types" : [[["global"],24]], 32 | "mesh_shape": "x:64,y:4", 33 | "layout": "batch:x,vocab:y,heads:y", 34 | "activation_function": "gelu", 35 | "recompute_grad": true, 36 | "gradient_clipping": 1.0, 37 | "tokens_per_mb_per_replica": 2048 38 | } 39 | 40 | -------------------------------------------------------------------------------- /configs/gpt3_13B_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_head": 40, 3 | "n_vocab": 50257, 4 | "embed_dropout": 0, 5 | "lr": 0.0001, 6 | "lr_decay": "cosine", 7 | "warmup_steps": 3000, 8 | "beta1": 0.9, 9 | "beta2": 0.95, 10 | "epsilon": 1e-8, 11 | "ada_epsilon1": 1e-30, 12 | "ada_epsilon2": 1e-3, 13 | "opt_name": "adam", 14 | "weight_decay": 0.10, 15 | "train_batch_size": 1024, 16 | 
"attn_dropout": 0, 17 | "train_steps": 143075, 18 | "eval_steps": 0, 19 | "predict_steps": 1, 20 | "res_dropout": 0, 21 | "eval_batch_size": 128, 22 | "predict_batch_size": 1, 23 | "iterations": 500, 24 | "n_embd": 5120, 25 | "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], 26 | "model_path": "gs://neo-models/GPT3_13B", 27 | "n_ctx": 2048, 28 | "n_layer": 40, 29 | "scale_by_depth": true, 30 | "scale_by_in": false, 31 | "attention_types" : [[["global", "local"],20]], 32 | "mesh_shape": "x:16,y:16", 33 | "layout": "batch:x,embd:y,memory_length:y", 34 | "activation_function": "gelu", 35 | "recompute_grad": true, 36 | "gradient_clipping": 1.0, 37 | "tokens_per_mb_per_replica": 2048, 38 | "precision": "bfloat16" 39 | } 40 | 41 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | 4 | mongo: 5 | image: mongo 6 | ports: 7 | - 127.0.0.1:27017:27017 8 | environment: 9 | MONGO_INITDB_ROOT_USERNAME: user 10 | MONGO_INITDB_ROOT_PASSWORD: password 11 | MONGO_INITDB_DATABASE: db 12 | expose: 13 | - 27017 14 | networks: 15 | - omniboard 16 | volumes: 17 | - ./data:/data/db 18 | 19 | mongoClientTemp: 20 | image: mongo:latest 21 | container_name: mongoClientTemp 22 | links: 23 | - mongo:mongo 24 | command: mongo --host mongo -u user -p password --eval "db.getSiblingDB('db').createUser({user:'readonly', pwd:'password', roles:[{role:'read',db:'db'}]});" 25 | depends_on: 26 | - mongo 27 | networks: 28 | - omniboard 29 | 30 | omniboard_readonly: 31 | #image: vivekratnavel/omniboard:latest 32 | build: https://github.com/lucidrains/omniboard.git 33 | command: ["--mu", "mongodb://readonly:password@mongo:27017/db"] 34 | ports: 35 | - 0.0.0.0:8081:9000 36 | networks: 37 | - omniboard 38 | depends_on: 39 | - mongo 40 | 41 | omniboard: 42 | #image: vivekratnavel/omniboard:latest 43 | build: 
https://github.com/lucidrains/omniboard.git 44 | command: ["--mu", "mongodb://user:password@mongo:27017/db?authSource=admin"] 45 | expose: 46 | - 9000 47 | networks: 48 | - omniboard 49 | depends_on: 50 | - mongo 51 | 52 | nginx: 53 | image: dhswt/nginx-basic-auth:1.3 54 | environment: 55 | - HTPASSWD=isaac: #put passwd here 56 | - FORWARD_HOST=omniboard 57 | - FORWARD_PORT=9000 58 | networks: 59 | - omniboard 60 | depends_on: 61 | - omniboard 62 | ports: 63 | - 0.0.0.0:8080:80 64 | expose: 65 | - 8080 66 | networks: 67 | omniboard: 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # testing 2 | .test/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
DATASETS = {}

# Load every dataset config once at import time, keyed by file stem.
# NOTE: the directory is resolved relative to the current working directory.
for path in Path("configs/dataset_configs").glob("*.json"):
    dataset_id = path.stem
    DATASETS[dataset_id] = json.loads(path.read_text())


def fetch_model_params(model):
    """Load a model config and attach the dataset configs it references.

    Args:
        model: either a path ending in ".json" or a bare config name, which
            is resolved to "configs/<model>.json".

    Returns:
        A `defaultdict` holding the config values plus:
          * "dataset_configs": {dataset_id: dataset config} for each entry of
            "datasets" (entries may be an id or a list whose first item is
            the id);
          * "padding_id" / "eos_id": taken from the last listed dataset
            (defaults 0 and 1), only when at least one dataset is listed;
          * "mlm_training" / "causal": booleans, `causal` being the negation
            of `mlm_training`.
        Keys that are absent read as None.

    Raises:
        AssertionError: if no dataset is listed and "no_dataset" is not set,
            if a dataset id has no config in DATASETS, or if the model's
            "n_vocab" is smaller than a dataset's "n_vocab".
    """
    model_path = model if model.endswith(".json") else f"configs/{model}.json"
    with open(model_path) as f:
        params = json.load(f)

    # A config may omit "datasets" entirely (e.g. together with
    # "no_dataset": true); treat a missing or null entry as an empty list
    # instead of failing with a TypeError while iterating None.
    dataset_ids = [
        entry[0] if isinstance(entry, list) else entry
        for entry in params.get("datasets") or []
    ]
    no_datasets = params.get("no_dataset", False)
    assert no_datasets or len(dataset_ids) > 0, "You must specify at least one dataset id in the model config"

    datasets = {}
    last_dataset = None
    for dataset_id in dataset_ids:
        assert dataset_id in DATASETS, f"Dataset '{dataset_id}' was not found under dataset_configs/ folder. Please follow the example.json in that folder."
        dataset = DATASETS[dataset_id]
        assert params["n_vocab"] >= dataset["n_vocab"], f"The embedding table size '{params['n_vocab']}' must be greater or equal to the vocab size used to encode the dataset '{dataset_id}' ({dataset['n_vocab']})"
        datasets[dataset_id] = dataset
        last_dataset = dataset

    # The special-token ids come from the last dataset in the list.
    if last_dataset is not None:
        params["padding_id"] = last_dataset.get("padding_id", 0)
        params["eos_id"] = last_dataset.get("eos_id", 1)

    params["dataset_configs"] = datasets

    # Set some other parameter defaults. Only a value equal to True enables
    # MLM training; anything else (including a missing key) means causal LM.
    params["mlm_training"] = params.get("mlm_training") == True
    params["causal"] = not params["mlm_training"]

    # Set all other parameter values to default to None
    params = defaultdict(lambda: None, params)
    return params
# parser

parser = argparse.ArgumentParser()
# --base_dir is required: without it Path(None) below raises a TypeError.
# The help text names the extensions that --file_type actually accepts.
parser.add_argument("--base_dir", type=str, required=True,
                    help="Path to where your files are located. Files ending in .xz are treated as "
                         "archives, files ending in .txt as raw text (see --file_type).")
parser.add_argument("--output_dir", type=str, default="tokenizers", help="Where to put the tokenizer")
parser.add_argument("--file_type", type=str, choices=["xz", "txt"], default="xz", help="Extension of file to parse")
parser.add_argument("--vocab_size", type=int, help="Size of vocabulary", required=True)
args = parser.parse_args()

# main script

data_path = Path(args.base_dir)
out_path = Path(args.output_dir)

# The output directory is deleted and recreated below, so refuse to run when
# it is the same directory as the input data.
if out_path.resolve() == data_path.resolve():
    parser.error("--output_dir must differ from --base_dir: the output directory is wiped before use")

archives = glob(str(data_path / f"*.{args.file_type}"))

# Start from an empty output directory (nested paths are created as needed).
if os.path.exists(out_path):
    shutil.rmtree(out_path)
out_path.mkdir(parents=True, exist_ok=True)

# Materialize every input as a plain-text file inside the output directory.
for arch in tqdm(archives):
    name = os.path.basename(arch).split(".")[0] + ".txt"
    fp = out_path / name

    if args.file_type == 'xz':
        g = Reader(arch).stream_data()

        # Explicit UTF-8 so the dump does not depend on the locale encoding.
        with open(fp, "w", encoding="utf-8") as f:
            for s in g:
                f.write(s)
                f.write("\n\n")
    elif args.file_type == 'txt':
        shutil.copyfile(str(arch), str(fp))

data_files = glob(str(out_path / "*.txt"))
if not data_files:
    raise SystemExit('No data files found')

# Train on a random 20% of the files, but never on fewer than one: with
# fewer than five files int(0.2 * n) is 0 and the sample would be empty.
sample_size = max(1, int(0.2 * len(data_files)))
data_files = random.sample(data_files, sample_size)

# Initialize a tokenizer
tokenizer = Tokenizer(models.BPE())

# Customize pre-tokenization and decoding
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
tokenizer.normalizer = NFKC()

# And then train
trainer = trainers.BpeTrainer(vocab_size=args.vocab_size, min_frequency=2, special_tokens=["<|endoftext|>", "<|padding|>"])
# NOTE(review): (trainer, files) is the argument order this script has always
# used; requirements.txt pins tokenizers==0.9.4 — confirm before upgrading.
tokenizer.train(trainer, data_files)

# And Save it
tokenizer_path = out_path / "byte-level-bpe.tokenizer.json"
tokenizer.save(str(tokenizer_path), pretty=True)

print(f'tokenizer saved at {str(tokenizer_path)}')
# Activations that map one-to-one onto a mesh-tensorflow primitive.
BASE_FNS = {'gelu': mtf.gelu,
            'relu': mtf.relu,
            'sigmoid': mtf.sigmoid,
            'tanh': mtf.tanh,
            'selu': mtf.selu,
            'elu': mtf.elu,
            'abs': mtf.abs,
            'sin': mtf.sin,
            'cos': mtf.cos,
            'sign': mtf.sign,
            'silu': mtf.swish,
            'softplus': mtf.softplus
            }


def _arcsinh(x):
    """Return log(x + sqrt(1 + x^2)), i.e. arcsinh built from mtf primitives."""
    return mtf.log(x + mtf.sqrt(1 + x ** 2))


def _var(x, init):
    """Create a scalar variable on x's mesh, initialised to the constant `init`.

    The variable name gets a random hex suffix, so every call creates a
    distinct variable.
    """
    return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                            initializer=tf.constant_initializer(init), dtype=x.dtype)


def _pos_var(x, val):
    """Return `val` plus the softplus of a fresh zero-initialised scalar variable."""
    return mtf.softplus(_var(x, 0)) + val


def _rrelu(x):
    """Blend x with |x| using a scale drawn once from random.random() at call time."""
    negative_scale = random.random()
    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)


def _elish(x):
    """Piecewise activation: one expression where x > 0, another elsewhere."""
    # `cond` is 1.0 where x > 0 and 0.0 elsewhere, so it selects between the branches.
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(x)
    return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1)


# Activations composed from mtf primitives; several create trainable scalars via _var/_pos_var.
CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
              'id': lambda x: x,
              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
              'spike': lambda x: 1 / (1 + x ** 2),
              'spike2': lambda x: mtf.exp(-x ** 2),
              # Fixed: the bare name `tanh` was never defined, so selecting this raised NameError.
              'tanhshrink': lambda x: x - mtf.tanh(x),
              'softsign': lambda x: x / (mtf.abs(x) + 1),
              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
              'rrelu': _rrelu,
              'elish': _elish,
              'arcsinh': _arcsinh,
              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
                      _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
              'cosid': lambda x: mtf.cos(x) - x,
              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
              'lisht': lambda x: x * mtf.tanh(x),
              'seagull': lambda x: mtf.log(1 + x ** 2),
              'snake': lambda x: x + mtf.sin(x) ** 2,
              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
              'softplusmone': lambda x: mtf.softplus(x) - 1
              }


def get_activation_fn(params):
    """Look up the activation callable named by params["activation_fn"].

    Args:
        params: mapping of model parameters. When it has no "activation_fn"
            key, a notice is printed and "gelu" is used.

    Returns:
        The callable registered under that name; BASE_FNS is checked before
        CUSTOM_FNS.

    Raises:
        ValueError: if the name is in neither table.
    """
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
        activation_fn = "gelu"

    if activation_fn in BASE_FNS:
        return BASE_FNS[activation_fn]

    if activation_fn in CUSTOM_FNS:
        return CUSTOM_FNS[activation_fn]

    # Report the offending value; the old message printed the literal text "activation_fn".
    raise ValueError(f'unknown activation function "{activation_fn}" in config')
lambada_src_uri = 'http://eaidata.bmk.sh/data/lambada_test.jsonl'
normalization = 'NFKC'


# Note: this task is called "lambada" but it really refers to OpenAI's version
# of the task, which actually differs in some ways from the task described in
# the original paper. So, strictly speaking, accuracy values from this task
# should not be compared to accuracy values from the original lambada task.
# For more information, see
#   https://github.com/openai/gpt-2/issues/131

def lambada_create_tokens_data(params, path):
    """Download the LAMBADA test set, tokenize it and cache the tokens at `path`.

    Args:
        params: model parameters, forwarded to `fetch_encoder`.
        path: destination of the JSON cache file.

    Returns:
        A list with one encoded token sequence per example.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    # Download and encode *before* touching the cache file. Previously the file
    # was opened for writing first, so a failed request left an empty file that
    # later made `json.load` fail in lambada_read_or_create_tokens_data.
    req = requests.get(lambada_src_uri)
    req.raise_for_status()
    # Skip blank lines so an empty trailing line cannot break json.loads.
    jsons = [json.loads(line) for line in req.iter_lines() if line]
    texts = [ftfy.fix_text(j['text'], normalization=normalization) for j in jsons]
    enc = fetch_encoder(params)
    arrays = [encode(enc, t) for t in texts]

    # Write to a sibling temp file and rename, so a crash mid-dump cannot leave
    # a truncated cache behind.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(arrays, f)
    os.replace(tmp_path, path)
    return arrays


def lambada_read_or_create_tokens_data(params, path):
    """Return the cached token data at `path`, creating the cache if it is missing."""
    # if you tell me where the file should go, i will helpfully create it for you
    if not os.path.exists(path):
        return lambada_create_tokens_data(params, path)
    with open(path) as f:
        return json.load(f)


def bin_pack(params, tokens_data):
    """Pack token sequences into fixed-width rows, each sequence followed by EOS.

    Args:
        params: reads 'eos_id', 'n_ctx' and 'eval_batch_size'.
        tokens_data: iterable of token-id lists.

    Returns:
        A uint16 array of shape (rows, n_ctx). Unused slots hold the dummy
        token 1, and all-dummy rows are appended until `rows` is a multiple
        of 'eval_batch_size'.
    """
    eos_token = params['eos_id']
    n_ctx = params['n_ctx']
    dummy_token = 1
    pad_batch_size = params['eval_batch_size']
    bins = []
    for a in tokens_data:
        # Start a new row when this sequence plus its EOS would overflow the current one.
        if len(bins) == 0 or len(bins[-1]) + len(a) + 1 > n_ctx:
            bins.append([])
        bins[-1] += a
        bins[-1].append(eos_token)
    while len(bins) % pad_batch_size != 0:
        bins.append([])
    bins_array = np.full((len(bins), n_ctx), dummy_token, dtype=np.uint16)
    for i, b in enumerate(bins):
        bins_array[i, 0:len(b)] = b
    return bins_array


def lambada_init(params):
    """Resolve the token cache path and record the number of eval steps in `params`.

    Sets params['lambada_tokens_path'] and params['lambada_n_steps']. The path
    is taken from the first dataset listed in params['datasets'], falling back
    to "./lambada.json" when that dataset config has no 'lambada_tokens_path'.
    """
    ds_configs = params['dataset_configs']
    candidate_paths = [
        ds_configs[ds_id].get('lambada_tokens_path', "./lambada.json")
        for ds_id, _, _, _ in params['datasets']
    ]
    assert len(candidate_paths) > 0, 'lambada_tokens_path not found in the dataset config'
    lt_path = candidate_paths[0]
    assert lt_path.endswith('.json'), 'lambada_tokens_path must have extension json'

    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    params['lambada_tokens_path'] = lt_path
    params['lambada_n_steps'] = len(bins_array) // params['eval_batch_size']


def lambada_get_task_info(params):
    """Return the task metadata dict; currently only the number of eval steps."""
    return {
        'n_steps': params['lambada_n_steps'],
    }


# The LAMBADA evaluation code looks at the logits of each position just before an eos_token
def lambada_input(params):
    """Build the repeating, batched tf.data pipeline of (inputs, targets) rows.

    For every position i, the target is the token at i + 1 when the token at
    i + 2 is EOS, and EOS otherwise (indices wrap around modulo n_ctx).
    """
    # NOTE(review): bin_pack separates examples with params['eos_id'], while this
    # value is derived from n_vocab. The two agree for the GPT-2 vocabulary
    # (eos_id 50256); confirm for other tokenizers before relying on the metric.
    eos_token = 50256 if params['n_vocab'] >= 50257 else 0
    n_ctx = params['n_ctx']
    lt_path = params['lambada_tokens_path']
    tokens_data = lambada_read_or_create_tokens_data(params, lt_path)
    bins_array = bin_pack(params, tokens_data)
    dataset = tf.data.Dataset.from_tensor_slices(bins_array)

    def _get_output(row):
        row = tf.cast(row, dtype=tf.int32)
        indexes = tf.range(n_ctx)
        results = tf.gather(row, (indexes + 1) % n_ctx)
        eos_next_positions = tf.math.equal(tf.gather(row, (indexes + 2) % n_ctx), eos_token)
        output = tf.where(eos_next_positions, results, tf.constant(eos_token, shape=[n_ctx]))
        row = tf.reshape(row, [n_ctx])
        row = tf.cast(row, dtype=tf.int32)
        output = tf.reshape(output, [n_ctx])
        output = tf.cast(output, dtype=tf.int32)
        return row, output

    dataset = dataset.map(_get_output, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(params['eval_batch_size'], drop_remainder=True)
    dataset = dataset.repeat()
    return dataset


# Registry of evaluation tasks: task name -> its init / info / input callables.
task_descriptors = {
    'lambada': {
        'init_fn': lambada_init,
        'get_task_info_fn': lambada_get_task_info,
        'input_fn': lambada_input,
    }
}
def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None,
                    n_iter=50):
    """Gradient function for `entmax_forward`, in the form `mtf.custom_gradient` expects.

    Args:
        explicit_inputs: one-element sequence holding the forward input.
        all_inputs, forward_operations: unused here.
        outputs: one-element sequence holding the forward output `y`.
        output_grads: one-element sequence holding the gradient w.r.t. `y`.
        alpha: entmax exponent; must match the forward pass.
        dim: dimension the forward pass normalised over; defaults to the last
            dimension of `y`, mirroring `entmax_forward`.
        n_iter: unused here; accepted so forward and backward share kwargs.

    Returns:
        A one-element tuple with the gradient w.r.t. the forward input.
    """
    y, = outputs
    dY, = output_grads

    # Fixed: with dim=None the reductions below previously ran over *all*
    # dimensions, while the forward pass normalises over the last one only.
    dim = y.shape[-1] if dim is None else dim

    gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y))
    dX = dY * gppr

    q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim)
    dX = dX - q * gppr

    return dX,


def entmax_forward(x, alpha=1.3, dim=None, n_iter=50):
    """Compute alpha-entmax of `x` along `dim` by bisection on the threshold tau.

    Args:
        x: input mtf tensor.
        alpha: exponent, strictly between 1 and 2.
        dim: dimension to normalise over; defaults to the last dimension of x.
        n_iter: number of bisection steps.

    Returns:
        A tensor shaped like `x`, renormalised to sum to 1 along `dim`.
    """
    assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2'

    _gp = lambda x, alpha: x ** (alpha - 1)
    _gp_inv = lambda x, alpha: mtf.pow(x, (1 / (alpha - 1)))
    _p = lambda x, alpha: _gp_inv(mtf.relu(x), alpha)

    dim = x.shape[-1] if dim is None else dim
    d = dim.size

    x = x * (alpha - 1)

    max_val = mtf.reduce_max(x, reduced_dim=dim)

    # Initial bracket [tau_lo, tau_hi] for the threshold.
    tau_lo = max_val - _gp(1, alpha)
    tau_hi = max_val - _gp(1 / d, alpha)

    f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1

    dm = tau_hi - tau_lo

    for _ in range(n_iter):
        # Halve the bracket; move tau_lo up whenever the midpoint keeps the sign of f_lo.
        dm = dm / 2
        tau_m = tau_lo + dm
        p_m = _p(x - tau_m, alpha)
        f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1

        mask = mtf.greater_equal((f_m * f_lo), 0)
        tau_lo = mtf.where(mask, tau_m, tau_lo)

    p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim)
    return p_m


def entmax(x, alpha=1.3, dim=None, n_iter=50):
    """Entmax with a hand-written gradient, wired up through `mtf.custom_gradient`."""
    kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter)

    return mtf.custom_gradient(
        partial(entmax_forward, **kwargs),
        partial(entmax_backward, **kwargs),
        [x]
    )


def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
    """Cross entropy between `targets` and entmax(logits) over `vocab_dim`.

    Args:
        logits: mtf tensor whose dims include `vocab_dim`.
        targets: integer tensor (hard targets, one-hot encoded here) or a
            tensor with the same dims as `logits` (soft targets).
        vocab_dim: the dimension to normalise and reduce over.
        z_loss: accepted for signature compatibility; not used in this function.

    Returns:
        The loss tensor with `vocab_dim` reduced away.

    Raises:
        ValueError: if the target dims do not match, or `vocab_dim` is not a
            dimension of `logits`.
    """
    if targets.dtype.is_integer:
        # hard targets
        if (set(targets.shape.dims) != set(logits.shape.dims).difference([vocab_dim])):
            raise ValueError(
                "entmax_cross_entropy_with_logits with hard targets "
                "dims in targets=%s should be dims in logits=%s other than "
                "vocab_dim=%s" % (targets, logits, vocab_dim))
        targets = mtf.one_hot(targets, vocab_dim, dtype=logits.dtype)
    elif set(targets.shape.dims) != set(logits.shape.dims):
        raise ValueError(
            "entmax_cross_entropy_with_logits with soft targets "
            "dims in targets=%s should be dims in logits=%s" % (targets, logits))

    if vocab_dim not in logits.shape.dims:
        raise ValueError("vocab_dim must be in logits.shape.dims")

    log_entmax = mtf.log(entmax(logits, dim=vocab_dim))

    loss = mtf.negative(
        mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim))

    return loss


def sample_categorical(x, dim=None):
    """Draw one index per row of `x` along `dim` (default: last dimension).

    A uniform random number is compared against the cumulative sum of `x`;
    the argmax of the resulting 0/1 mask is returned.
    """
    dim = x.shape[-1] if dim is None else dim

    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)


def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    """Return an additive attention bias of shape [nd, ns].

    Entries are -1e10 where the (offset) destination index is smaller than the
    source index and 0 elsewhere, in `variable_dtype.activation_dtype`.
    """
    # The old mask_attn_weights applied directly to the QK;
    # this returns a bias that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    # n_src and n_dest are both the same, i.e equal to sequence length
    # We rename ns because we want bias to have shape [batch, heads, memory_length, sequence] to match up with QK^T
    # Information flows from k and v (memory_length) to q (sequence)
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(i, j), dtype) * -1e10


def parse_inputs(mtf_features, other_features):
    """Unpack the input tensor and the dimensions the model needs.

    Returns:
        (x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim),
        where the first two dims come from x.shape and the rest from
        `other_features`.
    """
    # Parse inputs and labels from the mtf_features / other_features input dicts
    # All dimensions are defined inside model_fn for efficiency
    x = mtf_features["inputs"]

    batch_dim = x.shape[0]
    sequence_dim = x.shape[1]
    embd_dim = other_features["embd_dim"]
    vocab_dim = other_features["vocab_dim"]
    embed_sequence_dim = other_features["embed_sequence_dim"]

    return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim
# helper functions

@contextmanager
def not_raises(exception):
    """Context manager that fails the running test if `exception` is raised inside it."""
    try:
        yield
    except exception:
        logging.error(traceback.format_exc())
        # pytest.fail raises on its own; the extra `raise` in front of it was dead code.
        pytest.fail("DID RAISE {0}".format(exception))

# fixtures

# Shared model parameters; missing keys read as None through the defaultdict.
params = defaultdict(lambda: None, {
    "n_head": 1,
    "n_ctx": 4,
    "n_embd": 2,
    "n_vocab": 256,
    "embed_dropout": 0.,
    "n_layer": 2,
    "num_microbatches": 1,
    "train_batch_size": 1,
    "causal": True,
    "attention_types": ['global', 'local'],
    "res_dropout": 0.1,
    "rotary_emb": True,
    "activation_function": "gelu",
    "moe_layers": (1,),
    "num_mem_kv": 16,
    "no_weight_tie": True,
    "moe_params": {
        'moe_dropout_rate': 0.0
    },
    "mesh_shape": [],
    "layout": {},
    "local_attention_radius": 128,
    "share_parameters": True,
    "rezero": True
})

# tests

def test_model():
    """Build the GPT-2 graph on a tiny config and lower it without raising."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    seq_len = params["n_ctx"]

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", seq_len)

    features = {
        'inputs': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32),
        'labels': mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    }

    # create mask

    num_mem_kv = params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])

    other_features = {}
    variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype)
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        logits, _, _ = gpt2.model(features, other_features, params, mesh, variable_dtype=variable_dtype)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        logits = lowering.export_to_tf_tensor(logits)


def test_sampling():
    """Run autoregressive sampling (entmax variant) on a tiny config without raising."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    batch_dim = mtf.Dimension("batch", 1)
    sequence_dim = mtf.Dimension("sequence", 1)

    inputs = mtf.ones(mesh, mtf.Shape((batch_dim, sequence_dim)), tf.int32)
    inputs = mtf.pad(inputs, [0, 3], sequence_dim.name)

    # Work on a private copy: writing "mode" into the shared module-level
    # fixture made the outcome of the other tests depend on execution order.
    sampling_params = params.copy()
    sampling_params["mode"] = "predict"

    # create mask

    seq_len = sampling_params["n_ctx"]
    num_mem_kv = sampling_params.get('num_mem_kv', 0)
    length_dim = mtf.Dimension('sequence', seq_len)
    memory_length_dim = mtf.Dimension('memory_length', seq_len + num_mem_kv)
    embed_sequence_dim = mtf.Dimension('embed_sequence', seq_len)
    embd_dim = mtf.Dimension("embd", sampling_params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", sampling_params["n_vocab"])

    other_features = {}

    other_features["attn_bias"] = biasmask_attn_weights(mesh, length_dim, memory_length_dim, mtf.VariableDType(tf.float32))
    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    with not_raises(Exception):
        samples = sample_autoregressive(
            inputs, other_features=other_features, params=sampling_params, variable_dtype=mtf.VariableDType(),
            remove_partial_sequences=sampling_params["remove_partial_sequences"],
            stop_at_token=sampling_params["eos_id"], sampling_use_entmax=True)

        mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        samples = lowering.export_to_tf_tensor(samples)

# mlm

# Parameters for the masked-language-model input test.
mlm_params = defaultdict(lambda: None, {
    "n_head": 1,
    "n_ctx": 4,
    "n_embd": 1,
    "n_vocab": 256,
    "embed_dropout": 0.,
    "n_layer": 2,
    "num_microbatches": 1,
    "train_batch_size": 1,
    "attention_types": ['global', 'local'],
    "res_dropout": 0.1,
    "mesh_shape": [],
    "layout": {},
    "share_parameters": True,
    "mlm_training": True,
    "mlm_mask_id": 3,
    "mlm_cls_token_id": 4,
    "mlm_random_token_prob": 0.1
})

def test_mlm_sample_text():
    """mlm_sample_text must return features of length n_ctx without raising."""
    document = tf.random.normal((16,))
    with not_raises(Exception):
        features, labels = mlm_sample_text(mlm_params, document, random_documents = True)
        assert features.shape == (mlm_params['n_ctx'],)

# entmax

def test_entmax():
    """entmax, its gradient and sample_categorical must build and lower cleanly."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    length = mtf.Dimension("tensor_length", 8)
    tensor = mtf.range(mesh, length, tf.float32)
    output = entmax(tensor)
    grad = mtf.gradients([output], [tensor])[0]
    sample = sample_categorical(output, length)

    mesh_impl = placement_mesh_impl.PlacementMeshImpl(shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    sample = lowering.export_to_tf_tensor(sample)
    grad = lowering.export_to_tf_tensor(grad)
def clip_by_global_norm(grads, clip_norm):
    """Clip the grads by global norm.

    Args:
        grads: list of gradient tensors; `None` entries are passed through.
        clip_norm: maximum allowed global norm.

    Returns:
        (clipped_grads, global_norm).
    """
    global_norm = mtf.sqrt(mtf.add_n([mtf.reduce_sum(mtf.square(t)) for t in grads if t is not None]))
    # The multiplier is 1 while global_norm <= clip_norm and clip_norm / global_norm above it.
    multiplier = clip_norm / mtf.maximum(global_norm, clip_norm)
    clipped_grads = [None if t is None else t * multiplier for t in grads]
    return clipped_grads, global_norm


def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op.

    Args:
        mesh: the mtf mesh holding the trainable variables.
        loss: scalar loss tensor, differentiated when `inp_var_grads` is None.
        params: reads "lr", "gradient_clipping", "train_steps", "lr_decay",
            "warmup_steps", "opt_name", "weight_decay", "beta1" and, depending
            on the optimizer, "beta2"/"epsilon" or "ada_epsilon1"/"ada_epsilon2";
            "lr_decay_end" is optional.
        variable_dtype: mtf.VariableDType; its slice_dtype is used for the
            learning rate and the gradient cast.
        inp_var_grads: optional precomputed gradients, one per trainable variable.

    Returns:
        (learning_rate, update_ops, var_grads_fp).
    """
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision. A gradient may be None; clip_by_global_norm and
    # AdamWeightDecayOptimizer.apply_grad both tolerate that, so the cast must too.
    var_grads_fp = [None if v is None else mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"]*0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        # is_warmup is 1.0 during warmup and 0.0 afterwards, so this selects between the two rates.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype
        )
    else:
        # NOTE(review): Adafactor receives the raw params["lr"], not the
        # scheduled/warmed-up `learning_rate` above - confirm this is intended.
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"]
        )

    if params["gradient_clipping"] is not None:
        # Fixed: the clip constant used to be built unconditionally at the top,
        # which failed when gradient_clipping was None (i.e. clipping disabled).
        clip_value = mtf.constant(mesh, params["gradient_clipping"], dtype=variable_dtype.slice_dtype)
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp


class AdamWeightDecayOptimizer(mtf.optimize.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""

    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 variable_dtype=None):
        """Constructs a AdamWeightDecayOptimizer.

        Args:
            learning_rate: scalar (or scalar tensor) step size.
            weight_decay_rate: decoupled weight-decay coefficient; falsy disables decay.
            beta_1: decay rate of the first-moment estimate.
            beta_2: decay rate of the second-moment estimate.
            epsilon: added to sqrt(v) in the denominator.
            exclude_from_weight_decay: optional list of regex patterns; variables
                whose name matches any of them are not decayed.
            variable_dtype: stored on the instance; not read by the methods in this class.
        """

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay
        self.variable_dtype = variable_dtype

    def apply_grad(self, grad, var):
        """See base class.

        Returns the list of assignment ops for `var` and its two Adam slots,
        or an empty list (after logging a warning) when `grad` is None.
        """
        if grad is None:
            tf.logging.warning("Gradient is None for variable %s" % var.name)
            return []

        grad = mtf.to_float(grad)

        assignments = []

        m = mtf.get_variable(
            var.mesh, var.name + "/adam_m", var.shape,
            initializer=tf.zeros_initializer(),
            # master_dtype=self.variable_dtype.master_dtype,
            # slice_dtype=self.variable_dtype.slice_dtype,
            # activation_dtype=self.variable_dtype.activation_dtype,
            trainable=False)

        v = mtf.get_variable(
            var.mesh, var.name + "/adam_v", var.shape,
            initializer=tf.zeros_initializer(),
            # master_dtype=self.variable_dtype.master_dtype,
            # slice_dtype=self.variable_dtype.slice_dtype,
            # activation_dtype=self.variable_dtype.activation_dtype,
            trainable=False)

        # Standard Adam update.
        next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
        next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

        update = next_m / (mtf.sqrt(next_v) + self.epsilon)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        if self._do_use_weight_decay(var.name):
            update += mtf.to_float(var.value) * self.weight_decay_rate

        update_with_lr = self.learning_rate * update

        var_update = mtf.assign_sub(var, update_with_lr)

        assignments.extend(
            [var_update,
             mtf.assign(m, next_m),
             mtf.assign(v, next_v)])
        return assignments

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True
155 | if self._do_use_weight_decay(var.name): 156 | update += mtf.to_float(var.value) * self.weight_decay_rate 157 | 158 | update_with_lr = self.learning_rate * update 159 | 160 | var_update = mtf.assign_sub(var, update_with_lr) 161 | 162 | assignments.extend( 163 | [var_update, 164 | mtf.assign(m, next_m), 165 | mtf.assign(v, next_v)]) 166 | return assignments 167 | 168 | def _do_use_weight_decay(self, param_name): 169 | """Whether to use L2 weight decay for `param_name`.""" 170 | if not self.weight_decay_rate: 171 | return False 172 | if self.exclude_from_weight_decay: 173 | for r in self.exclude_from_weight_decay: 174 | if re.search(r, param_name) is not None: 175 | return False 176 | return True -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import mesh_tensorflow as mtf 2 | import tensorflow.compat.v1 as tf 3 | import mesh_tensorflow.transformer as mtf_transformer 4 | 5 | from models.utils import entmax, sample_categorical 6 | from models.gpt2 import gpt2 7 | 8 | def sample_autoregressive(partial_sequences, 9 | other_features, 10 | params, 11 | stop_at_token=50256, 12 | max_steps=None, 13 | temperature=0.9, 14 | variable_dtype=mtf.VariableDType(tf.float32), 15 | encoder_output=None, 16 | encoder_sequence_id=None, 17 | encoder_inputs=None, 18 | shared_params=None, 19 | has_partial_sequences=True, 20 | encoder_layer_outputs=None, 21 | never_end=False, 22 | remove_partial_sequences=False, 23 | sampling_keep_top_k=-1, 24 | sampling_use_entmax = False, 25 | bos_id=50256, 26 | ): 27 | """Sample randomly one token at a time. 28 | 29 | The partial_sequences represent partial sequences to be continued. The 30 | first tokens of each sequence are nonzero representing the given partial 31 | sequences and the last tokens of each sequence are zeros, representing what 32 | needs to be filled in. 
33 | 34 | If there are no partial sequences (you want to sample from the beginning), 35 | then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and 36 | has_partial_sequences=False (so we can skip computation). 37 | 38 | Args: 39 | partial_sequences: an int32 Tensor with shape [, length_dim] 40 | stop_at_token: an optional integer eos id. Stop when we produce it. 41 | max_steps: an optional integer, the max number of steps to decode. 42 | temperature: an optional floating point value between 0.0 and 1.0 0.0 43 | means argmax, 1.0 means sample according to predicted distribution. 44 | variable_dtype: a mtf.VariableDType 45 | encoder_output: an optional Tensor 46 | encoder_sequence_id: an optional Tensor 47 | encoder_inputs: an optional Tensor 48 | shared_params: an optional dictionary 49 | has_partial_sequences: a boolean 50 | encoder_layer_outputs: optional - readonly list of tensor activations when 51 | decoding, one per each input layer + the embedding layer 52 | never_end: a boolean - if set, then avoid generating stop_at_token 53 | remove_partial_sequences: a boolean - whether to remove the partial 54 | sequences from the output 55 | sampling_keep_top_k: an integer - if not -1, only sample from the top k 56 | logits. 
57 | bos_id: beginning of sequence id 58 | 59 | Returns: 60 | a Tensor with shape [, length_dim] 61 | """ 62 | 63 | inputs = partial_sequences # Partial sequences to fill in 64 | batch_dims = inputs.shape.dims[:-1] 65 | length_dim = inputs.shape.dims[-1] 66 | padding_id = params.get("padding_id", 0) 67 | slow_sampling = params.get("slow_sampling", False) 68 | 69 | 70 | initial_position = mtf.reduce_sum( 71 | mtf.to_int32(mtf.not_equal(inputs, padding_id)), reduced_dim=length_dim) # Gets position where zero padding starts 72 | 73 | length_range = mtf.range(inputs.mesh, length_dim, tf.int32) 74 | input_full_attention = True # for now hardcode this to true bc lazy 75 | if input_full_attention: 76 | # Vanilla autoregressive model - each position can see previous positions. 77 | # Think this feeds in to the loop fn and tells each position where it can attend to? 78 | read_priority = write_priority = length_range * mtf.to_int32( 79 | mtf.greater(length_range, initial_position)) 80 | else: 81 | read_priority = write_priority = length_range 82 | 83 | # Builds context to pass around internally 84 | # The 'first part' context records initial states of k / v / x 85 | 86 | if not slow_sampling: 87 | context_first_part = mtf_transformer.transformer.Context( 88 | model=None, 89 | mesh=inputs.mesh, 90 | batch_dims=batch_dims, 91 | length_dim=length_dim, 92 | variable_dtype=variable_dtype, 93 | mode="first_part", 94 | position=length_range, 95 | position_is_default=True, 96 | new_states=[], 97 | initial_position=initial_position, 98 | sequence_id=None, 99 | encoder_output=encoder_output, 100 | encoder_sequence_id=encoder_sequence_id, 101 | constant_states=[], 102 | shared_params=shared_params, 103 | encoder_layer_outputs=encoder_layer_outputs, 104 | write_priority=write_priority, 105 | read_priority=read_priority, 106 | inputs=inputs, 107 | encoder_inputs=encoder_inputs) 108 | 109 | with tf.variable_scope("gpt2"): 110 | logits, _, _ = gpt2.model({"inputs": inputs}, other_features, 
params, inputs.mesh, variable_dtype=variable_dtype, context=context_first_part) 111 | 112 | if not has_partial_sequences: 113 | initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] 114 | else: 115 | initial_states = context_first_part.new_states 116 | else: 117 | initial_states = [] 118 | 119 | if not has_partial_sequences: 120 | partial_sequences_eos_count = 0 121 | 122 | if stop_at_token is not None: 123 | partial_sequences_eos_count = mtf.reduce_sum( 124 | mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)), 125 | reduced_dim=length_dim) 126 | 127 | def cond_fn(position, ids, *unused_states): 128 | """Should we run another loop iteration?""" 129 | past_end = mtf.greater_equal(position, length_dim.size) 130 | if max_steps: 131 | past_end = mtf.logical_or( 132 | past_end, mtf.greater_equal(position - initial_position, max_steps)) 133 | 134 | is_done = past_end 135 | if stop_at_token is not None: 136 | eos_count = mtf.reduce_sum( 137 | mtf.to_int32(mtf.equal(ids, stop_at_token)), 138 | reduced_dim=length_dim) 139 | has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) 140 | is_done = mtf.logical_or(is_done, has_additional_eos) 141 | all_done = mtf.reduce_all(is_done) 142 | return mtf.logical_not(all_done) 143 | 144 | def body_fn(position, ids, *states): 145 | """One step in the decode loop.""" 146 | nonlocal sampling_keep_top_k 147 | 148 | context = mtf_transformer.transformer.Context( 149 | model=None, 150 | mesh=inputs.mesh, 151 | batch_dims=batch_dims, 152 | length_dim=length_dim, 153 | variable_dtype=variable_dtype, 154 | mode="incremental", 155 | position=position, 156 | position_is_default=True, 157 | states=states, 158 | new_states=[], 159 | initial_position=position, 160 | sequence_id=None, 161 | encoder_output=encoder_output, 162 | encoder_sequence_id=encoder_sequence_id, 163 | shared_params=shared_params, 164 | encoder_layer_outputs=encoder_layer_outputs, 165 | write_priority=write_priority, 166 | 
read_priority=read_priority, 167 | inputs=ids, 168 | encoder_inputs=encoder_inputs) if not slow_sampling else None 169 | 170 | with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): 171 | logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context = context) 172 | 173 | if not sampling_use_entmax: 174 | # By default, do top_k sampling of 0.9 175 | if sampling_keep_top_k == -2: 176 | sampling_keep_top_k = int(logits.shape[-1].size * 0.1) 177 | 178 | if sampling_keep_top_k != -1: 179 | if sampling_keep_top_k <= 0: 180 | raise ValueError("sampling_keep_top_k must either be -1 or positive.") 181 | k_largest = mtf.nth_largest_element( 182 | logits, n=sampling_keep_top_k, 183 | reduced_dim=other_features["vocab_dim"]) 184 | logits = mtf.where(mtf.less_equal(logits, k_largest), 185 | mtf.ones_like(logits) * -1e6, logits) 186 | 187 | ids_this_step = mtf.sample_with_temperature( 188 | logits, other_features["vocab_dim"], temperature) 189 | else: 190 | ids_this_step = sample_categorical(entmax(logits)) 191 | 192 | if slow_sampling: 193 | ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) 194 | else: 195 | ids_this_step = mtf.reshape(ids_this_step, (batch_dims)) 196 | 197 | one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) 198 | one_new_id = ids_this_step * one_hot 199 | new_ids = (1 - one_hot) * ids + one_new_id 200 | new_position = position + 1 201 | 202 | ret = [new_position, new_ids] 203 | if context is not None: 204 | ret += context.new_states 205 | return ret 206 | 207 | while_loop_inputs = [initial_position, inputs] + initial_states 208 | final_position, outputs = mtf.while_loop( 209 | cond_fn, body_fn, while_loop_inputs)[:2] 210 | del final_position 211 | if has_partial_sequences and remove_partial_sequences: 212 | # Remove partial sequences from outputs 213 | partial_length = mtf.reduce_sum( 214 | mtf.to_int32(mtf.not_equal(partial_sequences, padding_id)), 215 | 
# Command-line interface of the experiment supervisor. Parsing happens at
# import time, so merely importing this module consumes sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--tpu', type=str, required=True) # Name of TPU to train on, if any
parser.add_argument('--model', type=str, required=True) # JSON file that contains model parameters
parser.add_argument('--experiment_name', type=str, required=True) # name of experiment (will show up in omniboard)
parser.add_argument('--steps_per_checkpoint', type=int, default=5000)
# NOTE: store_false, so args.autostack is True unless --autostack is passed.
parser.add_argument('--autostack', action="store_false")
parser.add_argument('--auto_layout', action="store_true")
parser.add_argument('--auto_layout_and_mesh_shape', action="store_true")
parser.add_argument('--new', action='store_true')
parser.add_argument('--test', action='store_true')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--predict', action='store_true')
parser.add_argument('--no_delete_tpu', action='store_true')
# Heartbeat grace period (seconds) used right after a (re)start.
parser.add_argument('--initial_heartbeat_timeout', type=int, default=7200)
parser.add_argument('--heartbeat_timeout', type=int, default=1800) # kill and restart if nothing logged to tensorboard in this many seconds
args = parser.parse_args()

# Model hyper-parameters loaded from the JSON config named by --model.
params = fetch_model_params(args.model)

ex = sacred.Experiment(args.experiment_name)
# NOTE(review): the MongoDB url and credentials are hard-coded literals;
# consider reading them from the environment instead.
ex.observers.append(sacred.observers.QueuedMongoObserver(url='127.0.0.1:27017', db_name='db', username='user', password='password'))
def get_open_port(lo=8000, hi=8100):
    """Return the first port in ``[lo, hi)`` on localhost with no listener.

    A port counts as free when a TCP connect attempt to it fails, i.e.
    ``connect_ex`` returns a non-zero error code.

    Raises:
        RuntimeError: if no port in the range is free (or the range is
            empty). Previously the function fell off the end and returned
            ``None``, which the caller formatted into ``--port None``.
    """
    for port in range(lo, hi):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', port)) != 0:
                return port
    raise RuntimeError('no open port found in range [{}, {})'.format(lo, hi))


def train_thread(args, tpu, id, q):
    """Run ``main.py`` as a subprocess and supervise it until it exits.

    Args:
        args: parsed command-line namespace of this script.
        tpu: name of the TPU passed to ``main.py --tpu``.
        id: run id; selects ``run_configs/config_{id}.json`` and is forwarded
            as ``--sacred_id``.
        q: ``queue.Queue`` polled once a minute; a ``('kill',)`` item makes
            this thread terminate the subprocess.

    On a zero exit status the whole process is sent SIGINT. On any other
    exit the TPU is recreated with ``pu recreate`` unless
    ``args.no_delete_tpu`` is set, and pending queue items are dropped.
    """
    print('starting training on', tpu)

    # pass binary flags through
    # NOTE(review): the main.py parser visible in this repository defines
    # neither --test nor --autostack; forwarding them would be rejected by
    # argparse there - confirm against the main.py actually deployed.
    opts = ''
    for flag in ['auto_layout', 'auto_layout_and_mesh_shape', 'new', 'test', 'predict', 'eval', ]:
        if getattr(args, flag):
            opts += ' --' + flag

    # --autostack is a store_false flag, so it is forwarded when it is False.
    for flag in ['autostack', ]:
        if not getattr(args, flag):
            opts += ' --' + flag

    cmd = "python3 main.py --tpu {tpu} --model run_configs/config_{id}.json --steps_per_checkpoint {steps_per_checkpoint} {opts} --sacred_id {run_id}".format(tpu=tpu, id=id, steps_per_checkpoint=args.steps_per_checkpoint, opts=opts, run_id=id)
    print('Running:', cmd)
    proc = subprocess.Popen(cmd, shell=True)

    # poll until it's exited
    while proc.poll() is None:
        time.sleep(60)
        try:
            nq, *_ = q.get_nowait()
        except queue.Empty:
            continue

        if nq == 'kill':
            print('train thread recieved kill signal from logging thread')
            # first send SIGTERM
            proc.terminate()

            time.sleep(60)

            # if it still hasn't exited, we send SIGKILL
            if proc.poll() is None:
                print('SIGTERM not successful, sending SIGKILL')
                proc.kill()

    print('exited training!')
    if proc.returncode == 0:
        print('exited gracefully')
        # Interrupt the main thread so the supervising loop stops as well.
        os.kill(os.getpid(), signal.SIGINT)
        return

    if args.no_delete_tpu:
        print('recreate done, exiting train_thread - not killing tpu!')
        return
    print("Recreating {} in 60sec...".format(tpu))
    time.sleep(60)
    os.system("pu recreate {} --yes --retry 3600 --retry-randomness 1.5".format(tpu))
    print('recreate done, exiting train_thread')

    # clear out queue
    while True:
        try:
            q.get_nowait()
            print('dropped request in queue after pu recreate')
        except queue.Empty:
            break
def get_json(uri, params=None, timeout=15):
    """GET ``uri`` and return the decoded JSON body.

    Raises whatever ``requests`` raises, including the HTTP error produced
    by ``raise_for_status`` on a non-success status code.
    """
    resp = requests.get(uri, params=params, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def get_tag_sets(base_uri):
    """Return ``{run: tag names}`` from the scalars plugin tag listing."""
    # rstrip so a base uri with a trailing slash does not produce '//data'.
    j = get_json(f"{base_uri.rstrip('/')}/data/plugin/scalars/tags", {'experiment': ''})
    assert isinstance(j, dict)
    return {run: j[run].keys() for run in j}


def get_scalar_data(base_uri, run, tag):
    """Return the list of scalar records logged for ``tag`` in ``run``."""
    j = get_json(f"{base_uri.rstrip('/')}/data/plugin/scalars/scalars", {'experiment': '', 'run': run, 'tag': tag})
    assert isinstance(j, list)
    return j


def get_run_data(port):
    """Collect the metrics of interest from the tensorboard at ``port``.

    Returns a dict with any of the keys ``loss``, ``val_loss``,
    ``lambada_acc`` and ``lambada_ppl``, each mapping to a list of
    3-element records. Polling is best effort: an ordinary failure is
    printed and whatever was gathered so far is returned.

    Only ``Exception`` is caught. The former bare ``except:`` also swallowed
    ``KeyboardInterrupt``, which is exactly how ``train_thread`` asks the
    main thread to stop (it sends SIGINT to the process).
    """
    base_uri = f'http://localhost:{port}'
    r = {}
    try:
        tag_sets = get_tag_sets(base_uri)
        runs = tag_sets.keys()
        if '.' in runs:
            if 'loss' in tag_sets['.']:
                r['loss'] = get_scalar_data(base_uri, '.', 'loss')
        if 'eval' in runs:
            if 'loss' in tag_sets['eval']:
                r['val_loss'] = get_scalar_data(base_uri, 'eval', 'loss')
        if 'eval_lambada' in runs:
            if 'lambada_acc' in tag_sets['eval_lambada']:
                r['lambada_acc'] = get_scalar_data(base_uri, 'eval_lambada', 'lambada_acc')
            if 'lambada_log_ppl' in tag_sets['eval_lambada']:
                # stored as log-perplexity; exponentiate the value field
                r['lambada_ppl'] = [
                    [t, s, math.exp(lp)]
                    for [t, s, lp] in get_scalar_data(base_uri, 'eval_lambada', 'lambada_log_ppl')
                ]
    except Exception:
        import traceback
        traceback.print_exc()
    return r
@ex.main
def main(_run):
    """Sacred entry point: supervise training and mirror metrics into sacred.

    Copies the model config to ``run_configs/``, starts tensorboard in a
    detached ``screen`` session, then repeatedly runs ``train_thread`` in a
    background thread. While that thread is alive it polls tensorboard, the
    ``predictions_{id}_*`` files and ``eval_{id}.jsonl`` once a minute and
    logs anything new to sacred. If nothing new is logged for longer than the
    current heartbeat timeout, a ``('kill',)`` message is queued so the
    training thread restarts the run.

    NOTE(review): indentation was reconstructed from a whitespace-collapsed
    copy of the file; nesting of the statements marked below should be
    confirmed against the original.
    """
    print('Starting run', _run._id)
    print('experiment main invoked with argv:', " ".join(sys.argv))
    print('WARNING: please remember to remove old metric log files from the model directory.')

    os.makedirs('run_configs', exist_ok=True)
    # --model is either a path to a .json file or a name under configs/
    shutil.copy(args.model if args.model.endswith('.json') else 'configs/{}.json'.format(args.model), 'run_configs/config_{}.json'.format(_run._id))

    tensorboard_port = get_open_port()
    print('Tensorboard at port:', tensorboard_port)
    print('Tensorboard url: ', 'http://eleutherai.bmk.sh:'+ str(tensorboard_port))
    # The second tensorboard invocation (without --bind_all) runs only if the first one fails.
    os.system("screen -S tensorboard_{} -d -m bash -c 'tensorboard --logdir {} --port {} --bind_all --reload_multifile=true || tensorboard --logdir {} --port {} --reload_multifile=true'".format(_run._id, params["model_path"], tensorboard_port,params["model_path"], tensorboard_port,))
    atexit.register(goodbye, _run._id)

    # Highest step already forwarded to sacred, per metric key.
    curr_step = {}
    seen_predictions = set()

    heartbeat_timeout = args.initial_heartbeat_timeout * 2
    while True:
        last_tb_log_time = time.time()
        start_time = time.time()
        q = queue.Queue()
        trainthd = threading.Thread(target=train_thread, args=(args, args.tpu, _run._id, q))
        trainthd.start()

        while trainthd.is_alive():
            time.sleep(60)

            if start_time + args.initial_heartbeat_timeout < time.time():
                # after initial args.initial_heartbeat_timeout grace period, now we want to set the timeout threshold much lower
                heartbeat_timeout = args.heartbeat_timeout

            print('Polling tensorboard for metrics...')
            data = get_run_data(tensorboard_port)
            for k in data.keys():
                for ts, step, val in data[k]:
                    if step <= curr_step.get(k, -1):
                        continue
                    _run.log_scalar(k, val, step)
                    if k == 'loss':
                        _run.log_scalar('tb_ts', ts, step)
                        # NOTE(review): assumed to sit inside the 'loss' branch - confirm
                        print('Logged to sacred: step={},loss={},tb_ts={}'.format(step, val, ts))

                    # found something new, so logging!
                    last_tb_log_time = time.time()

                    # NOTE(review): assumed to sit inside the inner loop - confirm
                    curr_step[k] = step

            for f in glob.glob('predictions_{}_*'.format(_run._id)):
                if f in seen_predictions:
                    continue
                print('collecting prediction file', f)
                ex.add_artifact(f)

                seen_predictions.add(f)

            # collect eval metrics from jsonl
            if os.path.exists(f'eval_{_run._id}.jsonl'):
                with open(f'eval_{_run._id}.jsonl') as fh:
                    for line in fh:
                        ob = json.loads(line)
                        val_step = ob['global_step']
                        val_task = ob['task']
                        for metr in ob.keys():
                            k = 'fs.' + val_task + '.' + metr
                            if metr in ['task', 'global_step']: continue
                            if val_step <= curr_step.get(k, -1): continue
                            _run.log_scalar(k, ob[metr], val_step)
                            curr_step[k] = val_step

            if time.time() - last_tb_log_time > heartbeat_timeout:
                # the run hasn't logged in a while, so we restart it
                q.put(('kill',))

                # give training thread some time to do its thing and recreate tpu
                while trainthd.is_alive():
                    print('logging thread waiting for killing stalled run and for tpu recreate to finish')
                    time.sleep(60)

                # reset heartbeat timeout to initial
                heartbeat_timeout = args.initial_heartbeat_timeout
                last_tb_log_time = time.time()


        # Without TPU recreation there is nothing to restart on, so stop after one attempt.
        if args.no_delete_tpu:
            break
def goodbye(id):
    """Exit hook: print the farewell lines and quit the run's tensorboard screen session."""
    for farewell in ("You are now leaving the Python sector.",
                     "Sie verlassen den pythonischen Sektor."):
        print(farewell)

    quit_cmd = "screen -S tensorboard_{} -X quit".format(id)
    os.system(quit_cmd)


if __name__ == '__main__':
    # Register every file whose last dot-separated component is 'py' with sacred.
    python_files = [path for path in glob.glob("**/*", recursive=True)
                    if path.split('.')[-1] in ['py']]
    for path in python_files:
        print('Adding', path, 'to sacred')
        ex.add_source_file(path)

    # Model params are applied last, so a 'tpu_name' key in them wins.
    experiment_config = {'tpu_name': args.tpu}
    experiment_config.update(params)
    ex.add_config(experiment_config)

    ex.run()
def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, pos_emb, variable_dtype, context=None):
    """Build the function applying transformer layer ``layer_num``.

    Feature switches are read from ``params`` once, here; the returned
    ``fn(x)`` builds the layer under ``tf.variable_scope(scope)`` and returns
    ``(x, aux_loss)``. ``aux_loss`` is the mixture-of-experts auxiliary loss
    when ``layer_num`` is listed in ``params["moe_layers"]`` and a zero scalar
    otherwise.

    NOTE(review): indentation was reconstructed from a whitespace-collapsed
    copy of the file; confirm nesting against the original.
    """
    use_mlp_glu = params["mlp_glu"] == True
    use_scale_norm = params["scalenorm"] == True
    use_moe = exists(params["moe_layers"]) and (layer_num in params["moe_layers"])
    use_rezero = params["rezero"] == True
    macaron_attention = params["macaron"] == True

    def fn(x):
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

            # rezero takes precedence over scalenorm, which takes precedence over layer norm
            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

            if macaron_attention:
                # extra MLP ahead of attention; both MLP outputs are scaled by 0.5
                mult = 0.5
                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
                m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)

                x = x + (m * mult)
            else:
                mult = 1

            if attention_type != "none":
                res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
                a = attn(res_x, "attn", nx, attention_type=attention_type,
                         params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
                         variable_dtype=variable_dtype, context=context, pos_emb=pos_emb)
            else:
                # attention disabled for this layer: the residual branch is x itself
                a = x

            # NOTE(review): the dtype is passed by keyword here but positionally
            # for "norm_rezero_2" below - confirm both match the helper signatures.
            x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

            res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)

            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)
                moe_params.add_hparam("moe_min_expert_capacity", 1)
                moe_params.add_hparam("moe_use_experts_attention", False)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params,
                                                                           train=moe_train,
                                                                           mesh_shape=params["mesh_shape"],
                                                                           layout=params["layout"],
                                                                           activation=params.get("moe_activation",
                                                                                                 "relu"),
                                                                           variable_dtype=variable_dtype,
                                                                           num_microbatches=params["num_microbatches"])
                m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout")
            else:

                mlp_fn = mlp_glu if use_mlp_glu else mlp
                intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)

                # Define intermediate layer of mlp - to split
                dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)

                m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)
                # no MoE in this layer: contribute a zero scalar to the auxiliary loss
                aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
            return x, aux_loss

    return fn
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow.

    Embeds the tokens, adds positional information (rotary, axial or a
    learned ``wpe`` table depending on ``params``), applies
    ``params["n_layer"]`` blocks and projects to vocabulary logits.

    Returns:
        ``(logits, loss, loss_batch)``. ``loss`` and ``loss_batch`` are
        ``None`` unless ``params["mode"]`` is ``"train"`` or ``"eval"``.

    NOTE(review): indentation (in particular which statements sit inside each
    ``tf.variable_scope``) was reconstructed from a whitespace-collapsed copy
    of the file; confirm against the original, since scope nesting affects
    variable names.
    """

    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = exists(params["axial_pos_emb"])
    use_rotary_emb = exists(params["rotary_emb"])

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    # Position encoding

    # rotary embeddings are handed to every block; the other two variants
    # produce a table (wpe) that is added to the token embedding below
    if use_rotary_emb:
        wpe = None
        layer_pos_emb = rotary_positional_emb(mesh, sequence_dim, params, variable_dtype)
    elif use_axial_pos_emb:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
        layer_pos_emb = None
    else:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
        layer_pos_emb = None

    if exists(wpe):
        with tf.variable_scope("pos_embd"):
            # Positional embedding
            position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
                    context.position - 1)
            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
            if params["embed_dropout"] > 0 and params["mode"] == "train":
                # NOTE(review): reuses the op name "wte_dropout" from the token
                # embedding above - possibly a copy-paste; confirm intent.
                pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
            h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        # with shared parameters every layer is built under the same (empty) scope name
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         pos_emb = layer_pos_emb,
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss else entmax_cross_entropy_with_logits

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
def setup_logging(args):
    """Send the "tensorflow" logger to ``logs/<model name>.log`` and to stdout.

    The log file is named after ``args.model`` with directory and extension
    removed. Returns the configured logger.
    """
    Path("logs").mkdir(exist_ok=True)
    tf.logging.set_verbosity(logging.INFO)
    tf.get_logger().propagate = False  # Remove double log on console

    model_file = os.path.basename(args.model)
    name, _ = os.path.splitext(model_file)
    file_handler = logging.FileHandler(f"logs/{name}.log")
    console_handler = logging.StreamHandler(sys.stdout)

    logger = logging.getLogger("tensorflow")
    logger.handlers = [file_handler, console_handler]
    return logger


def get_batch_size(params):
    """Return the batch size configured for the current ``params['mode']``."""
    batch_size_key = f"{params['mode']}_batch_size"
    return params[batch_size_key]


def add_mode_to_params(params, mode):
    """Store the estimator ``mode`` in ``params["mode"]`` as "predict", "eval" or "train".

    Mutates and returns ``params``; raises ``ValueError`` for any other mode.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        mode_name = "predict"
    elif mode == tf.estimator.ModeKeys.EVAL:
        mode_name = "eval"
    elif mode == tf.estimator.ModeKeys.TRAIN:
        mode_name = "train"
    else:
        raise ValueError(f"Invalid mode {mode}")
    params["mode"] = mode_name
    return params
def simd_mesh_setup(params, mesh_shape, layout_rules):
    """Build the variable placer and SimdMeshImpl used to lay tensors out on TPU cores.

    Returns:
        ``(var_placer, mesh_impl)``
    """
    tpu_context = params["context"]
    num_hosts = tpu_context.num_hosts
    place_on_host = tpu_context.tpu_host_placement_function
    device_list = [place_on_host(host_id=i) for i in range(num_hosts)]
    tf.logging.info(f"device_list = {device_list}")

    # TODO: Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica

    # Worker 0 caches all the TPU binaries, so only it starts with a non-zero usage
    devices_memory_usage = [replica_cache_size * tpu_context.num_replicas]
    devices_memory_usage += [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list, devices_memory_usage)

    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, tpu_context.device_assignment)

    return var_placer, mesh_impl


def remove_batch_from_layout(layout):
    """
    Drop every comma-separated layout rule that mentions "batch".
    Useful for prediction steps, when you no longer want large batches.

    :param layout: string describing tf-mesh layout
    :return: layout minus batch dimension
    """
    kept_rules = [rule for rule in layout.split(',') if "batch" not in rule]
    return ",".join(kept_rules)
def yes_or_no(question):
    """Prompt on stdin until the reply starts with 'y' or 'n'; return True for yes."""
    while True:
        reply = str(input(question+' (y/n): ')).lower().strip()
        if reply[:1] == 'y':
            return True
        if reply[:1] == 'n':
            return False


def remove_gs_or_filepath(path):
    """Recursively delete ``path``: via ``gsutil`` for ``gs://`` urls, ``rmtree`` otherwise."""
    import shlex  # local import: only this function needs it

    parsed_url = urlparse(path)
    if parsed_url.scheme == "gs":
        # Quote the path so spaces or shell metacharacters in it cannot split
        # or alter the command line.
        os.system(f"gsutil rm -rf {shlex.quote(path)}")
        return
    rmtree(path)


def save_config(params_dict, logdir):
    """Write ``params_dict`` as a text summary under ``{logdir}/config``.

    The config is rendered as JSON-like text. A value is wrapped in double
    quotes when its string form contains a letter, unless it is a boolean, a
    list (starts with ``[``) or belongs to the ``epsilon`` key.
    """
    print(f"Saving config to {logdir}")
    text = "{\n\n"
    total_params = len(params_dict)
    for count, key in enumerate(params_dict):
        config_value = str(params_dict[key])
        # TODO: Making a manual exception for parsing epsilon right now since it's the only number in
        #       scientific notation. Should fix this.
        needs_quotes = (
            re.search('[a-zA-Z]', config_value)
            and config_value.lower() not in ('true', 'false')
            and config_value[0] != '['
            and key != "epsilon"
        )
        if needs_quotes:
            config_value = f'"{config_value}"'
        # every entry but the last is followed by a comma
        separator = '\n\n' if count == total_params - 1 else ',\n\n'
        text += f'"{str(key)}"' + ' : ' + config_value + separator
    text += '\n\n}'

    sess = tf.InteractiveSession()
    try:
        summary_op = tf.summary.text("run_config", tf.convert_to_tensor(text))
        summary_writer = tf.summary.FileWriter(f"{logdir}/config", sess.graph)
        text = sess.run(summary_op)
        summary_writer.add_summary(text, 0)
        summary_writer.flush()
        summary_writer.close()
    finally:
        # Close the session before resetting the graph; it was previously
        # left open (and installed as the default session).
        sess.close()
    tf.reset_default_graph()
    print('Done!')
def expand_attention_types_params(params_list):
    """Expand ``[[types, repeats], ...]`` into a flat list: each ``types`` sequence repeated ``repeats`` times."""
    return [
        attention_type
        for item in params_list
        for _ in range(item[1])
        for attention_type in item[0]
    ]


def get_n_trainable_vars(graph):
    """
    Prints the total number of elements across all trainable variables of a MTF model.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    def n_elements(variable):
        # product of the sizes of the variable's dimensions
        count = 1
        for dim in variable.shape.dims:
            count *= dim.size
        return count

    total_parameters = 0
    for variable in graph.trainable_variables:
        total_parameters += n_elements(variable)
    print(f"\n\nN TRAINABLE VARS:\n{total_parameters:,}\n\n")


def print_dim_names(graph):
    """
    Print the distinct dimension names used by the graph's variables
    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    # Gather the names of every variable's dimensions, then de-duplicate
    all_dim_names = [
        dim_name
        for variable in graph.all_variables
        for dim_name in variable.shape.dimension_names
    ]
    unique_dims = list(set(all_dim_names))

    print("ALL DIM NAMES:")
    for dim_name in unique_dims:
        print(dim_name)
    print('\n')
def get_graph_info(graph):
    """
    Wrapper fn that prints the number of trainable vars in an MTF graph and all of its dim names
    TODO: how to get un-trainable dim-names too, batch etc.

    :param graph: Mesh-Tensorflow graph
    :return: None
    """
    get_n_trainable_vars(graph)
    print_dim_names(graph)


def loss_denominator(targets, num_microbatches):
    """Denominator applied to losses.

    This is the size of the targets tensor multiplied by the number of
    microbatches.

    Args:
      targets: a mtf.Tensor
      num_microbatches: an integer - greater than one if the step has been
        serialized into multiple microbatches to save memory.
    Returns:
      a float
    """
    return float(float(targets.shape.size) * num_microbatches)


def check_dataset(input_fn, params, global_step=None):
    """Decode and print the first example produced by ``input_fn``, then exit the process.

    ``global_step`` is forwarded to ``input_fn`` only when it is not None.
    """
    tf.enable_eager_execution()
    if global_step is not None:
        dataset = input_fn(params, global_step=global_step)
    else:
        dataset = input_fn(params)
    dataset_iter = dataset.make_one_shot_iterator()
    tensor, _ = next(dataset_iter)
    enc = fetch_encoder(params)

    # Printing happens inside the loop so that an empty batch prints nothing
    # instead of failing on an unbound `txt`.
    for p in tensor[:1]:
        txt = enc.decode(p)

        print('-' * 50)
        print(txt[:500], '\n\n...\n\n', txt[-500:])
        print('-' * 50)
    # sys.exit rather than the `site`-provided exit()/quit() helpers, which
    # are meant for the interactive shell and are not always defined.
    sys.exit()


def auto_layout(graph, mesh_shape, logits, loss):
    """Print the layout chosen by MTF auto layout for ``mesh_shape``, then exit the process."""
    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    print(f"Auto-selected layout:\n{layout_rules}\nRe-initialize graph with selected layout")
    sys.exit()


def auto_layout_and_mesh_shape(graph, num_cores, logits, loss):
    """Print the layout and mesh shape chosen by MTF auto layout for ``num_cores``, then exit the process."""
    layout_rules, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, num_cores,
                                                                   [logits, loss], max_mesh_shape_dimensions=4)
    print(f"Num cores:\n{num_cores}\nAuto-selected layout:\n{layout_rules}\nAuto-selected mesh shape:\n{mesh_shape}" \
          f"\nRe-initialize graph with selected layout & mesh shape")
    sys.exit()
def create_host_call(model_dir):
    """Construct a host_call writing scalar summaries.

    Borrowed from t2t.

    Args:
        model_dir: String containing path to train
    Returns:
        (fn, args) Pair to be called by TPUEstimator as the host_call, or
        None when the graph has no scalar summaries registered.
    """

    graph = tf.get_default_graph()
    # A list of (name, lowered tensor) tuples
    summaries = graph.get_collection(mtf.utils.SCALAR_SUMMARIES_COLLECTION_KEY)

    def maybe_cast(tensor):
        # Only scalars are accepted; int64 is narrowed to int32 and bfloat16
        # widened to float32, everything else is passed through unchanged.
        assert tensor.shape.is_compatible_with([]), tensor.name
        if tensor.dtype == tf.int64:
            return tf.to_int32(tensor)
        if tensor.dtype == tf.bfloat16:
            return tf.cast(tensor, tf.float32)
        return tensor

    # Each scalar becomes a length-1 vector
    reshaped_tensors = [tf.reshape(maybe_cast(t), [1]) for _, t in summaries]

    # When no supported summaries are found, don't create host_call. Otherwise,
    # TPU outfeed queue would enqueue global_step while host_call doesn't dequeue
    # it, eventually causing hang.
    if not reshaped_tensors:
        return None

    def host_call_fn(global_step, *args):
        """Training host call. Creates scalar summaries for training metrics."""
        # This function is executed on the CPU and should not directly reference
        # any Tensors in the rest of the `model_fn`. To pass Tensors from the
        # model to the `model_fn`, provide as part of the `host_call`.
        global_step = tf.cast(global_step[0], tf.int64)
        with tf2.summary.create_file_writer(model_dir).as_default():
            # We cannot directly use any tensor from summaries, because each
            # tensor here must be a concat of multiple tensors from all shards.
            # Therefore, we rely on the assumption that args will have the same
            # length as summaries, and all tensors in args will have the same
            # order of self._tup_summaries.
            assert len(args) == len(summaries)
            for i, tensor in enumerate(args):
                name = summaries[i][0]
                # the per-shard values are averaged into one scalar
                tf2.summary.scalar(name, tf.reduce_mean(tensor), step=global_step)
            # NOTE(review): nesting of this return relative to the `with`
            # block was reconstructed - confirm against the original.
            return tf.summary.all_v2_summary_ops()

    global_step_t = tf.reshape(tf.to_int32(tf.train.get_global_step()), [1])
    return host_call_fn, [global_step_t] + reshaped_tensors
def natural_sort(l):
    """Return ``l`` sorted case-insensitively with digit runs compared as integers."""
    def as_token(text):
        # digit runs become ints, everything else is lower-cased
        if text.isdigit():
            return int(text)
        return text.lower()

    def alphanum_key(key):
        pieces = re.split('([0-9]+)', key)
        return [as_token(piece) for piece in pieces]

    return sorted(l, key=alphanum_key)
def parse_args():
    """Parse the command line of main.py and return the argparse namespace.

    Exits through ``parser.error`` (usage message, status 2) when ``--model``
    is missing. This replaces an ``assert``, which is stripped when Python
    runs with ``-O`` and would then let ``model=None`` through.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--tpu", type=str, help="Name of TPU to train on, if any.")
    parser.add_argument("--gpu_ids", nargs="+", type=str, default=["device:GPU:0"],
                        help="If training on GPU, can specify your GPU names in a list - i.e 'device:GPU:0 device:GPU:1'")
    parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.")
    parser.add_argument("--steps_per_checkpoint", type=int, default=5000, help="Save a model checkpoint every X steps.")
    parser.add_argument("--auto_layout", action="store_true", help="If set, generates and prints the most memory "
                                                                   "efficient layout according to MTF auto layout.")
    parser.add_argument("--auto_layout_and_mesh_shape", action="store_true",
                        help="If set, generates and prints the most memory efficient layout and mesh shape according to"
                             " MTF auto layout.")
    parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
                                                           "starts a new training run")
    parser.add_argument("--predict", action="store_true", help="If set, uses the model to predict rather than train.")
    parser.add_argument("--eval", action="store_true", help="If set, run model in evaluation mode.")
    parser.add_argument("--prompt", type=str, help="path to .txt file containing a prompt for prediction. If empty, "
                                                   "defaults to unicorns.",
                        default="")
    parser.add_argument("--check_dataset", action="store_true",
                        help="If set, outputs sample from the dataset and quits.")
    parser.add_argument("--sacred_id", type=str, default="nosacred", help="Sacred run id.")
    parser.add_argument("--entmax_sampling", action="store_true", help="(experimental) use entmax sampling")
    parser.add_argument("--export", action="store_true", help="If set, will export the model.")
    args = parser.parse_args()
    if args.model is None:
        parser.error("Model must be set")
    return args
def main(args):
    """Run export, prediction, evaluation or training for the model in ``args.model``.

    The mode is chosen from the CLI flags, in this order of precedence:
    ``--export``, ``--predict``, ``--eval``; otherwise the model is trained,
    pausing every ``--steps_per_checkpoint`` steps to predict / evaluate when
    the params request predict steps, eval steps or eval tasks.

    Args:
        args: parsed command-line namespace. Reads ``model``, ``prompt``,
            ``check_dataset``, ``new``, ``auto_layout``,
            ``auto_layout_and_mesh_shape``, ``tpu``, ``gpu_ids``,
            ``steps_per_checkpoint``, ``predict``, ``export``,
            ``entmax_sampling``, ``eval`` and ``sacred_id``.
    """
    # Setup logging
    logger = setup_logging(args)

    # Read params of model
    params = fetch_model_params(args.model)

    # Fetch appropriate input functions
    input_fn = params.get("input_fn", "sequential_input")
    if input_fn == "sequential_input":
        input_fn = sequential_input
    elif input_fn == "generic_text":
        input_fn = generic_text
    pred_input_fn = pred_input
    handle_pred_output_fn = handle_pred_output

    # get current step
    current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
    logger.info(f"Current step {current_step}")

    if params["mlm_training"]:
        mlm_sample_text_fn = partial(mlm_sample_text, params)
        input_fn = partial(generic_text, sample_text_fn=mlm_sample_text_fn)
        if args.check_dataset:
            check_dataset(input_fn, params)

    # Fetch encoder per params
    encoder = fetch_encoder(params)

    pred_input_fn = partial(pred_input_fn, path_to_prompt=args.prompt, logger=logger, enc=encoder)

    # Sample from Dataset if check dataset flag is on
    if args.check_dataset:
        check_dataset(input_fn, params, global_step=current_step)

    # Confirm deletion of checkpoint files if --new flag is set
    if args.new:
        if yes_or_no(f"Are you sure you want to remove '{params['model_path']}' to start afresh?"):
            remove_gs_or_filepath(params["model_path"])
        else:
            exit()

    # Save config to logdir for experiment management
    save_config(params, params["model_path"])

    # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    params["num_cores"] = mesh_shape.size
    params["auto_layout"] = args.auto_layout
    params["auto_layout_and_mesh_shape"] = args.auto_layout_and_mesh_shape
    params["use_tpu"] = args.tpu is not None
    params["gpu_ids"] = args.gpu_ids
    params["steps_per_checkpoint"] = args.steps_per_checkpoint
    # Expand attention types param
    params["attention_types"] = expand_attention_types_params(params["attention_types"])
    assert len(params["attention_types"]) == params["n_layer"]  # Assert that the length of expanded list = num layers
    params["predict_batch_size"] = params.get("predict_batch_size", 1)  # Default to 1
    params["predict"] = args.predict
    params['model'] = params.get("model", "GPT")  # Default model selection to GPT since it's the only option for now
    params["export"] = args.export
    # Set sampling parameters
    params["sampling_use_entmax"] = args.entmax_sampling

    # Sample quality of MoE models suffers when using the faster sampling method, so default to slow_sampling if
    # moe layers are present
    params["slow_sampling"] = params["moe_layers"] is not None

    logger.info(f"params = {params}")

    # Get eval tasks from params
    eval_tasks = params.get("eval_tasks", [])
    has_predict_or_eval_steps_or_eval_tasks = params["predict_steps"] > 0 or params["eval_steps"] > 0 or len(
        eval_tasks) > 0

    for t in eval_tasks:
        assert t in task_descriptors, f"Eval task '{t}' is not known"
        task_descriptors[t]["init_fn"](params)

    # Set up TPUs and Estimator
    if args.tpu == "colab":
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() if params["use_tpu"] else None
    else:
        tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) if params["use_tpu"] else None

    config = tpu_config.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=params["model_path"],
        save_checkpoints_steps=None,  # Disable the default saver
        save_checkpoints_secs=None,  # Disable the default saver
        log_step_count_steps=params["iterations"],
        save_summary_steps=params["iterations"],
        tpu_config=tpu_config.TPUConfig(
            num_shards=mesh_shape.size,
            iterations_per_loop=params["iterations"],
            num_cores_per_replica=1,
            per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))

    # NOTE(review): this estimator evaluates with train_batch_size while the per-task
    # estimators below use params["eval_batch_size"] - confirm the difference is intended.
    estimator = tpu_estimator.TPUEstimator(
        use_tpu=params["use_tpu"],
        model_fn=model_fn,
        config=config,
        train_batch_size=params["train_batch_size"],
        eval_batch_size=params["train_batch_size"],
        predict_batch_size=params["predict_batch_size"],
        params=params)

    def _make_task_estimator(task):
        # Each eval task gets its own estimator whose params carry the task name.
        task_params = params.copy()
        task_params["eval_task"] = task
        return tpu_estimator.TPUEstimator(
            use_tpu=params["use_tpu"],
            model_fn=model_fn,
            config=config,
            train_batch_size=params["train_batch_size"],
            eval_batch_size=params["eval_batch_size"],
            predict_batch_size=params["predict_batch_size"],
            params=task_params)

    eval_task_estimators = {
        task: _make_task_estimator(task)
        for task in eval_tasks
    }

    if args.export:
        export_model(estimator, "export", params)
        return

    if args.predict:
        # Predict
        predictions = estimator.predict(input_fn=pred_input_fn)
        logger.info("Predictions generated")
        enc = fetch_encoder(params)
        handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")
        return

    def save_eval_results(task, eval_results):
        # Append one JSON line per evaluation to eval_<sacred_id>.jsonl.
        def as_python(x):
            # numpy scalars are not JSON serialisable; unwrap them
            if isinstance(x, numpy.generic):
                return x.item()
            return x
        eval_results = {k: as_python(v) for k, v in eval_results.items()}
        with open(f'eval_{args.sacred_id}.jsonl', 'a') as fh:
            json.dump({'task': task, 'current_step': current_step, **eval_results}, fh)
            fh.write('\n')

    def run_eval():
        logger.info("Running evaluation...")
        eval_results = estimator.evaluate(
            input_fn=partial(input_fn, eval=True),
            steps=params["eval_steps"])
        logger.info(f"Eval results: {eval_results}")
        save_eval_results('validation', eval_results)

    def run_eval_tasks():
        for task in eval_tasks:
            logger.info(f"Starting evaluation task '{task}'")
            task_info = task_descriptors[task]["get_task_info_fn"](params)
            task_estimator = eval_task_estimators[task]
            task_input_fn = task_descriptors[task]["input_fn"]
            eval_results = task_estimator.evaluate(
                input_fn=task_input_fn,
                steps=task_info["n_steps"],
                name=task)
            logger.info(f"Eval task '{task}' results: {eval_results}")
            save_eval_results(task, eval_results)

    if args.eval:
        run_eval_tasks()
        if params["eval_steps"] > 0:
            run_eval()
        return

    elif has_predict_or_eval_steps_or_eval_tasks:
        # Eval and train - stop and predict and/or eval every checkpoint
        while current_step < params["train_steps"]:
            next_checkpoint = min(current_step + args.steps_per_checkpoint,
                                  params["train_steps"])

            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=next_checkpoint)
            current_step = next_checkpoint

            if params["predict_steps"] > 0:
                logger.info("Running prediction...")
                predictions = estimator.predict(input_fn=pred_input_fn)
                enc = fetch_encoder(params)
                handle_pred_output_fn(predictions, logger, enc, params, out_name=f"predictions_{args.sacred_id}_{current_step}")

            if params["eval_steps"] > 0:
                run_eval()

            if eval_tasks:
                run_eval_tasks()

        return
    else:
        # Else, just train
        while current_step < params["train_steps"]:
            # Else, don't stop and restart
            estimator.train(input_fn=partial(input_fn, global_step=current_step, eval=False), max_steps=params["train_steps"])
            # Refresh the step from the checkpoint directory. Without this the loop
            # condition can never become false, so the loop spins forever once
            # training has finished.
            reached_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params["model_path"]))
            if reached_step <= current_step:
                # train() returned without advancing; retrying would loop indefinitely
                logger.info(f"Training made no progress past step {current_step}; stopping.")
                break
            current_step = reached_step
Files ending in .zst are " 20 | "treated as archives, all others as raw text.") 21 | parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord") 22 | parser.add_argument("--name", type=str, default="openwebtext", 23 | help="Name of output files will be name_i.tfrecords where i is the number of the file") 24 | parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords") 25 | parser.add_argument("--encoder_path", type=str, 26 | help="Path to encoder files, or leave unspecified to use GPT2 tokenizer") 27 | parser.add_argument("--minimum_size", type=int, default=100, help="Minimum size a document has to be to be included") 28 | parser.add_argument("--ftfy", action="store_false", help="normalize with ftfy") 29 | parser.add_argument("--wikitext-detokenize", action="store_false", help="use wikitext detokenizer") 30 | parser.add_argument("--separator", nargs="+", type=int, default=[50256], 31 | help="separator to place between files in chunk mode") 32 | parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. " 33 | "Should equal your model's context size") 34 | parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion") 35 | parser.add_argument("--processes", type=int, default=0, help="Number of processes to use. 
def wikitext_detokenizer(string):
    """Return ``string`` with WikiText-style tokenisation spacing undone.

    A fixed list of rewrites is applied strictly in order, so later rules see
    the output of earlier ones. Plain rules use ``str.replace``; regex rules
    use ``re.sub``.
    """
    degree_sign = chr(176)

    # Each entry is (use_regex, pattern, replacement).
    rewrites = (
        # contractions
        (False, "s '", "s'"),
        # the slashes in this pattern are matched literally by ``re``
        (True, r"/' [0-9]/", r"/'[0-9]/"),
        # number separators
        (False, " @-@ ", "-"),
        (False, " @,@ ", ","),
        (False, " @.@ ", "."),
        # punctuation
        (False, " : ", ": "),
        (False, " ; ", "; "),
        (False, " . ", ". "),
        (False, " ! ", "! "),
        (False, " ? ", "? "),
        (False, " , ", ", "),
        # double brackets
        (True, r"\(\s*([^\)]*?)\s*\)", r"(\1)"),
        (True, r"\[\s*([^\]]*?)\s*\]", r"[\1]"),
        (True, r"{\s*([^}]*?)\s*}", r"{\1}"),
        (True, r"\"\s*([^\"]*?)\s*\"", r'"\1"'),
        (True, r"'\s*([^']*?)\s*'", r"'\1'"),
        # miscellaneous
        (False, "= = = =", "===="),
        (False, "= = =", "==="),
        (False, "= =", "=="),
        (False, " " + degree_sign + " ", degree_sign),
        (False, " \n", "\n"),
        (False, "\n ", "\n"),
        (False, " N ", " 1 "),
        (False, " 's", "'s"),
    )

    for use_regex, pattern, replacement in rewrites:
        if use_regex:
            string = re.sub(pattern, replacement, string)
        else:
            string = string.replace(pattern, replacement)

    return string
def _int64_feature(value):
    """
    Wraps `value` in a tf.train.Feature holding an Int64List.

    `value` is passed straight through as the list contents, so it must be an
    iterable of ints rather than a single scalar.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def write_to_file(writer, data):
    """
    Serializes `data` as a tf.train.Example with a single int64 "text" feature
    and writes it with `writer.write`.
    """
    feature = {
        "text": _int64_feature(data)
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(tf_example.SerializeToString())


def get_tokenizer(args):
    """
    Returns the pretrained 'gpt2' GPT2TokenizerFast when `args.encoder_path` is
    None, otherwise a Tokenizer loaded from the file at `args.encoder_path`.
    """
    if args.encoder_path is None:
        return GPT2TokenizerFast.from_pretrained('gpt2')
    return Tokenizer.from_file(args.encoder_path)


def split_list(l, n):
    """
    Splits a list / string into consecutive chunks of length `n`; the last
    chunk may be shorter. An empty input gives an empty list.
    """
    return [l[i:i + n] for i in range(0, len(l), n)]


def archive_to_tokens(f, encoder, args, prefix=None):
    """
    Generator over the documents in archive `f`, yielding one list of token
    chunks per document.

    Each document is optionally cleaned (`args.ftfy`, `args.wikitext_detokenize`),
    encoded, followed by `args.separator`, and split into chunks of
    `args.chunk_size` tokens. `prefix` (a list of token ids, default empty) is
    prepended to the FIRST document only.
    """
    # None instead of a shared mutable default list
    prefix = [] if prefix is None else prefix
    reader = Reader(f)
    for doc in reader.stream_data(threaded=False):
        if args.ftfy:  # fix text with ftfy if specified
            doc = ftfy.fix_text(doc, normalization='NFKC')
        if args.wikitext_detokenize:
            doc = wikitext_detokenizer(doc)
        # NOTE(review): the list concatenation assumes encoder.encode returns a plain list
        # of ids - confirm this also holds for tokenizers loaded via Tokenizer.from_file.
        doc = encoder.encode(doc) + args.separator  # read document from lmd and append separator token
        yield split_list(prefix + doc, args.chunk_size)  # split into chunk_size chunks
        prefix = []  # only the first document receives the prefix
def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
    """
    Writes a list of tokenized documents to consecutive .tfrecords files.

    Files are named `{out_name}_{index}[_{process_no}]_{n_docs}.tfrecords`,
    where `index` starts at `start_no` and `n_docs` is the number of entries
    actually written to that file.

    Returns `(next_index, remainder)`. `remainder` is the short trailing chunk
    that was held back (only when `write_remainder` is False), else None.
    Returns None when `files` is None or empty.
    """
    if files is None:
        return
    chunks = split_list(files, files_per)
    if not chunks:
        return

    if len(chunks[-1]) != files_per and not write_remainder:
        # hold back the short last chunk so the caller can top it up later
        remainder = chunks.pop(-1)
    else:
        remainder = None

    for chunk in chunks:
        fp = f"{output_dir}/{out_name}_{start_no}"
        if process_no is not None:
            fp += f"_{process_no}"
        # Label each file with its own entry count; previously every file got the
        # length of the LAST chunk, mislabelling full files ahead of a short one.
        fp += f"_{len(chunk)}"
        fp += ".tfrecords"
        with tf.io.TFRecordWriter(fp) as writer:
            for f in chunk:
                write_to_file(writer, f)
        start_no += 1
    return start_no, remainder


def get_files(input_dir, filetypes=None):
    """
    Returns the paths (as strings) of all files directly inside `input_dir`
    whose names end with one of `filetypes`.

    Raises FileNotFoundError (an Exception subclass) when nothing matches.
    """
    if filetypes is None:
        filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"]
    files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes]
    # flatten list of list -> list and stringify Paths
    flattened_list = [str(item) for sublist in files for item in sublist]
    if not flattened_list:
        raise FileNotFoundError(f"did not find any files at this path {input_dir}, "
                                f"please also ensure your files are in format {filetypes}")
    return flattened_list


def read_checkpoint(checkpoint_path, resume_from_checkpoint=True):
    """
    Reads `"<files_processed>, <tfrecord_count>"` from `checkpoint_path`.

    Returns `(0, 0)` when resuming is disabled, the file does not exist, or its
    contents cannot be read or parsed.
    """
    if resume_from_checkpoint and os.path.isfile(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as checkpoint_file:  # close the handle deterministically
                resume_files_processed, tfrecord_count = [int(i) for i in checkpoint_file.read().split(", ")]
        except (OSError, ValueError):
            # unreadable or malformed checkpoint -> start from scratch
            pass
        else:
            print(f"\nResuming from tfrecord no. {tfrecord_count} / file no. {resume_files_processed}")
            return resume_files_processed, tfrecord_count
    return 0, 0


def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False,
                     resume_from_checkpoint=False, display_pbar=False):
    """
    Tokenizes the input files, splits them into chunks and writes tfrecords.

    `params` is a `(files, args, process_no)` tuple so the function can be
    mapped over a multiprocessing pool. `save_checkpoints` is currently unused:
    the checkpoint file is rewritten after every batch of tfrecords.

    Returns a dict with the keys "discarded", "processed" and "successful".
    """
    files, args, process_no = params
    enc = get_tokenizer(args)  # get tokenizer

    # init metadata
    discarded_files = 0
    files_processed = 0
    pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ",
                disable=not display_pbar)
    # NOTE(review): every worker process shares this one checkpoint path - confirm that
    # resuming is only meant to be used in single-process mode.
    checkpoint_path = f"{args.output_dir}/checkpoint.txt"
    resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint)

    data_to_prepend = []
    tokenized_files_array = []

    for f in files:
        for tokenized_files in archive_to_tokens(f, enc, args, prefix=data_to_prepend):
            files_processed += 1
            if files_processed < resume_files_processed:
                continue  # resume from checkpoint

            # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file
            data_to_prepend = []
            n_tokens = len(tokenized_files[-1])
            if n_tokens < args.chunk_size:
                data = tokenized_files.pop(-1)
                if n_tokens >= args.minimum_size:
                    data_to_prepend = data
                else:
                    discarded_files += 1

            # add tokenized files > chunk size to main array
            tokenized_files_array.extend(tokenized_files)

            if len(tokenized_files_array) >= args.files_per * write_every_n_files:  # write every n files
                _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                         output_dir=args.output_dir, out_name=args.name,
                                                         start_no=tfrecord_count, process_no=process_no)
                pbar.update(_tfrecord_count - tfrecord_count)  # update progress bar
                pbar.set_description(
                    f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
                tfrecord_count = _tfrecord_count
                tokenized_files_array = remainder if remainder is not None else []  # add remaining files to next chunk
                with open(checkpoint_path, "w") as checkpoint_file:
                    checkpoint_file.write(f"{files_processed}, {tfrecord_count}")

    if len(tokenized_files_array) >= args.files_per:  # also write at end
        _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per,
                                                 output_dir=args.output_dir, out_name=args.name,
                                                 start_no=tfrecord_count, process_no=process_no)
        pbar.update(_tfrecord_count - tfrecord_count)
        pbar.set_description(
            f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ")
        tfrecord_count = _tfrecord_count
        with open(checkpoint_path, "w") as checkpoint_file:
            checkpoint_file.write(f"{files_processed}, {tfrecord_count}")
    else:
        remainder = tokenized_files_array  # add remaining to remainder

    if write_remainder:
        # write out the remaining files even if there's less than files_per.
        # process_no is passed like in the writes above so that two workers with the same
        # tfrecord_count and remainder size cannot overwrite each other's file.
        write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name,
                    start_no=tfrecord_count, write_remainder=True, process_no=process_no)

    successful_files = files_processed - discarded_files
    return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files}


def create_tfrecords_mp(files, args):
    """
    Runs `create_tfrecords` over `files` in a pool of `args.processes` workers
    and returns the summed "discarded" / "processed" / "successful" counts.
    """
    # Keep the chunk size at least 1: with fewer files than processes the integer
    # division is 0, which makes split_list fail on a zero range step.
    files_per_process = max(1, len(files) // args.processes)
    files = split_list(files, files_per_process)
    meta = {"discarded": 0, "processed": 0, "successful": 0}
    with Pool(processes=args.processes) as pool:
        pbar = tqdm(pool.imap(create_tfrecords, zip(files, repeat(args), range(len(files)))), total=len(files))
        # iterating the tqdm wrapper already advances the bar; an extra pbar.update()
        # per result would count every item twice
        for results in pbar:
            for k, v in results.items():
                meta[k] += v  # update metadata
    return meta
def model_fn(features, labels, mode, params):
    """Estimator model function: builds the Mesh-TensorFlow graph and returns a TPUEstimatorSpec.

    Args:
        features: input token tensor, or a dict holding it under the "feature" key.
        labels: target token tensor (same handling as `features`); may be None.
        mode: one of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}.
        params: model / run configuration mapping. It is passed through
            add_mode_to_params and further keys ("num_microbatches", and
            "remove_partial_sequences" in PREDICT mode) are set on the result.

    Returns:
        A tpu_estimator.TPUEstimatorSpec for the requested mode.
    """
    # Get global step
    global_step = tf.train.get_global_step()

    # Construct mtf graph + mesh from params
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(params["mesh_shape"])
    layout_rules = mtf.convert_to_layout_rules(params["layout"])

    # Mesh setup
    if params["use_tpu"]:
        var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules)
    else:
        var_placer = None
        gpu_ids = params["gpu_ids"]
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, gpu_ids)

    # Trainable variable precision
    # Store to checkpoints in master type, train in slice type, compute in activation type
    if params["precision"] == "bfloat16":
        variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32,
                                           activation_dtype=tf.bfloat16)
    else:
        variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)

    # Build mtf mesh object
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step
    features_dict = {"inputs": features, "labels": labels}
    sequence_length_dict = {"inputs": params["n_ctx"], "labels": params["n_ctx"]}

    params = add_mode_to_params(params, mode)
    batch_size = get_batch_size(params)

    batch_dim = mtf.Dimension("batch", batch_size)
    batch_dims = [batch_dim]
    feature_length = sequence_length_dict["inputs"]
    length_dim = mtf.Dimension("sequence", feature_length)

    # Import every non-None input as a fully replicated int32 mtf tensor of shape [batch, sequence]
    mtf_features = {}
    for key, x in features_dict.items():
        if x is not None:
            feature_shape = mtf.Shape(batch_dims + [length_dim])
            if type(features_dict[key]) == dict:
                features_dict[key] = features_dict[key]["feature"]
            x = tf.cast(features_dict[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            mtf_features[key] = mtf.import_fully_replicated(
                mesh, x, feature_shape, name=key)

    # Instantiate dict for dimensions, bias, etc that can be calculated here once then passed into model
    other_features = {}
    memory_length_dim = mtf.Dimension("memory_length", length_dim.size)

    attn_bias = biasmask_attn_weights(mesh, length_dim, memory_length_dim, variable_dtype) if params["causal"] else None

    # Add attn_bias into mtf_features
    other_features["attn_bias"] = attn_bias

    # Define other Dimensions that we'll need inside the model
    embd_dim = mtf.Dimension("embd", params["n_embd"])
    vocab_dim = mtf.Dimension("vocab", params["n_vocab"])
    # We need this because gathering when both the args have the same dimension in them breaks things
    # This dim is specifically for the weights
    # This prevents the "Einsum has lhs dimension without corresponding rhs or output dimension." error
    embed_sequence_dim = mtf.Dimension("embed_sequence", params["n_ctx"])

    other_features["embd_dim"] = embd_dim
    other_features["vocab_dim"] = vocab_dim
    other_features["embed_sequence_dim"] = embed_sequence_dim
    other_features["memory_length_dim"] = memory_length_dim

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Set up the model for prediction
        inputs = mtf_features["inputs"]
        if params["remove_partial_sequences"] is None:
            params["remove_partial_sequences"] = False

        export = params.get("export", False)

        if not export:
            mtf_samples = sample_autoregressive(
                inputs, other_features=other_features, params=params, variable_dtype=variable_dtype,
                remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"],
                sampling_use_entmax=params['sampling_use_entmax'], max_steps=params["predict_max_steps"])

        else:
            # When exporting, the first value returned by gpt2.model is used as the output
            # in place of autoregressive samples.
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                               variable_dtype=variable_dtype, context=None)

        mtf_samples = mtf.anonymize(mtf_samples)
        inputs = mtf.anonymize(inputs)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
        inputs = lowering.export_to_tf_tensor(inputs)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {
            "inputs": inputs,
            "outputs": outputs}

        def scaffold_fn():
            return tf.train.Scaffold(
                local_init_op=tf.group(
                    tf.train.Scaffold.default_local_init_op(),
                    lowering.copy_masters_to_slices(),
                    name="mtf_local_init_op"),
                ready_op=tf.concat(
                    [tf.report_uninitialized_variables(),
                     resources.report_uninitialized_resources()],
                    axis=0,
                    name="mtf_ready_op"))

        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            scaffold_fn=scaffold_fn,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    # We're not predicting, so we better be training or evaluating
    assert mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Gets number of microbatches per batch for serialized training
        # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed
        num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim,
                                                                                  sequence_length=sequence_length_dict,
                                                                                  mesh_shape=mesh_shape,
                                                                                  layout_rules=layout_rules,
                                                                                  tokens_per_microbatch_per_replica=
                                                                                  params["tokens_per_mb_per_replica"]))
    else:
        num_microbatches = 1

    params["num_microbatches"] = num_microbatches  # Add num microbatches to params

    if num_microbatches > 1:

        # For serialize_training_step we need to modify the model to output results in a dict
        def serialized_fn(mtf_features):
            if params["model"] == "GPT":
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype)
                return {"logits": logits, "loss": loss, "loss_batch": loss_batch}
            else:
                raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

        # Serialize the training step - Gradients are accumulated locally and reduced once.
        var_grads, output_dict = mtf.serialize_training_step(mtf_features, serialized_fn, batch_dim, num_microbatches)
        loss = output_dict["loss"]
        loss_batch = output_dict["loss_batch"]
        logits = output_dict["logits"]
    else:
        # If we're not splitting into microbatches, return logits & loss as is
        if params["model"] == "GPT":
            with mtf.utils.outside_all_rewrites():
                with tf.variable_scope('gpt2'):
                    logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh,
                                                          variable_dtype=variable_dtype, context=None)
        else:
            raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]")

    # Auto layout generation
    if params["auto_layout"]:
        auto_layout(graph, mesh_shape, logits, loss)
    if params["auto_layout_and_mesh_shape"]:
        auto_layout_and_mesh_shape(graph, params["num_cores"], logits, loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # In TRAIN mode, get optimizer
        if params["num_microbatches"] > 1:
            # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn
            # So we pass them in here
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype,
                                                     inp_var_grads=var_grads)
        else:
            # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank
            _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype)
        # Log summaries to tensorboard
        mtf.scalar_summary("loss", loss)
        # Log gradients if in params
        if params["log_grads"] not in [None, False]:
            for g in var_grads:
                grad_norm = mtf.sqrt(mtf.reduce_sum(mtf.square(g)))
                mtf.scalar_summary("grads/norm" + g.name[:-2], grad_norm)
    else:
        # For now, we can only export fully-replicated tensors.
        # This has to be done before lowering or they will not be included in the graph
        mean_logits = mtf.reduce_mean(logits, reduced_dim=vocab_dim)
        max_logits = mtf.argmax(logits, vocab_dim)
        del logits
        fully_replicated_mean_logits = mtf.anonymize(mean_logits)
        fully_replicated_max_logits = mtf.anonymize(max_logits)
        fully_replicated_loss_batch = mtf.anonymize(loss_batch)

    # Gets & prints info about no. trainable vars in the model & dimension names
    get_graph_info(graph)

    # 'lowers' mtf tensors into a tf graph - this enables us to export results as tf tensors
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=True)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Use our patched version until mtf updates theirs
        host_call = create_host_call(params['model_path'])
        mtf.utils.remove_summaries()

        # Creates train_op
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))  # Need to manually increment global_step
        tf.logging.info(f"tf_update_ops: {tf_update_ops}")
        train_op = tf.group(tf_update_ops)
    else:
        tf_mean_logits = lowering.export_to_tf_tensor(fully_replicated_mean_logits)
        tf_max_logits = lowering.export_to_tf_tensor(fully_replicated_max_logits)
        tf_loss_batch = tf.to_float(lowering.export_to_tf_tensor(fully_replicated_loss_batch))

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Set up the checkpoint server and return the TPUEstimatorSpec
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                params["model_path"],
                save_steps=params["steps_per_checkpoint"],
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                host_call=host_call,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        elif mode == tf.estimator.ModeKeys.EVAL:
            # Evaluation metrics
            def _perplexity(loss):
                perplexity = tf.exp(loss)
                return tf.metrics.mean(perplexity)

            def _bits_per_byte(loss):
                # NOTE(review): 0.29335 looks like a dataset-specific tokens-per-byte ratio - confirm its origin
                bpb = loss * (0.29335 / math.log(2))
                return tf.metrics.mean(bpb)

            def _metric_fn(tf_mean_logits, tf_loss_batch):
                mean_logits = tf.metrics.mean(tf_mean_logits)
                loss = tf.reduce_mean(tf_loss_batch)
                perp = _perplexity(loss)
                bpb = _bits_per_byte(loss)
                return {"mean_logits": mean_logits, "perplexity": perp, "bits per byte": bpb}

            def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch):
                # Only positions whose label differs from eos_id are scored
                eos_token = params["eos_id"]
                answer_positions = tf.where(tf.math.not_equal(labels, eos_token))

                correct_answers = tf.gather_nd(tf.math.equal(tf_max_logits, labels), answer_positions)
                accuracy = tf.metrics.mean(tf.cast(correct_answers, tf.float32))

                # I guess tf_loss_batch has z_loss and maybe other stuff added to it
                # so maybe this should be calculated separately in the future
                answer_loss = tf.gather_nd(tf_loss_batch, answer_positions)
                log_perplexity = tf.metrics.mean(answer_loss)

                return {"lambada_acc": accuracy, "lambada_log_ppl": log_perplexity}

            eval_task = params["eval_task"]
            if eval_task == "lambada":
                eval_metrics = (_lambada_metric_fn, [labels, tf_max_logits, tf_loss_batch])
            else:
                eval_metrics = (_metric_fn, [tf_mean_logits, tf_loss_batch])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1), 45 | master_dtype=variable_dtype.master_dtype, 46 | slice_dtype=variable_dtype.slice_dtype, 47 | activation_dtype=variable_dtype.activation_dtype) 48 | 49 | x = norm(x, axis, epsilon) 50 | x = x * g 51 | return x 52 | 53 | 54 | def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None): 55 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 56 | if axis is sentinel: 57 | axis = x.shape[-1] 58 | 59 | with tf.variable_scope(scope): 60 | n_state = x.shape[-1] 61 | 62 | g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1), 63 | master_dtype=variable_dtype.master_dtype, 64 | slice_dtype=variable_dtype.slice_dtype, 65 | activation_dtype=variable_dtype.activation_dtype) 66 | b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0), 67 | master_dtype=variable_dtype.master_dtype, 68 | slice_dtype=variable_dtype.slice_dtype, 69 | activation_dtype=variable_dtype.activation_dtype) 70 | 71 | x = norm(x, axis, epsilon) 72 | x = x * g + b 73 | return x 74 | 75 | 76 | def linear_attention(q, k, v): 77 | batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) 78 | q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") 79 | k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in") 80 | 81 | dim_in = k.shape[-1] 82 | 83 | q = mtf.softmax(q, dim_in) 84 | k = mtf.softmax(k, seq_dim) 85 | 86 | context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out]) 87 | attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) 88 | return attn 89 | 90 | 91 | def causal_linear_attention(q, k, v, eps = 1e-6): 92 | batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) 93 | q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") 94 | k = 
mtf.rename_dimension(k, "features_per_head", "features_per_head_in") 95 | 96 | dim_in = k.shape[-1] 97 | 98 | q = mtf.softmax(q, dim_in) 99 | k = mtf.exp(k) 100 | 101 | cumulative_k = mtf.cumsum(k, seq_dim) + eps 102 | D_inv = 1. / mtf.einsum([q, cumulative_k], output_shape=[batch_dim, seq_dim, head_dim]) 103 | 104 | context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out]) 105 | cumulative_context = mtf.cumsum(context, seq_dim) 106 | 107 | attn = mtf.einsum([q, cumulative_context, D_inv], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) 108 | return attn 109 | 110 | 111 | def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False): 112 | # nf = number of features 113 | if params["scale_by_depth"] and scale: 114 | # Scale by sqrt(num_layers), only happens at the final projection before a res block output 115 | w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"])) 116 | if params["scale_by_in"]: # Scale by sqrt(num_input_features) 117 | w_init_stdev = w_init_stdev * (1. 
def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):
    """Dense projection of the last dimension of *x* onto dimension *nf*.

    Args:
        x: input tensor; its last dimension is the one reduced.
        scope: name given to the dense layer.
        nf: output dimension ("number of features").
        w_init_stdev: base stddev of the normal kernel initialiser.
        variable_dtype: forwarded to ``mtf.layers.dense``.
        params: config mapping; "scale_by_depth" and "scale_by_in" are read
            unconditionally and "n_layer" when depth scaling applies, so
            despite the ``None`` default a mapping is required.
        scale: marks the final projection of a residual block, enabling
            depth scaling of the initialiser.
    """
    # nf = number of features
    if params["scale_by_depth"] and scale:
        # Scale by sqrt(num_layers), only happens at the final projection before a res block output
        w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"]))
    if params["scale_by_in"]:  # Scale by sqrt(num_input_features)
        w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size))  # Dimension is a namedtuple of (name, size)
    # Not in the variable_scope because mtf already has a variable_scope in it
    with tf.variable_scope("conv1d_main"):
        c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True,
                             kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev),
                             variable_dtype=variable_dtype,
                             )
        return c


def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """memory / key values from all attention paper

    Creates ``num_mem_kv`` learned key and value vectors per head, broadcasts
    them over the batch and prepends them to *k* and *v* along "sequence".
    Returns the extended ``(k, v)`` pair.
    """
    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype,
                             )
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    # Rename so the memory slots can be concatenated with the real sequence.
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v


def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None, pos_emb=None):
    """Multi-head self-attention over *x* followed by the output projection.

    Args:
        x: input tensor, [batch, seq, n_embd].
        scope: variable scope for this layer.
        n_state: embedding dimension; its size must divide by params["n_head"].
        attention_type: "local", "global" or "linear".
        params: config mapping ("n_head", "n_embd", "mode", "res_dropout",
            plus "causal", "attn_dropout", "local_attention_radius",
            "num_mem_kv" depending on the path taken).
        bias: attention mask bias for the "global" path, or None for no mask.
        dim_seq: sequence dimension used for incremental decoding.
        memory_length_dim: dimension k/v are renamed to on the "global" path.
        variable_dtype: dtype bundle for created variables.
        context: optional decoding context; when its mode is "incremental"
            cached k/v states are updated at ``context.position - 1``.
        pos_emb: optional ``(cos, sin)`` rotary embedding pair.

    Raises:
        NotImplementedError: for an unknown *attention_type*.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        # Same module as `mtf.transformer`, reached through the alias the rest
        # of this function already uses.
        mtfparams = mtf_transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Overwrite only the current position of the cached keys/values.
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            # States are recorded before the rotary embedding is applied.
            context.record_new_states([k, v])

        if exists(pos_emb):
            cos, sin = pos_emb
            k = apply_rotary_emb(k, cos, sin)

            if is_incremental_inference(context):
                # The query only needs the embedding of the current position.
                seq_dim = cos.shape.get_dim_by_name('sequence')
                cos = mtf.gather(cos, context.position - 1, seq_dim)
                sin = mtf.gather(sin, context.position - 1, seq_dim)

            q = apply_rotary_emb(q, cos, sin)

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":

                # TODO: pass in fake context
                # Fix: default to "no mask" so the name is always bound; the
                # original raised UnboundLocalError when bias was None.
                broadcasted_bias = None
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a


def mlp(x, scope, n_state, *, variable_dtype, params):
    """Feed-forward block: project to *n_state*, activate, project back.

    Residual dropout (params["res_dropout"]) is applied in "train" mode only.
    """
    activation_fn = get_activation_fn(params)
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params))
        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2


def mlp_glu(x, scope, n_state, *, variable_dtype, params):
    """Gated feed-forward block.

    The first projection is split in two along its last dimension; one half
    is multiplied by the activation of the other before projecting back.
    Residual dropout is applied in "train" mode only.
    """
    activation_fn = get_activation_fn(params)
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        # Fix: `variable_dtype` is a required keyword-only argument of
        # `linear`; omitting it raised TypeError on every call.
        h = linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)

        h, gate = mtf.split(h, h.shape[-1], 2)
        h *= activation_fn(gate)

        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    """Build an axial positional embedding table.

    Two learned tables, one per axial dimension from params["axial_pos_emb"],
    are broadcast against each other, averaged, and flattened into a single
    [axial_dim_1 * axial_dim_2, embd_dim] tensor.
    """
    # Use axial position encoding
    axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

    axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
    dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))]

    axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    # Expand both tables to the full (axial_dim_0, axial_dim_1, embd) grid.
    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
                                   (axial_wpe_1, axial_wpe_2))
    wpe = (axial_wpe_1 + axial_wpe_2) / 2

    # Flatten the two axial dimensions into one position dimension.
    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])

    return wpe

def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype):
    """Return the ``(cos, sin)`` tables used by :func:`apply_rotary_emb`.

    Both have dimensions [sequence_dim, features_per_head], where
    features_per_head is params["n_embd"] // params["n_head"].
    """
    dtype = variable_dtype.master_dtype
    dim_head = params["n_embd"] // params["n_head"]

    dim_head = mtf.Dimension("features_per_head", dim_head)
    half_dim_head = mtf.Dimension("half_features_per_head", dim_head.size // 2)

    # Inverse frequencies 1 / 10000^(2i / d) for i in [0, d/2).
    dim_range = mtf.range(mesh, half_dim_head, dtype) * 2 / dim_head.size
    half_freqs = 1. / mtf.pow(mtf.constant(mesh, 10000, dtype = dtype), dim_range)

    # Outer product of positions and inverse frequencies.
    seq = mtf.range(mesh, sequence_dim, dtype)
    half_freqs = mtf.einsum([half_freqs, seq], [sequence_dim, half_dim_head])

    # Duplicate the half table so it spans the whole head dimension.
    # NOTE(review): the rename presumably requires an even head size, since
    # the concatenated length is 2 * (size // 2) -- confirm.
    freqs = mtf.concat((half_freqs, half_freqs), half_dim_head.name)
    freqs = mtf.rename_dimension(freqs, half_dim_head.name, dim_head.name)
    return mtf.cos(freqs), mtf.sin(freqs)

def rotate_half(x):
    """Split "features_per_head" of *x* into halves (x1, x2) and return concat(-x2, x1)."""
    dim_head_name = "features_per_head"
    dim_head = x.shape.get_dim_by_name(dim_head_name)
    half_dim_head_size = dim_head.size // 2
    x1 = mtf.slice(x, 0, half_dim_head_size, dim_head_name)
    x2 = mtf.slice(x, half_dim_head_size, half_dim_head_size, dim_head_name)
    return mtf.concat((-x2, x1), dim_head.name)

def apply_rotary_emb(x, cos, sin):
    """Apply the rotary embedding: ``x * cos + rotate_half(x) * sin``."""
    rotated_x = rotate_half(x)
    return x * cos + rotated_x * sin
24 | logging.warning( 25 | "inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length") 26 | count = 0 27 | for item in tf.io.tf_record_iterator(filename): 28 | count += 1 29 | return count 30 | 31 | 32 | def _get_skip_index(all_files, n_batches): 33 | prev_cumsum = 0 34 | cumsum = 0 35 | global_n_documents = None 36 | for count, f in cycle(enumerate(all_files)): 37 | prev_cumsum = cumsum 38 | if _get_number_of_documents(f) is not None: 39 | cumsum += _get_number_of_documents(f) 40 | elif global_n_documents is None: 41 | global_n_documents = _get_number_of_documents_by_iteration(f) 42 | cumsum += global_n_documents 43 | else: 44 | cumsum += global_n_documents 45 | if cumsum == n_batches: 46 | remainder = 0 47 | skip_idx = count + 1 48 | elif cumsum > n_batches: 49 | remainder = n_batches - prev_cumsum 50 | skip_idx = count 51 | break 52 | return skip_idx, remainder 53 | 54 | 55 | def _parse_function(example_proto): 56 | features = { 57 | "text": tf.VarLenFeature(tf.int64) 58 | } 59 | parsed_features = tf.parse_single_example(example_proto, features) 60 | return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0]) 61 | 62 | 63 | def autoregressive_sample_text(params, x): 64 | vals1 = x[:params["n_ctx"]] 65 | vals2 = x[1:params["n_ctx"] + 1] 66 | 67 | vals1 = tf.reshape(vals1, [params["n_ctx"]]) 68 | vals2 = tf.reshape(vals2, [params["n_ctx"]]) 69 | vals1 = tf.cast(vals1, dtype=tf.int32) 70 | vals2 = tf.cast(vals2, dtype=tf.int32) 71 | return vals1, vals2 72 | 73 | 74 | def sequential_input(params, global_step=None, eval=False): 75 | """ 76 | Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either: 77 | 78 | - has the number of documents for each tfrecord file encoded in the title in the format 79 | _.tfrecords. 80 | 81 | OR 82 | 83 | - has a fixed number of documents per tfrecord file. 
def sequential_input(params, global_step=None, eval=False):
    """
    Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either:

        - has the number of documents for each tfrecord file encoded in the title in the format
          <name>_<n_documents>.tfrecords.

        OR

        - has a fixed number of documents per tfrecord file.

    If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read.
    If this isn't the case, it may result in errors, or some samples being missed.

    This means we can calculate the number of samples we've seen so far using the global step,
    and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient.

    If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
    performance, as it results in less repeated data.

    Args:
        params: config mapping; reads "dataset_configs", "train_batch_size" /
            "eval_batch_size", "iterations", and optionally
            "shuffle_input_filenames" and "seed".
        global_step: current training step; required unless *eval* is true.
        eval: read each dataset's "eval_path" instead of "path".

    Returns:
        An endlessly repeating dataset of batched (inputs, targets) pairs.

    Raises:
        ValueError: if the configured paths match no files.
    """
    if not eval:
        assert global_step is not None
    logging.warning(
        "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
    batch_size = params['eval_batch_size' if eval else 'train_batch_size']

    path_key = 'path' if not eval else 'eval_path'
    filenames = []
    for dataset_config in params['dataset_configs'].values():  # iterate through each dataset and read params
        # glob all files that fit the pattern specified in dataset_configs
        filenames.extend(tf.io.gfile.glob(dataset_config[path_key]))

    if not filenames:
        # Fail early and clearly: an empty list previously surfaced much
        # later as an unrelated error (e.g. an empty eval_path).
        raise ValueError(f"sequential_input() found no tfrecord files for '{path_key}' in dataset_configs")

    filenames = natural_sort(filenames)
    shuffle_filenames = params.get("shuffle_input_filenames", True)
    if shuffle_filenames:
        seed = params.get('seed', 1)  # shuffle deterministically
        random.seed(seed)
        random.shuffle(filenames)

    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()  # repeat filenames to infinity

    if not eval:
        # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
        skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[
            "train_batch_size"])  # TODO: fix for > 1 epoch
        dataset = dataset.skip(skip_idx)  # skip to skip idx

        # read tfrecord examples and skip remainder
        dataset = dataset.apply(tf.data.TFRecordDataset)
        dataset = dataset.skip(remainder)
    else:
        # shuffle filenames if in eval mode
        dataset = dataset.shuffle(len(filenames))
        dataset = dataset.apply(tf.data.TFRecordDataset)

    # parse the tokenized data from the tfrecord files and split into inputs / targets
    dataset = dataset.map(_parse_function, num_parallel_calls=1)
    dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)

    # batch data and repeat to infinity
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
    return dataset.repeat()
def pred_input(params, logger, enc=None,
               path_to_prompt=""):
    """Build a one-element dataset holding the tokenized prompt for prediction.

    The prompt is read from *path_to_prompt*, or a built-in default text is
    used when the path is empty. Prompts longer than params["n_ctx"] keep
    their last n_ctx tokens; shorter ones are right-padded with
    params["padding_id"]. The result is broadcast to
    [params["batch_size"], n_ctx] and emitted as ``(tokens, tokens)``.
    """
    unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
               "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
               "researchers was the fact that the unicorns spoke perfect English."

    if path_to_prompt == "":
        text = unicorns
    else:
        # Context manager closes the prompt file; it was previously left open.
        with open(path_to_prompt, "r") as prompt_file:
            text = prompt_file.read()
    tokens = encode(enc, text)

    if len(tokens) > params["n_ctx"]:
        logger.info("The length of your input prompt is longer than the model's context length - truncating input.")
        tokens = tokens[len(tokens) - params["n_ctx"]:]
    if len(tokens) < params["n_ctx"]:
        tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"])
    t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]])
    dataset = tf.data.Dataset.from_tensors(t)

    def _dummy_labels(x):
        # The input pipeline must yield (features, labels); labels are unused here.
        return x, x

    dataset = dataset.map(_dummy_labels)
    return dataset


def handle_pred_output(predictions, logger, enc, params, out_name="test"):
    """Decode each prediction and write it to ``{out_name}.txt`` and to *logger*.

    Every item of *predictions* must provide an "outputs" array of token ids.
    Each is cut at the first params["eos_id"], then at the first
    params["padding_id"], before being decoded with ``enc.decode``.
    """
    with tf.gfile.Open(f"{out_name}.txt", "w") as f:
        for i, p in enumerate(predictions):
            p = p["outputs"]

            # remove eos + padding ids from output
            # NOTE(review): argmax returns 0 both for "not found" and for
            # "found at index 0", so a sequence that starts with the id is
            # left untruncated -- confirm whether that is intended.
            idx = np.argmax(p == params['eos_id'])
            if idx > 0:
                p = p[:idx]
            idx = np.argmax(p == params['padding_id'])
            if idx > 0:
                p = p[:idx]

            text = enc.decode(p)
            f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
            f.write(text)
            f.write("\n" + "=" * 80 + "\n")

            logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
            logger.info(text)
            logger.info("\n" + "=" * 80 + "\n")
def generic_text(params, eval=False, sample_text_fn=None, **kwargs):
    """Deprecated input fn: sample from several weighted text datasets.

    Each entry of params["datasets"] is a ``(dataset_id, stitch, datatype,
    weight)`` tuple; its files are resolved through params["dataset_configs"]
    ("eval_path" when *eval* is true, otherwise "path") and turned into a
    dataset by :func:`text_dataset`. The datasets are mixed according to
    their weights and batched with the train or eval batch size.
    Extra keyword arguments are accepted and ignored.
    """
    logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.")
    i = 0 if not eval else 1  # not used below

    weights = []
    datasets = []

    for dataset in params["datasets"]:
        dataset_id, stitch, datatype, weight = dataset

        assert dataset_id in params[
            'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration'
        dataset_config = params['dataset_configs'][dataset_id]

        path_key = 'path' if not eval else 'eval_path'
        path = dataset_config[path_key]

        # batch=False: batching happens once, after the datasets are mixed.
        datasets.append(text_dataset(
            tf.io.gfile.glob(path),
            params,
            stitch=stitch,
            datatype=datatype,
            batch=False,
            sample_text_fn=sample_text_fn
        ))

        weights.append(weight)

    batch_size = params['eval_batch_size' if eval else 'train_batch_size']

    seed = params.get('seed', None)
    dataset = tf.data.experimental.sample_from_datasets(datasets, weights=weights, seed=seed)
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
    return dataset


def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None):
    """Deprecated: build a repeating dataset of sampled token windows from tfrecord *files*.

    Args:
        files: tfrecord file names.
        params: config mapping; reads "seed", "eos_id", "train_batch_size",
            "iterations" and whatever the sampling function needs.
        stitch: number of documents joined (separated by eos) per sample when
            *datatype* contains "documents".
        datatype: dataset kind; "documents" in the name enables stitching and
            "documents_random" selects random-offset sampling.
        batch: when true, batch with params["train_batch_size"].
        sample_text_fn: optional replacement sampling function; it is called
            with a ``random_documents`` keyword bound.
    """
    seed = params.get('seed', None)
    # A configured seed switches the pipeline to single-threaded map calls.
    deterministic = seed is not None
    num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE

    dataset = tf.data.Dataset.from_tensor_slices(files)

    if deterministic:
        dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
    else:
        dataset = dataset.apply(
            tf.data.experimental.parallel_interleave(tf.data.TFRecordDataset, cycle_length=4, sloppy=False))

    if "documents" in datatype:
        # Returns the sparse tokens together with their length, for stitching.
        def _parse_function(example_proto):
            features = {
                # "hash": tf.VarLenFeature(tf.string),
                "text": tf.VarLenFeature(tf.int64)
            }
            parsed_features = tf.parse_single_example(example_proto, features)
            return parsed_features["text"], parsed_features["text"].dense_shape[0]
    else:
        def _parse_function(example_proto):
            features = {
                "text": tf.VarLenFeature(tf.int64)
            }
            parsed_features = tf.parse_single_example(example_proto, features)
            return parsed_features["text"]  # Assuming the text is not sparse

    dataset = dataset.map(_parse_function, num_parallel_calls=1)

    # Subsample method
    if "documents" in datatype:
        # Since samples can be less than the correct length, and TPUs don't like variable lengths, this function stitches together enough samples
        # to have a text at least 1024 tokens long. For this to work the stitch parameter must be correctly tuned so that
        # stitch * min(characters_in_text) >= amount
        def _stitch_text(x, y):
            x = tf.sparse.to_dense(x)

            def _get_x(i):
                # Row i trimmed to its true length y[i] (drops densification padding).
                return tf.gather(x[i], tf.range(y[i]))

            out = _get_x(0)
            eos_id = params['eos_id']

            for i in range(1, stitch):
                out = tf.concat([out, [eos_id], _get_x(i)], axis=0)  # text1<|endoftext|>text2

            return out

        # Hack-y way to stitch together multiple texts

        dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text,
                                                                                                  num_parallel_calls=num_parallel_calls)

    # Sample 1024(+1) tokens from the stitched together text
    is_random_documents = datatype == "documents_random"
    if sample_text_fn is not None:
        _sample_text = partial(sample_text_fn, random_documents=is_random_documents)
    else:
        _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text
        _sample_text = partial(_sample_text, params)

    dataset = dataset.map(_sample_text, num_parallel_calls=num_parallel_calls)

    if batch:
        dataset = dataset.batch(params["train_batch_size"], drop_remainder=True).prefetch(params["iterations"] * 2)

    dataset = dataset.repeat()

    return dataset
def autoregressive_sample_text_random_documents(params, x):
    """Return ``(inputs, targets)`` windows of n_ctx tokens taken at a random offset of *x*.

    The offset is drawn uniformly from [0, size(x) - (n_ctx + 1)), so *x*
    must hold more than n_ctx + 1 tokens. Targets are the inputs shifted by
    one position; both are int32 with static shape [n_ctx].
    """
    seed = params.get('seed', None)
    s = tf.size(x)
    r = tf.random.uniform([], maxval=s - (params["n_ctx"] + 1), dtype=tf.dtypes.int32, seed=seed)
    r1 = tf.range(r, r + params["n_ctx"])
    r2 = tf.range(r + 1, (r + 1) + params["n_ctx"])
    r1 = tf.reshape(r1, [params["n_ctx"]])  # Somehow, this makes the compiler happy
    r2 = tf.reshape(r2, [params[
        "n_ctx"]])  # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input
    vals1 = tf.gather(x, r1)
    vals2 = tf.gather(x, r2)

    vals1 = tf.reshape(vals1, [params["n_ctx"]])
    vals2 = tf.reshape(vals2, [params["n_ctx"]])
    vals1 = tf.cast(vals1, dtype=tf.int32)
    vals2 = tf.cast(vals2, dtype=tf.int32)
    return vals1, vals2


def mlm_sample_text(params, x, random_documents=False):
    """Produce ``(masked_features, labels)`` for masked-language-model training.

    Args:
        params: config mapping. "n_ctx" and "mlm_mask_id" are required;
            "mlm_cls_token_id", "n_vocab", "mlm_mask_ignore_ids",
            "mlm_mask_prob" (0.15), "mlm_same_token_prob" (0.10),
            "mlm_random_token_prob" (0.) and "seed" are optional.
        x: 1-D token tensor.
        random_documents: take the window at a random offset instead of from
            the start of *x*.

    Returns:
        Two int32 tensors of shape [n_ctx]: the corrupted inputs, and labels
        holding the original token at every position selected for prediction
        and 0 everywhere else.
    """
    seed = params.get('seed', None)
    ctx_len = params["n_ctx"]
    assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'

    mask_id = params['mlm_mask_id']
    cls_token_id = params.get('mlm_cls_token_id', None)
    num_tokens = params.get('n_vocab', None)

    mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
    # Fix: only register the CLS id when one is configured; adding None made
    # the tf.not_equal comparisons below fail.
    if cls_token_id is not None:
        mask_ignore_ids.add(cls_token_id)

    mask_prob = params.get('mlm_mask_prob', 0.15)
    same_token_prob = params.get('mlm_same_token_prob', 0.10)
    random_token_prob = params.get('mlm_random_token_prob', 0.)

    # Leave one slot free for the CLS token when it is used.
    seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)

    if random_documents:
        s = tf.size(x)
        r = tf.random.uniform([], maxval=(s - seq_len), dtype=tf.dtypes.int32, seed=seed)
        r1 = tf.range(r, r + seq_len)
        r1 = tf.reshape(r1, [seq_len])
        features = tf.gather(x, r1)
    else:
        features = x[:seq_len]

    # add cls token id if specified by `mlm_cls_token_id`
    if cls_token_id is not None:
        features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)

    features = tf.cast(features, dtype=tf.int32)
    shape = features.shape
    # Keep the uncorrupted tokens: they are the prediction targets.
    original_tokens = features

    # determine which tokens are mask-able (id 0 is always excluded)
    can_mask = tf.not_equal(features, 0)
    for ignore_id in mask_ignore_ids:
        can_mask &= tf.not_equal(features, ignore_id)

    # generate boolean mask for masking ids
    # NOTE(review): every tf.random.uniform call below receives the same
    # `seed`; with a fixed seed the draws may be identical rather than
    # independent -- confirm.
    mask_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), mask_prob)
    mask_mask &= can_mask

    # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same
    replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
                           1 - same_token_prob)

    # randomly replace some tokens with random tokens before masking
    if random_token_prob > 0:
        random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed),
                                    random_token_prob)
        random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed)

        # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids`
        random_can_mask = tf.not_equal(random_tokens, 0)
        for ignore_id in mask_ignore_ids:
            random_can_mask &= tf.not_equal(random_tokens, ignore_id)

        features = tf.where(random_token_mask & random_can_mask, random_tokens, features)

    # mask the tokens
    mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
    masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)

    # labels will be set to 0 for all non-masked tokens
    # Fix: the branches were swapped (labels were zeroed at the masked
    # positions), and the targets now come from the uncorrupted tokens.
    labels = tf.where(mask_mask, original_tokens, tf.zeros(shape, dtype=tf.int32))

    masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels))
    return masked_features, labels
14 | 15 | In addition to the functionality offered by GPT-3, we also offer the following: 16 | * [Local attention](https://arxiv.org/abs/2004.05150) 17 | * [Linear attention](https://arxiv.org/abs/1812.01243) 18 | * [Mixture of Experts](https://arxiv.org/abs/1701.06538) 19 | * [Axial Positional embedding](https://arxiv.org/abs/1912.12180) 20 | 21 | NB, while neo can *technically* run a training step at 200B+ parameters, it is very inefficient at those scales. This, as well as the fact that many GPUs became available to us, among other things, prompted us to move development over to [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/). 22 | 23 | # Pretrained Models 24 | 25 | **Update 21/03/2021:** 26 | 27 | We're proud to release two pretrained GPT-Neo models trained on The Pile. The weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/public/AI/gptneo-release/). 28 | 29 | 1.3B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_XL/ 30 | 31 | 2.7B: https://mystic.the-eye.eu/public/AI/gptneo-release/GPT3_2-7B/ 32 | 33 | For more information on how to get these set up, see the colab notebook, or read through the rest of the readme.
34 | 35 | ## Model Evaluations 36 | 37 | #### Linguistic Reasoning 38 | 39 | | Model and Size | Pile BPB | Pile PPL | Wikitext PPL | Lambada PPL | Lambada Acc | Winogrande | Hellaswag | 40 | |------------------|------------|-----------|--------------|-------------|-------------|------------|------------| 41 | | **GPT-Neo 125M** | ----- | ----- | **32.285** | **30.266** | **37.36%** | **50.43%** | **28.67%** | 42 | | GPT-3 125M | ----- | ----- | ----- | 18.6 | 42.7% | 52.0% | 33.7% | 43 | | **GPT-Neo 350M** | ----- | ----- | **22.5657** | **13.876** | **47.27%** | **51.14%** | **32.16%** | 44 | | GPT-3 350M | ----- | ----- | ----- | 9.09 | 54.3% | 52.1% | 43.6% | 45 | | GPT-3 Ada | 0.9631 | ----- | ----- | 9.954 | 51.60% | 52.90% | 35.93% | 46 | | **GPT-Neo 1.3B** | **0.7527** | **6.159** | **13.10** | **7.498** | **57.23%** | **55.01%** | **38.66%** | 47 | | GPT-3 1.3B | ----- | ----- | ----- | 5.44 | 63.6% | 58.7% | 54.7% | 48 | | GPT-2 1.5B | 1.0468 | ----- | 17.48 | 10.634 | 51.21% | 59.40% | 40.03% | 49 | | **GPT-Neo 2.7B** | **0.7165** | **5.646** | **11.39** | **5.626** | **62.22%** | **56.50%** | **42.73%** | 50 | | GPT-3 2.7B | ----- | ----- | ----- | 4.60 | 67.1% | 62.3% | 62.8% | 51 | 52 | 53 | #### Physical and Scientific Reasoning 54 | 55 | | Model and Size | MathQA | PubMedQA | Piqa | 56 | |------------------|------------|------------|------------| 57 | | **GPT-Neo 125M** | **22.78%** | **55.10%** | **63.06%** | 58 | | GPT-3 125M | ----- | ----- | 64.6% | 59 | | **GPT-Neo 350M** | **23.45%** | **53.80%** | **65.07%** | 60 | | GPT-3 350M | ----- | ----- | 70.2% | 61 | | GPT-3 Ada | 24.29% | 52.80% | 68.88% | 62 | | **GPT-Neo 1.3B** | **24.05%** | **54.40%** | **71.11%** | 63 | | GPT-3 1.3B | ----- | ----- | 75.1% | 64 | | GPT-2 1.5B | 23.64% | 58.33% | 70.78% | 65 | | **GPT-Neo 2.7B** | **24.72%** | **57.54%** | **72.14%** | 66 | | GPT-3 2.7B | ----- | ----- | 75.6% | 67 | 68 | 69 | **Note:** All evaluations were done using our [evaluation 
harness](https://github.com/EleutherAI/lm-evaluation-harness). Some results for GPT-2 and GPT-3 are inconsistent with the values reported in the respective papers. We are currently looking into why, and would greatly appreciate feedback and further testing of our eval harness. 70 | 71 | # Setup 72 | 73 | ```bash 74 | git clone https://github.com/EleutherAI/GPTNeo 75 | cd GPTNeo 76 | pip3 install -r requirements.txt 77 | ``` 78 | # Training Setup 79 | 80 | ## TPUs: 81 | 82 | Sign up for [Google Cloud Platform](https://cloud.google.com/), and create a [storage bucket](https://cloud.google.com/storage). 83 | 84 | Create your VM through a google shell (`https://ssh.cloud.google.com/`) with `ctpu up --vm-only` so that it can connect to your Google bucket and TPUs and install the requirements with pip (see above). 85 | 86 | Google colab provides tpu-v8s for free, which should be enough to finetune our models up to GPT3XL (1.5B parameter) sizes. 87 | Click [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EleutherAI/GPTNeo/blob/master/GPTNeo_example_notebook.ipynb) to run through our example colab notebook. 88 | 89 | For more detailed instructions, run through our [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below. 90 | 91 | ## GPUs: 92 | 93 | You can also choose to train GPTNeo locally on your GPUs. To do so, you can omit the Google cloud setup steps above, and git clone the repo locally. Run through the [Training Guide](https://github.com/EleutherAI/GPTNeo#training-guide) below, then when running main.py, you simply have to omit the `tpu` flag, and pass in GPU ids instead. 94 | 95 | Note: Some users have reported having difficulty getting MTF to recognize their GPUs. See [here](https://github.com/EleutherAI/gpt-neo/issues/150) for details and instructions on how to fix it. 
96 | 97 | # Generating Text 98 | 99 | Once you have a trained model, or you've downloaded one of our pre-trained models, generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so: 100 | 101 | ```bash 102 | python3 main.py --predict --prompt --tpu --model 103 | ``` 104 | 105 | or, if using GPUs: 106 | 107 | ```bash 108 | python3 main.py --predict --prompt --gpu_ids --model 109 | ``` 110 | 111 | # Training Guide 112 | 113 | ## 1. Create your Tokenizer (OPTIONAL) 114 | 115 | We recommend you use [Huggingface's pretrained GPT2 tokenizer](https://huggingface.co/transformers/model_doc/gpt2.html#transformers.GPT2Tokenizer) with our repo (instructions provided below), but if you want to train a model with a different vocabulary size, we provide facilities to train your own tokenizer like so: 116 | 117 | ```bash 118 | python data/train_tokenizer.py \ 119 | --base_dir ./path/to/your/txt/files \ 120 | --output_dir ./output/path \ 121 | --file_type txt \ 122 | --vocab_size 50257 123 | 124 | # if it succeeded, you should see the message 125 | # 'tokenizer saved at ./output/path/byte-level-bpe.tokenizer.json' 126 | ``` 127 | 128 | ## 2. Tokenizing your Dataset 129 | 130 | If you just want to test training, you can skip this step and download some dummy data like so: 131 | 132 | ``` 133 | wget https://storage.googleapis.com/connors-datasets/bundestag/bundestag_0.tfrecords 134 | ``` 135 | 136 | Then copy the data to your bucket, or if using GPUs, a local directory: 137 | 138 | ``` 139 | gsutil cp bundestag_0.tfrecords gs:/// 140 | ``` 141 | 142 | If using your own data to train, you can use the `data/create_tfrecords.py` script to encode your text data into tfrecords. 143 | 144 | Your data must either be in the form of lots of normal .txt files (one document per file), or in any format supported by [lm_dataformat](https://github.com/leogao2/lm_dataformat). 
145 | 146 | You can run the script without parameters to see help for all options. 147 | 148 | In **document mode**, each example in the tfrecords is one (variably sized) document. This is to be used with the `documents_fixed` and `documents_random` sampling modes (for more details see the parameters reference section). 149 | Document mode is the default mode. 150 | 151 | The below command will tokenize all files in acceptable formats in *input_dir* using the GPT2 tokenizer and save them to *output_dir*: 152 | ``` 153 | python3 create_tfrecords.py --mode documents --input_dir <input_dir> --name <name> --output_dir <output_dir> --use_gpt2_tokenizer --minimum_size <minimum_size> 154 | ``` 155 | 156 | - `input_dir`: Defines the folder where your data is located. The script will encode all files present in this folder. 157 | - `name`: Name of output files will be `name_i.tfrecords` where i is the number of the file. 158 | - `output_dir`: Where to save the tfrecords to 159 | - `use_gpt2_tokenizer`: Whether to use the pretrained HuggingFace GPT2 tokenizer, in which case the separator will be set to [50256]. 160 | - `encoder_path`: if not using the pretrained gpt2 tokenizer, use this flag to provide a path to your generated tokenizer json. 161 | - `separator`: Written in list format, the separator token(s) to insert between documents (e.g. "[0]"). Will depend on your encoder. 162 | - `minimum_size`: The minimum size (in tokens) a document must have, otherwise it is discarded. This is what will later determine your `stitch` parameter: `stitch * minimum_size` must always be greater than or equal to `n_ctx` (for more details see the parameters reference section). 163 | 164 | ## 4. Using a Dataset in a Model 165 | 166 | To use a dataset in a model, you must first register that dataset under the `./configs/dataset_configs` folder. First choose a filename with a `.json` extension. That filename will serve as the dataset identification. The config should be filled out in the following manner. 
167 | 168 | If you have a dataset encoded using the pretrained gpt2 tokenizer, you can specify that like so: 169 | 170 | ```json 171 | { 172 | "n_vocab": 50257, 173 | "path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords", 174 | "eval_path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords", 175 | "tokenizer_is_pretrained": true, 176 | "tokenizer_path": "gpt2" 177 | } 178 | ``` 179 | 180 | or if you've trained a custom tokenizer, like so: 181 | 182 | ```json 183 | { 184 | "n_vocab": 32768, 185 | "path": "./path/to/your/*.tfrecords", 186 | "eval_path": "./path/to/your/eval/*.tfrecords", 187 | "tokenizer_path": "./path/to/your/byte-level-bpe.tokenizer.json" 188 | } 189 | ``` 190 | 191 | Finally, in your model config, add the filename that you created above to the `datasets` array. 192 | 193 | The `<dataset_id>` will be the filename, excluding the `.json`, that you created above. 194 | 195 | ``` 196 | "datasets": [[<dataset_id>, <stitch>, <sampling_mode>, <weight>]] # datasets key defines at run time how each dataset is processed for training 197 | ``` 198 | 199 | ## 5. Choose a model configuration 200 | 201 | Once you have your datasets set up, find a suitable config in `/configs`. 202 | 203 | Here we use a GPT3-XL sized model as an example, but there are many more in `./configs`, all of which have short summaries in the Available Configs section. 204 | 205 | All you need to do is edit the dataset id as described above, and edit `model_path` (where logs and checkpoints will be saved) to point to a cloud bucket you have write access to (or a local path, if using GPUs). 
206 | 207 | ```json 208 | { 209 | "n_head": 32, 210 | "n_vocab": 50257, 211 | "embed_dropout": 0.1, 212 | "lr": 0.0002, 213 | "lr_decay": "cosine", 214 | "warmup_steps": 3000, 215 | "beta1": 0.9, 216 | "beta2": 0.95, 217 | "epsilon": 1e-8, 218 | "opt_name": "adam", 219 | "weight_decay": 0.1, 220 | "train_batch_size": 512, 221 | "attn_dropout": 0.1, 222 | "train_steps": 286150, 223 | "eval_steps": 0, 224 | "predict_steps": 1, 225 | "res_dropout": 0.1, 226 | "eval_batch_size": 128, 227 | "predict_batch_size": 1, 228 | "iterations": 2500, 229 | "n_embd": 2048, 230 | "datasets": [["your_dataset_name", 25, "documents_random", 1.0]], 231 | "model_path": "gs://neo-models/GPT3_XL", 232 | "n_ctx": 2048, 233 | "n_layer": 24, 234 | "scale_by_depth": true, 235 | "scale_by_in": false, 236 | "attention_types" : [[["global"],24]], 237 | "mesh_shape": "x:128,y:2", 238 | "layout": "batch:x,memory_length:y,embd:y", 239 | "activation_function": "gelu", 240 | "recompute_grad": true, 241 | "gradient_clipping": 1.0, 242 | "tokens_per_mb_per_replica": 2048 243 | } 244 | ``` 245 | 246 | 247 | ## 6. Run Training 248 | 249 | ``` 250 | python3 main.py --model --steps_per_checkpoint --tpu 251 | ``` 252 | 253 | - `tpu`: Name of the TPU to use. 254 | - `steps_per_checkpoint`: The frequency in steps at which to save checkpoints. 255 | - `--auto_layout` and `--auto_layout_and_mesh_shape` (Optional): Disable training and instead auto generate a memory efficient `layout` (and `mesh_shape`) 256 | - `gpu_ids`: if training using GPUs, omit the `tpu` flag and pass in the ids of your gpus. In the example below, we train on 3 GPUs, specifying their device ids delimited by spaces: 257 | 258 | ``` 259 | python3 main.py --model --steps_per_checkpoint --gpu_ids 260 | ``` 261 | 262 | # Available Configs 263 | 264 | We have several model sizes available, but some of our configs require large TPUs and will need tweaking to run on smaller machines, or GPUs. 
Below is a short guide to each model in the configs directory: 265 | 266 | TODO 267 | 268 | # Extra Features: 269 | 270 | ## Training (with Sacred) 271 | 272 | [Sacred](https://github.com/IDSIA/sacred) helps track experiments and is much nicer to work with than tensorboard. 273 | 274 | To set up: 275 | 276 | 1. Install Docker and Docker-compose 277 | 278 | 2. Run `docker-compose up` 279 | 280 | To use: 281 | 282 | 1. Ensure model_dir doesn't have any metric logs in it (it trips up the metric stuff for tensorboard, which assumes that it's a continuation of the existing run). You can use `gsutil rm -r ...` to delete the model dir. 283 | 284 | 2. Run `python3 run_experiment.py --tpu sometpuhere --model someconfig.json`. Options are the same as `main.py`. 285 | 286 | 3. You can go to http://server_ip_goes_here:8081/ to see the Omniboard overview. If you prefer to see a tensorboard, the script also spins one up and automatically assigns it a port. The script should print out the tensorboard port near the top of the log. 287 | 288 | ## Peeking at a Dataset 289 | 290 | If you are ever confused by the dataset of a particular config file, you can easily check the minimum and maximum token ids with a single command. This is useful for making sure that the vocabulary size of the model is larger than the maximum token id. Tensorflow will not error if you try to gather on a matrix with out-of-bounds indices, so you need to make sure your vocabulary size is sufficiently large. 291 | 292 | ```bash 293 | python main.py --model {config_name} --check_dataset 294 | ``` 295 | 296 | ## Masked Language Modeling 297 | 298 | In addition to being able to train large GPTs, this repository also allows you to easily do masked language modeling (BERT, RoBERTa). In order to do so, you must follow two additional steps. 299 | 300 | 1. When tokenizing your dataset, you must reserve a special id for the `[mask]` token. 301 | 302 | 2. 
In the configs, you will have to define two additional fields: 303 | 304 | ```python 305 | "mlm_training": true, # must be set to true 306 | "mlm_mask_id": <mask_id> # the mask id that you reserved from above 307 | ``` 308 | 309 | That's all you need to train a model with the MLM objective, good for any type of data that you have encoded properly. If you would like to tweak the other related hyperparameters, please continue reading. 310 | 311 | ```python 312 | "mlm_cls_token_id": <cls_token_id>, # auto append specified CLS token id on the left 313 | "mlm_mask_prob": 0.15, # the probability of masking a token, defaults to 15% 314 | "mlm_same_token_prob": 0.10, # probability of keeping the token the same, defaults to 10% 315 | "mlm_random_token_prob": 0.10, # probability of tokens that are replaced with random tokens, 10% was recommended by the BERT paper 316 | "mlm_mask_ignore_ids": [<special_token_id_1>, <special_token_id_2>] # ignore masking other special tokens, if any 317 | ``` 318 | 319 | ## Parameter Reference 320 | 321 | Pick a valid config from `/configs` and tweak the parameters as needed: 322 | 323 | - `n_head`: The number of attention heads. 324 | - `n_embd`: Size of the hidden layers, must be divisible by `n_head`. 325 | - `n_vocab`: Vocabulary size. 326 | - `embed_dropout`, `res_dropout`, `attn_dropout`: Dropout probability for word embedding/residuals/attention 327 | - `lr`: Learning rate 328 | - `warmup_steps`: Number of steps before full learning rate is reached (linear ramp from `0` to `lr`). 329 | - `lr_decay`: `cosine` or `linear`. 330 | - `opt_name`: `adam` or `adafactor`. 331 | - `beta1`, `beta2` and `epsilon`: `adam` optimizer params. 332 | - `beta1`, `ada_epsilon1` and `ada_epsilon2`: `adafactor` optimizer params. 333 | - `weight_decay`: Weight decay parameter, if not present no weight decay is used (the weight decay fix for Adam is used) (default: 0.01) (optional). 334 | - `train_batch_size`: Batch size during training. 
335 | - `train_steps`: Number of training steps (batches), set to roughly ~1 epoch for now (total number of tokens in your dataset / number of tokens per batch (= `train_batch_size` * `n_ctx`)). 336 | - `eval_steps`: Number of steps to run for each evaluation. Set to `0` for no eval. I.e. after every checkpoint, the model is tested for `eval_steps` steps. 337 | - `iterations`: Number of steps queued to the TPU, must be smaller than `steps_per_checkpoint`. (default: 500) 338 | - `datasets`: List of tfrecords datasets to use. Each dataset is a list with the following parameters: `[dataset_id, stitch, sampling_mode, weight]`. So for example for a single dataset (note the double list): `[["bundestag", 10, "documents_random", 1.0]]` 339 | + `dataset_id`: The name of a dataset configuration file in `./configs/dataset_configs` 340 | + `stitch`: If `sampling_mode` `documents_random` is used, the input pipeline samples this amount of texts into one to sample from. You must select stitch so that `stitch * minimum_document_length >= n_ctx` 341 | + `sampling_mode`: `chunks` (tfrecords are preprocessed into the correct length and are read sequentially) or `documents_random` (`stitch` amount of documents are concatenated and then a `n_ctx` chunk is randomly subsampled) 342 | + `weight`: How much relative weight this dataset should have compared to others 343 | - `model`: Which model to train. Currently only `GPT` is supported, and it defaults to this if not present. 344 | - `model_path`: Google storage bucket location (or local path, if using GPUs) to save model checkpoints and logs. 345 | - `n_ctx`: Size of context window. Default is 2048. 346 | - `n_layer`: Number of layers (blocks) in the model. 347 | - `scale_by_depth`: If true, the weight initialization of layers is scaled by their depth as in the GPT2 paper. 348 | - `scale_by_in`: If true, the weight initialization of layers is scaled by their number of inputs as in the GPT2 paper. 
349 | - `mesh_shape`: A Mesh is an n-dimensional array of processors with named dimensions used for parallelism in the mesh-tensorflow library. Each Tensor is split evenly across mesh dimensions according to the layout (see below). The `mesh_shape` is the shape of this array, and the product of its dimensions must be equal to the number of processors, e.g., for a v3-128 TPU: "mesh_shape": "x:16,y:8". 350 | - `layout`: A Tensor is laid out on its mesh with one slice on each processor. A Tensor "layout", is an injective partial map specifying which dimensions of the tensor are (evenly) split across which dimensions of the mesh. No dimension of a tensor may be split across two dimensions of its mesh and no two dimensions of a tensor may be split across the same dimension of its mesh. The user defines a global set of layout rules in the form of (tensor-dimension-name, mesh-dimension-name) pairs. A dimension of a tensor is split across a dimension of its mesh if there is a matching rule, e.g. (for the above example mesh_shape): "layout": "batch:x,heads:y" 351 | - `activation_function`: `selu` (self normalizing) or `gelu` (used by OA), activation function used in feed-forward passes. (default: gelu) 352 | - `attention_types`: the type of attention for each layer in a list of the following format [[["attention_type"], n_layers]]. e.g. for a 12 layer net [[["global"], 12]] or [[["local"], 10], [["global"], 2]]. 353 | + Choose from: `linear`, `global`, `local` or `none`. We have found a 50/50 mix of `global` and `linear` to work well. `none` allows you to create feed-forward only layers for more efficient [PAR Transformer](https://arxiv.org/abs/2009.04534) models. 354 | - `precision`: `float32` or `bfloat16`. 355 | - `tokens_per_mb_per_replica`: If not None, will split the batch up into smaller microbatches containing `tokens_per_mb_per_replica` tokens to avoid OOMs. Gradients are accumulated locally and reduced once. IMPORTANT: mb refers to *minibatch* not megabyte here. 
356 | 357 | **Mixture of Experts** 358 | 359 | - `moe_layers`: A list of layer numbers to append a [mixture of experts](https://arxiv.org/abs/1701.06538) layer onto. E.g.: `[2,4,6,8,10,12]`. 360 | We have experimentally found a moe layer for every two self-attention layers to work well. 361 | - `moe_params`: a dictionary of additional kwargs to pass in to the moe layer. E.g.: 362 | `{"moe_dropout_rate": 0.0 }` 363 | 364 | **Experimental features** 365 | 366 | - `axial_pos_emb_`: If true, uses [axial positional embedding](https://arxiv.org/abs/1912.12180). 367 | - `mlp_glu`: If true, uses a gated linear unit variant of feed forward layers. 368 | - `scalenorm`: If true, uses scalenorm instead of layernorm. 369 | - `rezero`: If true, uses [rezero](https://www.groundai.com/project/rezero-is-all-you-need-fast-convergence-at-large-depth/1) instead of layernorm. 370 | - `num_mem_kv`: adds memory / key values from the [all-attention paper](https://arxiv.org/pdf/1907.01470.pdf). Param is an int with the number of desired mem/key values. 371 | - `macaron`: if true - uses a [macaron transformer](https://arxiv.org/pdf/1906.02762.pdf) for each layer block. 
372 | 373 | ## TODO: 374 | 375 | - [x] finalize documentation 376 | - [ ] update configs 377 | 378 | ## Citing GPT-Neo 379 | 380 | If you have found GPT-Neo helpful in your work, you can cite this repository as 381 | 382 | ``` 383 | @software{gpt-neo, 384 | author = {Black, Sid and 385 | Gao, Leo and 386 | Wang, Phil and 387 | Leahy, Connor and 388 | Biderman, Stella}, 389 | title = {{GPT-Neo: Large Scale Autoregressive Language 390 | Modeling with Mesh-Tensorflow}}, 391 | month = mar, 392 | year = 2021, 393 | note = {{If you use this software, please cite it using 394 | these metadata.}}, 395 | publisher = {Zenodo}, 396 | version = {1.0}, 397 | doi = {10.5281/zenodo.5297715}, 398 | url = {https://doi.org/10.5281/zenodo.5297715} 399 | } 400 | 401 | ``` 402 | The version number should be replaced with the version number you are using, and the year corresponds to the project's open-source release. 403 | 404 | If you are specifically interested in citing the GPT-Neo models trained on [the Pile](https://arxiv.org/abs/2101.00027), we would appreciate also citing 405 | ``` 406 | @article{gao2020pile, 407 | title={The Pile: An 800GB Dataset of Diverse Text for Language Modeling}, 408 | author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and others}, 409 | journal={arXiv preprint arXiv:2101.00027}, 410 | year={2020} 411 | } 412 | ``` 413 | --------------------------------------------------------------------------------