├── .bettercodehub.yml ├── .deepsource.toml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── 32big_mixer.json ├── 32ctx_mixer.json └── 32mixer_group.json ├── main.py ├── pytest.ini ├── requirements.txt ├── scripts ├── Dockerfile ├── build_and_push.sh ├── chunk_video_json.py ├── compile_local_text2tfrecord.sh ├── compile_train_tokenizer.sh ├── install_packages.sh ├── local_text2tfrecord.pyx ├── requirements.txt ├── run_experiments.py ├── run_local_text2tfrecord.py ├── run_manager.py ├── run_train_tokenizer.py ├── split_video_json.py ├── text2tfrecord.py ├── train_tokenizer.pyx └── video2tfrecord.py ├── src ├── dataclass.py ├── inputs.py ├── interface.py ├── main.py ├── model │ ├── __init__.py │ ├── activation.py │ ├── backend.py │ ├── basic.py │ ├── convolution.py │ ├── embedding.py │ ├── frontend.py │ ├── momentumnet.py │ ├── normalization.py │ ├── revnet.py │ └── spatial.py ├── mtf_wrapper.py ├── optimizer │ ├── __init__.py │ ├── backend.py │ ├── context.py │ ├── gradients.py │ ├── learning_rate.py │ └── optimizers.py ├── rest_api.py ├── run │ ├── dataloader_placement.py │ ├── inference.py │ ├── run.py │ ├── train.py │ └── utils_run.py ├── tf_wrapper.py ├── utils_core.py └── utils_mtf.py └── tests ├── backend.py ├── basic_linear_square_test.py ├── basic_pointwise_test.py └── variable_test.py /.bettercodehub.yml: -------------------------------------------------------------------------------- 1 | component_depth: 2 2 | languages: 3 | - python 4 | - script 5 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["*/test/**"] 4 | 5 | [[analyzers]] 6 | name = "python" 7 | enabled = true 8 | 9 | [analyzers.meta] 10 | runtime_version = "3.x.x" 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # testing 2 | .test/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | logs/ 110 | *.log 111 | test_* 112 | test/ 113 | .vscode 114 | 115 | #Video 116 | *.mkv 117 | *.mp4 118 | *.ogg 119 | *.webm 120 | *.avi 121 | *flv 122 | *.vtt 123 | *.tfrecord 124 | 125 | run_configs/ 126 | /data/channel_video_id_list.json 127 | 128 | # intellij 129 | .idea 130 | 131 | creds.json 132 | creds.* 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020-2021 Yannic Kilcher (yk), Lucas Nestler (clashluke), Shawn Presser (shawwn), Jan (xmaster96) 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OBST 2 | 3 | Copyright (c) 2020-2021 Yannic Kilcher (yk), Lucas Nestler (clashluke), Shawn Presser (shawwn), Jan (xmaster96) 4 | 5 | ## Quickstart 6 | 7 | First, create your VM through [google cloud shell](https://ssh.cloud.google.com/) with `ctpu up --vm-only`. 
This way it has all the necessary permissions to connect to your Buckets and TPUs.\
8 | Next, install the requirements with pip on your VM using `git clone https://github.com/tensorfork/obst && cd obst && python3 -m pip install -r requirements.txt`.\
9 | Finally, start a TPU to kick off a training run using `python3 main.py --model configs/big_ctx.json --tpu ${YOUR_TPU_NAME}`.
10 |
11 | ## Acknowledgements
12 |
13 | * [Mesh Tensorflow](https://github.com/tensorflow/mesh/) as the machine learning library
14 | * Initial code forked from [Eleuther AI's GPT-Neo](https://github.com/EleutherAI/gpt-neo)
15 |
16 | We also want to explicitly thank
17 | * [tensorfork](https://www.tensorfork.com/) and [TRC](https://sites.research.google/trc/) for providing us with the required compute (TPUs)
18 | * [Ben Wang (kindiana)](https://github.com/kingoflolz) and [Shawn Presser](https://twitter.com/theshawwn) for their invaluable knowledge about TensorFlow, TPUs, and language models
19 | * [Gwern Branwen](https://www.gwern.net/index), [Tri Songz](https://github.com/trisongz/) and [Aleph Alpha](https://aleph-alpha.de/) for financing our storage and servers
--------------------------------------------------------------------------------
/configs/32big_mixer.json:
--------------------------------------------------------------------------------
1 | {
2 |   "embedding_stddev": 0.004,
3 |   "calc_accuracy": true,
4 |   "scale_by_depth": true,
5 |   "block_config": [
6 |     {
7 |       "layer": [
8 |         "norm-shift-scale-features-group",
9 |         "bottleneck_group_linear-in:relu-mid:relu-mid:norm-mid:shift-mid:scale-mid:features"
10 |       ]
11 |     },
12 |     {
13 |       "layer": [
14 |         "norm-shift-scale-features-group",
15 |         "attention-biased_attention_map-absolute-input_as_value-shared",
16 |         "norm-shift-scale-features-group",
17 |         "activation-gelu",
18 |         "attention-biased_attention_map-absolute-input_as_value-shared"
19 |       ]
20 |     }
21 |   ],
22 |   "group_linear_factor": 2,
23 |   "intermediate_feed_forward_multiplier_multiplier": 0.5,
24 |   "depth": 32,
25 |   "use_initial_position_embedding": false,
26 |   "sequence_length": 512,
27 |   "features_per_head": 512,
28 |   "heads": 8,
29 |   "use_random_dataloader": false,
30 |   "shuffle_buffer": 1048576,
31 |   "buffer_size": 64,
32 |   "train_batch_size": 1024,
33 |   "interleaved_datasets": 64,
34 |   "data_seed": 134567,
35 |   "dataset_configs": [
36 |     {
37 |       "path": "gs://ggpt4/the-char-pile/*",
38 |       "type": "text",
39 |       "weight": 1
40 |     }
41 |   ],
42 |   "vocab_size": 256,
43 |   "model_mode": "gpt",
44 |   "use_language": true,
45 |   "adaptive_gradient_clipping": false,
46 |   "gradient_clip": 1,
47 |   "learning_rate": 0.01,
48 |   "opt_beta1": 0.9,
49 |   "memory_reduction_strategy": "revnet",
50 |   "opt_beta2": 0.99,
51 |   "optimizer": "adaptive_clip:0.003-sm3-momentum:0.9:1:1-learning_rate",
52 |   "weight_decay": 0.0001,
53 |   "weight_centralisation": false,
54 |   "weight_standardisation": false,
55 |   "macro_batching": 1,
56 |   "macro_batch_loss_smoothing": true,
57 |   "model_path": "gs://ggpt4/runs/aa/activation/features=seq=512-batch=1024-mixer-group_bottleneck",
58 |   "steps_per_checkpoint": 256,
59 |   "use_checkpointing": false,
60 |   "calculation_dtype": "bfloat16",
61 |   "storage_dtype": "bfloat16",
62 |   "optimizer_slice_dtype": "bfloat16",
63 |   "slice_dtype": "float32",
64 |   "sampling_temperature": 0.75,
65 |   "use_autoregressive_sampling": true,
66 |   "initial_autoregressive_position": 64,
67 |   "learning_rate_config": {"linear_warmup": {"final_step": 4096}}
68 | }
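A quick way to sanity-check a config like the one above before launching — a minimal sketch, not part of the repository; the TPU name is a placeholder:

```python
# Minimal sketch (illustrative only): load a config, derive two aggregate
# numbers from it, and launch main.py the same way the README quickstart does.
import json
import subprocess

with open("configs/32big_mixer.json") as f:
    config = json.load(f)

# Tokens consumed per training step and the total model width.
tokens_per_step = config["sequence_length"] * config["train_batch_size"]  # 512 * 1024 = 524,288
width = config["heads"] * config["features_per_head"]                     # 8 * 512 = 4,096
print(f"{tokens_per_step:,} tokens/step at model width {width}")

# "my-tpu" is a placeholder TPU name; main.py (shown further below) reads the
# config via --model and the TPU via --tpu.
subprocess.run(["python3", "main.py", "--model", "configs/32big_mixer.json", "--tpu", "my-tpu"])
```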
-------------------------------------------------------------------------------- /configs/32ctx_mixer.json: -------------------------------------------------------------------------------- 1 | { 2 | "embedding_stddev": 0.004, 3 | "calc_accuracy": true, 4 | "scale_by_depth": true, 5 | "block_config": [ 6 | { 7 | "layer": [ 8 | "norm-shift-scale-features-group", 9 | "bottleneck_group_linear-in:relu-mid:relu-mid:norm-mid:shift-mid:scale-mid:features" 10 | ] 11 | }, 12 | { 13 | "layer": [ 14 | "norm-shift-scale-features-group", 15 | "attention-biased_attention_map-absolute-input_as_value-shared", 16 | "norm-shift-scale-features-group", 17 | "activation-gelu", 18 | "attention-biased_attention_map-absolute-input_as_value-shared" 19 | ] 20 | } 21 | ], 22 | "group_linear_factor": 2, 23 | "intermediate_feed_forward_multiplier_multiplier": 0.5, 24 | "depth": 32, 25 | "use_initial_position_embedding": false, 26 | "sequence_length": 2048, 27 | "features_per_head": 256, 28 | "heads": 8, 29 | "use_random_dataloader": false, 30 | "shuffle_buffer": 1048576, 31 | "buffer_size": 64, 32 | "train_batch_size": 256, 33 | "interleaved_datasets": 64, 34 | "data_seed": 134567, 35 | "dataset_configs": [ 36 | { 37 | "path": "gs://ggpt4/the-char-pile/*", 38 | "type": "text", 39 | "weight": 1 40 | } 41 | ], 42 | "vocab_size": 256, 43 | "model_mode": "gpt", 44 | "use_language": true, 45 | "adaptive_gradient_clipping": false, 46 | "gradient_clip": 1, 47 | "learning_rate": 0.01, 48 | "opt_beta1": 0.9, 49 | "memory_reduction_strategy": "revnet", 50 | "opt_beta2": 0.99, 51 | "optimizer": "adaptive_clip:0.003-sm3-momentum:0.9:1:1-learning_rate", 52 | "weight_decay": 0.0001, 53 | "weight_centralisation": false, 54 | "weight_standardisation": false, 55 | "macro_batching": 1, 56 | "macro_batch_loss_smoothing": true, 57 | "model_path": "gs://ggpt4/runs/aa/activation/features=256-seq=2048-batch=256-mixer-group_bottleneck", 58 | "steps_per_checkpoint": 256, 59 | "use_checkpointing": false, 60 | "calculation_dtype": "bfloat16", 61 | "storage_dtype": "bfloat16", 62 | "optimizer_slice_dtype": "bfloat16", 63 | "slice_dtype": "float32", 64 | "sampling_temperature": 0.75, 65 | "use_autoregressive_sampling": true, 66 | "initial_autoregressive_position": 64, 67 | "learning_rate_config": {"linear_warmup": {"final_step": 4096}} 68 | } -------------------------------------------------------------------------------- /configs/32mixer_group.json: -------------------------------------------------------------------------------- 1 | { 2 | "embedding_stddev": 0.004, 3 | "calc_accuracy": true, 4 | "scale_by_depth": true, 5 | "block_config": [ 6 | { 7 | "layer": [ 8 | "norm-shift-scale-features-group", 9 | "bottleneck_group_linear-in:relu-mid:relu-mid:norm-mid:shift-mid:scale-mid:features" 10 | ] 11 | }, 12 | { 13 | "layer": [ 14 | "norm-shift-scale-features-group", 15 | "attention-biased_attention_map-absolute-input_as_value-shared", 16 | "norm-shift-scale-features-group", 17 | "activation-gelu", 18 | "attention-biased_attention_map-absolute-input_as_value-shared" 19 | ] 20 | } 21 | ], 22 | "group_linear_factor": 2, 23 | "intermediate_feed_forward_multiplier_multiplier": 0.5, 24 | "depth": 32, 25 | "use_initial_position_embedding": false, 26 | "sequence_length": 256, 27 | "features_per_head": 256, 28 | "heads": 8, 29 | "use_random_dataloader": false, 30 | "shuffle_buffer": 1048576, 31 | "buffer_size": 64, 32 | "train_batch_size": 4096, 33 | "interleaved_datasets": 64, 34 | "data_seed": 134567, 35 | "dataset_configs": [ 36 | { 37 | "path": 
"gs://ggpt4/the-char-pile/*", 38 | "type": "text", 39 | "weight": 1 40 | } 41 | ], 42 | "vocab_size": 256, 43 | "model_mode": "gpt", 44 | "use_language": true, 45 | "adaptive_gradient_clipping": false, 46 | "gradient_clip": 1, 47 | "learning_rate": 0.01, 48 | "opt_beta1": 0.9, 49 | "memory_reduction_strategy": "revnet", 50 | "opt_beta2": 0.99, 51 | "optimizer": "adaptive_clip:0.003-sm3-momentum:0.9:1:1-learning_rate", 52 | "weight_decay": 0.0001, 53 | "weight_centralisation": false, 54 | "weight_standardisation": false, 55 | "macro_batching": 1, 56 | "macro_batch_loss_smoothing": true, 57 | "model_path": "gs://ggpt4/runs/aa/activation/features=seq=256-batch=4096-mixer-group_bottleneck", 58 | "steps_per_checkpoint": 256, 59 | "use_checkpointing": false, 60 | "calculation_dtype": "bfloat16", 61 | "storage_dtype": "bfloat16", 62 | "optimizer_slice_dtype": "bfloat16", 63 | "slice_dtype": "float32", 64 | "sampling_temperature": 0.75, 65 | "use_autoregressive_sampling": true, 66 | "initial_autoregressive_position": 64, 67 | "learning_rate_config": {"linear_warmup": {"final_step": 4096}} 68 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | "Main" script that parses arguments and starts functions that actually build the model graph and start 3 | training if so desired. 4 | """ 5 | 6 | import argparse 7 | 8 | import tensorflow as tf 9 | 10 | from src import main 11 | 12 | if __name__ == "__main__": 13 | tf.compat.v1.disable_v2_behavior() 14 | 15 | modes = ','.join(main.RUN_MODE_FNS.keys()) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--tpu", type=str, help="Name of TPU to train on") 19 | parser.add_argument("--model", type=str, default=None, help="JSON file that contains model parameters.") 20 | parser.add_argument("--workers", type=int, default=1, help="Number of workers in WebAPI.") 21 | parser.add_argument("--run_mode", type=str, default="train", help=modes) 22 | parser.add_argument("--debug_grad", help="Log the gradients to tensorbord.") 23 | 24 | args = parser.parse_args() 25 | 26 | if args.run_mode not in main.RUN_MODE_FNS: 27 | raise ValueError(f"'{args.run_mode}' is not a supported argument for" 28 | f" --run_mode, please use one of {modes}.") 29 | 30 | main.main(args) 31 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = *_test.py 3 | python_functions = *_test -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | mesh-tensorflow 3 | tensorflow==2.8.2 4 | tokenizers 5 | transformers 6 | 7 | tpunicorn 8 | google-auth 9 | google-api-python-client 10 | cython 11 | pysimdjson 12 | 13 | absl-py 14 | ftfy 15 | jsonlines 16 | lm_dataformat 17 | ortools 18 | pytest 19 | sacred 20 | attrs 21 | 22 | opencv-python 23 | Pillow 24 | git+https://github.com/ytdl-org/youtube-dl.git 25 | google-cloud-storage 26 | oauth2client 27 | utils 28 | scipy 29 | 30 | fastapi 31 | uvicorn 32 | click 33 | -------------------------------------------------------------------------------- /scripts/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:latest 2 | COPY requirements.txt install_packages.sh ./ 3 | RUN 
chmod +x install_packages.sh && ./install_packages.sh && rm install_packages.sh 4 | RUN mkdir buffer datasets 5 | COPY video2tfrecord.py ./ 6 | -------------------------------------------------------------------------------- /scripts/build_and_push.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | docker build -t ykilcher/jannet . 7 | docker push ykilcher/jannet 8 | -------------------------------------------------------------------------------- /scripts/chunk_video_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import json 4 | import os 5 | 6 | 7 | if __name__ == '__main__': 8 | 9 | parser = argparse.ArgumentParser() 10 | 11 | parser.add_argument('load_path', type=str, 12 | help='The path to a json file containing video information, or a path to a folder containing ' 13 | 'json files with video information.') 14 | parser.add_argument('min_duration', type=int, help='The Minimum duration a chunk is suppose to contain.') 15 | parser.add_argument('-prefix', type=str, default='', help='A save file prerfix.') 16 | 17 | args = parser.parse_args() 18 | 19 | load_path = args.load_path 20 | min_duration = args.min_duration 21 | prefix = args.prefix 22 | 23 | if os.path.isdir(load_path): 24 | load_path = [os.path.join(load_path, p) for p in os.listdir(load_path)] 25 | else: 26 | load_path = [load_path] 27 | 28 | ids = [] 29 | duration = [] 30 | 31 | for l in load_path: 32 | json_load = json.load(open(l)) 33 | 34 | ids = ids + json_load['id'] 35 | duration = duration + json_load['duration'] 36 | 37 | chunks_ids = [] 38 | chunks_duration = [] 39 | 40 | _chunk_ids = [] 41 | _chunk_duration = [] 42 | _chunk_duration_sum = 0 43 | 44 | videos = list(zip(ids, duration)) 45 | random.shuffle(videos) 46 | random.shuffle(videos) 47 | 48 | for i, d in videos: 49 | 50 | _chunk_ids.append(i) 51 | _chunk_duration.append(d) 52 | _chunk_duration_sum = _chunk_duration_sum + d 53 | 54 | if _chunk_duration_sum >= min_duration: 55 | 56 | chunks_ids.append(_chunk_ids) 57 | chunks_duration.append(_chunk_duration) 58 | 59 | _chunk_ids = [] 60 | _chunk_duration = [] 61 | _chunk_duration_sum = 0 62 | 63 | chunks_ids.append(_chunk_ids) 64 | chunks_duration.append(_chunk_duration) 65 | 66 | ids = chunks_ids 67 | duration = chunks_duration 68 | 69 | chunk_video_count = 0 70 | chunk_video_duration = 0 71 | 72 | for i in range(len(ids)): 73 | buffer_video_count = len(ids[i]) 74 | buffer_video_duration = sum(duration[i]) 75 | 76 | print('chunk:', i, 'videos:', buffer_video_count, 'duration:', buffer_video_duration) 77 | 78 | chunk_video_count += buffer_video_count 79 | chunk_video_duration += buffer_video_duration 80 | 81 | print('') 82 | print('total num of videos:', chunk_video_count, 'total video duration:', chunk_video_duration) 83 | 84 | path = f"{prefix}work_chunks.json" 85 | dump = {'id': ids, 'duration': duration} 86 | 87 | json.dump(dump, open(path, 'w')) -------------------------------------------------------------------------------- /scripts/compile_local_text2tfrecord.sh: -------------------------------------------------------------------------------- 1 | python_flags=`python3-config --cflags --ldflags --includes --libs` 2 | python_flags=`echo "${python_flags//-g}"` 3 | python_include=-I`python3 -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` 4 | 5 | optimization_options="-fsingle-precision-constant -fcx-fortran-rules 
-flto -Ofast -ffast-math -ffinite-math-only -fno-trapping-math -frounding-math -freciprocal-math -fassociative-math -fno-signaling-nans -fstdarg-opt" 6 | code_generation_options="-fwrapv -fPIC -fdelete-dead-exceptions" 7 | preprocessor_options="-pthread" 8 | machine_options="-march=native -mtune=native -msse2 -msse4.2 -mavx -msse4.1 -msse -msse3 -mstackrealign -mmmx -maes -mpclmul -mclflushopt -mfsgsbase -mrdrnd -mf16c -mpopcnt -mfxsr -mxsave -mxsaveopt -msahf -mcx16 -mmovbe -mshstk -mcrc32 -mmwaitx -mrecip -minline-all-stringops" 9 | linker_options="-s -shared" 10 | c_dialect_options="-fopenmp -fopenacc -fsigned-char" 11 | 12 | gcc_options="$python_flags $python_include $optimization_options $code_generation_options $preprocessor_options $machine_options $linker_options $c_dialect_options" 13 | 14 | gcc_options=`echo $gcc_options | tr '\n' ' ' | tr '\r' ' ' | tr '\t' ' ' | tr ' ' ' '` 15 | 16 | 17 | echo "Global GCC Flags:" 18 | echo "$gcc_options" 19 | 20 | 21 | function compile { 22 | file=${1} 23 | echo "Cythonizing.." 24 | python3 -m cython "$file.pyx" -3 -Wextra -D 25 | flags="$file.c $gcc_options -o $file.so" 26 | echo "Executing gcc.." 27 | time ((gcc-11 $flags) || (gcc-10 $flags) || (gcc-9 $flags) || (gcc-8 $flags) || (gcc $flags)) 28 | echo "Testing compilation.." 29 | python3 -c "import $file" 30 | echo 31 | } 32 | 33 | 34 | compile local_text2tfrecord 35 | -------------------------------------------------------------------------------- /scripts/compile_train_tokenizer.sh: -------------------------------------------------------------------------------- 1 | python_flags=`python3-config --cflags --ldflags --includes --libs` 2 | python_flags=`echo "${python_flags//-g}"` 3 | python_include=-I`python3 -c 'import numpy, sys; sys.stdout.write(numpy.get_include()); sys.stdout.flush()'` 4 | 5 | optimization_options="-fsingle-precision-constant -fcx-fortran-rules -flto -Ofast -ffast-math -ffinite-math-only -fno-trapping-math -frounding-math -freciprocal-math -fassociative-math -fno-signaling-nans -fstdarg-opt" 6 | code_generation_options="-fwrapv -fPIC -fdelete-dead-exceptions" 7 | preprocessor_options="-pthread" 8 | machine_options="-march=native -mtune=native -msse2 -msse4.2 -shared -mavx -msse4.1 -msse -msse3 -mstackrealign -mmmx -maes -mpclmul -mclflushopt -mfsgsbase -mrdrnd -mf16c -mpopcnt -mfxsr -mxsave -mxsaveopt -msahf -mcx16 -mmovbe -mshstk -mcrc32 -mmwaitx -mrecip -minline-all-stringops" 9 | linker_options="-s -shared" 10 | c_dialect_options="-fopenmp -fopenacc -fsigned-char" 11 | 12 | gcc_options="$python_flags $python_include $optimization_options $code_generation_options $preprocessor_options $machine_options $linker_options $c_dialect_options" 13 | 14 | gcc_options=`echo $gcc_options | tr '\n' ' ' | tr '\r' ' ' | tr '\t' ' ' | tr ' ' ' '` 15 | 16 | 17 | echo "Global GCC Flags:" 18 | echo "$gcc_options" 19 | 20 | 21 | function compile { 22 | file=${1} 23 | echo "Cythonizing.." 24 | python3 -m cython "$file.pyx" -3 -Wextra -D 25 | flags="$file.c $gcc_options -o $file.so" 26 | echo "Executing gcc.." 27 | time ((gcc-11 $flags) || (gcc-10 $flags) || (gcc-9 $flags) || (gcc-8 $flags) || (gcc $flags)) 28 | echo "Testing compilation.." 
29 | python3 -c "import $file" 30 | echo 31 | } 32 | 33 | 34 | compile train_tokenizer 35 | -------------------------------------------------------------------------------- /scripts/install_packages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Bash "strict mode", to help catch problems and bugs in the shell 4 | # script. Every bash script you write should include this. See 5 | # http://redsymbol.net/articles/unofficial-bash-strict-mode/ for 6 | # details. 7 | set -euvo pipefail 8 | 9 | # Tell apt-get we're never going to be able to give manual 10 | # feedback: 11 | export DEBIAN_FRONTEND=noninteractive 12 | 13 | # Update the package listing, so we know what package exist: 14 | apt-get update 15 | apt-get install -y python3 python3-pip ffmpeg libgl-dev git 16 | 17 | python3 -m pip install -U pip 18 | pip3 install -r requirements.txt 19 | 20 | # Delete cached files we don't need anymore: 21 | apt-get clean 22 | rm -rf /var/lib/apt/lists/* 23 | -------------------------------------------------------------------------------- /scripts/local_text2tfrecord.pyx: -------------------------------------------------------------------------------- 1 | """ 2 | tokenization to bpe or character embeddings of text datasets 3 | make sure to first run train_tokenizer to get the preprocessed files locally 4 | """ 5 | 6 | #!python 7 | # cython: boundscheck=False 8 | # cython: initializedcheck=False 9 | # cython: nonecheck=False 10 | # cython: wraparound=False 11 | # cython: cdivision=True 12 | # cython: profile=False 13 | # cython: linetrace=False 14 | # cython: language_level=3 15 | 16 | 17 | import argparse 18 | import io 19 | import os 20 | import shutil 21 | import time 22 | import random 23 | import multiprocessing 24 | import urllib3 25 | 26 | import jsonlines 27 | import requests 28 | import simdjson 29 | import tensorflow as tf 30 | import zstandard 31 | from google.cloud import storage 32 | from transformers import GPT2TokenizerFast 33 | 34 | 35 | DEF NAME = "gpt2-bpe" 36 | DEF INT64 = 1 37 | DEF BUCKET_NAME = "obst-euw4a-aa" 38 | DEF OUTPUT_DIR = "the-fixed-gpt2-bpe-pile/" 39 | DEF PROCS = 12 40 | DEF SERVICE_ACCOUNT_JSON_PATH = "a.json" 41 | DEF BUFFER_SIZE = 2 ** 24 42 | DEF PRINTERVALL = 16 43 | 44 | 45 | cdef void create_tfrecords(unsigned short pid): 46 | cdef unicode prefix = f"{'int64' if INT64 else 'bytes'}_{NAME}_" 47 | 48 | bucket = storage.Client.from_service_account_json(SERVICE_ACCOUNT_JSON_PATH).get_bucket(BUCKET_NAME) 49 | encode = (GPT2TokenizerFast.from_pretrained('gpt2') if INT64 else str).encode 50 | 51 | cdef unsigned short splits = 30 52 | cdef unsigned short i = 0 53 | cdef unicode txt = "" 54 | cdef unicode filename = "" 55 | cdef unsigned long long processed_chars = 0 56 | cdef unsigned long tfrecord_count = 0 57 | 58 | cdef unsigned long last_write = time.time() 59 | cdef unsigned long start_time = time.time() 60 | 61 | for i in range(pid, splits, PROCS): 62 | with open(f'{i}.txt', 'r', BUFFER_SIZE * 2) as f: 63 | while True: 64 | txt = f.read(BUFFER_SIZE) 65 | if not txt: 66 | break 67 | processed_chars += BUFFER_SIZE 68 | joined = encode(txt) 69 | 70 | filename = f"{prefix}{tfrecord_count:_>6d}_{processed_chars}_{len(joined)}.tfrecord" 71 | 72 | with tf.io.TFRecordWriter(filename) as writer: 73 | if INT64: 74 | feature = {"text": tf.train.Feature(int64_list=tf.train.Int64List(value=joined))} 75 | else: 76 | feature = {"text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[joined]))} 77 | tf_example = 
tf.train.Example(features=tf.train.Features(feature=feature)) 78 | writer.write(tf_example.SerializeToString()) 79 | 80 | while True: 81 | try: 82 | bucket.blob(f'{OUTPUT_DIR}{filename}').upload_from_filename(filename) 83 | break 84 | except urllib3.exceptions.TimeoutError: 85 | pass 86 | tfrecord_count += 1 87 | if tfrecord_count % PRINTERVALL == 0: 88 | print(f"[{pid:{len(str(PROCS))}d}/{PROCS}] Processed: {processed_chars} - Total: {time.time()-start_time:.0f}s - Since last write: {time.time()-last_write:.0f}s") 89 | last_write = time.time() 90 | 91 | 92 | cpdef main(): 93 | processes = [multiprocessing.Process(target=create_tfrecords, args=(pid,)) for pid in range(PROCS)] 94 | for p in processes: 95 | p.start() 96 | for p in processes: 97 | p.join() 98 | -------------------------------------------------------------------------------- /scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | google-api-python-client 3 | oauth2client 4 | opencv-python 5 | Pillow 6 | git+https://github.com/ytdl-org/youtube-dl.git 7 | google-cloud-storage 8 | transformers 9 | -------------------------------------------------------------------------------- /scripts/run_experiments.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | import hashlib 4 | import json 5 | import math 6 | import time 7 | import os 8 | 9 | import numpy as np 10 | 11 | def str2bool(v: str) -> bool: 12 | if isinstance(v, bool): 13 | return v 14 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 15 | return True 16 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 17 | return False 18 | raise argparse.ArgumentTypeError('Boolean value expected.') 19 | 20 | 21 | if __name__ == '__main__': 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('base_config', type=str, 26 | help="The path to the .json config with will be use as the bases for this run.") 27 | parser.add_argument('tpu_start_id', type=int, help="The tpu ID at with the TPU IDs will start.") 28 | parser.add_argument('--run_config', default='', type=str, help="The path to the .json config that continents " 29 | "the Hyperparameters to be run. 
The config must contain a" 30 | " dict in with each entry is a list, the list continents " 31 | "the different Hyperparameters for variable") 32 | parser.add_argument('--run_name_prefix', type=str, default='gs://text-datasets/video-transformer/') 33 | parser.add_argument('--number_of_repetitions', type=int, default=1, help="The number of times the same " 34 | "parameters will get tested.") 35 | parser.add_argument('--repetition_start_idx', type=int, default=0) 36 | parser.add_argument('--use_preemptible', type=str, default='true') 37 | parser.add_argument('--tpu_type', type=str, default='v3-8') 38 | parser.add_argument('--zone', type=str, default='europe-west4-a') 39 | parser.add_argument('--network', type=str, default='tpu-euw4a') 40 | parser.add_argument('--start_up_sleep', type=int, default=0) 41 | parser.add_argument('--project', type=str, default='mlops-engine') 42 | parser.add_argument('--use_manager', type=str, default='False') 43 | 44 | args = parser.parse_args() 45 | 46 | tpu_type = args.tpu_type 47 | tpu_type_str = '"' + tpu_type + '"' 48 | 49 | with open(args.base_config) as f: 50 | base_config = json.load(f) 51 | 52 | if args.run_config != "": 53 | with open(args.run_config) as f: 54 | run_config = json.load(f) 55 | else: 56 | run_config = {} 57 | 58 | if not os.path.exists("../buffer_configs/"): 59 | os.makedirs("../buffer_configs/") 60 | 61 | tpu_id = args.tpu_start_id 62 | run_config_key = list(run_config.keys()) 63 | 64 | _key = [np.arange(len(run_config[key])) for key in run_config_key] 65 | key_pos = np.meshgrid(*_key, sparse=False) 66 | key_pos = np.stack(key_pos, axis=-1) 67 | _shape = key_pos.shape 68 | key_pos = np.reshape(key_pos, newshape=(np.prod(_shape[:-1]), _shape[-1])) 69 | 70 | for pos in key_pos: 71 | 72 | copy_base_config = base_config.copy() 73 | 74 | for idx, key in enumerate(run_config_key): 75 | copy_base_config[key] = run_config[key][pos[idx]] 76 | 77 | for repetition_idx in range(args.repetition_start_idx, args.number_of_repetitions): 78 | tpu_name = f"tpu-{tpu_type}-{args.network}-{tpu_id}" 79 | 80 | cors = int(str(tpu_type).split('-')[-1]) 81 | if cors == 8: 82 | tpu_range = f"10.48.{tpu_id}.0/29" 83 | else: 84 | cidr = int(32 + 2 - math.log2(cors)) 85 | _tpu_id = tpu_id + 2 86 | 87 | tpu_range = f"10.{_tpu_id}.0.0/{cidr}" 88 | 89 | run_name = f"-run={repetition_idx}" 90 | run_name = "-".join([f"{key}={copy_base_config[key]}" for key in run_config_key]) + run_name 91 | run_name = run_name.replace(' ', '_').replace("'", '').replace(":", '=').replace(",", '-') 92 | run_name = run_name.replace('[', '|').replace(']', '|') 93 | 94 | copy_base_config['model_path'] = args.run_name_prefix + run_name 95 | 96 | with open(f"../buffer_configs/{tpu_id}_{run_name}.json", 'w+') as w: 97 | w.write(json.dumps(copy_base_config)) 98 | 99 | experiment_command = f"python3 ../main.py --model ../buffer_configs/" \ 100 | f"{tpu_id}_{run_name}.json --tpu {tpu_name}" 101 | delete_command = f"pu delete {tpu_name} --yes" 102 | tpu_creat_command = f"gcloud compute tpus create {tpu_name} --zone {args.zone} " \ 103 | f"--range {tpu_range} --network {args.network} --version 1.15.5 " \ 104 | f"--accelerator-type {tpu_type_str} --project {args.project}" 105 | 106 | if str2bool(args.use_preemptible): 107 | tpu_creat_command = tpu_creat_command + " --preemptible" 108 | 109 | if str2bool(args.use_manager): 110 | comm = f"python3 run_manager.py '{experiment_command}' {tpu_name} {tpu_type} {args.zone} " \ 111 | f"{args.network} {args.run_name_prefix + run_name} 
{str2bool(args.use_preemptible)}" 112 | else: 113 | comm = f"({tpu_creat_command} && {experiment_command}) ; {delete_command}" 114 | 115 | if len(run_name) > 66: 116 | run_name = hashlib.sha256(run_name.encode('utf-8')).hexdigest() 117 | 118 | prosses_name = f"tpu_id:{tpu_id}--{run_name}" 119 | 120 | subprocess.run(['screen', '-dmS', prosses_name, 'bash', '-c', comm]) 121 | 122 | tpu_id = tpu_id + 1 123 | 124 | print(f"Creating {prosses_name}") 125 | time.sleep(args.start_up_sleep) 126 | -------------------------------------------------------------------------------- /scripts/run_local_text2tfrecord.py: -------------------------------------------------------------------------------- 1 | import local_text2tfrecord 2 | 3 | if __name__ == '__main__': 4 | local_text2tfrecord.main() 5 | -------------------------------------------------------------------------------- /scripts/run_manager.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | import hashlib 4 | import json 5 | import math 6 | import random 7 | import time 8 | import signal 9 | import os 10 | 11 | import tensorflow as tf 12 | from tpuapi import TPUServiceAPI 13 | 14 | 15 | 16 | def str2bool(v: str) -> bool: 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | 26 | class GFile: 27 | 28 | def __init__(self, name, mode): 29 | self.file = tf.io.gfile.GFile(name, mode) 30 | self.write_count = 0 31 | 32 | def fileno(self): 33 | return 9 34 | 35 | def write(self, data): 36 | self.file.write(data) 37 | 38 | self.write_count = self.write_count + 1 39 | 40 | if self.write_count > 200: 41 | self.flush() 42 | 43 | return 3 44 | 45 | def write_flush(self, data): 46 | self.file.write(data) 47 | self.flush() 48 | return 3 49 | 50 | def flush(self): 51 | self.write_count = 0 52 | self.file.flush() 53 | 54 | def close(self): 55 | self.flush() 56 | self.file.close() 57 | 58 | if __name__ == '__main__': 59 | 60 | parser = argparse.ArgumentParser() 61 | 62 | parser.add_argument('run_command', type=str) 63 | parser.add_argument('tpu_name', type=str) 64 | parser.add_argument('tpu_type', type=str) 65 | parser.add_argument('zone', type=str) 66 | parser.add_argument('network', type=str) 67 | parser.add_argument('model_path', type=str) 68 | parser.add_argument('preemptible', type=str) 69 | 70 | args = parser.parse_args() 71 | 72 | run_command = args.run_command 73 | tpu_name = args.tpu_name 74 | tpu_type = args.tpu_type 75 | model_path = args.model_path 76 | preemptible = str2bool(args.preemptible) 77 | 78 | cors = int(str(tpu_type).split('-')[-1]) 79 | tpu_id = int(str(tpu_name).split('-')[-1]) 80 | 81 | if cors == 8: 82 | tpu_range = f"10.48.{tpu_id}.0/29" 83 | else: 84 | cidr = int(32 + 2 - math.log2(cors)) 85 | _tpu_id = tpu_id + 2 86 | 87 | tpu_range = f"10.{_tpu_id}.0.0/{cidr}" 88 | 89 | tpu_client = TPUServiceAPI(project='mlops-engine') 90 | 91 | out_io = GFile(f"{model_path}/run.log", 'w') 92 | 93 | 94 | def wait_for_tpu(): 95 | ready = False 96 | ready_count = 0 97 | 98 | while not ready: 99 | time.sleep(15) 100 | ready = tpu_client.is_tpu_ready(tpu_name)['healthy'] 101 | ready_count = ready_count + 1 102 | 103 | if ready_count > 15: 104 | ready_count = 0 105 | tpu_log = tpu_client.recreate(tpu_name, mesh=tpu_type, tf_version='1.15.5', 106 | zone=args.zone, 
107 |                                               preemptible=preemptible, wait=True, network=args.network)
108 |
109 |                 out_io.write_flush(f"\n\n\n{tpu_log}\n\n\n")
110 |
111 |     try:
112 |         tpu_log = tpu_client.create(tpu_name, mesh=tpu_type, tf_version='1.15.5', zone=args.zone,
113 |                                     cidrblock=tpu_range, preemptible=preemptible, wait=True, network=args.network)
114 |
115 |         out_io.write_flush(f"{tpu_log}\n\n\n")
116 |
117 |         wait_for_tpu()
118 |
119 |         pro = subprocess.Popen(run_command, stdout=out_io, stderr=out_io, shell=True, preexec_fn=os.setsid)
120 |
121 |         done = False
122 |
123 |         while not done:
124 |             time.sleep(300 + random.randint(0, 300))
125 |
126 |             health = tpu_client.is_tpu_ready(tpu_name)
127 |             if pro.poll() is not None:
128 |                 if health['healthy']:
129 |                     done = True
130 |
131 |             if not health['healthy']:
132 |                 os.killpg(os.getpgid(pro.pid), signal.SIGTERM)
133 |
134 |                 time.sleep(60)
135 |
136 |                 out_io.flush()
137 |
138 |                 tpu_log = tpu_client.recreate(tpu_name, mesh=tpu_type, tf_version='1.15.5',
139 |                                               zone=args.zone, cidrblock=tpu_range,
140 |                                               preemptible=preemptible, wait=True, network=args.network)
141 |
142 |                 out_io.write_flush(f"\n\n\n{tpu_log}\n\n\n")
143 |
144 |                 wait_for_tpu()
145 |
146 |                 pro = subprocess.Popen(run_command, stdout=out_io, stderr=out_io, shell=True, preexec_fn=os.setsid)
147 |     except Exception as e:
148 |         out_io.write_flush(f"\n\n\nrun_manager has crashed\n{e}\n\n\n")
149 |
150 |     out_io.write_flush(f"\n\n\nIt seems the run is done.")
151 |
152 |     try:
153 |         tpu_log = tpu_client.delete(tpu_name)
154 |         out_io.write_flush(f"\n{tpu_log}")
155 |     except Exception:
156 |         out_io.write_flush(f"\nFailed to delete the TPU.")
157 |
158 |     out_io.close()
--------------------------------------------------------------------------------
/scripts/run_train_tokenizer.py:
--------------------------------------------------------------------------------
1 |
2 | try:
3 |     from train_tokenizer import main
4 | except ImportError:
5 |     import os
6 |     os.system("bash compile_train_tokenizer.sh")
7 |     del os
8 |     from train_tokenizer import main
9 |
10 | if __name__ == '__main__':
11 |     main()
--------------------------------------------------------------------------------
/scripts/split_video_json.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import json
4 | import os
5 |
6 |
7 | # Function copied from video2tfrecord.py
8 | def split_equal(ids: list, duration: list, num: int, min_duration: int = 256):
9 |
10 |     sort = sorted(zip(duration, ids))[::-1]
11 |
12 |     ids_split = [[] for i in range(num)]
13 |     duration_split = [[] for i in range(num)]
14 |     duration_sum = [0] * num
15 |
16 |     for d, i in sort:
17 |         if np.sum(d) > min_duration or min_duration <= 0:
18 |             pos = np.argmin(duration_sum)
19 |
20 |             ids_split[pos].append(i)
21 |             duration_split[pos].append(d)
22 |             duration_sum[pos] = duration_sum[pos] + np.sum(d)
23 |
24 |     return ids_split, duration_split
25 |
26 |
27 | if __name__ == '__main__':
28 |
29 |     parser = argparse.ArgumentParser()
30 |
31 |     parser.add_argument('load_path', type=str,
32 |                         help='The path to a json file containing video information, or a path to a folder containing '
33 |                              'json files with video information.')
34 |     parser.add_argument('split', type=int, help='The number of equal splits.')
35 |     parser.add_argument('-prefix', type=str, default='', help='A save file prefix.')
36 |
37 |     args = parser.parse_args()
38 |
39 |     load_path = args.load_path
40 |     split = args.split
41 |     prefix = args.prefix
42 |
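    # The steps below: gather the 'id'/'duration' lists from one JSON file or a
    # directory of them, normalise the nesting (scalar durations get their ids
    # wrapped into singleton chunks, while per-chunk duration lists are summed),
    # then hand everything to split_equal() above, which greedily assigns the
    # longest remaining chunk to the least-loaded split (an LPT-style heuristic).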
43 |     if os.path.isdir(load_path):
44 |         load_path = [os.path.join(load_path, p) for p in os.listdir(load_path)]
45 |     else:
46 |         load_path = [load_path]
47 |
48 |     ids = []
49 |     duration = []
50 |
51 |     for l in load_path:
52 |         json_load = json.load(open(l))
53 |
54 |         ids = ids + json_load['id']
55 |         duration = duration + json_load['duration']
56 |
57 |     if not isinstance(duration[0], list):  # scalar durations: wrap each id as its own chunk
58 |         ids = [[id] for id in ids]
59 |     else:
60 |         duration = [np.sum(_duration) for _duration in duration]
61 |
62 |     ids, duration = split_equal(ids, duration, split, -1)
63 |
64 |     split_chunk_count = 0
65 |     split_video_count = 0
66 |     split_video_duration = 0
67 |
68 |     for i in range(len(ids)):
69 |         buffer_chunk_count = len(ids[i])
70 |         buffer_video_count = int(np.sum([np.sum([len(___ids) for ___ids in __ids]) for __ids in ids[i]]))
71 |         buffer_video_duration = int(np.sum([np.sum(d) for d in duration[i]]))
72 |
73 |         print('split:', i, 'chunks:', buffer_chunk_count, 'videos:',
74 |               buffer_video_count, 'duration:', buffer_video_duration)
75 |
76 |         split_chunk_count += buffer_chunk_count
77 |         split_video_count += buffer_video_count
78 |         split_video_duration += buffer_video_duration
79 |
80 |     print('')
81 |     print('total num of chunks:', split_chunk_count, 'total num of videos:',
82 |           split_video_count, 'total video duration:', split_video_duration)
83 |     print('')
84 |
85 |     for idx, (i, d) in enumerate(zip(ids, duration)):
86 |
87 |         path = "{}work_split_{}.json".format(prefix, idx)
88 |         dump = {'id': i, 'duration': d}
89 |
90 |         json.dump(dump, open(path, 'w'))
--------------------------------------------------------------------------------
/scripts/text2tfrecord.py:
--------------------------------------------------------------------------------
1 | """tokenization to bpe or character embeddings of text datasets"""
2 |
3 | import argparse
4 | import io
5 | import os
6 | import shutil
7 | import time
8 | import random
9 | import multiprocessing
10 |
11 | import jsonlines
12 | import requests
13 | import simdjson
14 | import tensorflow as tf
15 | import zstandard
16 | from google.cloud import storage
17 | from transformers import GPT2TokenizerFast
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--name", type=str, default="text",
21 |                     help="Name of output files will be name_i.tfrecords where i is the number of the file")
22 | parser.add_argument("--procs", type=int, default=2, help="Number of processes in multiprocessing")
23 | parser.add_argument("--output_dir", type=str, default="gs://jannet/the-bpe-pile/",
24 |                     help="Where to put tfrecords (in a bucket)")
25 | parser.add_argument("--int64", type=bool, default=True, help="Whether to encode as bytes or int64")
26 | parser.add_argument("--service_account_json_path", type=str, default="a.json", help="Service account json from gcp")
27 | parser.add_argument("--buffer_size", type=int, default=2 ** 25, help="This is a minimum size, not a maximum size. "
28 |                                                                      "tfrecords will have this minimum size as well.")
29 | parser.add_argument("--separator", type=int, default=4,
30 |                     help="separator to place between files in chunk mode."
31 | "Default is 0 (Null) in case of byte encodings, " 32 | "50256 for tokenized texts") 33 | 34 | 35 | def file_generator(args, pid, procs): 36 | base_url = 'http://eaidata.bmk.sh/data/pile/train/%s.jsonl.zst' 37 | splits = 30 38 | parse_fn = simdjson.Parser().parse 39 | tmp_name = f".tmp.download.{pid}" 40 | 41 | def _json_parser(x): 42 | return parse_fn(x.encode()).as_dict() 43 | 44 | for i in range(pid, splits, procs): 45 | with requests.get(base_url.replace("%s", str(i).zfill(2)), stream=True) as r, open(tmp_name, 'wb') as f: 46 | shutil.copyfileobj(r.raw, f) 47 | with open(tmp_name, 'rb') as f: 48 | for item in jsonlines.Reader(io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(f)), 49 | loads=_json_parser): 50 | if isinstance(item, dict): 51 | item = item['text'] 52 | if isinstance(item, list): 53 | item = chr(args.separator).join(item) 54 | yield item 55 | os.remove(tmp_name) 56 | 57 | 58 | def create_tfrecords(args, pid, procs): 59 | slash_idx = args.output_dir.find('/') 60 | bucket_name, output_dir = args.output_dir[:slash_idx], args.output_dir[slash_idx + 1:] 61 | bucket = storage.Client.from_service_account_json(args.service_account_json_path).get_bucket(bucket_name) 62 | join = chr(args.separator).join 63 | prefix = f"{'int64' if args.int64 else 'bytes'}_{args.name}_" 64 | encode = (GPT2TokenizerFast.from_pretrained('gpt2') if args.int64 else str).encode 65 | 66 | files_processed = 0 67 | tfrecord_count = 0 68 | chunk = 0 69 | buffer_size = 0 70 | tokenized_files = [] 71 | 72 | last_write = start_time = time.time() 73 | 74 | for f in file_generator(args, pid, procs): 75 | buffer_size += len(f) 76 | tokenized_files.append(f) 77 | files_processed += 1 78 | 79 | if buffer_size > chunk * args.buffer_size // 4: 80 | print(f"Worker: {pid:{len(str(procs))}d} | Buffer: {buffer_size * 2 ** -20:.1f}MB | " 81 | f"Files: {files_processed} - TFrecords: {tfrecord_count} | " 82 | f"Wrote: {time.time() - last_write:.0f}s ago - Started: {time.time() - start_time:.0f}s ago", 83 | end='') 84 | chunk += 1 85 | 86 | if buffer_size > args.buffer_size: 87 | filename = f"{prefix}{tfrecord_count:_>6d}_{files_processed}_{buffer_size}.tfrecord" 88 | 89 | joined = encode(join(tokenized_files)) 90 | tokenized_files.clear() 91 | 92 | with tf.io.TFRecordWriter(filename) as writer: 93 | if args.int64: 94 | feature = {"text": tf.train.Feature(int64_list=tf.train.Int64List(value=joined))} 95 | else: 96 | feature = {"text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[joined]))} 97 | tf_example = tf.train.Example(features=tf.train.Features(feature=feature)) 98 | writer.write(tf_example.SerializeToString()) 99 | 100 | bucket.blob(f'{output_dir}{filename}').upload_from_filename(filename) 101 | 102 | os.remove(filename) 103 | chunk = 0 104 | buffer_size = 0 105 | tfrecord_count += 1 106 | 107 | print("") 108 | 109 | last_write = time.time() 110 | 111 | 112 | def main(): 113 | args = parser.parse_args() 114 | 115 | if not args.output_dir.endswith("/"): 116 | args.output_dir = args.output_dir + "/" 117 | if not args.output_dir.startswith("gs://"): 118 | print("Output dir isn't a cloud bucket. 
Exiting.") 119 | return 120 | args.output_dir = args.output_dir[len('gs://'):] 121 | processes = [multiprocessing.Process(target=create_tfrecords, args=(args, pid, args.procs)) for pid in range(args.procs)] 122 | for p in processes: 123 | p.start() 124 | for p in processes: 125 | p.join() 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /scripts/train_tokenizer.pyx: -------------------------------------------------------------------------------- 1 | #!python 2 | # cython: boundscheck=False 3 | # cython: initializedcheck=False 4 | # cython: nonecheck=False 5 | # cython: wraparound=False 6 | # cython: cdivision=True 7 | # cython: profile=False 8 | # cython: linetrace=False 9 | # cython: language_level=3 10 | 11 | import datetime 12 | import io 13 | import multiprocessing 14 | import os 15 | import string 16 | import threading 17 | import time 18 | import typing 19 | from queue import Queue 20 | 21 | import jsonpickle 22 | from ftfy import fix_text 23 | from simdjson import Parser 24 | from tokenizers import Regex, Tokenizer 25 | from tokenizers.models import BPE 26 | from tokenizers.pre_tokenizers import Split 27 | from tokenizers.trainers import BpeTrainer 28 | from zstandard import ZstdDecompressor 29 | 30 | # config 31 | DEF PROCESSES = 8 32 | DEF VOCAB_SIZE = 65536UL 33 | DEF PREFETCH = 128 34 | DEF CACHE_CAPACITY = 1UL << 20 35 | DEF BASE_PATH = "/mnt/wolf/" 36 | DEF DOWNLOAD_CACHE_PATH = "/mnt/wolf/" 37 | DEF BASE_URL = 'http://eaidata.bmk.sh/data/pile/train/%s.jsonl.zst' 38 | # https://the-eye.eu/public/AI/pile/train/%s.jsonl.zst 39 | DEF PRINT_INTERVAL = 100000 40 | DEF SPLITS = 30 41 | DEF REMOVE_INTERMEDIATE = False 42 | DEF REMOVE_LAST_INTERMEDIATE = False 43 | DEF STREAM = False # if less than 2TB memory are available 44 | 45 | cdef void log(unicode text, const unsigned char pid, const unsigned char i): 46 | with open(f"{BASE_PATH}log/{pid}.txt", 'a') as f: 47 | f.write(f'Proc: {pid} | Slice: {i} | Time: {datetime.datetime.now()} | {text}\n') 48 | 49 | cdef unicode download_command(const unsigned char i, unicode tmp_name): 50 | url = "http://eaidata.bmk.sh/data/" if i % 2 else "https://the-eye.eu/public/AI/" 51 | return f"wget {url}/pile/train/{i:02d}.jsonl.zst -O {tmp_name}.tmp -t inf --timeout 15 && mv {tmp_name}.tmp {tmp_name}" 52 | 53 | cdef void sleep_till_exists(unicode file_path): 54 | while not os.path.exists(file_path): 55 | time.sleep(300) 56 | 57 | cdef void wait_for_bash(const unsigned char i, const unsigned char pid, unicode start, unicode end, unicode command): 58 | cdef unicode completion = f'{BASE_PATH}done/{pid}.txt' 59 | log(start, pid, i) 60 | os.system(f'{command} && echo 1 > {completion}') 61 | sleep_till_exists(completion) 62 | os.remove(completion) 63 | log(end, pid, i) 64 | 65 | cdef void locked_execution(const unsigned char i, const unsigned char pid, lock: threading.Semaphore, unicode start, 66 | unicode end, unicode command): 67 | lock.acquire() 68 | wait_for_bash(i, pid, start, end, command) 69 | lock.release() 70 | 71 | cdef unsigned char check_files(list paths): 72 | cdef unicode path = "" 73 | for path in paths: 74 | if os.path.exists(path): 75 | return 1 76 | return 0 77 | 78 | cdef void extract(const unsigned char pid, lock: threading.Semaphore): 79 | cdef unicode tmp_name = f"{DOWNLOAD_CACHE_PATH}{pid}" 80 | cdef unicode tmp_zstd = tmp_name + '.zst' 81 | print(f"extract {pid} sleep") 82 | if check_files([tmp_name, tmp_name + '.txt']): 83 | 
print(f"extract {pid} no") 84 | return 85 | sleep_till_exists(tmp_zstd) 86 | print(f"extract {pid} start") 87 | locked_execution(pid, pid, lock, "Extracting", "Finished extraction", f"unzstd {tmp_zstd} && mv {tmp_name} {tmp_name}.jsonl") 88 | if REMOVE_INTERMEDIATE: 89 | os.remove(tmp_zstd) 90 | 91 | cdef void download(const unsigned char i, const unsigned char pid, lock: threading.Semaphore): 92 | cdef unicode tmp_name = f"{DOWNLOAD_CACHE_PATH}{pid}" 93 | cdef unicode tmp_zstd = tmp_name + '.zst' 94 | if check_files([tmp_zstd] + [tmp_name, tmp_name + '.txt'] * (not STREAM)): 95 | return 96 | locked_execution(i, pid, lock, "Downloading", "Finished download", download_command(i, tmp_zstd)) 97 | 98 | cdef unicode fix_string(bytes byte_line, const unsigned short pid, const unsigned short i, const unsigned long idx, 99 | unsigned long long * total): 100 | cdef unicode out = Parser().parse(byte_line)['text'] 101 | total[0] += len(out) 102 | if idx % PRINT_INTERVAL == 0: 103 | log(f"{total[0]:15,}B", pid, i) 104 | out = fix_text(out) 105 | out = out.replace(' ', '\t') 106 | return out 107 | 108 | cdef void file_generator(queue: Queue, lock: threading.Semaphore, const unsigned char pid): 109 | cdef unicode log_path = f"{BASE_PATH}log/{pid}.txt" 110 | cdef unicode tmp_name = "" 111 | cdef bytes byte_line = b"" 112 | cdef unsigned long long total = 0 113 | cdef unsigned long idx = 0 114 | cdef unsigned char i = 0 115 | stream_reader = ZstdDecompressor().stream_reader 116 | 117 | with open(log_path, 'w') as f: 118 | f.write('') 119 | 120 | for i in range(pid, SPLITS, PROCESSES): 121 | total = 0 122 | tmp_name = f"{DOWNLOAD_CACHE_PATH}{i}.zst" 123 | log("Starting", pid, i) 124 | download(i, pid, lock) 125 | 126 | with open(tmp_name, 'rb') as f: 127 | for idx, byte_line in enumerate(io.BufferedReader(stream_reader(f))): 128 | queue.put(fix_string(byte_line, pid, i, idx, &total)) 129 | if REMOVE_LAST_INTERMEDIATE: 130 | os.remove(tmp_name) 131 | 132 | def iterator(queue: Queue, procs: typing.List[multiprocessing.Process]): 133 | die = False 134 | while True: 135 | try: 136 | yield queue.get(timeout=60) 137 | except: 138 | die = True 139 | for p in procs: 140 | if p.is_alive(): 141 | die = False 142 | break 143 | if die: 144 | break 145 | 146 | cdef jsonl_to_txt(const unsigned short i, lock: threading.Lock): 147 | cdef unicode tmp_name = f"{DOWNLOAD_CACHE_PATH}{i}" 148 | cdef unicode txt_name = tmp_name + '.txt' 149 | cdef bytes byte_line = b"" 150 | cdef unsigned long long total = 0 151 | cdef int idx = 0 152 | print(f"jsonl {i:2d} sleep") 153 | if check_files([tmp_name + '.txt']): 154 | print(f"jsonl {i:2d} no") 155 | return 156 | sleep_till_exists(tmp_name + '.jsonl') 157 | print(f"jsonl {i:2d} start") 158 | 159 | lock.acquire() 160 | with open(txt_name + '.tmp', 'w') as o: 161 | o.write('') 162 | with open(tmp_name + '.jsonl', 'rb', 2 ** 20) as f: 163 | with open(txt_name + '.tmp', 'a', 2 ** 20) as o: 164 | for idx, byte_line in enumerate(f): 165 | o.write(fix_string(byte_line, i, i, idx, &total) + chr(4)) 166 | lock.release() 167 | os.rename(txt_name + '.tmp', txt_name) 168 | if REMOVE_INTERMEDIATE: 169 | os.remove(tmp_name + '.jsonl') 170 | 171 | 172 | cpdef void main(): 173 | cdef tuple procs = tuple() 174 | for path in ('', 'download', 'log', 'done'): 175 | if not os.path.exists(BASE_PATH + path): 176 | os.mkdir(BASE_PATH + path) 177 | if not os.path.exists(DOWNLOAD_CACHE_PATH): 178 | os.mkdir(DOWNLOAD_CACHE_PATH) 179 | 180 | cdef unicode split_chars = string.digits + " \t\n\r\x0b\x0c" 181 | for 
c in string.punctuation: 182 | split_chars += '\\' + c 183 | regex = Regex(f"""[{split_chars}]|[^{split_chars}]+""") 184 | tokenizer = Tokenizer(BPE(unk_token='\x01', cache_capacity=CACHE_CAPACITY, merges=None, dropout=None)) 185 | tokenizer.pre_tokenizer = Split(regex, 'isolated') 186 | cdef list formatted = [f"{DOWNLOAD_CACHE_PATH}{i}.txt" for i in range(SPLITS)] 187 | trainer = BpeTrainer(special_tokens=[chr(i) for i in range(256)], vocab_size=VOCAB_SIZE) 188 | manager = multiprocessing.Manager() 189 | down_lock = manager.Semaphore(2) 190 | cdef unsigned short files_exist = 1 191 | cdef unicode file = "" 192 | for file in formatted: 193 | files_exist &= not os.path.exists(file) 194 | if STREAM and not files_exist: 195 | queue = manager.Queue(PREFETCH) 196 | 197 | procs = tuple([multiprocessing.Process(target=file_generator, args=(queue, down_lock, i)) 198 | for i in range(PROCESSES)]) 199 | for p in procs: 200 | p.start() 201 | 202 | while queue.qsize() < PREFETCH // 2: 203 | time.sleep(5) 204 | 205 | tokenizer.train_from_iterator(iterator(queue, procs), trainer) 206 | 207 | for p in procs: 208 | p.join() 209 | else: 210 | extract_lock = manager.Semaphore(PROCESSES) 211 | txt_lock = manager.Semaphore(PROCESSES) 212 | 213 | procs = tuple([multiprocessing.Process(target=download, args=(i, i, down_lock)) for i in range(SPLITS)] + 214 | [multiprocessing.Process(target=extract, args=(i, extract_lock)) for i in range(SPLITS)] + 215 | [multiprocessing.Process(target=jsonl_to_txt, args=(i, txt_lock)) for i in range(SPLITS)]) 216 | for p in procs: 217 | p.start() 218 | for p in procs: 219 | p.join() 220 | tokenizer.train(formatted, trainer) 221 | 222 | if REMOVE_LAST_INTERMEDIATE: 223 | for file in formatted: 224 | os.remove(file) 225 | tokenizer.save(".tmp.json") 226 | 227 | with open("tokenizer.json", 'w', errors='ignore') as w, open(".tmp.json", 'r', errors='ignore') as r: 228 | w.write(jsonpickle.dumps(jsonpickle.loads(r.read()), indent=4)) 229 | 230 | os.remove(".tmp.json") 231 | -------------------------------------------------------------------------------- /src/interface.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import random 3 | import time 4 | import typing 5 | 6 | import numpy as np 7 | from transformers import GPT2TokenizerFast 8 | 9 | from .dataclass import ModelParameter 10 | from .utils_core import chunks, color_print 11 | 12 | 13 | def render_video(model_output: typing.List[typing.Tuple[np.ndarray, typing.List[str]]], 14 | count: int, 15 | params: ModelParameter, 16 | save_prefix: str = "", 17 | upscale: int = 4, 18 | line_split: int = 2, 19 | text_color: typing.Tuple[int, int, int] = (255, 0, 255), 20 | text_pos: typing.Tuple[int, int] = (10, 625), 21 | text_size: float = 1.27, 22 | text_thickness: int = 3, 23 | text_line_offset: int = 50, 24 | prompt_sample_color: typing.Tuple[int, int, int] = (0, 128, 255), 25 | prompt_sample_pos: typing.Tuple[int, int] = (50, 50), 26 | ): 27 | import cv2 28 | writer = cv2.VideoWriter(f"{save_prefix}_{count}.avi", cv2.VideoWriter_fourcc(*"MJPG"), 1, 29 | (params.frame_width * upscale * len(model_output), params.frame_height * upscale)) 30 | 31 | for idx in range(len(model_output[0][0])): 32 | frame = [] 33 | for sub_idx in range(len(model_output)): 34 | 35 | sub_frame = model_output[sub_idx][0][idx] 36 | sub_frame = sub_frame * (params.color_quantization_value - 1) 37 | import scipy.ndimage 38 | sub_frame = scipy.ndimage.zoom(sub_frame, (upscale, upscale, 1), order=0) 39 
| sub_frame = np.uint8(sub_frame) 40 | cv2.cvtColor(sub_frame, cv2.COLOR_RGB2BGR) 41 | 42 | text = model_output[sub_idx][1] 43 | if text is not None: 44 | for i, _text in enumerate(chunks(text[idx], params.language_token_per_frame // line_split)): 45 | cv2.putText(sub_frame, _text, (text_pos[0], text_pos[1] + text_line_offset * i), 46 | cv2.FONT_HERSHEY_SIMPLEX, text_size, text_color, text_thickness) 47 | 48 | if params.use_autoregressive_sampling: 49 | prompt_sample_text = 'prompt' if idx < params.initial_autoregressive_position else 'sample' 50 | cv2.putText(sub_frame, prompt_sample_text, prompt_sample_pos, cv2.FONT_HERSHEY_SIMPLEX, 51 | text_size, prompt_sample_color, text_thickness) 52 | 53 | frame.append(sub_frame) 54 | 55 | frame = np.concatenate(frame, axis=1) 56 | writer.write(frame) 57 | 58 | writer.release() 59 | 60 | 61 | def process_token_output(token_out: np.ndarray, padding_token: int = -1, do_argmax: bool = True, 62 | bpe_tokenizer: GPT2TokenizerFast = None) -> typing.List[str]: 63 | _shape = token_out.shape 64 | if do_argmax: 65 | voc_size = _shape[3] * _shape[4] if len(_shape) > 4 else _shape[3] 66 | token_out = np.reshape(token_out, newshape=(_shape[0], _shape[1] * _shape[2], voc_size)) 67 | token_out = np.argmax(token_out, axis=2) 68 | else: 69 | token_out = np.reshape(token_out, newshape=(_shape[0], _shape[1] * _shape[2])) 70 | 71 | token_out_str = [] 72 | 73 | for token in token_out: 74 | if padding_token > -1 and padding_token in token: 75 | token = token[:token.tolist().index(padding_token)] 76 | 77 | if bpe_tokenizer is None: 78 | token_out_str.append( 79 | "".join( 80 | chr(tok) if tok > 31 and tok != 127 and tok != 10 else " " 81 | for tok in token 82 | ) 83 | ) 84 | 85 | else: 86 | token_out_str.append(bpe_tokenizer.decode([int(tok) for tok in token])) 87 | 88 | return token_out_str 89 | 90 | 91 | def process_video_output(out_frame: np.ndarray, params: ModelParameter) -> np.ndarray: 92 | out_frame = np.reshape(out_frame, (params.time_patch_size, params.frame_height_patch, params.frame_width_patch, 93 | params.time_patch, params.patch_size, params.patch_size, params.color_channels)) 94 | 95 | out_frame = np.transpose(out_frame, [0, 3, 1, 4, 2, 5, 6]) 96 | out_frame = np.reshape(out_frame, (params.sequence_length, params.frame_height, params.frame_width, 3)) 97 | 98 | return out_frame 99 | 100 | 101 | def gen_sample_fn(params: ModelParameter): 102 | state = {'sample_index': 0} 103 | 104 | def _video_fn(out): 105 | print('sample_idx:', state['sample_index']) 106 | 107 | token_inp = None 108 | token_out = None 109 | render_input = [] 110 | 111 | frame_out = out[0][0] 112 | if params.use_autoregressive_sampling: 113 | frame_out = frame_out[:-1] 114 | 115 | frame_out = process_video_output(frame_out, params) 116 | 117 | if params.use_language: 118 | token_out = process_token_output(out[2][0], params.padding_token, not params.use_autoregressive_sampling) 119 | 120 | if not params.use_autoregressive_sampling: 121 | frame_inp = out[1][0] 122 | frame_inp = frame_inp[1:params.time_patch_size + 1] 123 | frame_inp = process_video_output(frame_inp, params) 124 | 125 | if params.use_language: 126 | token_inp = process_token_output(out[3][0], params.padding_token, False) 127 | 128 | render_input.append((frame_inp, token_inp)) 129 | 130 | render_input.append((frame_out, token_out)) 131 | 132 | render_video(render_input, state['sample_index'], params) 133 | 134 | state['sample_index'] += 1 135 | if state['sample_index'] >= params.num_of_sample: 136 | exit() 137 | 138 | def 
138 | def _text_fn(out): 139 | print('sample_idx:', state['sample_index']) 140 | 141 | bpe_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') if params.vocab_size != 256 and \ 142 | params.vocab_size > 256 else None 143 | 144 | if params.use_autoregressive_sampling: 145 | 146 | if params.debug_sample: 147 | score = int(np.mean(np.equal(out[0][0], out[0][1])) * 100)  # builtin int; np.int has been removed from NumPy 148 | print(f"similarity score: {score}%\n") 149 | print([process_token_output(out[0], do_argmax=False, bpe_tokenizer=bpe_tokenizer)[0]]) 150 | print([process_token_output(out[0], do_argmax=False, bpe_tokenizer=bpe_tokenizer)[1]]) 151 | print([process_token_output(out[1], do_argmax=False, bpe_tokenizer=bpe_tokenizer)[0]]) 152 | print('') 153 | 154 | print('\n------\n') 155 | color_print(params, 'Prompt:') 156 | assert params.initial_autoregressive_position > 0 157 | print(process_token_output(out[1][:, :params.initial_autoregressive_position - 1], do_argmax=False, 158 | bpe_tokenizer=bpe_tokenizer)[0]) 159 | color_print(params, 'Output:') 160 | print(process_token_output(out[0][:, params.initial_autoregressive_position:], do_argmax=False, 161 | bpe_tokenizer=bpe_tokenizer)[0].rstrip()) 162 | else: 163 | print('target:') 164 | print(process_token_output(out[1], do_argmax=False, bpe_tokenizer=bpe_tokenizer)[0]) 165 | print('\nsample:') 166 | print(process_token_output(out[0], do_argmax=False, bpe_tokenizer=bpe_tokenizer)[0]) 167 | 168 | state['sample_index'] += 1 169 | if state['sample_index'] >= params.num_of_sample: 170 | exit() 171 | 172 | print('\n') 173 | 174 | return _video_fn if params.model_mode == 'jannet' else _text_fn 175 | 176 | 177 | def get_command_line_input_and_output_fn(params: ModelParameter): 178 | bpe_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') if params.vocab_size != 256 and \ 179 | params.vocab_size > 256 else None 180 | 181 | samp_temp = params.sampling_temperature 182 | end_iter = params.sequence_length 183 | _iter_pos = [0] 184 | _end_iter = [0] 185 | 186 | def input_fns(): 187 | 188 | while True: 189 | color_print(params, 'Enter Query:') 190 | query = input() 191 | 192 | if bpe_tokenizer is None: 193 | query = [ord(q) for q in query] 194 | else: 195 | query = bpe_tokenizer.encode(query) 196 | 197 | if len(query) >= params.sequence_length: 198 | color_print(params, f'Query is too long, the maximum number of tokens is ' 199 | f'{params.sequence_length}, but you have {len(query)} tokens.') 200 | continue 201 | 202 | iter_pos = len(query) + 1 203 | _iter_pos[0] = iter_pos 204 | 205 | _end_iter[0] = end_iter 206 | 207 | query = query + [0] * (params.sequence_length - len(query)) 208 | query = np.reshape(np.array(query, np.int32), newshape=(1, params.sequence_length, 1)) 209 | break 210 | 211 | return query, np.array([iter_pos], np.int32), \ 212 | np.array([samp_temp], np.float32), np.array([end_iter], np.int32) 213 | 214 | def output_fn(out): 215 | color_print(params, 'Response:') 216 | print(process_token_output(out[0][:, _iter_pos[0]:][:, :_end_iter[0]], do_argmax=False, 217 | bpe_tokenizer=bpe_tokenizer)[0].rstrip()) 218 | print('') 219 | 220 | return input_fns, output_fn 221 | 222 | 223 | class ContextExhaustedError(ValueError): 224 | pass 225 | 226 | 227 | class InvalidTokenError(ValueError): 228 | pass 229 | 230 | 231 | class InterfaceWrapper: 232 | def __init__(self, params: ModelParameter): 233 | self.params = params 234 | self.manager = multiprocessing.Manager() 235 | self.input_prompt_id = self.manager.Value(int, 0) 236 | self.tpu_input_id = self.manager.Value(int, 0) 237 | self.output_prompt_id = 
self.manager.Value(int, 0) 238 | self.input = self.manager.dict() 239 | self.output = self.manager.dict() 240 | 241 | def blocked_get(self, inp: dict, key: int): 242 | while key not in inp: 243 | time.sleep(self.params.default_sleep_duration) 244 | return inp.pop(key) 245 | 246 | def increment(self, idx) -> int: 247 | prompt_id = idx.get() 248 | idx.set(prompt_id + 1) 249 | return prompt_id 250 | 251 | def complete(self, query: typing.List[int], temperature: float, response_len: int, debug: bool = False, 252 | asynchronous: bool = False) -> typing.Union[typing.Callable, typing.Tuple[np.array, np.array], 253 | np.array]: 254 | iter_pos = len(query) 255 | 256 | if iter_pos >= self.params.sequence_length: 257 | raise ContextExhaustedError 258 | if query and max(query) >= self.params.vocab_size: 259 | raise InvalidTokenError 260 | 261 | prompt_id = self.increment(self.input_prompt_id) 262 | 263 | query = query + [random.randint(0, self.params.vocab_size - 1) for _ in range((self.params.sequence_length - len(query)))] 264 | query = np.reshape(np.array(query, np.int32), newshape=(1, self.params.sequence_length, 1)) 265 | 266 | self.input[prompt_id] = (query, np.array([iter_pos], np.int32), np.array([temperature], np.float32), 267 | np.array([min(response_len + len(query), self.params.sequence_length)], np.int32)) 268 | 269 | def _result(): 270 | response = self.blocked_get(self.output, prompt_id)[0].astype(np.int64) 271 | out = response[0, iter_pos:].flatten() 272 | return (out, response) if debug else out 273 | 274 | return _result if asynchronous else _result() 275 | 276 | def input_query(self): 277 | return self.blocked_get(self.input, self.increment(self.tpu_input_id)) 278 | 279 | def output_responds(self, out): 280 | self.output[self.increment(self.output_prompt_id)] = out 281 | 282 | 283 | def get_similarity_input_and_output_fn(params: ModelParameter): 284 | interface = InterfaceWrapper(params) 285 | 286 | def run(): 287 | time.sleep(10) 288 | 289 | for idx in range(params.num_of_sample): 290 | query = [random.randint(0, params.vocab_size - 1) for _ in range(min(32, params.sequence_length - 8))] 291 | 292 | out = [interface.complete(query=query, temperature=0.0, response_len=params.sequence_length, debug=True, 293 | asynchronous=True) for _ in range(params.equal_debugging_items_per_check)] 294 | base, *out = [f() for f in out] 295 | 296 | score = float(np.mean([np.mean(np.equal(base, o)) * 100 for o in out])) 297 | print(f"test:{idx} similarity score: {score:6.2f}%\n") 298 | 299 | run = multiprocessing.Process(target=run, daemon=True) 300 | run.start() 301 | 302 | return interface.input_query, interface.output_responds 303 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | "Sub Main" that contains one function to start the training loop. 
3 | """ 4 | 5 | import argparse 6 | import re 7 | import time 8 | 9 | import jsonpickle 10 | import mesh_tensorflow as mtf 11 | import numpy as np 12 | import tensorflow as tf2 13 | from tensorflow.python.ops import summary_ops_v2 as summary 14 | from tensorflow.python.tpu.device_assignment import device_assignment 15 | from tensorflow.python.tpu.topology import Topology 16 | from tensorflow_estimator.python.estimator import estimator as estimator_lib 17 | 18 | from .dataclass import ModelParameter 19 | from .inputs import dataset, gpt_neo_input 20 | from .interface import gen_sample_fn, get_command_line_input_and_output_fn, get_similarity_input_and_output_fn 21 | from .rest_api import get_api_input_and_output_fn 22 | from .run.run import computation_func 23 | 24 | tf = tf2.compat.v1 25 | tpu = tf.tpu 26 | 27 | 28 | def sample_output_fn(params: ModelParameter): 29 | return None, gen_sample_fn(params) 30 | 31 | 32 | def raise_str(arg: str): 33 | raise ValueError(arg) 34 | 35 | 36 | RUN_MODE_FNS = {'debug_old': sample_output_fn, 37 | 'sample': sample_output_fn, 38 | 'web_api': get_api_input_and_output_fn, 39 | 'debug': get_similarity_input_and_output_fn, 40 | 'query': get_command_line_input_and_output_fn, 41 | 'train': lambda x: raise_str("Train should've been caught by code above. Something is wrong.")} 42 | 43 | 44 | def main(args: argparse.Namespace) -> None: 45 | """ 46 | Given previously captured arguments, this function runs the following steps (in order): 47 | * Load given session_config 48 | * initialize data loader 49 | * create model graph 50 | * start training loop. 51 | :param args: argparse arguments from the parent main function 52 | :return: None 53 | """ 54 | # Setup logging 55 | model_path = args.model if args.model.endswith(".json") else f"session_configs/{args.model}.json" 56 | with open(model_path) as f: 57 | _params = f.read() 58 | _params = jsonpickle.loads(_params) 59 | params = ModelParameter(_params) 60 | params.web_workers = args.workers 61 | params.train = args.run_mode == 'train' 62 | params.debug_sample = args.run_mode == 'debug_old' 63 | params.debug_gradients = args.debug_grad is not None 64 | 65 | # Read params of model 66 | if params.train: 67 | param_dump = jsonpickle.dumps(_params, indent=4) 68 | with tf.io.gfile.GFile(f"{params.model_path}/run_config_{int(time.time())}.json", 'w') as f: 69 | f.write(param_dump) 70 | 71 | params.current_step = int(estimator_lib._load_global_step_from_checkpoint_dir(params.model_path)) 72 | 73 | # If run mode == sample, set the batch size to one 74 | if not params.train: 75 | if params.debug_sample: 76 | params.train_batch_size = 2 77 | params.use_autoregressive_sampling = True 78 | params.sampling_temperature = 0 79 | else: 80 | params.train_batch_size = 1 81 | 82 | params = ModelParameter(params) 83 | 84 | # Fetch appropriate input functions 85 | if params.model_mode == 'jannet': 86 | input_fn = dataset 87 | elif params.model_mode == 'gpt': 88 | input_fn = gpt_neo_input 89 | 90 | # Set params for text only GPT mode. 
91 | params.use_language = True 92 | params.use_video = False 93 | params = ModelParameter(params) 94 | 95 | else: 96 | raise ValueError(f"model_mode needs to be 'jannet' or 'gpt'; '{params.model_mode}' " 97 | "is not a supported option.") 98 | 99 | # Add to params: auto_layout, auto_layout_and_mesh_shape, use_tpu, num_cores 100 | mesh_shape = mtf.convert_to_shape(params.mesh_shape) 101 | params.num_cores = mesh_shape.size 102 | # Expand attention types param 103 | 104 | mtf_mesh_shape = mtf.convert_to_shape(params.mesh_shape) 105 | params.layout_rules = mtf.convert_to_layout_rules(params.layout) 106 | 107 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(args.tpu) 108 | session_config = tf.ConfigProto() 109 | session_config.allow_soft_placement = True 110 | tpu_cluster_spec = tpu_cluster_resolver.cluster_spec() 111 | 112 | if tpu_cluster_spec: 113 | session_config.cluster_def.CopyFrom(tpu_cluster_spec.as_cluster_def()) 114 | 115 | with tf.Graph().as_default(): 116 | 117 | with tf.Session(target=tpu_cluster_resolver.master(), config=session_config) as sess: 118 | tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) 119 | 120 | all_devices = sess.list_devices() 121 | 122 | cpus = [] 123 | for d in all_devices: 124 | if d.device_type == 'CPU': 125 | cpus += [re.sub('device:CPU', 'cpu', d.name)] 126 | 127 | cpu_devices = [] 128 | for c in cpus: 129 | m = re.match('/job:(.*)/replica:(.*)/task:(.*)/.*', c) 130 | cpu_devices.append((m.group(1), int(m.group(2)), int(m.group(3)), c)) 131 | 132 | cpu_devices = [_[3] for _ in sorted(cpu_devices)] 133 | params.cpu_devices = [n for n in cpu_devices if 'coordinator' not in n] 134 | 135 | topology = sess.run(tpu.initialize_system()) 136 | topo_object = Topology(serialized=topology) 137 | 138 | params.num_cores = int(np.prod(topo_object.mesh_shape)) 139 | params.num_hosts = int(topo_object.num_tasks) 140 | params.num_cores_per_host = int(params.num_cores // params.num_hosts) 141 | if params.num_cores_per_host != int(topo_object.num_tpus_per_task): 142 | raise ValueError 143 | 144 | params.d_assignment = device_assignment(topology, num_replicas=params.num_cores, 145 | computation_shape=[1, ] * mtf.utils.topology_rank(topology)) 146 | params.mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mtf_mesh_shape, params.layout_rules, 147 | None, params.d_assignment) 148 | 149 | if params.train: 150 | summary_writer = summary.create_file_writer(params.model_path) 151 | with summary_writer.as_default(), (summary.always_record_summaries()): 152 | computation_func(params, 153 | input_fn, 154 | session_config, 155 | tpu_cluster_resolver, 156 | [lambda x: print(f"Current step: {x}")] * params.debug_train_step) 157 | return 158 | 159 | input_fns, output_fn = RUN_MODE_FNS[args.run_mode](params) 160 | 161 | computation_func(params, 162 | input_fn, 163 | session_config, 164 | tpu_cluster_resolver, 165 | [output_fn], 166 | input_fns) 167 | -------------------------------------------------------------------------------- /src/model/activation.py: -------------------------------------------------------------------------------- 1 | import mesh_tensorflow as mtf 2 | import numpy as np 3 | import tensorflow as tf 4 | 
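One aside before the operator definitions: the backward passes in this file hard-code closed-form derivatives, so they are easy to cross-check against autodiff. A standalone sanity check for the Mish derivative used below (plain eager TensorFlow, not part of the repository):

```python
import tensorflow as tf

x = tf.linspace(-4.0, 4.0, 101)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * tf.tanh(tf.math.softplus(x))  # mish(x)
auto = tape.gradient(y, x)

# mish'(x) = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))**2)
gte = tf.tanh(tf.math.softplus(x))
manual = gte + (1.0 - tf.square(gte)) * x * tf.sigmoid(x)
print(bool(tf.reduce_all(tf.abs(auto - manual) < 1e-5)))  # True
```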
5 | from .. import tf_wrapper as tfw 6 | from ..dataclass import BlockArgs 7 | from ..mtf_wrapper import relu as _relu, multiply, einsum, constant, sigmoid as _sigmoid, tanh as _tanh, softplus 8 | from ..utils_core import random_name, scoped 9 | 10 | tf1 = tf.compat.v1 11 | 12 | 13 | class MishForward(mtf.Operation): 14 | def __init__(self, x: mtf.Tensor): 15 | super().__init__([x], name=random_name("mish_forward")) 16 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 17 | 18 | def gradient(self, grad_ys): 19 | return MishBackward(self.inputs[0], grad_ys[0]).outputs 20 | 21 | def lower(self, lowering): 22 | mesh_impl = lowering.mesh_impl(self) 23 | 24 | def slicewise_fn(x): 25 | return tfw.multiply(x, tfw.tanh(tfw.softplus(x))) 26 | 27 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]]) 28 | lowering.set_tensor_lowering(self.outputs[0], y) 29 | 30 | 31 | class MishBackward(mtf.Operation): 32 | def __init__(self, x: mtf.Tensor, dy: mtf.Tensor): 33 | super().__init__([x, dy], name=random_name("mish_backward")) 34 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 35 | 36 | def lower(self, lowering): 37 | mesh_impl = lowering.mesh_impl(self) 38 | 39 | def slicewise_fn(x, dy): 40 | gte = tfw.tanh(tfw.softplus(x)) 41 | gte += (1. - tfw.square(gte)) * x * tfw.sigmoid(x)  # mish'(x) = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))**2) 42 | return tfw.multiply(dy, gte) 43 | 44 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]], lowering.tensors[self.inputs[1]]) 45 | lowering.set_tensor_lowering(self.outputs[0], y) 46 | 47 | 48 | class SiluForward(mtf.Operation): 49 | def __init__(self, x: mtf.Tensor): 50 | super().__init__([x], name=random_name("silu_forward")) 51 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 52 | 53 | def gradient(self, grad_ys): 54 | return SiluBackward(self.inputs[0], grad_ys[0]).outputs 55 | 56 | def lower(self, lowering): 57 | mesh_impl = lowering.mesh_impl(self) 58 | 59 | def slicewise_fn(x): 60 | return tfw.multiply(x, tfw.sigmoid(x)) 61 | 62 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]]) 63 | lowering.set_tensor_lowering(self.outputs[0], y) 64 | 65 | 66 | class SiluBackward(mtf.Operation): 67 | def __init__(self, x: mtf.Tensor, dy: mtf.Tensor): 68 | super().__init__([x, dy], name=random_name("silu_backward")) 69 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 70 | 71 | def lower(self, lowering): 72 | mesh_impl = lowering.mesh_impl(self) 73 | 74 | def slicewise_fn(x, dy): 75 | gte = tfw.sigmoid(x) 76 | return dy * gte * (1 + x * (1 - gte))  # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) 77 | 78 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]], lowering.tensors[self.inputs[1]]) 79 | lowering.set_tensor_lowering(self.outputs[0], y) 80 | 81 | 82 | class LeCunTanhForward(mtf.Operation): 83 | def __init__(self, x: mtf.Tensor): 84 | super().__init__([x], name=random_name("lecun_tanh_forward")) 85 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 86 | 87 | def gradient(self, grad_ys): 88 | return LeCunTanhBackward(self.inputs[0], grad_ys[0]).outputs 89 | 90 | def lower(self, lowering): 91 | mesh_impl = lowering.mesh_impl(self) 92 | 93 | def slicewise_fn(x): 94 | return tfw.tanh(x) + x * 0.1 95 | 96 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]]) 97 | lowering.set_tensor_lowering(self.outputs[0], y) 98 | 99 | 100 | class LeCunTanhBackward(mtf.Operation): 101 | def __init__(self, x: mtf.Tensor, dy: mtf.Tensor): 102 | super().__init__([x, dy], name=random_name("lecun_tanh_backward")) 103 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 104 | 105 | def 
lower(self, lowering): 106 | mesh_impl = lowering.mesh_impl(self) 107 | 108 | def slicewise_fn(x, dy): 109 | return tfw.multiply(dy, tfw.subtract(1.1, tfw.square(tfw.tanh(x)))) 110 | 111 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]], lowering.tensors[self.inputs[1]]) 112 | lowering.set_tensor_lowering(self.outputs[0], y) 113 | 114 | 115 | class SoftsignForward(mtf.Operation): 116 | def __init__(self, x: mtf.Tensor): 117 | super().__init__([x], name=random_name("softsign_forward")) 118 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 119 | 120 | def gradient(self, grad_ys): 121 | return SoftsignBackward(self.inputs[0], grad_ys[0]).outputs 122 | 123 | def lower(self, lowering): 124 | mesh_impl = lowering.mesh_impl(self) 125 | 126 | def slicewise_fn(x): 127 | return x / (1. + tfw.abs(x)) 128 | 129 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]]) 130 | lowering.set_tensor_lowering(self.outputs[0], y) 131 | 132 | 133 | class SoftsignBackward(mtf.Operation): 134 | def __init__(self, x: mtf.Tensor, dy: mtf.Tensor): 135 | super().__init__([x, dy], name=random_name("softsign_backward")) 136 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 137 | 138 | def lower(self, lowering): 139 | mesh_impl = lowering.mesh_impl(self) 140 | 141 | def slicewise_fn(x, dy): 142 | return dy / tfw.square(1. + tfw.abs(x)) 143 | 144 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]], lowering.tensors[self.inputs[1]]) 145 | lowering.set_tensor_lowering(self.outputs[0], y) 146 | 147 | 148 | def _output0(op): 149 | if not issubclass(op, mtf.Operation): 150 | raise ValueError 151 | 152 | def _wrapped(args: BlockArgs): 153 | return op(args.tensor).outputs[0] 154 | 155 | return _wrapped 156 | 157 | 158 | def _gelu(params, tensor: mtf.Tensor): 159 | return einsum([tensor, _tanh(einsum([tensor, tensor, tensor, constant(params, 0.044715 * np.sqrt(2 / np.pi))], 160 | output_shape=tensor.shape) + tensor * np.sqrt(2 / np.pi)) + 1.0, 161 | constant(params, 0.5)], output_shape=tensor.shape)  # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))) 162 | 163 | 164 | def gelu(args: BlockArgs): 165 | return scoped("gelu", _gelu, args.params, args.tensor) 166 | 167 | 168 | def relu(args: BlockArgs): 169 | return _relu(args.tensor) 170 | 171 | 172 | def sigmoid(args: BlockArgs): 173 | return _sigmoid(args.tensor) 174 | 175 | 176 | def tanh(args: BlockArgs): 177 | return _tanh(args.tensor) 178 | 179 | 180 | def _mtf_mish(tensor: mtf.Tensor): 181 | return multiply(_tanh(softplus(tensor)), tensor) 182 | 183 | 184 | def mtf_mish(args: BlockArgs): 185 | return scoped("mtf_mish", _mtf_mish, args.tensor) 186 | 187 | 188 | ACTIVATIONS = {'relu': relu, 189 | 'sigmoid': sigmoid, 190 | 'tanh': tanh, 191 | 'gelu': gelu, 192 | 'lecun_tanh': _output0(LeCunTanhForward), 193 | 'silu': _output0(SiluForward), 194 | 'mish': _output0(MishForward), 195 | "mtf_mish": mtf_mish, 196 | 'softsign': _output0(SoftsignForward) 197 | } 198 | 199 | 200 | def activate(args: BlockArgs) -> mtf.Tensor: 201 | """ 202 | Call activation function on mtf.Tensor. 203 | """ 204 | for fn_name in args: 205 | if fn_name not in ACTIVATIONS: 206 | continue 207 | return scoped(fn_name, ACTIVATIONS[fn_name], args) 208 | print(f'No activation function found for "{args.name_extras}". Falling back to identity. 
' 209 | f'Known functions: {list(ACTIVATIONS.keys())}') 210 | return args.tensor 211 | -------------------------------------------------------------------------------- /src/model/backend.py: -------------------------------------------------------------------------------- 1 | import random 2 | import typing 3 | 4 | import mesh_tensorflow as mtf 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.python.ops import array_ops, gen_linalg_ops, math_ops, random_ops 8 | from tensorflow.python.ops.init_ops import Initializer 9 | 10 | from ..dataclass import BlockArgs, ModelParameter 11 | from ..mtf_wrapper import einsum 12 | from ..utils_core import random_name, scoped 13 | from ..utils_mtf import OPT_DIMS, SHAPE, deduplicate, non_replicated_variable, get_fan_in 14 | 15 | tf1 = tf.compat.v1 16 | 17 | 18 | class OrthogonalInit(Initializer): 19 | def __init__(self, params: ModelParameter, shape: SHAPE, is_last: bool, fan_in_dims: OPT_DIMS = None): 20 | if fan_in_dims is None: 21 | fan_in_dims = [] 22 | self.params = params 23 | self.sizes = [d.size for d in shape] 24 | self.seed = random.randint(0, 2 ** 32) 25 | sizes = [d.size for d in mtf.Shape(shape)] 26 | fan_in = int(np.prod([d.size for d in (get_fan_in(params, shape) if not fan_in_dims else fan_in_dims)]))  # fall back to get_fan_in when no dims were given; the guard above already turned None into [] 27 | fan_out = np.prod(sizes) // fan_in 28 | self.transpose = transpose = fan_out > fan_in 29 | self.shape = (fan_out, fan_in) if transpose else (fan_in, fan_out) 30 | self.scale_by_depth = self.params.scale_by_depth and is_last 31 | 32 | def __call__(self, shape, dtype=None, partition_info=None): 33 | q, r = gen_linalg_ops.qr(random_ops.random_normal(self.shape, dtype=tf.float32, seed=self.seed)) 34 | q *= math_ops.sign(array_ops.diag_part(r)) 35 | if self.transpose: 36 | q = array_ops.matrix_transpose(q) 37 | out = array_ops.reshape(q, self.sizes) 38 | if self.scale_by_depth: 39 | out /= self.params.depth ** 0.5 40 | return tf.cast(out, dtype) 41 | 42 | 43 | def get_var(args: BlockArgs, shape: SHAPE, initializer: Initializer) -> mtf.Tensor: 44 | params: ModelParameter = args.params 45 | 46 | def _var(): 47 | return non_replicated_variable(params, random_name("get_variable"), shape, initializer, True, 48 | params.variable_dtype) 49 | 50 | if "shared" not in args: 51 | return _var() 52 | 53 | name = tf1.get_variable_scope().name 54 | scope = name.split('/') 55 | body_idx = scope.index("body0v") + 1 56 | block, full_fn_name = scope[body_idx:body_idx + 2] 57 | block, config = block.split('_') 58 | first_block = block == '0' 59 | fn_name = ''.join(c for c in full_fn_name if not c.isdigit()) 60 | 61 | cache = params.cached_parameters 62 | for idx in (config, fn_name): 63 | if idx not in cache: 64 | cache[idx] = {} 65 | cache = cache[idx] 66 | 67 | if "counter" not in cache: 68 | cache["counter"] = 0 69 | cache["index"] = 0 70 | cache["seen"] = set() 71 | if idx not in cache["seen"]: 72 | cache["index"] += 1 73 | cache["counter"] += first_block 74 | cache["seen"].add(full_fn_name) 75 | cache["index"] %= cache["counter"] 76 | fn_id = cache["index"] 77 | 78 | if fn_id not in cache: 79 | cache[fn_id] = {} 80 | cache = cache[fn_id] 81 | if "counter" not in cache: 82 | cache["counter"] = 0 83 | 84 | if first_block: 85 | var = _var() 86 | cache[cache["counter"]] = var 87 | cache["counter"] += 1 88 | return var 89 | 90 | if len(cache) == cache['counter'] + 1: 91 | cache["counter"] = 0 92 | var = cache[cache["counter"]] 93 | cache["counter"] += 1 94 | return var 95 | 96 | 97 | def orthogonal_var(args: BlockArgs, shape: 
typing.Union[typing.List[mtf.Dimension], mtf.Shape], 98 | fan_in_dims: OPT_DIMS = None) -> mtf.Tensor: 99 | shape = deduplicate(shape) 100 | return scoped("orthogonal_var", get_var, args, shape, OrthogonalInit(args.params, shape, args.is_last, fan_in_dims)) 101 | 102 | 103 | def normal_var(args: BlockArgs, shape: SHAPE, stddev: float = 0.02, mean: float = 0.) -> mtf.Tensor: 104 | shape = deduplicate(shape) 105 | return scoped("normal_var", get_var, args, shape, tf.random_normal_initializer(stddev=stddev, mean=mean)) 106 | 107 | 108 | def linear(args: BlockArgs, old: typing.List[mtf.Dimension], new: typing.List[mtf.Dimension]) -> mtf.Tensor: 109 | return einsum([args.tensor, orthogonal_var(args, old + new, old)], 110 | deduplicate((args.tensor.shape - old).dims + new)) 111 | 112 | 113 | def linear_to_features(args: BlockArgs, old: typing.Optional[typing.List[mtf.Dimension]] = None) -> mtf.Tensor: 114 | return linear(args, old, args.params.feature_dims) 115 | 116 | 117 | def linear_from_features(args: BlockArgs, new: typing.Optional[typing.List[mtf.Dimension]] = None) -> mtf.Tensor: 118 | return linear(args, args.params.feature_dims, new) 119 | -------------------------------------------------------------------------------- /src/model/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing 3 | 4 | import mesh_tensorflow as mtf 5 | import tensorflow as tf 6 | 7 | from .activation import activate 8 | from .backend import get_var, linear, orthogonal_var 9 | from .embedding import gather_embed 10 | from .normalization import norm 11 | from ..dataclass import BlockArgs 12 | from ..mtf_wrapper import (dropout as utils_dropout, sigmoid, exp, reduce_max, reduce_sum, einsum, reciprocal, reshape, 13 | multiply, stop_gradient) 14 | from ..utils_mtf import linear_shapes, anonymize_shape, unbind, replace_dim, anonymize_dim 15 | 16 | ATTENTION_DIM = typing.NamedTuple("AttentionDim", (('index', int), ('dim', mtf.Dimension))) 17 | 18 | tf1 = tf.compat.v1 19 | 20 | 21 | def rezero(args: BlockArgs) -> mtf.Tensor: 22 | return args.tensor * get_var(args, [], tf.constant_initializer(0)) 23 | 24 | 25 | def dropout(args: BlockArgs) -> mtf.Tensor: 26 | keep = 1 27 | for extra in args.name_extras: 28 | if extra.startswith('dropout_rate'): 29 | keep = 1 - float(extra[len('dropout_rate'):]) 30 | return utils_dropout(args.tensor, args.params.train, keep) 31 | 32 | 33 | def wrapped_linear(args: BlockArgs) -> mtf.Tensor: 34 | return linear(args, *linear_shapes(args)) 35 | 36 | 37 | def mixture_of_experts(args: BlockArgs) -> mtf.Tensor: 38 | old, new = linear_shapes(args) 39 | gate = linear(args, old, [args.params.expert_dim]) 40 | gate -= mtf.stop_gradient(reduce_max(gate, reduced_dim=args.params.expert_dim)) 41 | gate = exp(gate) 42 | return einsum([reciprocal(reduce_sum(gate, reduced_dim=args.params.expert_dim)), args.tensor, gate, 43 | orthogonal_var(args, old + new + [args.params.expert_dim])], 44 | output_shape=args.tensor.shape - old + new) 45 | 46 | 47 | def activated_linear(args: BlockArgs, prefix: str) -> mtf.Tensor: 48 | args = args([a[len(prefix):] for a in args if a.startswith(prefix)]) 49 | feed_forward_fn = mixture_of_experts if 'mixture_of_experts' in args else wrapped_linear 50 | out = dropout(args(activate(args(feed_forward_fn(args))))) 51 | if 'glu' in args or 'glu_add' in args: 52 | out = multiply(out, sigmoid(feed_forward_fn(args))) 53 | if 'glu_add' in args: 54 | out += activate(args(feed_forward_fn(args))) 55 | if 'norm' in args: 
56 | out = norm(args(out)) 57 | return out 58 | 59 | 60 | def activated_linear_in(args: BlockArgs) -> mtf.Tensor: 61 | return activated_linear(args, 'in:') 62 | 63 | 64 | def activated_linear_out(args: BlockArgs) -> mtf.Tensor: 65 | return activated_linear(args, 'out:') 66 | 67 | 68 | def feed_forward(args: BlockArgs) -> mtf.Tensor: 69 | return activated_linear_out(args(activated_linear_in(args))) 70 | 71 | 72 | def group_linear(args: BlockArgs) -> mtf.Tensor: 73 | anonymous_key = anonymize_shape(args.params.feature_dims, args.params.key_dim) 74 | return reshape(linear(args('group'), args.params.feature_dims, anonymous_key), args.tensor.shape) 75 | 76 | 77 | def sum_heads(args: BlockArgs) -> mtf.Tensor: 78 | return reduce_sum(args.tensor, reduced_dim=args.params.head_dim) 79 | 80 | 81 | def transpose_sequence_features(args: BlockArgs) -> mtf.Tensor: 82 | assert args.params.features_per_head == args.params.sequence_length, "ToDo: Support other shapes" 83 | tensor = mtf.rename_dimension(args.tensor, args.params.sequence_dim.name, "intermediate") 84 | tensor = mtf.rename_dimension(tensor, args.params.key_dim.name, args.params.sequence_dim.name) 85 | tensor = mtf.rename_dimension(tensor, "intermediate", args.params.key_dim.name) 86 | return mtf.transpose(tensor, args.tensor.shape) 87 | 88 | 89 | def reduced_half_linear(args: BlockArgs) -> mtf.Tensor: 90 | return group_linear(args(reduce_sum(args.tensor, reduced_dim=args.params.head_dim))) 91 | 92 | 93 | def product_key_memory(args: BlockArgs) -> mtf.Tensor: 94 | anonymous_key = anonymize_dim(args.params.key_dim) 95 | features = [args.params.pkm_dim, anonymous_key] 96 | assignment = linear(args, linear_shapes(args).old, [args.params.head_dim] + features) 97 | assignment = replace_dim(assignment, args.params.key_dim, anonymous_key) # No-op. 
Just for MTF propagation 98 | assignment = norm(args(assignment), features) 99 | assignment = mtf.cast(assignment, tf.float64) 100 | normalizer = reduce_max(assignment, reduced_dim=args.params.key_dim) 101 | normalizer = reduce_sum(normalizer, reduced_dim=args.params.pkm_dim) 102 | assignment -= stop_gradient(normalizer) 103 | assignment = exp(assignment) 104 | normalizer = reduce_sum(assignment, output_shape=assignment.shape - [args.params.key_dim]) 105 | normalizer = einsum(unbind(normalizer, args.params.pkm_dim), output_shape=normalizer.shape - args.params.pkm_dim) 106 | 107 | val, idx = mtf.top_1(assignment, args.params.key_dim) 108 | idx = mtf.einsum([mtf.cast(exp(math.log(args.params.features_per_head) * 109 | mtf.range(normalizer.mesh, args.params.pkm_dim, dtype=normalizer.dtype)), 110 | tf.int32), idx], output_shape=idx.shape - args.params.pkm_dim) 111 | val = einsum(unbind(val, args.params.pkm_dim), output_shape=val.shape - args.params.pkm_dim) / normalizer 112 | val = mtf.cast(val, args.params.variable_dtype.activation_dtype)  # cast back from float64 to the activation dtype 113 | out = gather_embed(args(idx), [args.params.product_key_value_dim] + args.params.feature_dims, 114 | [args.params.head_dim]) 115 | return out * val 116 | 117 | 118 | def feed_forward_product_key_memory(args: BlockArgs) -> mtf.Tensor: 119 | return product_key_memory(args(activated_linear_in(args))) 120 | 121 | 122 | def bottleneck_group_linear(args: BlockArgs) -> mtf.Tensor: 123 | args = args(activated_linear_in(args)) 124 | args.name_extras.extend(['group', 'mid:group', 'out:group']) 125 | args = args(activated_linear(args, 'mid:')) 126 | return activated_linear_out(args) 127 | -------------------------------------------------------------------------------- /src/model/convolution.py: -------------------------------------------------------------------------------- 1 | import mesh_tensorflow as mtf 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from .backend import OrthogonalInit, get_var 6 | from ..dataclass import BlockArgs 7 | from ..utils_core import random_name 8 | from ..utils_mtf import get_attention_dim 9 | 10 | tf1 = tf.compat.v1 11 | 12 | 13 | class ConvolutionForward(mtf.Operation): 14 | def __init__(self, args: BlockArgs, dim: mtf.Dimension, kernel_size: int, masked: bool): 15 | params = args.params 16 | x = args.tensor 17 | shape: mtf.Shape = x.shape 18 | batch = shape.dims[0].size 19 | self.sizes = sizes = [d.size for d in shape] 20 | space_dims = (shape - params.intermediate - params.feature_dims).dims[1:] 21 | dim_index = shape.dims.index(dim) 22 | space_dim_index = space_dims.index(dim) 23 | features = params.mesh_impl.slice_size(mtf.Shape(params.feature_dims)) 24 | 25 | self.weight_size = [features, features] 26 | self.kwargs = {'stride': 1, 'name': random_name('conv'), 'dilations': 1} 27 | 28 | if len(space_dims) == 1: 29 | self.kwargs['data_format'] = 'NWC' 30 | self.weight_size.append(kernel_size) 31 | self.input_size = [s.size for s in shape if s not in params.feature_dims] + [features] 32 | self.conv = tf.nn.conv1d 33 | input2d = self.input_size.copy() 34 | weight2d = self.weight_size.copy() 35 | input2d.insert(2, 1) 36 | weight2d.append(1) 37 | 38 | def back_filter(x, w, dy, **kwargs): 39 | x = tf.reshape(x, input2d) 40 | w = tf.reshape(w, weight2d) 41 | dy = tf.reshape(dy, input2d) 42 | out = tf1.nn.conv2d_backprop_filter(x, w, dy, **kwargs) 43 | return tf.reshape(out, self.input_size) 44 | 45 | def back_input(dy, w, **kwargs): 46 | w = tf.reshape(w, weight2d) 47 | dy = tf.reshape(dy, input2d) 48 | out = 
tf1.nn.conv2d_backprop_input(dy.shape, w, dy, **kwargs) 49 | return tf.reshape(out, self.input_size) 50 | 51 | self.filter_backprop = back_filter 52 | self.input_backprop = back_input 53 | elif space_dim_index == 0: 54 | self.kwargs['data_format'] = 'NHWC' 55 | self.weight_size.extend([kernel_size, 1]) 56 | self.input_size = [batch, sizes[1], int(np.prod(sizes[2:len(space_dims)])), features] 57 | self.conv = tf.nn.conv2d 58 | self.filter_backprop = tf1.nn.conv2d_backprop_filter 59 | self.input_backprop = tf.nn.conv2d_transpose 60 | elif space_dim_index == len(space_dims) - 1: 61 | self.kwargs['data_format'] = 'NHWC' 62 | self.weight_size.extend([1, kernel_size]) 63 | self.input_size = [batch, int(np.prod(sizes[1:len(space_dims) - 1])), sizes[len(space_dims)], features] 64 | self.conv = tf.nn.conv2d 65 | self.filter_backprop = tf1.nn.conv2d_backprop_filter 66 | self.input_backprop = tf.nn.conv2d_transpose 67 | else: 68 | self.kwargs['data_format'] = 'NDHWC' 69 | self.weight_size.extend([1, kernel_size, 1]) 70 | self.input_size = [batch, int(np.prod(sizes[1:dim_index])), sizes[dim_index], 71 | int(np.prod(sizes[dim_index + 1:len(space_dims)])), features] 72 | self.conv = tf.nn.conv3d 73 | self.filter_backprop = tf1.nn.conv3d_backprop_filter_v2 74 | self.input_backprop = tf.nn.conv3d_transpose 75 | self.kwargs['padding'] = 'SAME' 76 | if masked: 77 | self.kwargs['padding'] = [[w - 1, 0] for w in self.weight_size[2:]] 78 | 79 | fan_in = [mtf.Dimension(chr(i + ord('a')), w) for i, w in enumerate(self.weight_size[1:]) if w != 1] 80 | mtf_weight_size = params.feature_dims 81 | mtf_weight_size.extend(fan_in) 82 | weight = OrthogonalInit(params, mtf_weight_size, args.is_last, fan_in) 83 | super().__init__([x, get_var(args, mtf_weight_size, weight)], name=random_name("conv_forward")) 84 | self._outputs = [mtf.Tensor(self, x.shape, x.dtype)] 85 | self.params = params 86 | 87 | def gradient(self, grad_ys): 88 | return ConvolutionFilterBackward(self).outputs 89 | 90 | def lower(self, lowering): 91 | mesh_impl = lowering.mesh_impl(self) 92 | 93 | def slicewise_fn(x, w): 94 | x = tf.reshape(x, self.input_size) 95 | w = tf.reshape(w, self.weight_size) 96 | out = self.conv(x, w, **self.kwargs) 97 | return tf.reshape(out, self.sizes) 98 | 99 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[self.inputs[0]], lowering.tensors[self.inputs[1]]) 100 | lowering.set_tensor_lowering(self.outputs[0], y) 101 | 102 | 103 | class ConvolutionFilterBackward(mtf.Operation): 104 | def __init__(self, conv: ConvolutionForward): 105 | super().__init__(conv.inputs + conv.outputs, name=random_name("conv_backward")) 106 | self._outputs = conv.inputs 107 | self.conv = conv 108 | 109 | def lower(self, lowering): 110 | mesh_impl = lowering.mesh_impl(self) 111 | conv = self.conv 112 | 113 | def slicewise_fn(x, w, dy): 114 | x = tf.reshape(x, conv.input_size) 115 | w = tf.reshape(w, conv.weight_size) 116 | dy = tf.reshape(dy, conv.input_size) 117 | back_filter = conv.filter_backprop(x, w, dy, **conv.kwargs) 118 | back_input = conv.input_backprop(dy, w, **conv.kwargs) 119 | back_filter = tf.reshape(back_filter, conv.input_size) 120 | back_input = tf.reshape(back_input, conv.input_size) 121 | return back_input, back_filter 122 | 123 | dx, dw = mesh_impl.slicewise(slicewise_fn, *[lowering.tensors[self.inputs[i]] for i in range(3)]) 124 | lowering.set_tensor_lowering(self.outputs[0], dx) 125 | lowering.set_tensor_lowering(self.outputs[1], dw) 126 | 127 | 128 | def convolution(args: BlockArgs): 129 | raise ValueError("Convolution 
is currently broken") 130 | idx, dim = get_attention_dim(args) 131 | convolution_size = 16 132 | if len(args) > 0 and args[-1].isdigit(): 133 | convolution_size = int(args[-1]) 134 | return ConvolutionForward(args, dim, convolution_size, idx in args.params.masked_attention_dimensions).outputs[0] 135 | -------------------------------------------------------------------------------- /src/model/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing 3 | 4 | import mesh_tensorflow as mtf 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from .backend import normal_var, orthogonal_var 9 | from .. import tf_wrapper as tfw 10 | from ..dataclass import BlockArgs, ModelParameter 11 | from ..mtf_wrapper import einsum, reshape, multiply, zeros_like 12 | from ..utils_core import random_name, scoped 13 | from ..utils_mtf import DIM_LIST, SHAPE, linear_shapes, shape_size 14 | 15 | 16 | def _multi_dim_range_tf(params: ModelParameter, dims: DIM_LIST) -> mtf.Tensor: 17 | out, *items = [tfw.reshape(tfw.tf_range(0, dim.size * size, size), 18 | [1] * idx + [dim.size] + [1] * (len(dims) - idx - 1)) 19 | for idx, (dim, size) in enumerate(zip(dims, np.cumprod([1] + [d.size for d in dims[:-1]])))] 20 | for i in items: 21 | out += i 22 | return tfw.cast(out, params.variable_dtype.activation_dtype) 23 | 24 | 25 | def squeeze(tensor: tf.Tensor, removed_dims: typing.List[int]): 26 | shape = tensor.shape.as_list() 27 | for i, d in enumerate(sorted(removed_dims)): 28 | del shape[d - i] 29 | return tf.reshape(tensor, shape) 30 | 31 | 32 | def unsqueeze(tensor: tf.Tensor, added_dims: typing.List[int]): 33 | shape = tensor.shape.as_list() 34 | for d in sorted(added_dims): 35 | shape.insert(d, 1) 36 | return tf.reshape(tensor, shape) 37 | 38 | 39 | class ScatterAdd(mtf.Operation): 40 | """Assign to one or more variables.""" 41 | 42 | def __init__(self, out: mtf.Tensor, indices: mtf.Tensor, gradient: mtf.Tensor, 43 | squeeze_dims: typing.Optional[SHAPE]): 44 | super().__init__([out, indices, gradient], out.mesh, random_name("sparse_assign")) 45 | if isinstance(squeeze_dims, mtf.Shape): 46 | squeeze_dims = squeeze_dims.dims 47 | if squeeze_dims is None: 48 | squeeze_dims = [] 49 | self.squeeze_dims = squeeze_dims 50 | self.index_dims = [indices.shape.dims.index(dim) for dim in squeeze_dims if dim in indices.shape.dims] 51 | self.embed_dims = [out.shape.dims.index(dim) for dim in squeeze_dims if dim in out.shape.dims] 52 | self.grad_dims = [gradient.shape.dims.index(dim) for dim in squeeze_dims if dim in gradient.shape.dims] 53 | 54 | self.indices = indices 55 | self.grad = gradient 56 | self._outputs = [mtf.Tensor(self, out.shape, out.dtype)] 57 | 58 | def lower(self, lowering): 59 | mesh_impl = lowering.mesh_impl(self) 60 | flattened_dims = 0 61 | 62 | def assign_fn(val: tf.Tensor, indices: tf.Tensor, gradient: tf.Tensor) -> tf.Tensor: 63 | # Val: [1 (Heads), Keys, Features] 64 | # Indices: [Batch, Sequence, 1 (Heads), 1 (Keys)] 65 | # Gradient: [Batch, Sequence, 1 (Heads), Features] 66 | shape = val.shape 67 | val = squeeze(val, self.embed_dims) 68 | indices = squeeze(indices, self.index_dims) 69 | gradient = squeeze(gradient, self.grad_dims) 70 | indices = tf.reshape(indices, indices.shape.as_list() + [1]) 71 | val = tf.reshape(val, val.shape.as_list()[:-flattened_dims] + [-1]) 72 | gradient = tf.cast(tf.reshape(gradient, gradient.shape.as_list()[:-flattened_dims] + [-1]), val.dtype) 73 | return tf.reshape(tf.tensor_scatter_nd_add(val, 
indices, gradient), shape) 74 | 75 | out, indices, gradients = self.inputs 76 | for flattened_dims, (dim0, dim1) in enumerate(zip((out.shape - self.squeeze_dims).dims[::-1], 77 | (gradients.shape - self.squeeze_dims).dims[::-1]), 0): 78 | if dim0 != dim1: 79 | break 80 | flattened_dims = max(flattened_dims, 1) 81 | y = mesh_impl.slicewise(assign_fn, lowering.tensors[out], lowering.tensors[indices], 82 | lowering.tensors[gradients]) 83 | lowering.set_tensor_lowering(self.outputs[0], y) 84 | 85 | 86 | def scatter_add(out: mtf.Tensor, indices: mtf.Tensor, gradient: mtf.Tensor, squeeze_dims: typing.Optional[SHAPE] = None 87 | ) -> mtf.Tensor: 88 | return ScatterAdd(out, indices, gradient, squeeze_dims).outputs[0] 89 | 90 | 91 | class Gather(mtf.Operation): 92 | def __init__(self, args: BlockArgs, embedding: mtf.Tensor, squeeze_dims: typing.Optional[SHAPE]): 93 | super().__init__([args.tensor, embedding], args.params.mesh, name=random_name("gather")) 94 | if isinstance(squeeze_dims, mtf.Shape): 95 | squeeze_dims = squeeze_dims.dims 96 | if squeeze_dims is None: 97 | squeeze_dims = [] 98 | self.squeeze_dims = squeeze_dims 99 | self.squeezed_index_dims = [args.tensor.shape.dims.index(dim) for dim in squeeze_dims 100 | if dim in args.tensor.shape.dims] 101 | self.squeezed_embed_dims = [embedding.shape.dims.index(dim) for dim in squeeze_dims 102 | if dim in embedding.shape.dims] 103 | out_shape = args.tensor.shape - squeeze_dims + embedding.shape.dims[1:] 104 | self.args = args 105 | self.unsqueezed_dims = [out_shape.dims.index(dim) for dim in squeeze_dims if dim in out_shape.dims] 106 | self._outputs = [mtf.Tensor(self, out_shape, 107 | args.params.variable_dtype.activation_dtype)] 108 | 109 | def gradient(self, grad_ys: typing.List[mtf.Tensor]) -> typing.Tuple[None, mtf.Tensor]: 110 | indices, embedding = self.inputs 111 | return None, scatter_add(zeros_like(embedding), indices, grad_ys[0], self.squeeze_dims) 112 | 113 | def lower(self, lowering: mtf.Lowering): 114 | mesh_impl: mtf.simd_mesh_impl.SimdMeshImpl = lowering.mesh_impl(self) 115 | 116 | indices, embeddings = self.inputs 117 | 118 | def slicewise_fn(idx: tf.Tensor, embd: tf.Tensor) -> tf.Tensor: 119 | idx = squeeze(idx, self.squeezed_index_dims) 120 | embd = squeeze(embd, self.squeezed_embed_dims) 121 | out = tf.gather(embd, idx, axis=0) 122 | return unsqueeze(out, self.unsqueezed_dims) 123 | 124 | y = mesh_impl.slicewise(slicewise_fn, lowering.tensors[indices], lowering.tensors[embeddings]) 125 | lowering.set_tensor_lowering(self.outputs[0], y) 126 | 127 | 128 | class RelativeEmbeddingForward(mtf.Operation): 129 | def __init__(self, args: BlockArgs, shape: SHAPE): 130 | super().__init__([], args.params.mesh, name=random_name("rel_embed")) 131 | if isinstance(shape, list): 132 | shape = mtf.Shape(shape) 133 | self.args = args 134 | self.shape = shape 135 | self._outputs = [mtf.Tensor(self, shape, args.params.variable_dtype.activation_dtype)] 136 | 137 | def has_gradient(self): 138 | return False 139 | 140 | def lower(self, lowering: mtf.Lowering): 141 | mesh_impl: mtf.simd_mesh_impl.SimdMeshImpl = lowering.mesh_impl(self) 142 | 143 | params = self.args.params 144 | shape = self.shape 145 | 146 | position_dims: SHAPE = (shape - params.feature_dims) - params.intermediate 147 | feature_dims = linear_shapes(self.args).old 148 | position_count = shape_size(position_dims) 149 | 150 | cosine = 'cosine' in params.position_embedding 151 | 152 | shape_formula = ''.join(chr(ord('a') + i) for i in range(shape.ndims)) 153 | position_formula = 
''.join(shape_formula[shape.dims.index(d)] for d in position_dims) 154 | feature_formula = ''.join(shape_formula[shape.dims.index(d)] for d in feature_dims) 155 | 156 | positions = _multi_dim_range_tf(params, position_dims) 157 | features = _multi_dim_range_tf(params, feature_dims) 158 | additive = 0 159 | feature_count = shape_size(feature_dims) 160 | 161 | if cosine: 162 | additive = tfw.mod(features, 2) 163 | features = (features - additive) / 2 164 | additive = additive * math.pi 165 | feature_count /= 2 166 | 167 | features += 4 / feature_count 168 | features -= math.log(position_count / 2 / math.pi) 169 | features = tfw.exp(features) + additive 170 | out = tfw.einsum(f'{position_formula},{feature_formula}->{shape_formula}', positions, features) 171 | out = multiply(tfw.sin(out), params.embedding_stddev) 172 | lowering.set_tensor_lowering(self.outputs[0], mesh_impl.import_tf_tensor(self.outputs[0], out)) 173 | 174 | 175 | def _embed_var(args: BlockArgs, shape: SHAPE) -> mtf.Tensor: 176 | if 'orthogonal' in args: 177 | return orthogonal_var(args, shape) 178 | return normal_var(args, shape, args.params.embedding_stddev) 179 | 180 | 181 | def _embed(args: BlockArgs, shape: SHAPE) -> mtf.Tensor: 182 | if isinstance(shape, (list, tuple)): 183 | shape = mtf.Shape(shape) 184 | 185 | variables = [] 186 | position_dims: mtf.Shape = (shape - args.params.feature_dims) - args.params.intermediate 187 | feature_dims = linear_shapes(args).old 188 | 189 | if 'absolute' in args: 190 | out = _embed_var(args, shape) 191 | elif 'axial' in args: 192 | splits = 2 193 | for a in args: 194 | if a.isdigit(): 195 | splits = int(a) 196 | break 197 | tmp_dims = [] 198 | variables = [] 199 | 200 | def _new_part(size: int): 201 | tmp = mtf.Dimension(f'_{len(tmp_dims)}', size) 202 | tmp_dims.append(tmp) 203 | variables.append(_embed_var(args, [tmp] + feature_dims)) 204 | 205 | for dim in position_dims: 206 | base = int(dim.size ** (1 / splits)) 207 | while dim.size % base != 0: 208 | base -= 1 209 | final = dim.size // base ** (splits - 1) 210 | _new_part(final) 211 | for i in range(1, splits): 212 | _new_part(base) 213 | out = reshape(einsum(variables, output_shape=tmp_dims + feature_dims), shape) 214 | 215 | elif 'relative' in args: 216 | out = RelativeEmbeddingForward(args, shape).outputs[0] 217 | if 'learned' in args: 218 | out = multiply(out, _embed_var(args, feature_dims)) 219 | else: 220 | raise ValueError("Unsupported embedding. The following embeddings are supported:" 221 | " relative(-learned), absolute(-split) and axial(-split)") 222 | 223 | return out 224 | 225 | 226 | def embed(args: BlockArgs, shape: SHAPE) -> mtf.Tensor: 227 | return scoped('embed', _embed, args, shape) 228 | 229 | 230 | def gather_embed(args: BlockArgs, shape: SHAPE, squeezed_dims: typing.Optional[SHAPE] = None) -> mtf.Tensor: 231 | return Gather(args, scoped("gather", embed, args, shape), squeezed_dims).outputs[0] 232 | -------------------------------------------------------------------------------- /src/model/frontend.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from .activation import activate 7 | from .basic import dropout, feed_forward, rezero, group_linear, feed_forward_product_key_memory, reduced_half_linear, \ 8 | transpose_sequence_features, sum_heads, bottleneck_group_linear, product_key_memory 9 | from .convolution import convolution 10 | from .normalization import norm 11 | from .spatial import 
attention, cummean, cumsum 12 | from ..dataclass import BlockArgs, BlockConfig, ModelParameter 13 | from ..mtf_wrapper import add, multiply 14 | from ..utils_core import scoped 15 | 16 | ATTENTION_DIM = typing.NamedTuple("AttentionDim", (('index', int), ('dim', mtf.Dimension))) 17 | 18 | tf1 = tf.compat.v1 19 | 20 | 21 | def _get_block_part(block_part_config: BlockConfig, params: ModelParameter, block_input: mtf.Tensor) -> mtf.Tensor: 22 | out = block_input 23 | 24 | for idx, layer in enumerate(block_part_config.layer, 1): 25 | name, *extras = layer.split('-') 26 | args = BlockArgs(params, out, extras, idx == len(block_part_config.layer)) 27 | out = scoped(name + '_', LAYER_FUNCTIONS[name], args) 28 | 29 | if block_part_config.skip and block_part_config.memory_reduction_strategy in ("none", "checkpoint"): 30 | out += block_input 31 | return out 32 | 33 | 34 | def block_part_fn(params: ModelParameter, block_part_config: BlockConfig, block_input: mtf.Tensor, 35 | name_prefix: str = 'block') -> mtf.Tensor: 36 | return scoped(f"{name_prefix}_", _get_block_part, block_part_config, params, block_input) 37 | 38 | 39 | def split_path(args: BlockArgs) -> mtf.Tensor: 40 | base, *name_extras = [path for path in '-'.join(args.name_extras).split(';')] 41 | base = base.split('-') 42 | if 'add' in base: 43 | out = 0 44 | fn = add 45 | elif 'multiply' in base: 46 | out = 1 47 | fn = multiply 48 | else: 49 | raise ValueError 50 | 51 | for idx, conf in enumerate(name_extras): 52 | out = fn(out, _get_block_part(BlockConfig({'skip': False, 'layer': conf.split(',')}, ''), 53 | args.params, args.tensor)) 54 | 55 | return out 56 | 57 | 58 | LAYER_FUNCTIONS = {'feed_forward': feed_forward, 59 | 'attention': attention, 60 | 'cummean': cummean, 61 | 'cumsum': cumsum, 62 | 'norm': norm, 63 | 'rezero': rezero, 64 | 'activation': activate, 65 | 'convolution': convolution, 66 | 'dropout': dropout, 67 | 'group_linear': group_linear, 68 | 'split_path': split_path, 69 | 'feed_forward_product_key_memory': feed_forward_product_key_memory, 70 | 'product_key_memory': product_key_memory, 71 | 'reduced_half_linear': reduced_half_linear, 72 | 'transpose_sequence_features': transpose_sequence_features, 73 | 'bottleneck_group_linear': bottleneck_group_linear, 74 | 'sum_heads': sum_heads 75 | } 76 | -------------------------------------------------------------------------------- /src/model/momentumnet.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from .frontend import block_part_fn 7 | from ..dataclass import ModelParameter 8 | from ..utils_core import random_name 9 | from ..utils_mtf import gradient_iterator 10 | 11 | tf1 = tf.compat.v1 12 | 13 | 14 | class MomentumOperation(mtf.Operation): 15 | """Operation to implement custom gradients. 16 | 17 | See comments on custom_gradient() below. 
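The update implemented by this operation is the MomentumNet rule `v' = alpha * v + (1 - alpha) * f(x)`, `x' = x + v'`, which is exactly invertible, so the backward pass can reconstruct its inputs instead of storing activations. A plain-NumPy sketch of the round trip (not repository code; `f` stands in for `block_part_fn`):

```python
import numpy as np

alpha = 0.9
f = lambda t: np.tanh(t)                          # stand-in for block_part_fn

x, v = np.random.randn(4), np.random.randn(4)
v_new = alpha * v + (1 - alpha) * f(x)            # mirrors new_v below
x_new = x + v_new                                 # mirrors new_x below

x_rec = x_new - v_new                             # invert the x update
v_rec = (v_new - (1 - alpha) * f(x_rec)) / alpha  # invert the v update
print(np.allclose(x, x_rec), np.allclose(v, v_rec))  # True True
```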
18 | """ 19 | 20 | def __init__(self, params, block_config, x, x_backwards, v, v_backwards, index): 21 | graph: mtf.Graph = x.graph 22 | prev_ops = len(graph.operations) 23 | new_v = v * params.momentumnet_alpha 24 | fx = block_part_fn(params, block_config, x, index) 25 | new_v += fx * (1 - params.momentumnet_alpha) 26 | new_x = x + new_v 27 | fn_outputs = [new_x, x_backwards, new_v, v_backwards] 28 | forward_operations = graph.operations[prev_ops:] 29 | new_outputs = set() 30 | new_inputs = set() 31 | for op in forward_operations: 32 | new_inputs.update(set(op.inputs)) 33 | if not isinstance(op, mtf.Variable): 34 | new_outputs.update(set(op.outputs)) 35 | explicit_inputs = [x, x_backwards, v, v_backwards] 36 | variables = [t for t in list(new_inputs - new_outputs - set(explicit_inputs)) if t.dtype.is_floating] 37 | super(MomentumOperation, self).__init__(explicit_inputs + variables + fn_outputs, params.mesh, 38 | random_name("custom_gradient")) 39 | # Make sure no one uses the internals of this function, since the gradients 40 | # will probably not work correctly. 41 | for t in new_outputs - set(fn_outputs): 42 | t.usable = False 43 | 44 | self._graph: mtf.Graph = x.graph 45 | self._x: mtf.Tensor = x 46 | self.params: ModelParameter = params 47 | self._v: mtf.Tensor = v 48 | self._fx: mtf.Tensor = fx 49 | self._variables: typing.List[mtf.Variable] = variables 50 | self._fn_outputs: typing.List[mtf.Tensor] = fn_outputs 51 | self._outputs: typing.List[mtf.Tensor] = [mtf.Tensor(self, x.shape, x.dtype, index=i) 52 | for i, x in enumerate(fn_outputs)] 53 | self._forward_operations = forward_operations 54 | 55 | def lower(self, lowering): 56 | for fn_output, output in zip(self._fn_outputs, self._outputs): 57 | lowering.set_tensor_lowering(output, lowering.tensors[fn_output]) 58 | 59 | def gradient(self, grad_ys: typing.List[mtf.Tensor], 60 | params: typing.Optional[typing.List[mtf.Operation]] = None): 61 | dx, x_backwards, dv, v_backwards = grad_ys 62 | x: mtf.Tensor = self._x if x_backwards is None else x_backwards 63 | v: mtf.Tensor = self._v if v_backwards is None else v_backwards 64 | f_again_ops, mapping = self._graph.clone_operations(self._forward_operations, {self._x: x}) 65 | # figure out what Tensors are downstream of xs 66 | downstream = set([x] + self._variables) 67 | for op in f_again_ops: 68 | if op.has_gradient and set(op.inputs) & downstream: 69 | downstream |= set(op.outputs) 70 | fx = mapping[self._fx] 71 | tensor_to_gradient = {mapping[self.outputs[0]]: dx, mapping[self.outputs[2]]: dv} 72 | if params is None: 73 | with tf1.variable_scope(fx.graph.captured_variable_scope): 74 | for op in f_again_ops[::-1]: 75 | grad_outputs = [tensor_to_gradient.get(out) for out in op.outputs] 76 | if not op.has_gradient or not any(grad_outputs) or not set(op.inputs) & downstream: 77 | continue 78 | with tf1.variable_scope(op.name + "/momentumnet/gradients"): 79 | for inner_op, inp, grad in gradient_iterator(self.params, op, grad_outputs): 80 | if inp not in downstream or grad is None: 81 | continue 82 | if inp in tensor_to_gradient: 83 | tensor_to_gradient[inp] += grad 84 | else: 85 | tensor_to_gradient[inp] = grad 86 | yield dx + (1 - self.params.momentumnet_alpha) * tensor_to_gradient[x] 87 | yield x - v 88 | yield (dx + dv) * self.params.momentumnet_alpha 89 | yield (v - (1 - self.params.momentumnet_alpha) * fx) / self.params.momentumnet_alpha 90 | yield from (tensor_to_gradient.get(x) for x in self._variables) 91 | return 92 | tensor_to_gradient = {mapping[self.outputs[0]]: [0, 0, dx], 
mapping[self.outputs[2]]: [0, 0, dv]} 93 | yield self, params[1], x - v 94 | yield self, params[2], (dx + dv) * self.params.momentumnet_alpha 95 | yield self, params[3], (v - (1 - self.params.momentumnet_alpha) * fx) / self.params.momentumnet_alpha 96 | with tf1.variable_scope(fx.graph.captured_variable_scope): 97 | for op in f_again_ops[::-1]: 98 | grad_outputs = [] 99 | for out in op.outputs: 100 | grad = tensor_to_gradient.get(out) 101 | if grad is None: 102 | grad_outputs.append(None) 103 | continue 104 | grad_outputs.append(grad[2]) 105 | grad[0] += 1 106 | if grad[0] == len(grad[2].operation.inputs): 107 | del tensor_to_gradient[out] 108 | if not op.has_gradient or not any(grad_outputs) or not set(op.inputs) & downstream: 109 | continue 110 | for inner_op, inp, grad in gradient_iterator(self.params, op, grad_outputs): 111 | if inp not in downstream or grad is None: 112 | continue 113 | if inp in tensor_to_gradient: 114 | grad_list = tensor_to_gradient[inp] 115 | grad_list[1] += 1 116 | with tf1.variable_scope(op.name + "/momentumnet/gradients"): 117 | grad_list[2] += grad 118 | else: 119 | tensor_to_gradient[inp] = grad_list = [0, 1, grad] 120 | if len(inp.operation.outputs) != grad_list[1]: 121 | continue 122 | if inp not in self._variables: 123 | continue 124 | yield inner_op, params[4 + self._variables.index(inp)], grad_list[2] 125 | yield self, params[0], dx + (1 - self.params.momentumnet_alpha) * tensor_to_gradient[x][2]  # gradient w.r.t. x (params[0]; params[1..3] were served above); accumulator entries are [count, count, grad] lists, hence [2], mirroring RevGradOp 126 | -------------------------------------------------------------------------------- /src/model/normalization.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import numpy as np 5 | import tensorflow as tf 6 | from scipy.special import erfinv 7 | 8 | from .backend import normal_var, SHAPE 9 | from ..dataclass import BlockArgs 10 | from ..mtf_wrapper import einsum, reduce_mean, rsqrt_eps, square 11 | from ..utils_mtf import linear_shapes 12 | 13 | tf1 = tf.compat.v1 14 | 15 | 16 | def uniformly_sampled_gaussian(num_rand, dtype): 17 | rand = 2 * (np.arange(num_rand) + 0.5) / float(num_rand) - 1 18 | rand = np.sqrt(2) * erfinv(rand) 19 | return tf.constant(rand, dtype=dtype) 20 | 21 | 22 | def norm(args: BlockArgs, feature_shape: typing.Optional[SHAPE] = None) -> mtf.Tensor: 23 | block_input = args.tensor 24 | feature_shape = mtf.Shape(linear_shapes(args).old if feature_shape is None else feature_shape) 25 | normalized_shape = block_input.shape - (feature_shape - [args.params.head_dim] * ('group' in args)) 26 | 27 | block_input -= reduce_mean(block_input, output_shape=normalized_shape) 28 | scale = [rsqrt_eps(reduce_mean(square(block_input), output_shape=normalized_shape), 1e-5), block_input] 29 | if 'scale' in args: 30 | scale.append(normal_var(args, feature_shape, mean=1)) 31 | block_input = einsum(scale, output_shape=block_input.shape) 32 | if 'shift' in args: 33 | block_input += normal_var(args, feature_shape, mean=0) 34 | return block_input 35 | -------------------------------------------------------------------------------- /src/model/revnet.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from .frontend import block_part_fn 7 | from ..mtf_wrapper import add 8 | from ..utils_core import random_name 9 | from ..utils_mtf import gradient_iterator 10 | 11 | tf1 = tf.compat.v1 12 | 13 | 14 | class RevGradOp(mtf.Operation): 15 | """Operation to 
implement custom gradients. 16 | 17 | See comments on custom_gradient() below. 18 | """ 19 | 20 | def __init__(self, params, block_config, x1, x1_backwards, x2, x2_backwards, index): 21 | graph: mtf.Graph = x1.graph 22 | prev_ops = len(graph.operations) 23 | y1 = x1 + block_part_fn(params, block_config, x2, index) 24 | fn_outputs = [x2, x2_backwards, y1, x1_backwards] 25 | forward_operations = graph.operations[prev_ops:] 26 | new_outputs = set() 27 | new_inputs = set() 28 | for op in forward_operations: 29 | new_inputs.update(set(op.inputs)) 30 | if not isinstance(op, mtf.Variable): 31 | new_outputs.update(set(op.outputs)) 32 | explicit_inputs = [x1, x1_backwards, x2, x2_backwards] 33 | variables = [t for t in list(new_inputs - new_outputs - set(explicit_inputs)) if t.dtype.is_floating] 34 | super(RevGradOp, self).__init__(explicit_inputs + variables + fn_outputs, x1.mesh, 35 | random_name("custom_gradient")) 36 | # Make sure no one uses the internals of this function, since the gradients 37 | # will probably not work correctly. 38 | for t in new_outputs - set(fn_outputs): 39 | t.usable = False 40 | 41 | self._graph: mtf.Graph = x1.graph 42 | self.params = params 43 | self._x2: mtf.Tensor = x2 44 | self._y1: mtf.Tensor = y1 45 | self._variables: typing.List[mtf.Variable] = variables 46 | self._fn_outputs: typing.List[mtf.Tensor] = fn_outputs 47 | self._outputs: typing.List[mtf.Tensor] = [mtf.Tensor(self, x.shape, x.dtype, index=i) 48 | for i, x in enumerate(fn_outputs)] 49 | self._forward_operations = forward_operations[:-1] 50 | 51 | def lower(self, lowering): 52 | for fn_output, output in zip(self._fn_outputs, self._outputs): 53 | lowering.set_tensor_lowering(output, lowering.tensors[fn_output]) 54 | 55 | def gradient(self, grad_ys, params: typing.Optional[typing.List[mtf.Operation]] = None 56 | ) -> typing.Iterable[typing.Tuple[mtf.Operation, mtf.Tensor, mtf.Tensor]]: 57 | dy2, dy2_backwards, dy1, dy1_backwards = grad_ys 58 | x2 = self._x2 if dy2_backwards is None else dy2_backwards 59 | f_again_ops, mapping = self._graph.clone_operations(self._forward_operations, {self._x2: x2}) 60 | fx2 = mapping[self._forward_operations[-1].outputs[0]] 61 | # figure out what Tensors are downstream of xs 62 | downstream = set([x2] + self._variables) 63 | for op in f_again_ops: 64 | if op.has_gradient and set(op.inputs) & downstream: 65 | downstream |= set(op.outputs) 66 | tensor_to_gradient = {fx2: dy1} 67 | if params is None: 68 | yield dy1 69 | yield (self._y1 if dy1_backwards is None else dy1_backwards) - fx2 70 | with tf1.variable_scope(fx2.graph.captured_variable_scope): 71 | for op in f_again_ops[::-1]: 72 | grad_outputs = [tensor_to_gradient.get(out) for out in op.outputs] 73 | if not op.has_gradient or not any(grad_outputs) or not set(op.inputs) & downstream: 74 | continue 75 | with tf1.variable_scope(op.name + "/revnet/gradients"): 76 | for inner_op, inp, grad in gradient_iterator(self.params, op, grad_outputs): 77 | if inp not in downstream or grad is None: 78 | continue 79 | if inp in tensor_to_gradient: 80 | tensor_to_gradient[inp] += grad 81 | else: 82 | tensor_to_gradient[inp] = grad 83 | yield add(dy2, tensor_to_gradient[x2]) 84 | yield x2 85 | yield from (tensor_to_gradient.get(x) for x in self._variables) 86 | return 87 | tensor_to_gradient = {fx2: [0, 0, dy1]} 88 | yield self, params[0], dy1 89 | yield self, params[1], (self._y1 if dy1_backwards is None else dy1_backwards) - fx2 90 | yield self, params[3], x2 91 | with tf1.variable_scope(fx2.graph.captured_variable_scope): 92 | for 
op in f_again_ops[::-1]: 93 | grad_outputs = [] 94 | for out in op.outputs: 95 | grad = tensor_to_gradient.get(out) 96 | if grad is None: 97 | grad_outputs.append(None) 98 | continue 99 | grad_outputs.append(grad[2]) 100 | grad[0] += 1 101 | if grad[0] == len(grad[2].operation.inputs): 102 | del tensor_to_gradient[out] 103 | if not op.has_gradient or not any(grad_outputs) or not set(op.inputs) & downstream: 104 | continue 105 | for inner_op, inp, grad in gradient_iterator(self.params, op, grad_outputs): 106 | if inp not in downstream or grad is None: 107 | continue 108 | if inp in tensor_to_gradient: 109 | grad_list = tensor_to_gradient[inp] 110 | grad_list[1] += 1 111 | with tf1.variable_scope(op.name + "/revnet/gradients"): 112 | grad_list[2] = add(grad_list[2], grad) 113 | else: 114 | tensor_to_gradient[inp] = grad_list = [0, 1, grad] 115 | if len(inp.operation.outputs) != grad_list[1]: 116 | continue 117 | if inp not in self._variables: 118 | continue 119 | yield inner_op, params[4 + self._variables.index(inp)], grad_list[2] 120 | yield self, params[2], add(dy2, tensor_to_gradient[x2][2]) 121 | -------------------------------------------------------------------------------- /src/model/spatial.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from .basic import activated_linear_in, activated_linear_out 7 | from .embedding import embed 8 | from ..dataclass import BlockArgs 9 | from ..mtf_wrapper import einsum, greater_equal, multiply, less, exp, reduce_max, reduce_sum 10 | from ..utils_mtf import (anonymize, anonymize_dim, compare_range, get_attention_dim, is_masked, linear_shapes, 11 | random_name) 12 | 13 | ATTENTION_DIM = typing.NamedTuple("AttentionDim", (('index', int), ('dim', mtf.Dimension))) 14 | 15 | tf1 = tf.compat.v1 16 | 17 | 18 | def _masked_map(args: BlockArgs) -> typing.Tuple[mtf.Tensor, typing.Union[mtf.Tensor, int]]: 19 | dim = get_attention_dim(args).dim 20 | tmp = anonymize_dim(dim) 21 | bias = embed(args, [args.params.head_dim, dim, tmp]) 22 | return bias, compare_range(args.params, dim, tmp, greater_equal) if is_masked(args) else 1 23 | 24 | 25 | def _cumsum_grad(dy: mtf.Tensor, dim: int) -> mtf.Tensor: 26 | return mtf.cwise(lambda x: tf.cumsum(x, dim, reverse=True), [dy], 27 | name=random_name('cumsum_grad')) 28 | 29 | 30 | def cumsum(args: BlockArgs) -> mtf.Tensor: 31 | dim = args.tensor.shape.dims.index(get_attention_dim(args).dim) 32 | return mtf.cwise(lambda x: tf.cumsum(x, dim), [args.tensor], name=random_name("cumsum"), 33 | grad_function=lambda _, dy: [_cumsum_grad(dy, dim)]) 34 | 35 | 36 | def cummean(args: BlockArgs) -> mtf.Tensor: 37 | return cumsum(args) / (1 + mtf.range(args.tensor.mesh, get_attention_dim(args).dim, dtype=args.tensor.dtype, 38 | name=random_name("cummean"))) 39 | 40 | 41 | def attention(args: BlockArgs) -> mtf.Tensor: 42 | args.params.attention_idx += 1 43 | if "dot_product" in args or "input_as_value" not in args: 44 | base = args(activated_linear_in(args)) 45 | 46 | dim = get_attention_dim(args).dim 47 | tmp = anonymize_dim(dim) 48 | shape = args.tensor.shape 49 | 50 | logit = 0 51 | val = 0 52 | key = 0 53 | if 'dot_product' in args: 54 | if 'embedded' in args or 'context' in args: 55 | key = activated_linear_out(base) 56 | if 'embedded' in args or 'positional' in args: 57 | key += embed(args, [dim] + args.params.feature_dims) 58 | qry = activated_linear_out(base) 59 | qry *= dim.size ** -0.5 60 | logit_shape 
= shape - (mtf.Shape(linear_shapes(args).old) - [args.params.head_dim]) + tmp 61 | logit = einsum([qry, anonymize(key, dim)], output_shape=logit_shape) 62 | if "shared_key_value" in args: 63 | val = key 64 | if 'biased_softmax' in args: 65 | logit += multiply(*_masked_map(args)) 66 | if logit != 0: 67 | logit += (compare_range(args.params, dim, tmp, less) * 1e38) * -2 68 | logit -= mtf.stop_gradient(reduce_max(logit, reduced_dim=tmp)) 69 | logit = exp(logit) 70 | logit /= reduce_sum(logit, reduced_dim=tmp) 71 | if 'biased_attention_map' in args: 72 | logit += multiply(*_masked_map(args)) 73 | if 'scale_attention_map' in args: 74 | logit *= multiply(*_masked_map(args)) 75 | if val == 0: 76 | val = anonymize(args.tensor if "input_as_value" in args else activated_linear_out(base), dim) 77 | if not logit: 78 | raise UserWarning(f"WARNING: There is no spatial mixing happening with the following attention parameters: " 79 | f"{args.name_extras}.") 80 | return einsum([logit, val], shape) 81 | -------------------------------------------------------------------------------- /src/mtf_wrapper.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from .dataclass import ModelParameter 7 | from .utils_core import scoped 8 | 9 | tf1 = tf.compat.v1 10 | _NAME_INDEX = [0] 11 | 12 | DIM = typing.Union[mtf.Dimension, str] 13 | DIM_LIST = typing.List[mtf.Dimension] 14 | SHAPE = typing.Union[mtf.Shape, DIM_LIST] 15 | TENSORS = typing.List[mtf.Tensor] 16 | OPT_SHAPE = typing.Optional[SHAPE] 17 | OPT_DIMS = typing.Optional[DIM_LIST] 18 | OPT_DIM = typing.Optional[mtf.Dimension] 19 | 20 | 21 | def get_variable_for_tensor(tensor: typing.Union[mtf.Tensor, mtf.Variable]) -> mtf.Variable: 22 | if isinstance(tensor, mtf.Variable): 23 | return tensor 24 | op: mtf.Operation = tensor.operation 25 | while not isinstance(op, mtf.Variable): 26 | value: mtf.Tensor = op.inputs[0] 27 | op: mtf.Variable = value.operation 28 | return op 29 | 30 | 31 | def einsum(xs: TENSORS, output_shape: OPT_SHAPE = None, reduced_dims: OPT_DIMS = None) -> mtf.Tensor: 32 | return scoped("einsum", mtf.einsum, xs, output_shape, reduced_dims) 33 | 34 | 35 | def one_hot(indices: mtf.Tensor, output_dim: mtf.Dimension, on_value: float = 1.0, off_value: float = 0.0, 36 | dtype: tf.dtypes = tf.float32) -> mtf.Tensor: 37 | return scoped("one_hot", mtf.one_hot, indices, output_dim, on_value, off_value, dtype) 38 | 39 | 40 | def reduce_mean(tensor: mtf.Tensor, output_shape: OPT_SHAPE = None, reduced_dim: OPT_DIM = None) -> mtf.Tensor: 41 | return scoped("reduce_mean", mtf.reduce_mean, tensor, None, output_shape, reduced_dim) 42 | 43 | 44 | def reduce_sum(tensor: mtf.Tensor, output_shape: OPT_SHAPE = None, reduced_dim: OPT_DIM = None) -> mtf.Tensor: 45 | return scoped("reduce_sum", mtf.reduce_sum, tensor, None, output_shape, reduced_dim) 46 | 47 | 48 | def reduce_max(tensor: mtf.Tensor, output_shape: OPT_SHAPE = None, reduced_dim: OPT_DIM = None) -> mtf.Tensor: 49 | return scoped("reduce_max", mtf.reduce_max, tensor, None, output_shape, reduced_dim) 50 | 51 | 52 | def reduce_logsumexp(tensor: mtf.Tensor, reduced_dim: OPT_DIM = None) -> mtf.Tensor: 53 | return scoped("reduce_logsumexp", mtf.reduce_logsumexp, tensor, reduced_dim) 54 | 55 | 56 | def recompute_grad(fn: typing.Callable, explicit_inputs: typing.List[mtf.Tensor]) -> mtf.Tensor: 57 | return scoped("recompute_grad", mtf.recompute_grad, fn, explicit_inputs) 58 | 59 | 60 | def 
stop_gradient(tensor: mtf.Tensor): 61 | return scoped("stop_gradient", mtf.stop_gradient, tensor) 62 | 63 | 64 | def _softmax_cross_entropy_with_logits(params: ModelParameter, logits: mtf.Tensor, targets: mtf.Tensor): 65 | max_logit = reduce_max(stop_gradient(logits), reduced_dim=params.vocab_dim) 66 | log_z = log(reduce_sum(exp(logits - max_logit), reduced_dim=params.vocab_dim)) + max_logit 67 | loss = einsum([logits - log_z, one_hot(targets, params.vocab_dim, dtype=logits.dtype), 68 | constant_scalar(params, -1 / targets.size)], output_shape=[]) 69 | if params.z_loss: 70 | loss += einsum([log_z, log_z, constant_scalar(params, params.z_loss / targets.size)], output_shape=[]) 71 | return loss 72 | 73 | 74 | def softmax_cross_entropy_with_logits(params: ModelParameter, logits: mtf.Tensor, targets: mtf.Tensor) -> mtf.Tensor: 75 | return scoped("softmax_cross_entropy_with_logits", _softmax_cross_entropy_with_logits, params, logits, targets) 76 | 77 | 78 | def import_laid_out_tensor(params: ModelParameter, laid_out_tensor: object, shape: SHAPE, 79 | name: typing.Optional[str] = None): 80 | return scoped("import_laid_out_tensor", mtf.import_laid_out_tensor, params.mesh, laid_out_tensor, shape, name) 81 | 82 | 83 | def import_fully_replicated(params: ModelParameter, laid_out_tensor: tf.Tensor, shape: SHAPE, 84 | name: typing.Optional[str] = None): 85 | return scoped("import_fully_replicated", mtf.import_fully_replicated, params.mesh, laid_out_tensor, shape, name) 86 | 87 | 88 | def logical_not(tensor: mtf.Tensor): 89 | return scoped("logical_not", mtf.logical_not, tensor) 90 | 91 | 92 | def logical_and(x1: mtf.Tensor, x2: mtf.Tensor): 93 | return scoped("logical_and", mtf.logical_and, x1, x2) 94 | 95 | 96 | def identity(tensor: mtf.Tensor): 97 | return scoped("identity", mtf.identity, tensor) 98 | 99 | 100 | def while_loop(cond_fn: typing.Callable, body_fn: typing.Callable, inputs: TENSORS, 101 | num_loop_vars: typing.Optional[int] = None, has_accumulators: bool = False): 102 | return scoped("while_loop", mtf.while_loop, cond_fn, body_fn, inputs, num_loop_vars, has_accumulators) 103 | 104 | 105 | def anonymize(tensor: mtf.Tensor): 106 | return scoped("anonymize", mtf.anonymize, tensor) 107 | 108 | 109 | def random_uniform(params: ModelParameter, shape: SHAPE, dtype: typing.Optional[tf.DType] = None, maxval: float = 0, 110 | minval: float = 0): 111 | return scoped("random_uniform", mtf.random_uniform, params.mesh, shape, dtype=dtype, maxval=maxval, minval=minval) 112 | 113 | 114 | def relu(tensor: mtf.Tensor): 115 | return scoped("relu", mtf.relu, tensor) 116 | 117 | 118 | def tanh(tensor: mtf.Tensor): 119 | return scoped("tanh", mtf.tanh, tensor) 120 | 121 | 122 | def assign(var: mtf.Variable, new_val: mtf.Tensor): 123 | return scoped("assign", mtf.assign, get_variable_for_tensor(var), new_val) 124 | 125 | 126 | def assign_add(var: mtf.Variable, new_val: mtf.Tensor): 127 | return scoped("assign_add", mtf.assign_add, get_variable_for_tensor(var), new_val) 128 | 129 | 130 | def assign_sub(var: mtf.Variable, new_val: mtf.Tensor): 131 | return scoped("assign_sub", mtf.assign_sub, get_variable_for_tensor(var), new_val) 132 | 133 | 134 | def concat(tensors: typing.List[mtf.Tensor], concat_dim_name: str) -> mtf.Tensor: 135 | return scoped("concat", mtf.concat, tensors, concat_dim_name) 136 | 137 | 138 | def pad(tensor: mtf.Tensor, padding: typing.Tuple[int, int], dim_name: str) -> mtf.Tensor: 139 | return scoped("concat", mtf.pad, tensor, padding, dim_name) 140 | 141 | 142 | def constant(params: 
ModelParameter, value: typing.Union[int, float], shape: OPT_SHAPE = None, 143 | dtype: typing.Union[None, mtf.VariableDType, tf.DType] = None) -> mtf.Tensor: 144 | return scoped("constant", mtf.constant, params.mesh, value, shape, 145 | params.variable_dtype.activation_dtype if dtype is None else dtype) 146 | 147 | 148 | def constant_float(params: ModelParameter, value: typing.Union[int, float], shape: OPT_SHAPE = None) -> mtf.Tensor: 149 | return scoped("constant_float", mtf.constant, params.mesh, value, shape, tf.float32) 150 | 151 | 152 | def constant_int(params: ModelParameter, value: typing.Union[int, float], shape: OPT_SHAPE = None) -> mtf.Tensor: 153 | return scoped("constant_int", mtf.constant, params.mesh, value, shape, tf.int32) 154 | 155 | 156 | def constant_scalar(params: ModelParameter, value: typing.Union[int, float], dtype: tf.DType = None) -> mtf.Tensor: 157 | dtype = params.variable_dtype.activation_dtype if dtype is None else dtype 158 | return scoped("constant_scalar", mtf.constant, params.mesh, value, [], dtype) 159 | 160 | 161 | def optimizer_scalar(params: ModelParameter, value: typing.Union[int, float]) -> mtf.Tensor: 162 | return scoped("optimizer_scalar", mtf.constant, params.mesh, value, [], params.optimizer_calculation_dtype) 163 | 164 | 165 | def greater_equal(x1: mtf.Tensor, x2: mtf.Tensor, output_shape: OPT_SHAPE = None) -> mtf.Tensor: 166 | return scoped("greater_equal", mtf.greater_equal, x1, x2, output_shape) 167 | 168 | 169 | def greater(x1: mtf.Tensor, x2: mtf.Tensor, output_shape: OPT_SHAPE = None) -> mtf.Tensor: 170 | return scoped("greater", mtf.greater, x1, x2, output_shape) 171 | 172 | 173 | def less(x1: mtf.Tensor, x2: mtf.Tensor, output_shape: OPT_SHAPE = None) -> mtf.Tensor: 174 | return scoped("less", mtf.less, x1, x2, output_shape) 175 | 176 | 177 | def less_equal(x1: mtf.Tensor, x2: mtf.Tensor, output_shape: OPT_SHAPE = None) -> mtf.Tensor: 178 | return scoped("less_equal", mtf.less_equal, x1, x2, output_shape) 179 | 180 | 181 | def equal(x1: mtf.Tensor, x2: mtf.Tensor, output_shape: OPT_SHAPE = None) -> mtf.Tensor: 182 | return scoped("equal", mtf.equal, x1, x2, output_shape) 183 | 184 | 185 | def mod(x1: mtf.Tensor, x2: typing.Union[mtf.Tensor, int]) -> mtf.Tensor: 186 | return scoped("mod", lambda x, y: x % y, x1, x2) 187 | 188 | 189 | def sin(x: mtf.Tensor): 190 | return scoped("sin", mtf.sin, x) 191 | 192 | 193 | def negative(tensor: mtf.Tensor): 194 | return scoped("negative", lambda x: -x, tensor) 195 | 196 | 197 | def floordiv(x1: mtf.Tensor, x2: mtf.Tensor) -> mtf.Tensor: 198 | return scoped("floordiv", lambda x, y: x // y, x1, x2) 199 | 200 | 201 | def mtf_range(mesh: mtf.Mesh, dim: DIM, dtype: tf.DType) -> mtf.Tensor: 202 | return scoped("range", mtf.range, mesh, dim, dtype) 203 | 204 | 205 | def cast(tensor: mtf.Tensor, dtype: tf.DType) -> mtf.Tensor: 206 | return scoped("cast", mtf.cast, tensor, dtype) 207 | 208 | 209 | def exp(tensor: mtf.Tensor) -> mtf.Tensor: 210 | return scoped("exp", mtf.exp, tensor) 211 | 212 | 213 | def reciprocal(tensor: mtf.Tensor) -> mtf.Tensor: 214 | return scoped("reciprocal", mtf.reciprocal, tensor) 215 | 216 | 217 | def log(tensor: mtf.Tensor) -> mtf.Tensor: 218 | return scoped("log", mtf.log, tensor) 219 | 220 | 221 | def reshape(tensor: mtf.Tensor, new_shape: SHAPE): 222 | return scoped("reshape", mtf.reshape, tensor, new_shape) 223 | 224 | 225 | def argmax(tensor: mtf.Tensor, reduced_dim: mtf.Dimension): 226 | return scoped("argmax", mtf.argmax, tensor, reduced_dim) 227 | 228 | 229 | def 
sigmoid(tensor: mtf.Tensor) -> mtf.Tensor: 230 | return scoped("sigmoid", mtf.sigmoid, tensor) 231 | 232 | 233 | def sqrt(tensor: mtf.Tensor) -> mtf.Tensor: 234 | return scoped("sqrt", mtf.sqrt, tensor) 235 | 236 | 237 | def sqrt_eps(tensor: mtf.Tensor, epsilon: float = 1e-6) -> mtf.Tensor: 238 | return scoped("sqrt_eps", lambda x: sqrt(add(x, epsilon)), tensor) 239 | 240 | 241 | def rsqrt(tensor: mtf.Tensor) -> mtf.Tensor: 242 | return scoped("rsqrt", mtf.rsqrt, tensor) 243 | 244 | 245 | def rsqrt_eps(tensor: mtf.Tensor, epsilon: float = 1e-6) -> mtf.Tensor: 246 | return scoped("rsqrt_eps", lambda x: rsqrt(add(x, epsilon)), tensor) 247 | 248 | 249 | def softplus(tensor: mtf.Tensor) -> mtf.Tensor: 250 | return scoped("softplus", mtf.softplus, tensor) 251 | 252 | 253 | def square(tensor: mtf.Tensor) -> mtf.Tensor: 254 | return scoped("square", mtf.square, tensor) 255 | 256 | 257 | def broadcast(tensor: mtf.Tensor, new_shape: SHAPE) -> mtf.Tensor: 258 | return scoped("broadcast", mtf.broadcast, tensor, new_shape) 259 | 260 | 261 | def sign(tensor: mtf.Tensor) -> mtf.Tensor: 262 | return scoped("sign", mtf.sign, tensor) 263 | 264 | 265 | def shift(tensor: mtf.Tensor, offset: int, dim: DIM, wrap: bool) -> mtf.Tensor: 266 | return scoped("shift", mtf.shift, tensor, offset, dim, wrap) 267 | 268 | 269 | def maximum(x1: mtf.Tensor, x2: typing.Union[mtf.Tensor, int, float], output_shape: OPT_SHAPE = None) -> mtf.Tensor: 270 | return scoped("maximum", mtf.maximum, x1, x2, output_shape) 271 | 272 | 273 | def minimum(x1: mtf.Tensor, x2: typing.Union[mtf.Tensor, int, float], output_shape: OPT_SHAPE = None) -> mtf.Tensor: 274 | return scoped("minimum", mtf.minimum, x1, x2, output_shape) 275 | 276 | 277 | def add_n(*xs: typing.Union[typing.List[TENSORS], TENSORS]) -> mtf.Tensor: 278 | if len(xs) == 1 and not isinstance(xs[0], mtf.Tensor): 279 | xs = xs[0] 280 | return scoped("add_n", mtf.add_n, xs) 281 | 282 | 283 | def mtf_slice(tensor: mtf.Tensor, begin: int, size: int, dim_name: str): 284 | return scoped("slice", mtf.slice, tensor, begin, size, dim_name) 285 | 286 | 287 | def add(x1: typing.Union[mtf.Variable, mtf.Tensor], x2: mtf.Tensor): 288 | if isinstance(x1, mtf.Variable): 289 | x1 = x1.value 290 | return scoped("add", lambda x, y: x + y, x1, x2) 291 | 292 | 293 | def multiply(x1: typing.Union[mtf.Variable, mtf.Tensor], x2: mtf.Tensor): 294 | if isinstance(x1, mtf.Variable): 295 | x1 = x1.value 296 | return scoped("multiply", lambda x, y: x * y, x1, x2) 297 | 298 | 299 | def divide(x1: mtf.Tensor, x2: float): 300 | return scoped("divide", lambda x, y: x / y, x1, x2) 301 | 302 | 303 | def subtract(x1: mtf.Tensor, x2: mtf.Tensor): 304 | return scoped("subtract", add, x1, -x2) 305 | 306 | 307 | def ones(mesh: mtf.Mesh, shape: SHAPE, dtype: tf.DType) -> mtf.Tensor: 308 | return scoped("ones", mtf.ones, mesh, shape, dtype) 309 | 310 | 311 | def zeros(mesh: mtf.Mesh, shape: SHAPE, dtype: tf.DType) -> mtf.Tensor: 312 | return scoped("zeros", mtf.zeros, mesh, shape, dtype) 313 | 314 | 315 | def pow(x1: mtf.Tensor, x2: mtf.Tensor) -> mtf.Tensor: 316 | return scoped("pow", mtf.pow, x1, x2) 317 | 318 | 319 | def zeros_like(tensor: mtf.Tensor) -> mtf.Tensor: 320 | return scoped("zeros_like", mtf.zeros_like, tensor) 321 | 322 | 323 | def ones_like(tensor: mtf.Tensor) -> mtf.Tensor: 324 | return scoped("ones_like", mtf.ones_like, tensor) 325 | 326 | 327 | def dropout(tensor: mtf.Tensor, is_training: bool, keep_prob: typing.Optional[float] = None, 328 | rate: typing.Optional[float] = None, noise_shape: 
OPT_SHAPE = None) -> mtf.Tensor: 329 | return scoped("dropout", mtf.dropout, tensor, is_training, keep_prob, rate, noise_shape) 330 | -------------------------------------------------------------------------------- /src/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stores custom optimizer classes as well as a custom optimizer creation utility as a handy wrapper 3 | b""" 4 | 5 | import typing 6 | 7 | import mesh_tensorflow as mtf 8 | import tensorflow as tf2 9 | 10 | from .backend import import_mtf, import_float, get_var, variable 11 | from .backend import variable 12 | from .context import OptimizerCtx 13 | from .context import OptimizerCtx 14 | from .gradients import MULTI_LOSS_GRADIENTS 15 | from .learning_rate import get_learning_rate 16 | from .optimizers import OPTIMIZERS 17 | from ..dataclass import ModelParameter 18 | from ..mtf_wrapper import (cast, constant_float, constant_scalar, einsum, equal, greater_equal, mod, reduce_sum, assign, 19 | add, multiply, scoped, identity, zeros_like, negative, optimizer_scalar, reciprocal, 20 | reduce_mean, broadcast, assign_sub, assign_add, subtract, get_variable_for_tensor) 21 | from ..utils_mtf import feature_dims_used, to_fp32, gradient_iterator 22 | 23 | tf = tf2.compat.v1 24 | zeros = tf.zeros_initializer() 25 | 26 | 27 | def gradient_accumulation(ctx: OptimizerCtx): 28 | ctx.update_ops.append((assign_add if ctx.assign else add)(ctx.grad_buffer, ctx.grad)) 29 | 30 | 31 | def update(ctx: OptimizerCtx): 32 | params = ctx.params 33 | update_ops = ctx.update_ops 34 | learning_rate = ctx.learning_rate 35 | 36 | var = ctx.var 37 | if ctx.grad_buffer is not None: 38 | ctx.grad = reduce_mean(broadcast(identity(ctx.grad_buffer.value), [params.batch_dim] + ctx.grad.shape.dims), 39 | params.batch_dim) 40 | ctx.update_ops.append((assign if ctx.assign else mtf.multiply)(ctx.grad_buffer, zeros_like(ctx.grad))) 41 | 42 | for opt in params.optimizer.split('-'): 43 | opt, *args = opt.split(':') 44 | ctx.grad = scoped(opt, OPTIMIZERS[opt], ctx, *args) 45 | 46 | if 'rezero' in var.name: 47 | ctx.grad *= params.rezero_lr_multiplier 48 | 49 | features_used = feature_dims_used(params, var) 50 | large_tensor = features_used and len(var.shape.dims) > len(params.feature_dims) 51 | large_tensor |= not features_used and len(var.shape.dims) >= 2 # not norm or rezero + scalable catch-all 52 | large_tensor &= var.shape.size > 1 # not rezero 53 | large_tensor &= "norm" not in var.name # not norm 54 | large_tensor &= "rezero" not in var.name # not norm 55 | large_tensor &= "embed" not in var.name # not input/output embedding, position embedding, attention map bias 56 | large_tensor &= "input" not in var.name or "lang_in" in var.name or "vid_in" in var.name # not input 57 | large_tensor &= "output" not in var.name or "lang_out" in var.name or "vid_out" in var.name # not output 58 | 59 | if large_tensor and params.weight_decay > 0: 60 | ctx.grad += einsum([cast(var.value, params.optimizer_calculation_dtype), learning_rate, 61 | optimizer_scalar(params, params.weight_decay)], output_shape=var.shape) 62 | 63 | if ctx.assign: 64 | update_ops.append(assign_sub(ctx.var, ctx.grad)) 65 | else: 66 | update_ops.append(subtract(ctx.var, mtf.cast(ctx.grad, params.calculation_dtype))) 67 | 68 | 69 | def get_optimizer(loss_list: typing.List[mtf.Tensor], params: ModelParameter, manual_step: mtf.Tensor, fn: str 70 | ) -> typing.Tuple[typing.List[mtf.Assign], mtf.Tensor]: 71 | """ 72 | Creates optimizing and 
update/training operations.
73 |     :param loss_list: list of final scalar losses of the model
74 |     :param params: ModelParameter instance
75 |     :param manual_step: manually incremented global_step variable to account for grad accumulation
76 |     :param fn: whether to "accumulate" gradients or "update" parameters.
77 |     :return: list of update operations and the scalar learning rate
78 | 
79 |     There is no explicit check for "update": any other value behaves the same way. Just make sure the
80 |     accumulation pass is the one called "accumulate".
81 |     """
82 | 
83 |     dtype = params.optimizer_calculation_dtype
84 |     update_ops = []
85 | 
86 |     learning_rate_ctx = get_learning_rate(params, loss_list, update_ops)
87 |     learning_rate = import_mtf(params, learning_rate_ctx.learning_rate, "learning_rate")
88 | 
89 |     step = cast(equal(mod(cast(manual_step + optimizer_scalar(params, 1), dtype),
90 |                           import_mtf(params, params.grad_accumulation * 1., "grad_accum")),
91 |                       import_mtf(params, 0., "zero")), dtype)
92 |     neg_step = -step
93 |     mstep = 1 + neg_step
94 |     beta1 = 1 + neg_step * import_mtf(params, 1 - params.opt_beta1, "beta1")
95 |     beta2 = 1 + neg_step * import_mtf(params, 1 - params.opt_beta2, "beta2")
96 |     step_count = cast(learning_rate_ctx.global_steps_mtf, step.dtype) * step + mstep * 10 ** 9 + 1
97 | 
98 |     first_grad = {}
99 |     loss_1__loss_1 = loss_1__loss_2 = loss_2__loss_2 = 0
100 |     mgda = params.multi_loss_strategy == "mgda"
101 | 
102 |     if mgda:
103 |         loss_1__loss_1 = constant_float(params, 0, shape=[params.head_dim])
104 |         loss_1__loss_2 = constant_float(params, 0, shape=[params.head_dim])
105 |         loss_2__loss_2 = constant_float(params, 0, shape=[params.head_dim])
106 | 
107 |     tensor_to_gradient: typing.Dict[mtf.Tensor, typing.List[typing.Union[int, mtf.Tensor, mtf.Operation]]] = {}
108 |     tensor_to_var = {}
109 | 
110 |     for loss_idx, loss in enumerate(loss_list):
111 |         if mgda and loss_idx == 2:
112 |             v1v1 = reduce_sum(loss_1__loss_1, output_shape=[])
113 |             v1v2 = reduce_sum(loss_1__loss_2, output_shape=[])
114 |             v2v2 = reduce_sum(loss_2__loss_2, output_shape=[])
115 |             min_gamma = 0.001
116 |             gamma = multiply(constant_float(params, value=(1 - min_gamma), shape=[]),
117 |                              to_fp32(greater_equal(v1v2, v1v1)))
118 |             gamma += einsum([constant_float(params, value=min_gamma, shape=[]), to_fp32(greater_equal(v1v2, v2v2)),
119 |                              to_fp32(equal(gamma, 0))], output_shape=[])
120 |             gamma += einsum([optimizer_scalar(params, -1),
121 |                              to_fp32(equal(gamma, 0)),
122 |                              add(v1v2, negative(v2v2)),
123 |                              reciprocal(add(v1v1, v2v2) - multiply(-2, v1v2))],
124 |                             output_shape=[])
125 | 
126 |             loss = loss_list[0] * gamma + loss_list[1] * (1 - gamma)
127 | 
128 |         operations = loss.graph.operations
129 |         xs = []
130 |         tensor_to_var = {}
131 |         for x in params.variable_cache.values():
132 |             x: typing.Union[mtf.AddOperation, mtf.Variable] = x
133 |             value: mtf.Tensor = x.outputs[0]
134 |             op = get_variable_for_tensor(value)
135 |             if not op.trainable:
136 |                 continue
137 |             xs.append(value)
138 |             tensor_to_var[value] = x
139 | 
140 |         loss_grad = constant_scalar(params, 1.0)
141 |         downstream = set(xs)
142 | 
143 |         for op in operations:
144 |             if op.has_gradient and (set(op.inputs) & downstream):
145 |                 downstream |= set(op.outputs)
146 | 
147 |         tensor_to_gradient: typing.Dict[mtf.Tensor,
148 |                                         typing.List[typing.Union[int, mtf.Tensor, mtf.Operation]]] = {loss: [0, 0, loss_grad, None]}
149 | 
150 |         with tf.variable_scope(loss.graph.captured_variable_scope):
151 |             for op in operations[::-1]:
152 |                 grad_outputs = []
153 |                 for out in op.outputs:
154 |                     if out not in tensor_to_gradient:
155 | 
grad_outputs.append(None) 156 | continue 157 | 158 | grad_list: typing.Tuple[int, int, mtf.Tensor, mtf.Operation] = tensor_to_gradient[out] 159 | grad_outputs.append(grad_list[2]) 160 | grad_list[0] += 1 161 | 162 | if not op.has_gradient or not any(grad_outputs) or not (set(op.inputs) & downstream): 163 | continue 164 | for inner_op, inp, grad in gradient_iterator(params, op, grad_outputs): 165 | if inp not in downstream or grad is None: 166 | continue 167 | 168 | if inp in tensor_to_gradient: 169 | grad_list = tensor_to_gradient[inp] 170 | grad_list[1] += 1 171 | grad_list[2] += grad 172 | grad_list[3] = inner_op 173 | else: 174 | tensor_to_gradient[inp] = [0, 1, grad, inner_op] 175 | 176 | ctx = OptimizerCtx(op, grad_outputs, downstream, tensor_to_gradient, tensor_to_var, params, 177 | loss_idx, update_ops, {}, loss_list, first_grad, 178 | loss_1__loss_1, loss_1__loss_2, loss_2__loss_2, mstep, step, neg_step, dtype, 179 | beta1, beta2, learning_rate, step_count, fn == 'update') 180 | ctx.variable_to_gradient = {var: cast(tensor_to_gradient[tensor][2], params.optimizer_calculation_dtype) 181 | for tensor, var in tensor_to_var.items()} 182 | for tensor, var in tensor_to_var.items(): 183 | update(ctx(tensor, get_variable_for_tensor(tensor), ctx.variable_to_gradient[var], var.outputs[0])) 184 | if params.combine_assignments: 185 | ctx.update_ops = step.graph.combine_assignments(ctx.update_ops) 186 | return ctx.update_ops, learning_rate 187 | -------------------------------------------------------------------------------- /src/optimizer/backend.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf2 5 | from tensorflow.python.ops.init_ops import Initializer 6 | 7 | from ..dataclass import ModelParameter 8 | from ..mtf_wrapper import (import_fully_replicated) 9 | from ..utils_mtf import SHAPE, get_variable 10 | 11 | tf = tf2.compat.v1 12 | zeros = tf.zeros_initializer() 13 | 14 | 15 | def import_float(imported): 16 | return tf.constant(imported * 1.0, dtype=tf.float32, shape=[]) 17 | 18 | 19 | def get_var(params: ModelParameter, name: str, shape: SHAPE, initializer: Initializer = zeros): 20 | return get_variable(params, name, shape, initializer, False, params.optimizer_dtype) 21 | 22 | 23 | def variable(params: ModelParameter, base: mtf.Variable, name: str, shape: SHAPE): 24 | return get_variable(params, f"{base.name}/{params.optimizer.replace(':', '_')}/{name}", shape, zeros, False, 25 | params.optimizer_dtype) 26 | 27 | 28 | def import_mtf(params: ModelParameter, imported: typing.Union[tf.Tensor, float], name: str): 29 | return import_fully_replicated(params, tf.cast(imported, params.optimizer_calculation_dtype), [], name) 30 | -------------------------------------------------------------------------------- /src/optimizer/context.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | 5 | from ..dataclass import ModelParameter 6 | 7 | 8 | class OptimizerCtx: 9 | def __init__(self, op: mtf.Operation, grad_outputs: typing.List[mtf.Tensor], downstream: typing.Set[mtf.Operation], 10 | tensor_to_gradient: dict, tensor_to_var: dict, params: ModelParameter, loss_idx: int, update_ops: list, 11 | debug_gradients_dict: dict, loss_list: list, first_grad: dict, 12 | loss_1__loss_1: typing.Optional[mtf.Tensor], loss_1__loss_2: typing.Optional[mtf.Tensor], 13 | loss_2__loss_2: typing.Optional[mtf.Tensor], 
mstep: mtf.Tensor, step: mtf.Tensor, neg_step,
14 |                  dtype: mtf.VariableDType, beta1: mtf.Tensor, beta2: mtf.Tensor, learning_rate: mtf.Tensor,
15 |                  step_count: mtf.Tensor, assign: bool):
16 |         self.step_count = step_count
17 |         self.op = op
18 |         self.grad_outputs = grad_outputs
19 |         self.tensor_to_gradient = tensor_to_gradient
20 |         self.tensor_to_var = tensor_to_var
21 |         self.params = params
22 |         self.loss_idx = loss_idx
23 |         self.update_ops = update_ops
24 |         self.debug_gradients_dict = debug_gradients_dict
25 |         self.loss_list = loss_list
26 |         self.first_grad = first_grad
27 |         self.loss_1__loss_1 = loss_1__loss_1
28 |         self.loss_1__loss_2 = loss_1__loss_2
29 |         self.loss_2__loss_2 = loss_2__loss_2
30 |         self.mstep = mstep
31 |         self.step = step
32 |         self.dtype = dtype
33 |         self.neg_step = neg_step
34 |         self.beta1 = beta1
35 |         self.beta2 = beta2
36 |         self.learning_rate = learning_rate
37 |         self.assign = assign
38 |         self.args = [op, grad_outputs, downstream, tensor_to_gradient, tensor_to_var, params, loss_idx, update_ops,
39 |                      debug_gradients_dict, loss_list, first_grad, loss_1__loss_1, loss_1__loss_2, loss_2__loss_2, mstep,
40 |                      step, dtype, beta1, beta2, learning_rate]
41 | 
42 |         self.var: typing.Optional[mtf.Variable] = None
43 |         self.value: typing.Optional[mtf.Tensor] = None
44 |         self.tensor: typing.Optional[mtf.Tensor] = None
45 |         self.grad_buffer: typing.Optional[mtf.Variable] = None
46 |         self.grad: typing.Optional[mtf.Tensor] = None
47 |         self.original_grad: typing.Optional[mtf.Tensor] = None
48 |         self.variable_to_gradient: typing.Optional[typing.Dict[mtf.Variable, mtf.Tensor]] = {}
49 | 
50 |         self.global_norm_reciprocal: typing.Optional[mtf.Tensor] = None
51 | 
52 |     def __call__(self, tensor: mtf.Tensor, var: mtf.Variable, grad: mtf.Tensor, value: mtf.Tensor):
53 |         self.var = var
54 |         self.value = value
55 |         self.tensor = tensor
56 |         self.grad = self.original_grad = grad
57 |         self.op = self.tensor_to_gradient[tensor][3]
58 |         return self
59 | 
--------------------------------------------------------------------------------
/src/optimizer/gradients.py:
--------------------------------------------------------------------------------
1 | import mesh_tensorflow as mtf
2 | import tensorflow as tf2
3 | 
4 | from .context import OptimizerCtx
5 | from ..mtf_wrapper import add_n, einsum, minimum, reciprocal
6 | 
7 | tf = tf2.compat.v1
8 | zeros = tf.zeros_initializer()
9 | 
10 | 
11 | def pcgrad(ctx: OptimizerCtx, grad: mtf.Tensor):
12 |     op = ctx.op
13 |     loss_idx = ctx.loss_idx
14 |     loss_list = ctx.loss_list
15 |     first_grad = ctx.first_grad
16 | 
17 |     if 'body' not in op.name:
18 |         return grad
19 | 
20 |     if loss_idx < len(loss_list) - 1:
21 |         first_grad[op.name] = grad
22 |         return None
23 | 
24 |     all_grads = [grad, first_grad[op.name]]
25 |     g_square = [reciprocal(1e-8 + einsum([g, g], output_shape=[])) for g in all_grads[1:]]  # 1 / (eps + |g|^2)
26 | 
27 |     for i in range(len(all_grads)):
28 |         grad = all_grads.pop(0)
29 |         for g, sq in zip(all_grads, g_square):
30 |             grad -= einsum([g, minimum(einsum([grad, g], output_shape=[]), 0), sq], output_shape=grad.shape)
31 | 
32 |         all_grads.append(grad)
33 |         g_square.append(reciprocal(1e-8 + einsum([grad, grad], output_shape=[])))
34 | 
35 |     return add_n(all_grads)
36 | 
37 | 
38 | def mgda(ctx: OptimizerCtx, grad: mtf.Tensor):
39 |     op = ctx.op
40 |     loss_idx = ctx.loss_idx
41 |     first_grad = ctx.first_grad
42 |     params = ctx.params
43 | 
44 |     if loss_idx == 2:
45 |         return None
46 | 
47 |     if 'body' not in op.name:
48 |         return grad
49 | 
50 |     if loss_idx == 0:
51 |         first_grad[op.name] = grad
52 |         return None
53 | 
54 |     elif loss_idx == 1:
55 | 
ctx.loss_1__loss_1 += einsum([first_grad[op.name], first_grad[op.name]], [params.head_dim])
56 |         ctx.loss_1__loss_2 += einsum([first_grad[op.name], grad], [params.head_dim])
57 |         ctx.loss_2__loss_2 += einsum([grad, grad], [params.head_dim])
58 | 
59 |         del first_grad[op.name]
60 |         return None
61 | 
62 |     return grad
63 | 
64 | 
65 | MULTI_LOSS_GRADIENTS = {'mgda': mgda,
66 |                         'pcgrad': pcgrad}
67 | 
--------------------------------------------------------------------------------
/src/optimizer/learning_rate.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | import mesh_tensorflow as mtf
4 | import tensorflow as tf2
5 | 
6 | from .backend import import_float, import_fully_replicated
7 | from .. import tf_wrapper as tfw
8 | from ..dataclass import ModelParameter, LearningRateConfig
9 | from ..utils_mtf import weighted_add
10 | 
11 | tf = tf2.compat.v1
12 | 
13 | 
14 | class LearningRateCtx:
15 |     def __init__(self, params: ModelParameter, loss_list: typing.List[mtf.Tensor],
16 |                  update_ops: typing.List[mtf.Operation]):
17 |         global_step = tf.train.get_or_create_global_step()
18 |         self.params = params
19 |         self.learning_rate = tf.constant(value=params.learning_rate, shape=[], dtype=tf.float32)
20 |         self.global_steps_float = tf.cast(global_step, tf.float32)
21 |         self.loss_list = loss_list
22 |         self.global_steps_mtf = import_fully_replicated(params, global_step, [], "mtf_learning_rate")
23 |         self.update_ops = update_ops
24 |         self.config: typing.Optional[LearningRateConfig] = None
25 | 
26 | 
27 | def linear_warmup(ctx: LearningRateCtx):
28 |     warmup_steps_float = import_float(ctx.config.final_step)
29 |     is_warmup = tfw.cast(ctx.global_steps_float < warmup_steps_float, tf.float32)
30 |     warmup_factor = weighted_add(tfw.divide(ctx.global_steps_float, warmup_steps_float), 1, is_warmup)
31 |     ctx.learning_rate = tfw.multiply(ctx.learning_rate, warmup_factor)
32 | 
33 | 
34 | def exponential_decay(ctx: LearningRateCtx):
35 |     base = import_float(ctx.config.factor)
36 |     exp = tfw.maximum(tfw.subtract(ctx.global_steps_float, import_float(ctx.config.start_step)), import_float(0))
37 |     decay = tfw.pow(base, exp)
38 |     ctx.learning_rate = tfw.multiply(ctx.learning_rate, decay)
39 | 
40 | 
41 | def linear_decay(ctx: LearningRateCtx):
42 |     start_step = import_float(ctx.config.start_step)
43 |     final_step = import_float(ctx.config.final_step)
44 |     current_step = tfw.subtract(ctx.global_steps_float, start_step)
45 |     final_step = tfw.subtract(final_step, start_step)
46 |     decay = tfw.subtract(1, tfw.divide(current_step, final_step))
47 |     decay = tfw.maximum(tfw.minimum(decay, 1), 0)
48 |     ctx.learning_rate = tfw.multiply(ctx.learning_rate, decay)
49 | 
50 | 
51 | def lower_bound(ctx: LearningRateCtx):
52 |     ctx.learning_rate = tfw.maximum(ctx.learning_rate, ctx.config.factor)
53 | 
54 | 
55 | def upper_bound(ctx: LearningRateCtx):
56 |     ctx.learning_rate = tfw.minimum(ctx.learning_rate, ctx.config.factor)
57 | 
58 | 
59 | MODULES = {"linear_warmup": linear_warmup,
60 |            "exponential_decay": exponential_decay,
61 |            "linear_decay": linear_decay,
62 |            "lower_bound": lower_bound,
63 |            "upper_bound": upper_bound}
64 | 
65 | 
66 | def get_learning_rate(params: ModelParameter, loss_list: typing.List[mtf.Tensor],
67 |                       update_ops: typing.List[mtf.Operation]) -> LearningRateCtx:
68 |     ctx = LearningRateCtx(params, loss_list, update_ops)
69 |     for name, keys in params.learning_rate_config.items():
70 |         ctx.config = keys
71 |         tfw.scoped(name, MODULES[name], ctx)
72 |     return ctx
73 | 
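# A minimal sketch of how these modules compose, assuming LearningRateConfig (defined
# in src/dataclass.py) accepts the fields referenced above -- start_step, final_step,
# factor -- as keyword arguments; the concrete schema may differ:
#
#     params.learning_rate_config = {
#         "linear_warmup": LearningRateConfig(final_step=2000),
#         "linear_decay": LearningRateConfig(start_step=2000, final_step=100000),
#         "lower_bound": LearningRateConfig(factor=1e-6),
#     }
#
# get_learning_rate() folds the entries over ctx.learning_rate in dict order, so each
# module rescales or clamps the result of the previous one.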
-------------------------------------------------------------------------------- /src/optimizer/optimizers.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | 5 | from .backend import variable 6 | from .context import OptimizerCtx 7 | from ..mtf_wrapper import (cast, optimizer_scalar, einsum, minimum, 8 | reduce_mean, reduce_sum, assign, add, multiply, maximum, reciprocal, square, 9 | reduce_max, rsqrt, sqrt, add_n, negative, pow as mtf_pow) 10 | from ..utils_mtf import weighted_add 11 | 12 | 13 | def opt_rsqrt(tensor: mtf.Tensor) -> mtf.Tensor: 14 | return reciprocal(maximum(sqrt(tensor), 1e-5)) 15 | 16 | 17 | def debias_momentum(ctx: OptimizerCtx, momentum: mtf.Tensor) -> mtf.Tensor: 18 | return reciprocal(add(1, negative(mtf_pow(momentum, ctx.step_count)))) 19 | 20 | 21 | def debias(ctx: OptimizerCtx, tensor: mtf.Tensor, momentum: mtf.Tensor) -> mtf.Tensor: 22 | return multiply(tensor, debias_momentum(ctx, momentum)) 23 | 24 | 25 | def assign_ctx(ctx: OptimizerCtx, var: typing.Union[mtf.Variable, mtf.Tensor], value: mtf.Tensor 26 | ) -> typing.Union[mtf.Assign, mtf.AddOperation]: 27 | if ctx.assign: 28 | return assign(var, value) 29 | return add(multiply(var, 0), value) 30 | 31 | 32 | def adam(ctx: OptimizerCtx) -> mtf.Tensor: 33 | exp_avg_p2_ptr = variable(ctx.params, ctx.var, 'exp_avg_p2', ctx.var.shape) 34 | exp_avg_p1_ptr = variable(ctx.params, ctx.var, 'exp_avg_p1', ctx.var.shape) 35 | 36 | exp_avg_p2 = weighted_add(exp_avg_p2_ptr, square(ctx.grad), ctx.beta2) 37 | grad = weighted_add(exp_avg_p1_ptr, ctx.grad, ctx.beta1) 38 | 39 | ctx.update_ops.append(assign_ctx(ctx, exp_avg_p2_ptr, exp_avg_p2)) 40 | ctx.update_ops.append(assign_ctx(ctx, exp_avg_p1_ptr, grad)) 41 | return einsum([opt_rsqrt(debias(ctx, exp_avg_p2, ctx.beta2)), grad, 42 | debias_momentum(ctx, ctx.beta1)], output_shape=grad.shape) 43 | 44 | 45 | def novograd(ctx: OptimizerCtx) -> mtf.Tensor: 46 | if ctx.var.shape.ndims == 0: 47 | return adam(ctx) 48 | 49 | exp_avg_p1 = exp_avg_p1_ptr = variable(ctx.params, ctx.var, "exp_avg_p1", ctx.var.shape) 50 | exp_avg_p2 = exp_avg_p2_ptr = variable(ctx.params, ctx.var, "exp_avg_p2", []) 51 | 52 | exp_avg_p1 = add(multiply(ctx.beta1, exp_avg_p1), multiply(ctx.grad, opt_rsqrt(exp_avg_p2))) 53 | exp_avg_p2 = weighted_add(exp_avg_p2, reduce_sum(square(ctx.grad)), ctx.beta2) 54 | ctx.update_ops.extend([assign_ctx(ctx, exp_avg_p1_ptr, exp_avg_p1), 55 | assign_ctx(ctx, exp_avg_p2_ptr, exp_avg_p2)]) 56 | return add(multiply(ctx.beta1, exp_avg_p1), 57 | multiply(ctx.grad, opt_rsqrt(debias(ctx, exp_avg_p2, ctx.beta2)))) 58 | 59 | 60 | def sm3(ctx: OptimizerCtx) -> mtf.Tensor: 61 | if ctx.var.shape.ndims == 0: 62 | return adam(ctx) 63 | 64 | weight_update = variable(ctx.params, ctx.var, "dim0", [ctx.var.shape.dims[0]]) 65 | buffer = [weight_update] 66 | 67 | for i in range(1, ctx.var.shape.ndims): 68 | buffer.append(variable(ctx.params, ctx.var, f"dim{i}", [ctx.var.shape.dims[i]])) 69 | weight_update = minimum(weight_update, buffer[-1]) 70 | 71 | weight_update = add(weight_update, square(ctx.grad)) 72 | 73 | ctx.update_ops.extend([assign_ctx(ctx, buf_ptr, reduce_max(weight_update, output_shape=[dim])) 74 | for buf_ptr, dim in zip(buffer, weight_update.shape.dims)]) 75 | 76 | return multiply(ctx.grad, opt_rsqrt(weight_update)) 77 | 78 | 79 | def adaptive_gradient_clipping(ctx: OptimizerCtx, gradient_clip: str) -> mtf.Tensor: 80 | gradient_clip = float(gradient_clip) 81 | var = cast(ctx.value, 
ctx.params.optimizer_calculation_dtype) 82 | grd_norm = minimum(rsqrt(einsum([ctx.grad] * 2, output_shape=[])), 1e6) 83 | wgt_norm = maximum(sqrt(einsum([var] * 2, output_shape=[])), 1e-3) 84 | return ctx.grad * minimum(wgt_norm * grd_norm * gradient_clip, 1) 85 | 86 | 87 | def l2norm_gradient_clipping(ctx: OptimizerCtx, gradient_clip: str) -> mtf.Tensor: 88 | gradient_clip = float(gradient_clip) 89 | return einsum([ctx.grad, optimizer_scalar(ctx.params, gradient_clip), 90 | rsqrt(maximum(einsum([ctx.grad, ctx.grad], []), gradient_clip ** -2))]) 91 | 92 | 93 | def global_l2norm_gradient_clipping(ctx: OptimizerCtx, gradient_clip: str) -> mtf.Tensor: 94 | gradient_clip = float(gradient_clip) 95 | if ctx.global_norm_reciprocal is None: 96 | global_sum = add_n([reduce_sum(square(grad)) for grad in ctx.variable_to_gradient.values()]) 97 | ctx.global_norm_reciprocal = rsqrt(maximum(global_sum, gradient_clip ** -2)) 98 | return einsum([ctx.grad, optimizer_scalar(ctx.params, gradient_clip), ctx.global_norm_reciprocal]) 99 | 100 | 101 | def value_gradient_clipping(ctx: OptimizerCtx, gradient_clip: str) -> mtf.Tensor: 102 | gradient_clip = float(gradient_clip) 103 | return maximum(minimum(ctx.grad, gradient_clip), -gradient_clip) 104 | 105 | 106 | def gradient_centralisation(ctx: OptimizerCtx) -> mtf.Tensor: 107 | return ctx.grad - reduce_mean(ctx.grad) 108 | 109 | 110 | def weight_centralisation(ctx: OptimizerCtx) -> mtf.Tensor: 111 | return ctx.grad + reduce_mean(ctx.value) 112 | 113 | 114 | def multiply_learning_rate(ctx: OptimizerCtx) -> mtf.Tensor: 115 | return multiply(ctx.grad, ctx.learning_rate) 116 | 117 | 118 | def momentum(ctx: OptimizerCtx, momentum_multiplier: str, gradient_multiplier: str, nesterov: str) -> mtf.Tensor: 119 | nesterov = bool(int(nesterov)) 120 | momentum_multiplier = float(momentum_multiplier) 121 | gradient_multiplier = float(gradient_multiplier) 122 | 123 | state = variable(ctx.params, ctx.var, 'momentum', ctx.var.shape) 124 | new_state = momentum_multiplier * state + ctx.grad * gradient_multiplier 125 | ctx.update_ops.append(assign_ctx(ctx, state, new_state)) 126 | if not nesterov: 127 | return new_state 128 | return ctx.grad + momentum_multiplier * new_state 129 | 130 | 131 | OPTIMIZERS = {"adam": adam, 132 | "sm3": sm3, 133 | "novograd": novograd, 134 | "adaptive_clip": adaptive_gradient_clipping, 135 | "l2norm_clip": l2norm_gradient_clipping, 136 | "value_clip": value_gradient_clipping, 137 | "gradient_centralisation": gradient_centralisation, 138 | "weight_centralisation": weight_centralisation, 139 | "learning_rate": multiply_learning_rate, 140 | "global_l2norm_clip": global_l2norm_gradient_clipping, 141 | "momentum": momentum 142 | } 143 | 144 | 145 | def graft(ctx: OptimizerCtx, optimizer: str, *params: str) -> mtf.Tensor: 146 | return einsum([ctx.grad, rsqrt(reduce_sum(square(ctx.grad))), 147 | sqrt(reduce_sum(square(OPTIMIZERS[optimizer](ctx, *params))))], 148 | output_shape=ctx.grad.shape) 149 | 150 | 151 | OPTIMIZERS['graft'] = graft 152 | -------------------------------------------------------------------------------- /src/rest_api.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import typing 3 | 4 | import uvicorn 5 | from fastapi import FastAPI, HTTPException 6 | from pydantic import BaseModel 7 | from transformers import GPT2TokenizerFast 8 | 9 | from .dataclass import ModelParameter 10 | from .interface import InterfaceWrapper 11 | 12 | 13 | class Tokens(BaseModel): 14 | tokens: 
typing.List[int]
15 | 
16 | 
17 | class TokenCompletion(BaseModel):
18 |     token_completion: typing.List[int]
19 | 
20 | 
21 | class Completion(BaseModel):
22 |     completion: str
23 | 
24 | 
25 | class SanitizedTokens(BaseModel):
26 |     tokens: typing.List[int]
27 | 
28 | 
29 | class CompletionInput(BaseModel):
30 |     prompt: str = ""
31 |     max_tokens: int = 16
32 |     temperature: float = 1.
33 |     error: bool = True
34 | 
35 | 
36 | class RestAPI:
37 |     def __init__(self, params: ModelParameter):
38 |         self._interface = InterfaceWrapper(params)
39 |         self._params = params
40 |         self._tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
41 | 
42 |     async def check_tokens(self, tokens: typing.List[int], error: bool = True) -> SanitizedTokens:
43 |         if tokens and max(tokens) > self._params.vocab_size:
44 |             if error:
45 |                 raise HTTPException(status_code=400, detail=f"Invalid tokens sent. Tokens go up to "
46 |                                                             f"{self._params.vocab_size} but received {max(tokens)}.")
47 |             tokens = [t for t in tokens if t < self._params.vocab_size]
48 |         if len(tokens) > self._params.sequence_length:
49 |             if error:
50 |                 raise HTTPException(status_code=400, detail=f"Context too big. The model supports up to "
51 |                                                             f"{self._params.sequence_length} tokens but received {len(tokens)}.")
52 |             tokens = tokens[:self._params.sequence_length]
53 |         return SanitizedTokens(tokens=tokens)
54 | 
55 |     async def encode(self, prompt: str) -> Tokens:
56 |         out = list(prompt.encode()) if self._params.vocab_size == 256 else self._tokenizer.encode(prompt)
57 |         return Tokens(tokens=out)
58 | 
59 |     async def decode(self, prompt: typing.List[int]) -> Completion:
60 |         out = ''.join(chr(c) for c in prompt) if self._params.vocab_size == 256 else self._tokenizer.decode(prompt)
61 |         return Completion(completion=out)
62 | 
63 |     async def token_completion(self, params: CompletionInput) -> TokenCompletion:
64 |         tokens = (await self.encode(params.prompt)).tokens
65 |         tokens = (await self.check_tokens(tokens, params.error)).tokens
66 |         out = self._interface.complete(tokens, params.temperature, len(tokens) + params.max_tokens)
67 |         out = out.tolist()[:params.max_tokens]
68 |         return TokenCompletion(token_completion=out)
69 | 
70 |     async def completion(self, params: CompletionInput) -> Completion:
71 |         return await self.decode((await self.token_completion(params)).token_completion)
72 | 
73 | 
74 | def get_api_input_and_output_fn(params: ModelParameter):
75 |     rest_api = RestAPI(params)
76 |     fast_api = FastAPI()
77 | 
78 |     for key in dir(rest_api):
79 |         if key.startswith('_') or key.endswith('_'):
80 |             continue
81 |         fn = getattr(rest_api, key)
82 |         fast_api.post('/' + key, response_model=typing.get_type_hints(fn)["return"])(fn)
83 | 
84 |     run = multiprocessing.Process(target=uvicorn.run, daemon=True, args=(fast_api,),
85 |                                   kwargs={'host': '0.0.0.0', 'port': 62220, 'log_level': 'info',
86 |                                           'workers': params.web_workers})
87 |     run.start()
88 | 
89 |     return rest_api._interface.input_query, rest_api._interface.output_responds
90 | 
--------------------------------------------------------------------------------
/src/run/dataloader_placement.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | 
4 | import jsonpickle
5 | import numpy as np
6 | import tensorflow as tf
7 | from tensorflow.python.framework import ops
8 | from tensorflow.python.tpu import tpu_feed
9 | 
10 | from .. 
import tf_wrapper as tfw 11 | from ..dataclass import ModelParameter 12 | 13 | tf1 = tf.compat.v1 14 | Dataset = tf1.data.Dataset 15 | 16 | 17 | def place_dataloader(params: ModelParameter, input_fn): 18 | num_cores = params.mesh_impl.device_assignment.num_replicas 19 | 20 | ordered_ordinals = [] 21 | ordered_hosts = [] 22 | ordered_host_ids = [] 23 | host_id_to_its_pnums = collections.defaultdict(list) 24 | d_assignment = params.mesh_impl.device_assignment 25 | 26 | for pnum in range(num_cores): 27 | physical_pnum = params.mesh_impl.l2p(pnum) 28 | 29 | # For MTF, there's always 1 core per replica. So logical_core=0. 30 | ordered_ordinals.append(d_assignment.tpu_ordinal(replica=physical_pnum, logical_core=0)) 31 | host_device = d_assignment.host_device(replica=physical_pnum) 32 | host_id = int(host_device.lower().split("/task:")[1].split("/device:")[0]) 33 | ordered_hosts.append(host_device) 34 | ordered_host_ids.append(host_id) 35 | host_id_to_its_pnums[host_id].append(pnum) 36 | 37 | num_hosts = len(set(ordered_hosts)) 38 | 39 | pnum_maps = [] 40 | macro_batching_multi = params.macro_batching if params.train else 1 41 | batch_size = params.input_pipeline_shape[0].to_integer_list[0] * macro_batching_multi 42 | for mtf_shape in params.input_pipeline_shape: 43 | # Make sure that the batch size is the same across all input tensors. 44 | assert batch_size == mtf_shape.to_integer_list[0] * macro_batching_multi 45 | 46 | s_shape = params.mesh_impl.slice_shape(mtf_shape) 47 | shape_list = [dim_size // s_dim_size for dim_size, s_dim_size in zip(mtf_shape.to_integer_list, s_shape)] 48 | 49 | pnum_map_shape = shape_list + [num_cores // np.prod(shape_list)] 50 | assert np.prod(pnum_map_shape) == num_cores 51 | 52 | # Initialize the pnum_map to None. 53 | pnum_map = np.empty(pnum_map_shape, dtype=object) 54 | pnum_map[:] = None 55 | 56 | for pnum in range(num_cores): 57 | s_begin = params.mesh_impl.slice_begin(mtf_shape, pnum) 58 | coord = [dim_size // s_dim_size for dim_size, s_dim_size in zip(s_begin, s_shape)] 59 | # put pnum in pnum_map[coord] 60 | pnum_array_ref = pnum_map[tuple(coord)] 61 | for idx, value in enumerate(pnum_array_ref): 62 | if value is None: 63 | pnum_array_ref[idx] = pnum 64 | break 65 | 66 | pnum_maps.append(pnum_map) 67 | 68 | # For each sub-batch, we need to know which host should read it. 69 | if params.train: 70 | 71 | # This records how many datasets (ds) are already stored on each host. 72 | num_dss_per_host = [0] * num_hosts 73 | 74 | # A list of host_ids that holds datasets (ds). 75 | hosts_to_hold_ds = [] 76 | 77 | for sub_batch_pnum_map in pnum_maps[0]: 78 | 79 | num_pnums_per_host = [0] * num_hosts 80 | for pnum in sub_batch_pnum_map.flatten(): 81 | num_pnums_per_host[ordered_host_ids[pnum]] += 1 82 | 83 | host_metrics = [(host_id, num_pnums_per_host[host_id], num_dss_per_host[host_id]) for host_id in 84 | range(num_hosts)] 85 | host_id, _, _ = max(host_metrics, key=lambda keys: (keys[1], -keys[2])) 86 | 87 | num_dss_per_host[host_id] += 1 88 | hosts_to_hold_ds.append(host_id) 89 | 90 | else: 91 | # There should be just one dataset-holding host. Make the last host do it. 92 | hosts_to_hold_ds = [num_hosts - 1] 93 | 94 | sub_batch_size = batch_size // len(hosts_to_hold_ds) 95 | tf1.logging.info("MTF sub_batch_size: {}".format(sub_batch_size)) 96 | assert sub_batch_size * len(hosts_to_hold_ds) == batch_size 97 | 98 | # Slots for all laidout tensors. 
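    # all_laidout_tensors[pnum][input_i] will hold the slice of input `input_i` that
    # logical core `pnum` consumes; every slot must be filled before the infeed queue
    # is built (checked by the `None not in laidout_tensors` assert below).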
99 |     all_laidout_tensors = [[None] * len(params.input_pipeline_shape) for _ in range(num_cores)]
100 | 
101 |     log_path = params.model_path + "/DataLog.log"
102 |     _run_log = []
103 |     run_log = None
104 | 
105 |     if params.use_checkpointing:
106 |         if tf.io.gfile.exists(log_path):
107 |             _run_log = json.load(tf.io.gfile.GFile(log_path, 'r'))
108 | 
109 |         current_stats = {'steps': params.current_step, 'ctx': params.sequence_length,
110 |                          'slice_count': len(hosts_to_hold_ds),
111 |                          'interleave_size': params.interleaved_datasets,
112 |                          'batch_size': params.train_batch_size,
113 |                          'grad_accumulation': params.grad_accumulation,
114 |                          'token_patch_size': params.token_patch_size
115 |                          }
116 | 
117 |         size_dump = jsonpickle.dumps(_run_log + [current_stats], indent=4)
118 |         with tf.io.gfile.GFile(f"{params.model_path}/model_size.info", 'w') as f:
119 |             f.write(size_dump)
120 | 
121 |         if len(_run_log) > 0 and not params.use_random_dataloader:
122 |             _run_log = [r for r in _run_log if r['steps'] != params.current_step]
123 |             if len(_run_log) > 0:
124 |                 run_log = [_run_log.pop(-1)]
125 |                 for r in _run_log[::-1]:
126 |                     if run_log[-1]['steps'] != r['steps'] and r['steps'] != params.current_step:
127 |                         run_log.append(r)
128 |                 run_log = run_log[::-1]
129 | 
130 |                 for run_idx in range(len(run_log) - 1):
131 |                     run_log[run_idx]['steps'] = run_log[run_idx + 1]['steps'] - run_log[run_idx]['steps']
132 | 
133 |                 run_log[-1]['steps'] = params.current_step - run_log[-1]['steps']
134 | 
135 |                 if run_log[-1]['steps'] <= 0:
136 |                     run_log = None
137 | 
138 |     ds_iterator = []
139 |     # For each sub-batch, build its dataset and iterator on the chosen host.
140 |     for sub_batch_i, host_id in enumerate(hosts_to_hold_ds):
141 |         # Get the list of pnums for each input.
142 |         if params.train:
143 | 
144 |             all_sub_batch_pnums = []
145 |             for pnum_map in pnum_maps:
146 |                 sub_batch_pnums = pnum_map[sub_batch_i, ...].flatten().tolist()
147 |                 all_sub_batch_pnums.append(sub_batch_pnums)
148 | 
149 |         else:
150 | 
151 |             all_sub_batch_pnums = [pnum_map.flatten().tolist() for pnum_map in pnum_maps]
152 | 
153 |         with ops.device(f"/job:worker/task:{host_id}/device:CPU:0"):
154 |             dataset = input_fn(params, sub_batch_size, sub_batch_i, len(hosts_to_hold_ds), run_log)
155 |             if not params.use_random_dataloader and params.train and params.use_video:
156 |                 dataset = dataset.skip(params.current_step // params.macro_batching)
157 |             dataset = dataset.prefetch(params.buffer_size)
158 |             options = tf.data.Options()
159 |             # options.autotune.enabled = True
160 |             # options.deterministic = not params.train
161 |             options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO
162 |             options.experimental_optimization.filter_fusion = True
163 |             options.experimental_optimization.apply_default_optimizations = True
164 |             options.experimental_optimization.map_and_batch_fusion = True
165 |             options.experimental_optimization.map_and_filter_fusion = True
166 |             options.experimental_optimization.map_fusion = True
167 |             options.experimental_optimization.map_parallelization = True
168 |             options.experimental_optimization.noop_elimination = True
169 |             options.experimental_optimization.parallel_batch = True
170 |             options.experimental_optimization.shuffle_and_repeat_fusion = True
171 |             options.experimental_slack = True
172 |             options.threading.private_threadpool_size = 96
173 |             options.threading.max_intra_op_parallelism = 1
174 |             dataset: Dataset = dataset.with_options(options)
175 |             _ds_iterator = tf1.data.make_initializable_iterator(dataset)
176 |             ds_iterator.append(_ds_iterator)
177 | 
all_input_tensors = _ds_iterator.get_next()
178 | 
179 |             if isinstance(all_input_tensors, tf.Tensor):
180 |                 all_input_tensors = [all_input_tensors]
181 |             assert len(all_input_tensors) == len(all_sub_batch_pnums)
182 | 
183 |             for input_i in range(len(all_input_tensors)):
184 |                 input_tensor = all_input_tensors[input_i]
185 |                 sub_batch_pnums = all_sub_batch_pnums[input_i]
186 |                 mtf_input_shape = params.input_pipeline_shape[input_i]
187 | 
188 |                 # Initialize the slice cache for each input_i
189 |                 _slice_dict = collections.defaultdict(list)
190 | 
191 |                 for idx, pnum in enumerate(sub_batch_pnums):
192 | 
193 |                     s_begin = params.mesh_impl.slice_begin(mtf_input_shape, pnum)
194 |                     if params.train:
195 |                         # Always slice from 0 in the first dimension (batch dimension), since
196 |                         # input_tensor is a sub-batch tensor.
197 |                         s_begin[0] = 0
198 |                     if tuple(s_begin) in _slice_dict:
199 |                         input_slice = _slice_dict[tuple(s_begin)]
200 |                     else:
201 |                         s_shape = params.mesh_impl.slice_shape(mtf_input_shape)
202 |                         s_shape[0] = s_shape[0] * macro_batching_multi
203 |                         input_slice = _slice_dict[tuple(s_begin)] = tfw.slice(input_tensor, s_begin, s_shape)
204 | 
205 |                     all_laidout_tensors[pnum][input_i] = input_slice
206 | 
207 |     # Make sure that there are no Nones in all_laidout_tensors.
208 |     for laidout_tensors in all_laidout_tensors:
209 |         assert None not in laidout_tensors
210 | 
211 |     with ops.device(f"/job:worker/task:{hosts_to_hold_ds[0]}/device:CPU:0"):
212 | 
213 |         def _tpu_ordinal_function_impl(pnum):
214 |             return ordered_ordinals[pnum]
215 | 
216 |         def _placement_function_impl(pnum):
217 |             return ordered_hosts[pnum]
218 | 
219 |         laidout_tensors0 = all_laidout_tensors[0]
220 |         infeed_queue = tpu_feed.InfeedQueue(
221 |             number_of_tuple_elements=len(laidout_tensors0),
222 |             tuple_types=[x.dtype for x in laidout_tensors0],
223 |             tuple_shapes=[x.shape for x in laidout_tensors0])
224 |         enqueue_ops = infeed_queue.generate_enqueue_ops(
225 |             all_laidout_tensors,
226 |             tpu_ordinal_function=_tpu_ordinal_function_impl,
227 |             placement_function=_placement_function_impl)
228 | 
229 |     input_initializers = [ds.initializer for ds in ds_iterator]
230 | 
231 |     return input_initializers, enqueue_ops, infeed_queue
232 | 
233 | 
234 | def infeed_from_session(params: ModelParameter):
235 |     num_cores = params.mesh_impl.device_assignment.num_replicas
236 |     d_assignment = params.mesh_impl.device_assignment
237 |     ordered_ordinals = []
238 |     ordered_hosts = []
239 | 
240 |     for pnum in range(num_cores):
241 |         physical_pnum = params.mesh_impl.l2p(pnum)
242 |         host_device = d_assignment.host_device(replica=physical_pnum)
243 |         ordered_hosts.append(host_device)
244 | 
245 |         # For MTF, there's always 1 core per replica. So logical_core=0. 
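        # Together, `ordered_ordinals[pnum]` (core index within its host) and
        # `ordered_hosts[pnum]` (host device string) tell the infeed queue where to
        # place and enqueue each logical pnum's slice via the two helpers below.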
246 | ordered_ordinals.append(d_assignment.tpu_ordinal(replica=physical_pnum, logical_core=0)) 247 | 248 | def _tpu_ordinal_function_impl(pnum): 249 | return ordered_ordinals[pnum] 250 | 251 | def _placement_function_impl(pnum): 252 | return ordered_hosts[pnum] 253 | 254 | prompt = tf1.placeholder(dtype=tf.int32, shape=[t.size for t in params.token_dim_shape]) 255 | iter_pos = tf1.placeholder(dtype=tf.int32, shape=[1]) 256 | samp_temp = tf1.placeholder(dtype=tf.float32, shape=[1]) 257 | end_iter = tf1.placeholder(dtype=tf.int32, shape=[1]) 258 | 259 | all_laidout_tensors = [[prompt, prompt, iter_pos, samp_temp, end_iter] for _ in range(params.num_cores)] 260 | 261 | laidout_tensors0 = all_laidout_tensors[0] 262 | infeed_queue = tpu_feed.InfeedQueue( 263 | number_of_tuple_elements=len(laidout_tensors0), 264 | tuple_types=[x.dtype for x in laidout_tensors0], 265 | tuple_shapes=[x.shape for x in laidout_tensors0]) 266 | enqueue_ops = infeed_queue.generate_enqueue_ops(all_laidout_tensors, 267 | tpu_ordinal_function=_tpu_ordinal_function_impl, 268 | placement_function=_placement_function_impl) 269 | 270 | place_holders = [prompt, iter_pos, samp_temp, end_iter] 271 | return enqueue_ops, infeed_queue, place_holders 272 | -------------------------------------------------------------------------------- /src/run/inference.py: -------------------------------------------------------------------------------- 1 | import mesh_tensorflow as mtf 2 | import tensorflow as tf 3 | 4 | from ..dataclass import ModelParameter 5 | from ..model import build 6 | from ..mtf_wrapper import (constant_scalar, log, argmax, reshape, one_hot, equal, less_equal, mtf_range, greater, 7 | reduce_sum, cast, shift, ones, zeros, constant, random_uniform, greater_equal, logical_not, 8 | anonymize, add, negative, multiply) 9 | from ..utils_mtf import concat, pad, utils_slice, to_fp32, weighted_add 10 | 11 | tf1 = tf.compat.v1 12 | Dataset = tf1.data.Dataset 13 | 14 | 15 | def autoregressive_model(params: ModelParameter, 16 | frame_input=None, token_x_input=None, token_y_input=None, 17 | frame_mask_src=None, frame_mask_tag=None, token_mask=None, 18 | initial_pos=None, sampling_temperature=None, end_iterations=None): 19 | if params.use_video: 20 | # todo: fix token shift for video (Jan). 
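        # Autoregressive sampling runs as an mtf.while_loop: each body_fn iteration
        # does a full forward pass, splices the prediction for the current position
        # back into the inputs via a one-hot mask over the sequence dimension, and
        # cond_fn terminates the loop once `position` reaches `end_iterations`.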
21 | tkn_per_frame = mtf.Dimension("language_token_per_frame", 22 | params.language_token_per_frame) 23 | shape = [params.batch_dim, params.sequence_dim, tkn_per_frame, params.vocab_dim] 24 | 25 | def body_fn(position, token_x_input, token_y_input, frame_input, 26 | frame_mask_src, frame_mask_tag, token_mask, *states): 27 | 28 | _, _, _, _, _, frame_out, token_out = build(params, 29 | frame_input, 30 | ones(params.mesh, [], tf.float32), 31 | ones(params.mesh, [], tf.float32), 32 | token_x_input, 33 | token_y_input, 34 | frame_mask_src, 35 | frame_mask_tag, 36 | token_mask) 37 | 38 | frame_input = weighted_add(pad(frame_out, params.sequence_dim, (0, 1)), frame_input, 39 | one_hot(position, params.frame_input_sequence, dtype=tf.float32)) 40 | 41 | if params.use_language: 42 | one_hot_sequence = one_hot(position, params.sequence_dim, dtype=tf.float32) 43 | token_out = argmax(reshape(token_out, new_shape=shape), params.vocab_dim) 44 | padding_token = to_fp32(equal(token_out, params.padding_token)) 45 | 46 | token_x_input = weighted_add(reshape(token_out, new_shape=params.token_dim_shape), 47 | token_x_input, 48 | one_hot(position, params.sequence_dim, dtype=tf.int32)) 49 | 50 | token_pad = less_equal(mtf_range(params.mesh, tkn_per_frame, dtype=tf.float32), 51 | to_fp32(argmax(padding_token, reduced_dim=tkn_per_frame)), 52 | output_shape=token_out.shape) 53 | 54 | token_mask = weighted_add(reshape(to_fp32(token_pad), new_shape=params.token_dim_shape), 55 | to_fp32(token_mask), one_hot_sequence) 56 | 57 | frame_pad = to_fp32(greater(reduce_sum(padding_token, reduced_dim=tkn_per_frame), 0)) 58 | token_x_input = weighted_add(frame_pad, to_fp32(token_x_input), one_hot_sequence) 59 | 60 | token_x_input = cast(token_x_input, dtype=tf.int32) 61 | 62 | return position + 1, token_x_input, token_y_input, frame_input, frame_mask_src, frame_mask_tag, token_mask 63 | 64 | if token_mask is not None: 65 | token_mask = to_fp32(token_mask) 66 | if frame_mask_src is not None: 67 | frame_mask_src = to_fp32(frame_mask_src) 68 | if frame_mask_tag is not None: 69 | frame_mask_tag = to_fp32(frame_mask_tag) 70 | 71 | while_loop_inputs = [zeros(params.mesh, [], tf.int32) + params.initial_autoregressive_position, 72 | token_x_input, token_y_input, frame_input, frame_mask_src, frame_mask_tag, 73 | token_mask] 74 | 75 | else: # -> params.use_language 76 | def body_fn(position, token_x, token_y, sampling_temperature, *states): 77 | _, _, _, _, _, _, token_out = build(params, 78 | ones(params.mesh, [], tf.float32), 79 | ones(params.mesh, [], tf.float32), 80 | ones(params.mesh, [], tf.float32), 81 | token_x, 82 | token_y, 83 | ones(params.mesh, [], tf.float32), 84 | ones(params.mesh, [], tf.float32), 85 | ones(params.mesh, [], tf.float32)) 86 | 87 | one_hot_mask = one_hot(position, output_dim=params.sequence_dim, dtype=tf.int32) 88 | token_out = add(cast(token_out, dtype=tf.float32), 89 | multiply(log(negative(log(random_uniform(params, token_out.shape, 90 | maxval=1, minval=1e-9, dtype=tf.float32)))), 91 | negative(sampling_temperature))) 92 | token_out = argmax(token_out, params.vocab_dim) 93 | 94 | token_out = shift(token_out, offset=1, dim=params.sequence_dim, wrap=False) 95 | 96 | return (add(position, 1), weighted_add(token_out, token_x, one_hot_mask), 97 | token_y, cast(sampling_temperature, dtype=tf.float32)) 98 | 99 | if initial_pos is None: 100 | initial_pos = constant(params, value=params.initial_autoregressive_position, dtype=tf.int32) 101 | 102 | if params.debug_sample: 103 | token_initial_pos_mask = 
less_equal(mtf_range(params.mesh, params.sequence_dim, dtype=tf.int32), 104 | initial_pos) 105 | token_initial_pos_mask = cast(token_initial_pos_mask, tf.int32) 106 | token_x_input_a = utils_slice(token_x_input, 0, 1, dim=params.batch_dim) 107 | token_x_input_b = multiply(token_x_input_a, token_initial_pos_mask) 108 | token_x_input = concat([token_x_input_a, token_x_input_b], dim=token_x_input_a.shape[0]) 109 | 110 | if sampling_temperature is None: 111 | sampling_temperature = constant_scalar(params, params.sampling_temperature, dtype=tf.float32) 112 | 113 | if end_iterations is None: 114 | end_iterations = constant(params, value=params.sequence_length, dtype=tf.int32) 115 | 116 | while_loop_inputs = [initial_pos, token_x_input, token_y_input, sampling_temperature] 117 | 118 | def cond_fn(position, *states): 119 | is_done = greater_equal(position, end_iterations) 120 | is_done = reduce_sum(is_done) 121 | 122 | return logical_not(is_done) 123 | 124 | loop_out = mtf.while_loop(cond_fn=cond_fn, body_fn=body_fn, inputs=while_loop_inputs) 125 | 126 | token_out = None 127 | frame_out = None 128 | if params.use_language: 129 | token_out = loop_out[1] 130 | if params.use_video: 131 | frame_out = loop_out[3] 132 | 133 | return token_out, frame_out 134 | 135 | 136 | def get_infrence_model(params: ModelParameter): 137 | def infrence_model(frame_input, cat_mask_src, cat_mask_tag, token_x_input, token_y_input, frame_mask_src, 138 | frame_mask_tag, 139 | token_mask, initial_pos, sampling_temperature, end_iterations): 140 | 141 | if params.use_autoregressive_sampling: 142 | token_out, frame_out = autoregressive_model(params, 143 | frame_input, 144 | token_x_input, 145 | token_y_input, 146 | frame_mask_src, 147 | frame_mask_tag, 148 | token_mask, 149 | initial_pos, 150 | sampling_temperature, 151 | end_iterations) 152 | else: 153 | _, _, _, _, _, frame_out, token_out = build(params, 154 | frame_input, 155 | cat_mask_src, 156 | cat_mask_tag, 157 | token_x_input, 158 | token_y_input, 159 | frame_mask_src, 160 | frame_mask_tag, 161 | token_mask) 162 | 163 | if params.use_language: 164 | token_out = anonymize(token_out) 165 | if params.use_video: 166 | frame_out = anonymize(frame_out) 167 | 168 | return token_out, frame_out 169 | 170 | return infrence_model 171 | -------------------------------------------------------------------------------- /src/run/train.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import tensorflow as tf 5 | 6 | from ..dataclass import ModelParameter 7 | from ..model import build 8 | from ..mtf_wrapper import constant_scalar, get_variable_for_tensor 9 | from ..optimizer import get_optimizer 10 | from ..utils_core import NAME_INDICES 11 | from ..utils_mtf import unbind, deduplicate 12 | 13 | 14 | def none_cast(x: typing.Optional[mtf.Tensor]): 15 | if x is not None: 16 | return mtf.cast(x, tf.float64) 17 | 18 | 19 | def get_train_model(params: ModelParameter): 20 | def train_model(frame_input, cat_mask_src, cat_mask_tag, token_x_input, token_y_input, 21 | frame_mask_src, frame_mask_tag, token_mask, manual_global_step): 22 | slice_dim = mtf.Dimension("batch_slice", params.macro_batching) 23 | 24 | def inp_slice_fn(x: typing.Optional[mtf.Tensor]): 25 | if x is None: 26 | return [None] * params.macro_batching 27 | x = mtf.replace_dimensions(x, params.macro_batch_dim, [slice_dim, params.batch_dim]) 28 | return unbind(x, slice_dim) 29 | 30 | inputs = (frame_input, cat_mask_src, cat_mask_tag, 
token_x_input, token_y_input, frame_mask_src, frame_mask_tag, 31 | token_mask) 32 | inputs = list(zip(*map(inp_slice_fn, inputs))) 33 | idx = constant_scalar(params, 0, dtype=params.optimizer_calculation_dtype) 34 | all_ops = [] 35 | for i, args in enumerate(inputs, 1): 36 | params.is_last_mbatch = i == len(inputs) 37 | params.macro_batch_index = i 38 | NAME_INDICES.clear() 39 | params.cached_parameters.clear() 40 | loss, loss_list, video_loss, accuracy, token_loss, frame_out, token_out = build(params, *args) 41 | loss = none_cast(loss) 42 | video_loss = none_cast(video_loss) 43 | token_loss = none_cast(token_loss) 44 | if params.multi_loss_strategy == "linear": 45 | loss_list = [loss] 46 | elif params.multi_loss_strategy == "mgda": 47 | loss_list = [none_cast(x) for x in loss_list] + [None] 48 | graph: mtf.Graph = params.mesh.graph 49 | update_ops, learning_rate = get_optimizer(loss_list, params, idx, 50 | "update" if params.is_last_mbatch else "add") 51 | if params.is_last_mbatch: 52 | ops = graph.operations.copy() 53 | graph.operations.clear() 54 | graph.operations.extend(all_ops) 55 | graph.operations.extend(ops) 56 | graph._all_variables = deduplicate(graph.all_variables) 57 | graph._trainable_variables = deduplicate(graph.trainable_variables) 58 | graph._operations = deduplicate(graph.operations) 59 | else: 60 | idx += 1 61 | for tensor in update_ops: 62 | tensor: mtf.Tensor = tensor 63 | op = get_variable_for_tensor(tensor) 64 | tensor = mtf.stop_gradient(mtf.cast(tensor, op.activation_dtype)) 65 | params.variable_cache[op.full_name] = tensor.operation 66 | ops = graph.operations.copy() 67 | all_ops.extend([op for op in ops if not isinstance(op, mtf.Assign)]) 68 | graph.operations.clear() 69 | graph.operations.extend([op for op in ops if isinstance(op, mtf.Variable)]) 70 | 71 | return frame_out, token_out, learning_rate, loss, video_loss, token_loss, accuracy, update_ops, {} 72 | 73 | return train_model 74 | -------------------------------------------------------------------------------- /src/run/utils_run.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import jsonpickle 4 | import mesh_tensorflow as mtf 5 | import tensorflow as tf 6 | from tensorflow.python.ops import summary_ops_v2 as summary 7 | from tensorflow.python.tpu import tpu 8 | 9 | from .. 
import tf_wrapper as tfw
10 | from ..dataclass import ModelParameter
11 | from ..mtf_wrapper import import_laid_out_tensor
12 | from ..utils_core import color_print
13 | 
14 | tf1 = tf.compat.v1
15 | Dataset = tf1.data.Dataset
16 | 
17 | 
18 | class CheckpointLoaderHook(tf.estimator.SessionRunHook):
19 |     """Load the latest checkpoint right after the session starts."""
20 | 
21 |     def __init__(self, checkpoint_dir):
22 |         self.checkpoint_dir = checkpoint_dir
23 | 
24 |     def after_create_session(self, session, coord):
25 |         saver_collection = tf1.get_collection(tf1.GraphKeys.SAVERS)
26 |         if saver_collection:
27 |             check_point = tf.train.latest_checkpoint(self.checkpoint_dir)
28 |             if check_point:
29 |                 saver_collection[0].restore(session, check_point)
30 | 
31 | 
32 | def add_summary(tf_loss, value, global_step):
33 |     """Add all scalar summaries."""
34 | 
35 |     def _host_loss_summary(local_tf_loss, local_value, local_global_step):
36 |         """Add summary.scalar on the host side."""
37 |         gs = tfw.cast(local_global_step, tf.int64)
38 |         with tfw.control_dependencies([summary.scalar(key, local_value[key], step=gs) for key in local_value.keys()]):
39 |             return tfw.identity(local_tf_loss)
40 | 
41 |     # Cast the global step to tf.int32, since
42 |     # outside_compilation does not support tf.int64.
43 |     return tpu.outside_compilation(_host_loss_summary, tf_loss, value, tfw.cast(global_step, tf.int32))
44 | 
45 | 
46 | def add_histogram(tf_loss, value, global_step):
47 |     """Add all histogram summaries."""
48 | 
49 |     def _host_loss_summary(local_tf_loss, local_value, local_global_step):
50 |         """Add summary.histogram on the host side."""
51 |         gs = tfw.cast(local_global_step, tf.int64)
52 |         with tfw.control_dependencies([summary.histogram(key, local_value[key], step=gs)
53 |                                        for key in local_value.keys()]):
54 |             return tfw.identity(local_tf_loss)
55 | 
56 |     # Cast the global step to tf.int32, since
57 |     # outside_compilation does not support tf.int64.
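# (The host-side _host_loss_summary widens the step back to tf.int64, which the v2 summary ops expect.)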
58 |     return tpu.outside_compilation(_host_loss_summary, tf_loss, value, tfw.cast(global_step, tf.int32))
59 | 
60 | 
61 | def _import_tensor(params: ModelParameter, tensor, shape, name):
62 |     return import_laid_out_tensor(params, params.mesh_impl.LaidOutTensor([tensor]), shape, name)
63 | 
64 | 
65 | def analyze_model(params: ModelParameter, time_to_build: float, graph: mtf.Graph):
66 |     color_print(params, f"Built in {time_to_build:.1f}s")
67 |     param_count = int(sum([variable.size for variable in graph.trainable_variables]))
68 |     var_count = int(sum([variable.size for variable in graph.all_variables]))
69 |     embed_param_count = int(sum([variable.size for variable in
70 |                                  graph.trainable_variables if 'embed' in variable.name]))
71 |     gather_param_count = int(sum([variable.size for variable in
72 |                                   graph.trainable_variables if 'gather' in variable.name]))
73 |     body_param_count = int(sum([variable.size for variable in
74 |                                 graph.trainable_variables if 'body' in variable.name]))
75 | 
76 |     print('')
77 | 
78 |     constant = ' variables: '
79 |     variable_mapping = [('Core', param_count - embed_param_count),
80 |                         ('Embedding', embed_param_count - gather_param_count),
81 |                         ('Sparse', gather_param_count),
82 |                         ('Full Model', body_param_count),
83 |                         ('Untrainable', var_count - param_count),
84 |                         ('', 0),
85 |                         ('Total trainable', param_count),
86 |                         ('Total', var_count)]
87 |     variable_mapping = [(name, f'{int(count):,}') for name, count in variable_mapping]
88 |     max_str = max(len(name) for name, _ in variable_mapping)
89 |     max_int = max(len(count) for _, count in variable_mapping)
90 |     for name, count in variable_mapping:
91 |         if not name:
92 |             color_print(params, '-' * (max_str + max_int + len(constant)))
93 |             continue
94 |         color_print(params, f'{name:<{max_str}s}{constant}{count:>{max_int}s}')
95 | 
96 |     color_print(params, "\nDimensions:")
97 |     for dim_name in sorted(list(set([item for variable in graph.all_variables
98 |                                      for item in variable.shape.dimension_names]))):
99 |         color_print(params, dim_name)
100 |     print('')
101 | 
102 |     model_size = {'model_variables': int(param_count - embed_param_count),
103 |                   'embedding_variables': int(embed_param_count),
104 |                   'body_variables': int(body_param_count),
105 |                   'untrainable_variables': int(var_count - param_count),
106 |                   'total_trainable_variables': int(param_count),
107 |                   'total_variables': int(var_count)
108 |                   }
109 | 
110 |     if params.train:
111 |         size_dump = jsonpickle.dumps(model_size, indent=4)
112 |         with tf.io.gfile.GFile(f"{params.model_path}/model_size.info", 'w') as f:
113 |             f.write(size_dump)
114 | 
115 | 
116 | def rep_batch(params: ModelParameter, shape: typing.Union[mtf.Shape, typing.List[mtf.Dimension]]):
117 |     if params.macro_batching > 1 and params.train:
118 |         return mtf.replace_dimensions(shape, params.batch_dim, params.macro_batch_dim)
119 |     return shape
120 | 
--------------------------------------------------------------------------------
/src/tf_wrapper.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | import tensorflow
4 | import tensorflow as tf
5 | 
6 | from .utils_core import scoped as general_scoped
7 | 
8 | tf1 = tf.compat.v1
9 | 
10 | 
11 | def scoped(name: str, fn: typing.Callable, *args, **kwargs):
12 |     return general_scoped(f"tf_{name}", fn, *args, **kwargs)
13 | 
14 | 
15 | def softplus(tensor: tf.Tensor) -> tf.Tensor:
16 |     return scoped("softplus", tf.math.softplus, tensor)
17 | 
18 | 
19 | def divide(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
20 |     return scoped("divide", lambda x, y: x / y, x1,
x2)
21 | 
22 | 
23 | def multiply(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
24 |     return scoped("multiply", lambda x, y: x * y, x1, x2)
25 | 
26 | 
27 | def add(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
28 |     return scoped("add", lambda x, y: x + y, x1, x2)
29 | 
30 | 
31 | def subtract(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
32 |     return scoped("subtract", lambda x, y: x - y, x1, x2)
33 | 
34 | 
35 | def pow(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
36 |     return scoped("pow", lambda x, y: x ** y, x1, x2)
37 | 
38 | 
39 | def maximum(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
40 |     return scoped("maximum", tf.maximum, x1, x2)
41 | 
42 | 
43 | def equal(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
44 |     return scoped("equal", tf.equal, x1, x2)
45 | 
46 | 
47 | def greater(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
48 |     return scoped("greater", tf.greater, x1, x2)
49 | 
50 | 
51 | def less(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
52 |     return scoped("less", tf.less, x1, x2)
53 | 
54 | 
55 | def less_equal(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
56 |     return scoped("less_equal", tf.less_equal, x1, x2)
57 | 
58 | 
59 | def greater_equal(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
60 |     return scoped("greater_equal", tf.greater_equal, x1, x2)
61 | 
62 | 
63 | def minimum(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor:
64 |     return scoped("minimum", tf.minimum, x1, x2)
65 | 
66 | 
67 | def assign(var: tf.Tensor, val: tf.Tensor) -> tf.Tensor:
68 |     return scoped("assign", tf1.assign, var, val)
69 | 
70 | 
71 | def assign_add(var: tf.Tensor, val: tf.Tensor) -> tf.Tensor:
72 |     return scoped("assign_add", tf1.assign_add, var, val)
73 | 
74 | 
75 | def assign_sub(var: tf.Tensor, val: tf.Tensor) -> tf.Tensor:
76 |     return scoped("assign_sub", tf1.assign_sub, var, val)
77 | 
78 | 
79 | def group(ops: list):
80 |     return scoped("group", tf.group, ops)
81 | 
82 | 
83 | def identity(tensor: tf.Tensor):
84 |     return scoped("identity", tf.identity, tensor)
85 | 
86 | 
87 | def control_dependencies(ops: list):
88 |     return scoped("control_dependencies", tf.control_dependencies, ops)
89 | 
90 | 
91 | def slice(tensor: tf.Tensor, begin: typing.List[int], size: typing.List[int]) -> tf.Tensor:
92 |     return scoped("slice", tf.slice, tensor, begin, size)
93 | 
94 | 
95 | def constant(value: float, dtype: tensorflow.DType):
96 |     return scoped("constant", tf.constant, value, dtype, [])
97 | 
98 | 
99 | def tanh(tensor: tf.Tensor) -> tf.Tensor:
100 |     return scoped("tanh", tf.math.tanh, tensor)
101 | 
102 | 
103 | def square(tensor: tf.Tensor) -> tf.Tensor:
104 |     return scoped("square", tf.math.square, tensor)
105 | 
106 | 
107 | def sigmoid(tensor: tf.Tensor) -> tf.Tensor:
108 |     return scoped("sigmoid", tf.math.sigmoid, tensor)
109 | 
110 | 
111 | def abs(tensor: tf.Tensor) -> tf.Tensor:
112 |     return scoped("abs", tf.math.abs, tensor)
113 | 
114 | 
115 | def exp(tensor: tf.Tensor) -> tf.Tensor:
116 |     return scoped("exp", tf.math.exp, tensor)
117 | 
118 | 
119 | def sin(tensor: tf.Tensor) -> tf.Tensor:
120 |     return scoped("sin", tf.math.sin, tensor)
121 | 
122 | 
123 | def einsum(equation: str, *inputs: tf.Tensor) -> tf.Tensor:
124 |     return scoped("einsum", tf.einsum, equation, *inputs)
125 | 
126 | 
127 | def mod(tensor: tf.Tensor, modulo: int) -> tf.Tensor:
128 |     return scoped("mod", lambda x, y: (x % y), tensor, modulo)
129 | 
130 | 
131 | def reshape(tensor: tf.Tensor, new_shape: typing.List[int]):
132 |     return scoped("reshape", tf.reshape, tensor, new_shape)
133 | 
134 | 
135 | def tf_range(start: int, end: int, step: int):
136 |     return scoped("range", tf.range, start, end, step)
137 | 
138 | 
139 | 
def cast(tensor: tf.Tensor, dtype: tf.DType):
140 |     return scoped("cast", tf.cast, tensor, dtype)
141 | 
--------------------------------------------------------------------------------
/src/utils_core.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic utility functions that are called frequently across modules.
3 | """
4 | import functools
5 | import typing
6 | from datetime import datetime, timezone
7 | 
8 | import tensorflow as tf
9 | 
10 | from .dataclass import ModelParameter
11 | 
12 | NAME_INDICES = {}
13 | tf1 = tf.compat.v1
14 | 
15 | 
16 | def scoped(name: str, fn: typing.Callable, *args, **kwargs):
17 |     name = random_name(name)
18 |     with tf1.variable_scope(f'{name}v'), tf1.name_scope(f'{name}n'):
19 |         return fn(*args, **kwargs)
20 | 
21 | 
22 | def default(value: typing.Any, default_value: typing.Any) -> typing.Any:
23 |     """
24 |     Return a default value if a given value is None.
25 |     This is merely a comfort function to avoid typing out "x if x is None else y" over and over again.
26 |     :param value: value that can be None
27 |     :param default_value: default if value is None
28 |     :return: value or default_value
29 |     """
30 |     return default_value if value is None else value
31 | 
32 | 
33 | def chunks(lst: typing.List, n: int):
34 |     """
35 |     Yield successive n-sized chunks from lst.
36 |     :param lst: the list to be split.
37 |     :param n: the chunk size.
38 |     """
39 |     for i in range(0, len(lst), n):
40 |         yield lst[i:i + n]
41 | 
42 | 
43 | def timestamp():
44 |     return "{}".format(datetime.now(timezone.utc).isoformat())
45 | 
46 | 
47 | def color_print(params: ModelParameter, string):
48 |     print(f"{params.own_color}{timestamp()} {string}{params.other_color}", flush=True)
49 | 
50 | 
51 | def int_reduce_mul(*integers: typing.Union[typing.List[typing.Iterable[int]], typing.List[int]]) -> int:
52 |     if isinstance(integers[0], typing.Iterable):
53 |         integers = integers[0]
54 |     return functools.reduce(int.__mul__, integers)
55 | 
56 | 
57 | def random_name(prefix="") -> str:
58 |     """
59 |     Generates a unique name for the given prefix by appending an incrementing per-prefix counter.
60 |     Despite the function's name, the result is deterministic, so repeated graph builds produce
61 |     identical variable and name scopes.
62 |     :return: unique, deterministic name string
63 |     """
64 |     if prefix not in NAME_INDICES:
65 |         NAME_INDICES[prefix] = -1
66 |     NAME_INDICES[prefix] += 1
67 |     return f'{prefix}{NAME_INDICES[prefix]}'
68 | 
--------------------------------------------------------------------------------
/tests/backend.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | import mesh_tensorflow as mtf
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | from src.dataclass import BlockArgs, ModelParameter
8 | 
9 | tf1 = tf.compat.v1
10 | 
11 | tf1.disable_v2_behavior()
12 | 
13 | RELU_STD = 1 / 1.42
14 | 
15 | 
16 | class BaseTest:
17 |     def __init__(self,
18 |                  *args,
19 |                  mesh_shape: typing.Union[None, list, str] = None,
20 |                  layout_rules: typing.Union[None, list, str] = None,
21 |                  devices: typing.Union[None, typing.List[str]] = None,
22 |                  **kwargs):
23 |         self.mesh_shape = [] if mesh_shape is None else mesh_shape
24 |         self.layout_rules = [] if layout_rules is None else layout_rules
25 |         self.devices = ["cpu:0"] if devices is None else devices
26 | 
27 |         self.session_config = tf1.ConfigProto()
28 |         self.session_config.allow_soft_placement = True
29 | 
30 |     def _close_session(self):
31 |         default_session = tf1.get_default_session()
32 |         if default_session is not None:
33 |             default_session.close()
34 | 
35 |     def build(self, graph: mtf.Graph, mesh: mtf.Mesh,
36 |               *args, **kwargs) -> typing.Tuple[typing.List[mtf.Tensor], typing.Any]:
37 |         pass
38 | 
39 |     def run(self, sess: tf1.Session, outputs: typing.List[tf.Tensor], args: typing.Any) -> None:
40 |         pass
41 | 
42 |     def __call__(self, *args, **kwargs) -> None:
43 |         self._close_session()
44 | 
45 |         with tf.Graph().as_default() as tf_graph, tf1.Session(config=self.session_config, graph=tf_graph) as sess:
46 |             graph = mtf.Graph()
47 |             mesh = mtf.Mesh(graph, "MESH")
48 | 
49 |             outputs, args = self.build(graph, mesh, *args, **kwargs)
50 | 
51 |             mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(self.mesh_shape, self.layout_rules, self.devices)
52 |             lowering = mtf.Lowering(graph, {mesh: mesh_impl})
53 | 
54 |             outputs = [lowering.export_to_tf_tensor(output) for output in outputs]
55 | 
56 |             sess.run(tf1.global_variables_initializer())
57 |             sess.run(lowering.copy_masters_to_slices())
58 | 
59 |             self.run(sess, outputs, args)
60 | 
61 | 
62 | class OperationTest(BaseTest):
63 |     def __init__(self, **kwargs):
64 |         super(OperationTest, self).__init__(**kwargs)
65 |         params = ModelParameter(kwargs)
66 |         self.fp16 = "16" in (kwargs['calculation_dtype'] + kwargs['slice_dtype'] + kwargs['storage_dtype'])
67 |         self.args = BlockArgs(params, None, [''])
68 |         self.args.params.layout = self.layout_rules
69 |         self.args.params.mesh_shape = self.mesh_shape
70 |         self.tolerance = 1 / (params.train_batch_size * params.sequence_length * params.features) ** (0.05 if self.fp16 else 1 / 3)
71 | 
72 |     def _build(self, inp: mtf.Tensor) -> mtf.Tensor:
73 |         pass
74 | 
75 |     def _run(self, out: np.array) -> None:
76 |         pass
77 | 
78 |     def _is_close(self, x: np.array, y: np.array, rtol: float = 1e-3):
79 |         assert np.isclose(x, y, rtol, self.tolerance)
80 | 
81 |     def build(self, graph: mtf.Graph, mesh: mtf.Mesh,
82 |               *args, **kwargs) -> typing.Tuple[typing.List[mtf.Tensor], typing.Any]:
83 |         params = self.args.params
84 |         params.mesh = mesh
85 |         params.graph = graph
86 |         inp = mtf.random_normal(mesh, [params.batch_dim, params.sequence_dim] + params.feature_dims,
87 |                                 dtype=params.variable_dtype.activation_dtype)
88 | 
89 |         return [self._build(inp)], None
90 | 
91 |     def run(self, sess:
tf1.Session, outputs: typing.List[tf.Tensor], args: typing.Any) -> None: 92 | self._run(sess.run(outputs)[0]) 93 | 94 | 95 | def curry_class(base: typing.Type, **kwargs) -> typing.Callable: 96 | def _fn(**kw): 97 | return base(**kw, **kwargs) 98 | 99 | _fn.__name__ = f'{base.__name__}({",".join(f"{k}={v}" for k, v in kwargs.items())})' 100 | return _fn 101 | -------------------------------------------------------------------------------- /tests/basic_linear_square_test.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from backend import OperationTest, RELU_STD 9 | from src.model import basic 10 | 11 | tf1 = tf.compat.v1 12 | 13 | 14 | class Linear(OperationTest): 15 | @staticmethod 16 | def _activation() -> str: 17 | return '' 18 | 19 | @staticmethod 20 | def _target_std() -> float: 21 | return 1 22 | 23 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 24 | for _ in range(self.args.params.train_steps): 25 | inp = basic.activate(self.args(basic.wrapped_linear(self.args(inp)))([self._activation()])) 26 | return inp 27 | 28 | def _run(self, out: np.array) -> None: 29 | self.tolerance *= self.args.params.train_steps 30 | target_std = self._target_std() 31 | if self.args.params.scale_by_depth: 32 | target_std /= self.args.params.depth ** 0.5 33 | self._is_close(np.mean(np.std(out, -1)), target_std, 0.2) 34 | 35 | 36 | class ReLULinear(Linear): 37 | @staticmethod 38 | def _activation() -> str: 39 | return 'relu' 40 | 41 | def _target_std(self) -> float: 42 | return RELU_STD ** self.args.params.train_steps 43 | 44 | 45 | @pytest.mark.parametrize("test", [Linear, ReLULinear]) 46 | @pytest.mark.parametrize("calculation_dtype", ["bfloat16", "float32"]) 47 | @pytest.mark.parametrize("storage_dtype", ["bfloat16", "float32"]) 48 | @pytest.mark.parametrize("slice_dtype", ["bfloat16", "float32"]) 49 | @pytest.mark.parametrize("embd_per_head", [16, 256]) 50 | @pytest.mark.parametrize("heads", [1, 2, 8]) 51 | @pytest.mark.parametrize("scale_by_depth", [True, False]) 52 | @pytest.mark.parametrize("train_steps", [1, 2, 8]) 53 | def square_matmul_std_test(test: typing.Type, calculation_dtype: str, storage_dtype: str, slice_dtype: str, 54 | embd_per_head: int, heads: int, scale_by_depth: bool, train_steps: int): 55 | test(calculation_dtype=calculation_dtype, storage_dtype=storage_dtype, slice_dtype=slice_dtype, 56 | features_per_head=embd_per_head, heads=heads, batch_size=1, sequence_length=1, group_linear_factor=heads, 57 | scale_by_depth=scale_by_depth, train_steps=train_steps)() 58 | -------------------------------------------------------------------------------- /tests/basic_pointwise_test.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from backend import OperationTest, RELU_STD 9 | from src.model import basic 10 | 11 | tf1 = tf.compat.v1 12 | 13 | 14 | class ReZero(OperationTest): 15 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 16 | return basic.rezero(self.args(inp)) 17 | 18 | @staticmethod 19 | def _run(out: np.array) -> None: 20 | assert np.all(out == 0) 21 | 22 | 23 | class Dropout(OperationTest): 24 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 25 | return basic.dropout(self.args(inp)([f'dropout_rate{self.args.params.input_dropout}'])) 26 | 27 | def 
_run(self, out: np.array) -> None: 28 | self._is_close(np.sum(out == 0) / out.size, self.args.params.input_dropout, 0.2) 29 | 30 | 31 | class Activate(OperationTest): 32 | @staticmethod 33 | def _activation() -> str: 34 | return '' 35 | 36 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 37 | return basic.activate(self.args(inp)([self._activation()])) 38 | 39 | 40 | class Identity(Activate): 41 | def _run(self, out: np.array) -> None: 42 | self._is_close(np.std(out), 1) 43 | 44 | 45 | class ReLU(Activate): 46 | @staticmethod 47 | def _activation() -> str: 48 | return 'relu' 49 | 50 | def _run(self, out: np.array) -> None: 51 | self._is_close(np.std(out), RELU_STD, 0.2) 52 | 53 | 54 | @pytest.mark.parametrize("test", [ReZero, Dropout, Identity, ReLU]) 55 | @pytest.mark.parametrize("calculation_dtype", ["bfloat16", "float32"]) 56 | @pytest.mark.parametrize("storage_dtype", ["bfloat16", "float32"]) 57 | @pytest.mark.parametrize("slice_dtype", ["bfloat16", "float32"]) 58 | @pytest.mark.parametrize("embd_per_head", [64, 256, 1024]) 59 | @pytest.mark.parametrize("batch_size", [16, 256]) 60 | def pointwise_test(test: typing.Type, calculation_dtype: str, storage_dtype: str, slice_dtype: str, embd_per_head: int, 61 | batch_size: int): 62 | test(calculation_dtype=calculation_dtype, storage_dtype=storage_dtype, slice_dtype=slice_dtype, 63 | features_per_head=embd_per_head, heads=1, batch_size=batch_size, sequence_length=1)() 64 | -------------------------------------------------------------------------------- /tests/variable_test.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import mesh_tensorflow as mtf 4 | import numpy as np 5 | import pytest 6 | import tensorflow as tf 7 | 8 | from backend import OperationTest, curry_class 9 | from src.model import backend 10 | from src.utils_mtf import get_intermediate, deduplicate 11 | 12 | tf1 = tf.compat.v1 13 | 14 | 15 | class VariableCheck(OperationTest): 16 | def _in_dims(self) -> typing.List[mtf.Dimension]: 17 | return [] 18 | 19 | def _out_dims(self) -> typing.List[mtf.Dimension]: 20 | return [] 21 | 22 | def _shape(self) -> typing.List[mtf.Dimension]: 23 | return deduplicate(self._in_dims() + self._out_dims()) 24 | 25 | @staticmethod 26 | def _target_std() -> float: 27 | return 0 28 | 29 | @staticmethod 30 | def _target_mean() -> float: 31 | return 0 32 | 33 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 34 | return mtf.zeros(inp.mesh, self._shape()) 35 | 36 | def _run(self, out: np.array) -> None: 37 | self._is_close(np.std(out), self._target_std()) 38 | self._is_close(np.mean(out), self._target_mean()) 39 | 40 | 41 | class NormalCheck(VariableCheck): 42 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 43 | return backend.normal_var(self.args, self._shape(), self._target_std(), self._target_mean()) 44 | 45 | 46 | class NormShiftCheck(NormalCheck): 47 | @staticmethod 48 | def _target_std() -> float: 49 | return 0.02 50 | 51 | @staticmethod 52 | def _target_mean() -> float: 53 | return 0 54 | 55 | 56 | class NormScaleCheck(NormalCheck): 57 | @staticmethod 58 | def _target_std() -> float: 59 | return 0.02 60 | 61 | @staticmethod 62 | def _target_mean() -> float: 63 | return 1 64 | 65 | 66 | class EmbeddingCheck(NormalCheck): 67 | def _target_std(self) -> float: 68 | return self.args.params.embedding_stddev 69 | 70 | @staticmethod 71 | def _target_mean() -> float: 72 | return 0 73 | 74 | 75 | class OrthogonalCheck(VariableCheck): 76 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 77 
| return backend.orthogonal_var(self.args, self._shape()) 78 | 79 | def _target_std(self) -> float: 80 | size = np.prod([d.size for d in self._shape()]) 81 | feature_dims = self.args.params.feature_dims 82 | inp = feature_dims if self._in_dims() == feature_dims or self._out_dims() == feature_dims else self._in_dims() 83 | intermediate = np.prod([d.size for d in inp]) 84 | min_fan = min(size // intermediate, intermediate) 85 | std = ((min_fan * (1 - min_fan / size) ** 2 + (size - min_fan) * (min_fan / size) ** 2) / size) ** 0.5 86 | if not self.args.params.scale_by_depth: 87 | return std 88 | return std / self.args.params.depth ** 0.5 89 | 90 | 91 | class AllSumFeedForwardIn(OrthogonalCheck): 92 | def _in_dims(self) -> typing.List[mtf.Dimension]: 93 | return self.args.params.feature_dims 94 | 95 | def _out_dims(self) -> typing.List[mtf.Dimension]: 96 | return self.args.params.intermediate 97 | 98 | 99 | class AllSumFeedForwardOut(OrthogonalCheck): 100 | def _in_dims(self) -> typing.List[mtf.Dimension]: 101 | return self.args.params.intermediate 102 | 103 | def _out_dims(self) -> typing.List[mtf.Dimension]: 104 | return self.args.params.feature_dims 105 | 106 | 107 | class GroupFeedForwardIn(OrthogonalCheck): 108 | def _in_dims(self) -> typing.List[mtf.Dimension]: 109 | return get_intermediate(self.args(['group'])) 110 | 111 | def _out_dims(self) -> typing.List[mtf.Dimension]: 112 | return self.args.params.feature_dims 113 | 114 | 115 | class GroupFeedForwardOut(OrthogonalCheck): 116 | def _in_dims(self) -> typing.List[mtf.Dimension]: 117 | return self.args.params.feature_dims 118 | 119 | def _out_dims(self) -> typing.List[mtf.Dimension]: 120 | return get_intermediate(self.args(['group'])) 121 | 122 | 123 | class SharedOrthogonalVariable(GroupFeedForwardIn): 124 | def _get_shared_var(self, idx: int) -> mtf.Tensor: 125 | with tf1.variable_scope(f"gpt/body/{self.args.params.attention_idx}_0/feed_forward_{idx}/"): 126 | out = backend.orthogonal_var(self.args(['shared']), self._shape()) 127 | self.args.params.attention_idx += idx == 0 128 | return out 129 | 130 | def _run(self, out: np.array) -> None: 131 | assert all(np.array_equal(out[0], out[i]) for i in range(1, out.shape[0])) 132 | 133 | 134 | class SingleSharedVariable(SharedOrthogonalVariable): 135 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 136 | return mtf.stack([self._get_shared_var(0) for _ in range(self.args.params.depth)], "items") 137 | 138 | 139 | class DoubleSharedVariable(SharedOrthogonalVariable): 140 | def _build(self, inp: mtf.Tensor) -> mtf.Tensor: 141 | return mtf.stack([mtf.stack([self._get_shared_var(0), self._get_shared_var(1)], "non_shared") 142 | for _ in range(self.args.params.depth)], "items") 143 | 144 | 145 | @pytest.mark.parametrize("test", 146 | [curry_class(AllSumFeedForwardIn, scale_by_depth=True), 147 | curry_class(AllSumFeedForwardOut, scale_by_depth=True), 148 | curry_class(GroupFeedForwardIn, scale_by_depth=True), 149 | curry_class(GroupFeedForwardOut, scale_by_depth=True), 150 | curry_class(AllSumFeedForwardIn, scale_by_depth=False), 151 | curry_class(AllSumFeedForwardOut, scale_by_depth=False), 152 | curry_class(GroupFeedForwardIn, scale_by_depth=False), 153 | curry_class(GroupFeedForwardOut, scale_by_depth=False), 154 | NormShiftCheck, NormScaleCheck, EmbeddingCheck, SingleSharedVariable, DoubleSharedVariable]) 155 | @pytest.mark.parametrize("calculation_dtype", ["bfloat16", "float32"]) 156 | @pytest.mark.parametrize("storage_dtype", ["bfloat16", "float32"]) 157 | 
@pytest.mark.parametrize("slice_dtype", ["bfloat16", "float32"]) 158 | @pytest.mark.parametrize("embd_per_head", [1, 16, 256]) 159 | @pytest.mark.parametrize("heads", [1, 4]) 160 | def op_test(test: typing.Type, calculation_dtype: str, storage_dtype: str, slice_dtype: str, embd_per_head: int, 161 | heads: int): 162 | test(calculation_dtype=calculation_dtype, storage_dtype=storage_dtype, slice_dtype=slice_dtype, 163 | features_per_head=embd_per_head, heads=heads, batch_size=1, sequence_length=1)() 164 | --------------------------------------------------------------------------------