├── .gitignore ├── README.md ├── assets └── weights_urls.txt ├── minimal20b ├── __init__.py ├── constants.py ├── create.py ├── generate.py ├── model.py └── rotary.py ├── minimal20b_flax ├── __init__.py ├── create.py ├── generate.py ├── layernorm.py ├── model.py ├── model_xmap.py └── utils.py ├── requirements.txt ├── requirements_flax.txt └── scripts ├── eval ├── eval_harness.py └── requirements.txt └── eval_flax ├── eval_harness.py ├── eval_harness_xmap.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal GPT-NeoX-20B 2 | 3 | This is a fairly minimal implementation of GPT-NeoX-20B in PyTorch. It is meant primarily as an educational/reference implementation, rather than an optimized or feature-full implementation. 
4 | 5 | GPT-NeoX-20B is a 20B-parameter autoregressive Transformer model developed by [EleutherAI](https://www.eleuther.ai/) with the support of [CoreWeave](https://www.coreweave.com/), trained using the [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) library. 6 | 7 | Some notes about the model: 8 | 9 | * The model weights and activations come in half-precision (fp16). 10 | * In fp16, loading the model weights requires about 40GB of GPU memory. Running inference on even a single batch requires somewhat more. 11 | * The model supports up to a maximum sequence length of 2048 tokens. 12 | 13 | ## Setup 14 | 15 | ### Installation 16 | 17 | Install PyTorch with the appropriate CUDA version for your system, and then install the remaining dependencies from `requirements.txt` (basically just `tokenizers`). 18 | 19 | ```bash 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ### Download weights 24 | 25 | Following the [NeoX guide](https://github.com/EleutherAI/gpt-neox#download-links), download the model weights and tokenizer JSON file with the following command: 26 | 27 | ```bash 28 | wget --cut-dirs=5 -nH -r --no-parent --reject "index.html*" https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/ -P 20B_checkpoints 29 | ``` 30 | 31 | You can also manually download them from [here](https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/). Because of the size of the model, the weights are split across multiple files, following the DeepSpeed save format. 32 | 33 | #### Generate text 34 | 35 | Here is some sample code to generate text. Note that since we decode greedily with no fancy tricks, the generations tend to contain quite a lot of repetition. 36 | 37 | ```python 38 | import minimal20b 39 | import torch 40 | model = minimal20b.create_model( 41 | "/path/to/20B_checkpoints/global_step150000", 42 | use_cache=True, 43 | device="cuda:0", 44 | ) 45 | tokenizer = minimal20b.create_tokenizer( 46 | "/path/to/20B_checkpoints/20B_tokenizer.json", 47 | ) 48 | with torch.inference_mode(): 49 | minimal20b.greedy_generate_text( 50 | model, tokenizer, 51 | "GPTNeoX20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI.", 52 | max_seq_len=100, 53 | ) 54 | ``` 55 | 56 | #### Evaluation 57 | 58 | To run evaluation with the LM-eval-harness, you will need to install some additional dependencies (mostly just the eval harness library): 59 | 60 | ```bash 61 | pip install -r scripts/eval/requirements.txt 62 | ``` 63 | 64 | Most datasets are downloaded automatically via Hugging Face `datasets`, but if you are evaluating on LAMBADA, you will need to download the data separately: 65 | 66 | ```bash 67 | mkdir -p data/lambada 68 | wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl 69 | ``` 70 | 71 | Then, you can run the following command:
72 | 73 | ```bash 74 | python scripts/eval/eval_harness.py \ 75 | --model_path /path/to/20B_checkpoints/global_step150000 \ 76 | --tokenizer_path /path/to/20B_checkpoints/20B_tokenizer.json \ 77 | --tasks lambada,anli_r1,anli_r2,anli_r3,wsc,winogrande,hellaswag,piqa 78 | ``` 79 | 80 | | Task | Metric | NeoX Impl (2 GPU) | This Repo (1 GPU) | 81 | |------------|-----------------|-------------------|-------------------| 82 | | anli_r1 | acc | 0.3270 | 0.3300 | 83 | | | acc_stderr | 0.0148 | 0.0149 | 84 | | anli_r2 | acc | 0.3410 | 0.3420 | 85 | | | acc_stderr | 0.0150 | 0.0150 | 86 | | anli_r3 | acc | 0.3567 | 0.3617 | 87 | | | acc_stderr | 0.0138 | 0.0139 | 88 | | hellaswag | acc | 0.5351 | 0.5335 | 89 | | | acc_stderr | 0.0050 | 0.0050 | 90 | | | acc_norm | 0.7140 | 0.7126 | 91 | | | acc_norm_stderr | 0.0045 | 0.0045 | 92 | | lambada | acc | 0.7211 | 0.7223 | 93 | | | acc_stderr | 0.0062 | 0.0062 | 94 | | | ppl | 3.6760 | 3.6559 | 95 | | | ppl_stderr | 0.0760 | 0.0757 | 96 | | piqa | acc | 0.7748 | 0.7758 | 97 | | | acc_stderr | 0.0097 | 0.0097 | 98 | | | acc_norm | 0.7786 | 0.7856 | 99 | | | acc_norm_stderr | 0.0097 | 0.0096 | 100 | | winogrande | acc | 0.6598 | 0.6598 | 101 | | | acc_stderr | 0.0133 | 0.0133 | 102 | | wsc | acc | 0.5096 | 0.4808 | 103 | | | acc_stderr | 0.0493 | 0.0492 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /assets/weights_urls.txt: -------------------------------------------------------------------------------- 1 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json 2 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_00-model_00-model_states.pt 3 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_00-model_01-model_states.pt 4 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_02-model_00-model_states.pt 5 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_02-model_01-model_states.pt 6 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_03-model_00-model_states.pt 7 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_03-model_01-model_states.pt 8 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_04-model_00-model_states.pt 9 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_04-model_01-model_states.pt 10 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_05-model_00-model_states.pt 11 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_05-model_01-model_states.pt 12 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_06-model_00-model_states.pt 13 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_06-model_01-model_states.pt 14 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_07-model_00-model_states.pt 15 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_07-model_01-model_states.pt 16 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_08-model_00-model_states.pt 17 | 
https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_08-model_01-model_states.pt 18 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_09-model_00-model_states.pt 19 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_09-model_01-model_states.pt 20 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_10-model_00-model_states.pt 21 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_10-model_01-model_states.pt 22 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_11-model_00-model_states.pt 23 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_11-model_01-model_states.pt 24 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_12-model_00-model_states.pt 25 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_12-model_01-model_states.pt 26 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_13-model_00-model_states.pt 27 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_13-model_01-model_states.pt 28 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_14-model_00-model_states.pt 29 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_14-model_01-model_states.pt 30 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_15-model_00-model_states.pt 31 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_15-model_01-model_states.pt 32 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_16-model_00-model_states.pt 33 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_16-model_01-model_states.pt 34 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_17-model_00-model_states.pt 35 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_17-model_01-model_states.pt 36 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_18-model_00-model_states.pt 37 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_18-model_01-model_states.pt 38 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_19-model_00-model_states.pt 39 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_19-model_01-model_states.pt 40 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_20-model_00-model_states.pt 41 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_20-model_01-model_states.pt 42 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_21-model_00-model_states.pt 43 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_21-model_01-model_states.pt 44 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_22-model_00-model_states.pt 45 | 
https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_22-model_01-model_states.pt 46 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_23-model_00-model_states.pt 47 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_23-model_01-model_states.pt 48 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_24-model_00-model_states.pt 49 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_24-model_01-model_states.pt 50 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_25-model_00-model_states.pt 51 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_25-model_01-model_states.pt 52 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_26-model_00-model_states.pt 53 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_26-model_01-model_states.pt 54 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_27-model_00-model_states.pt 55 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_27-model_01-model_states.pt 56 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_28-model_00-model_states.pt 57 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_28-model_01-model_states.pt 58 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_29-model_00-model_states.pt 59 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_29-model_01-model_states.pt 60 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_30-model_00-model_states.pt 61 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_30-model_01-model_states.pt 62 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_31-model_00-model_states.pt 63 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_31-model_01-model_states.pt 64 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_32-model_00-model_states.pt 65 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_32-model_01-model_states.pt 66 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_33-model_00-model_states.pt 67 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_33-model_01-model_states.pt 68 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_34-model_00-model_states.pt 69 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_34-model_01-model_states.pt 70 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_35-model_00-model_states.pt 71 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_35-model_01-model_states.pt 72 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_36-model_00-model_states.pt 73 | 
https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_36-model_01-model_states.pt 74 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_37-model_00-model_states.pt 75 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_37-model_01-model_states.pt 76 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_38-model_00-model_states.pt 77 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_38-model_01-model_states.pt 78 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_39-model_00-model_states.pt 79 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_39-model_01-model_states.pt 80 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_40-model_00-model_states.pt 81 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_40-model_01-model_states.pt 82 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_41-model_00-model_states.pt 83 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_41-model_01-model_states.pt 84 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_42-model_00-model_states.pt 85 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_42-model_01-model_states.pt 86 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_43-model_00-model_states.pt 87 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_43-model_01-model_states.pt 88 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_44-model_00-model_states.pt 89 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_44-model_01-model_states.pt 90 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_45-model_00-model_states.pt 91 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_45-model_01-model_states.pt 92 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_47-model_00-model_states.pt 93 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_47-model_01-model_states.pt 94 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_48-model_00-model_states.pt 95 | https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/global_step150000/layer_48-model_01-model_states.pt -------------------------------------------------------------------------------- /minimal20b/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import NeoX20BModel, generate_mask 2 | from .constants import Args20b, ArgsDummy 3 | from .create import create_model, create_dummy_model, create_tokenizer 4 | from .generate import greedy_generate, greedy_generate_text 5 | -------------------------------------------------------------------------------- /minimal20b/constants.py: -------------------------------------------------------------------------------- 1 | class Args20b: 2 | vocab_size = 50432 3 | 
hidden_size = 6144 4 | num_attention_heads = 64 5 | rotary_pct = 0.25 6 | rotary_emb_base = 10000 7 | layernorm_epsilon = 1e-5 8 | num_layers = 44 9 | 10 | 11 | class ArgsDummy: 12 | vocab_size = 50432 13 | hidden_size = 64 14 | num_attention_heads = 4 15 | rotary_pct = 0.25 16 | rotary_emb_base = 10000 17 | layernorm_epsilon = 1e-5 18 | num_layers = 2 19 | -------------------------------------------------------------------------------- /minimal20b/create.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import auto as tqdm_lib 3 | 4 | import torch 5 | import tokenizers 6 | 7 | import minimal20b.model as model20b 8 | from minimal20b.constants import Args20b, ArgsDummy 9 | 10 | 11 | def create_model(checkpoint_path, use_cache=False, device=torch.device("cuda:0")): 12 | """ 13 | To prevent allocation memory on CPU, we initialize on 'meta' and individually 14 | port each module over to 'device' as we load each state dict. 15 | 16 | :param checkpoint_path: Path to the checkpoint folder 17 | :param use_cache: whether to use cache (i.e. for efficient generation) 18 | :param device: device that you want the model to end up on 19 | :return: model 20 | """ 21 | # Instantiate model 22 | pbar = tqdm_lib.tqdm(total=48) 23 | pbar.set_description("Instantiating model (~1 min)") 24 | model = model20b.NeoX20BModel(Args20b, use_cache=use_cache, device="meta") 25 | model = model.half().to_empty(device=device) 26 | pbar.update(1) 27 | 28 | # Load transformer layers 29 | for layer_i in range(Args20b.num_layers): 30 | pbar.set_description(f"Loading layer {layer_i}") 31 | filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt" 32 | filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt" 33 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, filename_tp1)) 34 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, filename_tp2)) 35 | state_dict = {} 36 | # Good 37 | # Keys where we concatenate on the second dim 38 | for key in [ 39 | "attention.dense.weight", 40 | "mlp.dense_4h_to_h.weight", 41 | ]: 42 | state_dict[key] = torch.cat([loaded_tp1[key], loaded_tp2[key]], dim=1) 43 | # Mapping individual split weights to custom split implementations 44 | # Layer Norms 45 | # Choose 1 46 | state_dict["input_layernorm.weight"] = ( 47 | loaded_tp1["input_layernorm.weight"] + loaded_tp2["input_layernorm.weight"]) / 2 48 | state_dict["input_layernorm.bias"] = ( 49 | loaded_tp1["input_layernorm.bias"] + loaded_tp2["input_layernorm.bias"]) / 2 50 | state_dict["post_attention_layernorm.weight"] = ( 51 | loaded_tp1["post_attention_layernorm.weight"] + loaded_tp2["post_attention_layernorm.weight"]) / 2 52 | state_dict["post_attention_layernorm.bias"] = ( 53 | loaded_tp1["post_attention_layernorm.bias"] + loaded_tp2["post_attention_layernorm.bias"]) / 2 54 | # LinearWithTPMerge 55 | state_dict["mlp.dense_h_to_4h.weight"] = torch.cat([ 56 | loaded_tp1["mlp.dense_h_to_4h.weight"], 57 | loaded_tp2["mlp.dense_h_to_4h.weight"], 58 | ], dim=0) 59 | state_dict["mlp.dense_h_to_4h.bias"] = torch.cat([ 60 | loaded_tp1["mlp.dense_h_to_4h.bias"], 61 | loaded_tp2["mlp.dense_h_to_4h.bias"], 62 | ], dim=0) 63 | state_dict["attention.query_key_value.weight"] = torch.cat([ 64 | loaded_tp1["attention.query_key_value.weight"], 65 | loaded_tp2["attention.query_key_value.weight"], 66 | ], dim=0) 67 | state_dict["attention.query_key_value.bias"] = torch.cat([ 68 | loaded_tp1["attention.query_key_value.bias"], 69 | loaded_tp2["attention.query_key_value.bias"], 70 
| ], dim=0) 71 | # LinearWithTPSplitBias 72 | state_dict["mlp.dense_4h_to_h.bias"] = ( 73 | loaded_tp1["mlp.dense_4h_to_h.bias"] 74 | + loaded_tp2["mlp.dense_4h_to_h.bias"] 75 | ) 76 | state_dict["attention.dense.bias"] = ( 77 | loaded_tp1["attention.dense.bias"] 78 | + loaded_tp2["attention.dense.bias"] 79 | ) 80 | # Just take one 81 | state_dict["attention.rotary_emb.inv_freq"] = loaded_tp1["attention.rotary_emb.inv_freq"] 82 | model.layer_list[layer_i].load_state_dict(state_dict) 83 | del loaded_tp1 84 | del loaded_tp2 85 | pbar.update(1) 86 | 87 | # Load input embedding 88 | pbar.set_description(f"Loading input embedding") 89 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 90 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 91 | model.embed_in.load_state_dict({"weight": torch.cat([ 92 | loaded_tp1["word_embeddings.weight"], 93 | loaded_tp2["word_embeddings.weight"], 94 | ], dim=0)}) 95 | del loaded_tp1 96 | del loaded_tp2 97 | pbar.update(1) 98 | 99 | # Load final layer norm 100 | pbar.set_description(f"Loading final layer norm") 101 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 102 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 103 | model.final_layer_norm.load_state_dict({ 104 | "weight": (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"])/2, 105 | "bias": (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"])/2, 106 | }) 107 | del loaded_tp1 108 | del loaded_tp2 109 | pbar.update(1) 110 | 111 | # Load output embedding 112 | pbar.set_description(f"Loading output embedding") 113 | loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 114 | loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 115 | model.logits_out.load_state_dict({ 116 | "weight": torch.cat([ 117 | loaded_tp1["final_linear.weight"], 118 | loaded_tp2["final_linear.weight"], 119 | ], dim=0), 120 | }) 121 | del loaded_tp1 122 | del loaded_tp2 123 | pbar.update(1) 124 | pbar.set_description("Done.") 125 | 126 | return model 127 | 128 | 129 | def create_dummy_model(use_cache=False, device=torch.device("cpu")): 130 | model = model20b.NeoX20BModel(ArgsDummy, use_cache=use_cache).half().to(device) 131 | return model 132 | 133 | 134 | def create_tokenizer(tokenizer_path): 135 | return tokenizers.Tokenizer.from_file(tokenizer_path) 136 | -------------------------------------------------------------------------------- /minimal20b/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from tqdm import auto as tqdm_lib 4 | 5 | 6 | def greedy_generate(model: nn.Module, input_ids: torch.Tensor, max_seq_len: int, 7 | verbose=True): 8 | """Generate greedily from 20B. 
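(This assumes the model was created with use_cache=True: after the first step only the newly generated token is fed back in, and the returned layer_past key/value cache supplies the earlier context.)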
9 | 10 | :param model: NeoX20BModel 11 | :param input_ids: token IDs [batch_size, seq_len] 12 | :param max_seq_len: max sequence length to generate up to (includes input_ids) 13 | :param verbose: whether to print progress 14 | 15 | :return: List of token IDs 16 | """ 17 | initial_input_length = input_ids.shape[1] 18 | current_input_ids = input_ids 19 | layer_past = None 20 | layer_past_length = 0 21 | all_token_ids = input_ids.tolist() 22 | batch_size = len(all_token_ids) 23 | 24 | if verbose: 25 | trange = tqdm_lib.trange(initial_input_length, max_seq_len) 26 | else: 27 | trange = range(initial_input_length, max_seq_len) 28 | 29 | for _ in trange: 30 | input_length = current_input_ids.shape[1] 31 | model_out, layer_past = model( 32 | current_input_ids, 33 | layer_past=layer_past, 34 | ) 35 | greedy_predicted_token_ids = model_out[:, -1].argmax(-1) 36 | current_input_ids = greedy_predicted_token_ids[:, None] 37 | for i in range(batch_size): 38 | all_token_ids[i].append(greedy_predicted_token_ids[i]) 39 | layer_past_length += input_length 40 | return all_token_ids 41 | 42 | 43 | def greedy_generate_text(model: nn.Module, 44 | tokenizer, 45 | initial_str: str, 46 | max_seq_len: int, 47 | device=torch.device("cuda:0"), 48 | verbose=True): 49 | """Generate greedily from 20B. 50 | 51 | :param model: NeoX20BModel 52 | :param tokenizer: NeoX20B tokenizer 53 | :param initial_str: initial string to start generation from 54 | :param max_seq_len: max sequence length to generate up to (includes input_ids) 55 | :param device: device to use 56 | :param verbose: whether to print progress 57 | 58 | :return: List of token IDs 59 | """ 60 | tokenized = tokenizer.encode(initial_str) 61 | input_ids = torch.LongTensor([tokenized.ids]).to(device) 62 | all_token_ids = greedy_generate(model=model, input_ids=input_ids, max_seq_len=max_seq_len, verbose=verbose) 63 | return tokenizer.decode(all_token_ids[0]) 64 | -------------------------------------------------------------------------------- /minimal20b/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | import minimal20b.rotary as rotary 6 | 7 | 8 | class NeoX20BModel(nn.Module): 9 | def __init__(self, args, use_cache=False, device=None): 10 | super().__init__() 11 | self.use_cache = use_cache 12 | self.embed_in = nn.Embedding(args.vocab_size, args.hidden_size, device=device) 13 | self.layer_list = nn.ModuleList([]) 14 | for layer_i in range(args.num_layers): 15 | self.layer_list.append(TransformerLayer(args, use_cache, device=device)) 16 | self.final_layer_norm = nn.LayerNorm( 17 | args.hidden_size, 18 | eps=args.layernorm_epsilon, 19 | device=device, 20 | ) 21 | self.logits_out = nn.Linear( 22 | args.hidden_size, 23 | args.vocab_size, 24 | bias=False, 25 | device=device, 26 | ) 27 | 28 | def forward(self, x, attention_mask=None, layer_past=None): 29 | if attention_mask is None: 30 | attention_mask = generate_mask(x.shape[1]).to(x.device) 31 | if self.use_cache: 32 | if layer_past is None: 33 | kv_length = x.shape[1] 34 | else: 35 | kv_length = layer_past[0].shape[1] + 1 36 | attention_mask = attention_mask[..., :x.shape[1], :kv_length] 37 | 38 | if layer_past is None: 39 | layer_past = [None] * len(self.layer_list) 40 | kv_cache_list = [] 41 | hidden_states = self.embed_in(x) 42 | hidden_states = self.pre_transformer_transpose(hidden_states) 43 | 44 | for layer_i, layer in enumerate(self.layer_list): 45 | hidden_states, kv_cache = layer( 46 | 
x=hidden_states, 47 | attention_mask=attention_mask, 48 | layer_past=layer_past[layer_i], 49 | ) 50 | kv_cache_list.append(kv_cache) 51 | hidden_states = self.post_transformer_transpose(hidden_states) 52 | hidden_states = self.final_layer_norm(hidden_states) 53 | logits = self.logits_out(hidden_states) 54 | if self.use_cache: 55 | return logits, kv_cache_list 56 | else: 57 | return logits 58 | 59 | @classmethod 60 | def pre_transformer_transpose(cls, x): 61 | return x.transpose(0, 1).contiguous() 62 | 63 | @classmethod 64 | def post_transformer_transpose(cls, x): 65 | return x.transpose(0, 1).contiguous() 66 | 67 | 68 | class TransformerLayer(nn.Module): 69 | def __init__(self, args, use_cache, device=None): 70 | super().__init__() 71 | self.use_cache = use_cache 72 | self.input_layernorm = nn.LayerNorm( 73 | args.hidden_size, 74 | eps=args.layernorm_epsilon, 75 | device=device, 76 | ) 77 | self.post_attention_layernorm = nn.LayerNorm( 78 | args.hidden_size, 79 | eps=args.layernorm_epsilon, 80 | device=device, 81 | ) 82 | self.attention = SelfAttention(args, self.use_cache, device=device) 83 | self.mlp = MLP(args) 84 | 85 | def forward(self, x, attention_mask, layer_past=None): 86 | residual = x 87 | ln_output = self.input_layernorm(x) 88 | attention_output, kv_cache = self.attention( 89 | ln_output, 90 | attention_mask, 91 | layer_past=layer_past, 92 | ) 93 | post_attn_ln = self.post_attention_layernorm(x) 94 | mlp_output = self.mlp(hidden_states=post_attn_ln) 95 | output = residual + mlp_output + attention_output 96 | return output, kv_cache 97 | 98 | 99 | class SelfAttention(nn.Module): 100 | def __init__(self, args, use_cache=False, device=None): 101 | super().__init__() 102 | self.hidden_size = args.hidden_size 103 | self.use_cache = use_cache 104 | self.num_attention_heads = args.num_attention_heads 105 | self.hidden_size_per_attention_head = args.hidden_size // args.num_attention_heads 106 | self.rotary_ndims = int(self.hidden_size_per_attention_head * args.rotary_pct) 107 | self.rotary_emb = rotary.RotaryEmbedding( 108 | self.rotary_ndims, 109 | base=args.rotary_emb_base, 110 | device=device, 111 | ) 112 | self.query_key_value = nn.Linear( 113 | args.hidden_size, 114 | 3 * args.hidden_size, 115 | device=device, 116 | ) 117 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) 118 | self.dense = nn.Linear( 119 | args.hidden_size, 120 | args.hidden_size, 121 | device=device, 122 | ) 123 | 124 | def forward(self, hidden_states, attention_mask, layer_past=None): 125 | has_layer_past = layer_past is not None and layer_past.numel() > 0 126 | 127 | # Compute QKV 128 | # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] 129 | qkv = self.query_key_value(hidden_states) 130 | 131 | # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] 132 | new_qkv_shape = qkv.size()[:-1] + ( 133 | self.num_attention_heads, 134 | 3 * self.hidden_size_per_attention_head, 135 | ) 136 | qkv = qkv.view(*new_qkv_shape) 137 | 138 | # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] 139 | query_layer = qkv[..., :self.hidden_size_per_attention_head] 140 | key_layer = qkv[..., self.hidden_size_per_attention_head: 2 * self.hidden_size_per_attention_head] 141 | value_layer = qkv[..., 2 * self.hidden_size_per_attention_head:] 142 | 143 | # Compute rotary embeddings 144 | query_rot, query_pass = ( 145 | query_layer[..., : self.rotary_ndims], 146 | query_layer[..., self.rotary_ndims:], 147 | ) 148 | key_rot, key_pass = ( 149 | key_layer[..., : self.rotary_ndims], 150 | key_layer[..., self.rotary_ndims:], 151 | ) 152 
| seq_len = key_layer.shape[0] 153 | offset = 0 154 | if has_layer_past: 155 | offset = layer_past[0].shape[0] 156 | seq_len += offset 157 | cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) 158 | query_layer, key_layer = rotary.apply_rotary_pos_emb( 159 | query_rot, key_rot, cos, sin, offset=offset, 160 | ) 161 | query_layer = torch.cat((query_layer, query_pass), dim=-1) 162 | key_layer = torch.cat((key_layer, key_pass), dim=-1) 163 | 164 | # Cache QKV values 165 | if has_layer_past: 166 | past_key, past_value = layer_past 167 | key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) 168 | value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) 169 | if self.use_cache: 170 | kv_cache = torch.stack((key_layer, value_layer)) 171 | else: 172 | kv_cache = None 173 | 174 | # Compute attention 175 | # noinspection PyTypeChecker 176 | context_layer = self.attention( 177 | query_layer, key_layer, value_layer, attention_mask 178 | ) 179 | 180 | # Reshape outputs 181 | # [b, np, sq, hn] --> [sq, b, np, hn] 182 | context_layer = context_layer.permute(2, 0, 1, 3).contiguous() 183 | 184 | # [sq, b, np, hn] --> [sq, b, hp] 185 | new_context_layer_shape = context_layer.size()[:-2] + ( 186 | self.hidden_size, 187 | ) 188 | context_layer = context_layer.view(*new_context_layer_shape) 189 | 190 | # ================= 191 | # Output. [sq, b, h] 192 | # ================= 193 | output = self.dense(context_layer) 194 | 195 | return output, kv_cache 196 | 197 | def attention(self, query_layer, key_layer, value_layer, attention_mask): 198 | # =================================== 199 | # Raw attention scores. [b, np, s, s] 200 | # =================================== 201 | 202 | # [b, np, sq, sk] 203 | output_size = ( 204 | query_layer.size(1), 205 | query_layer.size(2), 206 | query_layer.size(0), 207 | key_layer.size(0), 208 | ) 209 | 210 | # [sq, b, np, hn] -> [sq, b * np, hn] 211 | query_layer = query_layer.view( 212 | output_size[2], output_size[0] * output_size[1], -1 213 | ) 214 | key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) 215 | 216 | # preallocating result tensor: [b * np, sq, sk] 217 | matmul_result = torch.empty( 218 | output_size[0] * output_size[1], 219 | output_size[2], 220 | output_size[3], 221 | dtype=query_layer.dtype, 222 | device=query_layer.device, 223 | ) 224 | 225 | # Raw attention scores. [b * np, sq, sk] 226 | matmul_result = torch.baddbmm( 227 | matmul_result, 228 | query_layer.transpose(0, 1), # [b * np, sq, hn] 229 | key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] 230 | beta=0.0, 231 | alpha=(1.0 / self.norm_factor), 232 | ) 233 | 234 | # change view to [b, np, sq, sk] 235 | attention_scores = matmul_result.view(*output_size) 236 | 237 | # ================================================== 238 | # Update attention mask for inference. [b, np, sq, sk] 239 | # ================================================== 240 | 241 | # =========================== 242 | # Attention probs and dropout 243 | # =========================== 244 | 245 | # attention scores and attention mask [b, np, sq, sk] 246 | masked_scores = attention_mask_func(attention_scores, attention_mask) \ 247 | if attention_mask is not None else attention_scores 248 | attention_probs = torch.nn.Softmax(dim=-1)(masked_scores) 249 | 250 | # # This is actually dropping out entire tokens to attend to, which might 251 | # # seem a bit unusual, but is taken from the original Transformer paper. 
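# (Dropout is left disabled in this inference-only reference implementation; no attention_dropout module is defined in __init__, so enabling the line below as-is would fail.)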
252 | # attention_probs = self.attention_dropout(attention_probs) 253 | 254 | # ========================= 255 | # Context layer. [sq, b, hp] 256 | # ========================= 257 | 258 | # value_layer -> context layer. 259 | # [sk, b, np, hn] --> [b, np, sq, hn] 260 | 261 | # context layer shape: [b, np, sq, hn] 262 | output_size = ( 263 | value_layer.size(1), 264 | value_layer.size(2), 265 | query_layer.size(0), 266 | value_layer.size(3), 267 | ) 268 | 269 | # change view [sk, b * np, hn] 270 | value_layer = value_layer.view( 271 | value_layer.size(0), output_size[0] * output_size[1], -1 272 | ) 273 | 274 | # change view [b * np, sq, sk] 275 | attention_probs = attention_probs.view( 276 | output_size[0] * output_size[1], output_size[2], -1 277 | ) 278 | 279 | # matmul: [b * np, sq, hn] 280 | context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) 281 | 282 | # change view [b, np, sq, hn] 283 | context_layer = context_layer.view(*output_size) 284 | return context_layer 285 | 286 | 287 | class MLP(nn.Module): 288 | def __init__(self, args, device=None): 289 | super().__init__() 290 | ff_dim = 4 * args.hidden_size 291 | self.dense_h_to_4h = nn.Linear(args.hidden_size, ff_dim, device=device) 292 | self.dense_4h_to_h = nn.Linear(ff_dim, args.hidden_size, device=device) 293 | 294 | def forward(self, hidden_states): 295 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 296 | intermediate_parallel = bias_gelu_impl(intermediate_parallel) 297 | output = self.dense_4h_to_h(intermediate_parallel) 298 | return output 299 | 300 | 301 | # noinspection PyAbstractClass 302 | class GeLUFunction(torch.autograd.Function): 303 | # noinspection PyMethodOverriding 304 | @staticmethod 305 | # bias is an optional argument 306 | def forward(ctx, inputs): 307 | ctx.save_for_backward(inputs) 308 | return gelu(inputs) 309 | 310 | # noinspection PyMethodOverriding 311 | @staticmethod 312 | def backward(ctx, grad_output): 313 | inputs = ctx.saved_tensors 314 | tmp = gelu_back(grad_output, inputs) 315 | return tmp, tmp 316 | 317 | 318 | bias_gelu_impl = GeLUFunction.apply 319 | 320 | 321 | def generate_mask(seq_len): 322 | return torch.tril(torch.ones((1, 1, seq_len, seq_len), dtype=torch.bool)) 323 | 324 | 325 | def attention_mask_func(attention_scores, ltor_mask): 326 | """Assign -10000.0 to False cells in ltor_mask""" 327 | attention_scores.masked_fill_(~ltor_mask, -10000.0) 328 | return attention_scores 329 | 330 | 331 | @torch.jit.script 332 | def gelu(x): 333 | return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) 334 | 335 | 336 | # gradient of tanh approximation of gelu 337 | # gradient of actual gelu is: 338 | # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) 339 | @torch.jit.script 340 | def gelu_back(g, x): 341 | tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) 342 | # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 343 | ff = 0.5 * x * ( 344 | (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) 345 | ) + 0.5 * (1 + tanh_out) 346 | return ff * g 347 | -------------------------------------------------------------------------------- /minimal20b/rotary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class RotaryEmbedding(torch.nn.Module): 5 | 6 | def __init__(self, dim, base=10000, device=None): 7 | super().__init__() 8 | inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 9 | self.register_buffer('inv_freq', inv_freq) 10 | # Delay initialization until first forward call, because initial model on the 'meta' device 11 | self.cos_cached = None 12 | self.sin_cached = None 13 | 14 | def forward(self, x, seq_dim=1, seq_len=None): 15 | if seq_len is None: 16 | seq_len = x.shape[seq_dim] 17 | if self.cos_cached is None: 18 | t = torch.arange(2048, device=x.device, dtype=self.inv_freq.dtype) 19 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 20 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 21 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 22 | # [sx, 1 (b * np), hn] 23 | self.cos_cached = emb.cos()[:, None, None, :] 24 | self.sin_cached = emb.sin()[:, None, None, :] 25 | return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] 26 | 27 | 28 | def rotate_half(x): 29 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 30 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions 31 | 32 | 33 | # @torch.jit.script 34 | def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): 35 | cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 36 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 37 | -------------------------------------------------------------------------------- /minimal20b_flax/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import GPTNeoX20BModel 2 | from .create import load_model_weights 3 | -------------------------------------------------------------------------------- /minimal20b_flax/create.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from multiprocessing import Pool 4 | 5 | import jax 6 | from tqdm import auto as tqdm_lib 7 | 8 | import numpy as np 9 | 10 | # noinspection PyPep8Naming 11 | from jax.experimental import PartitionSpec as P 12 | from jax.experimental.pjit import pjit 13 | from jax.experimental import maps 14 | import jax.numpy as jnp 15 | from flax.core import frozen_dict 16 | from flax import traverse_util 17 | 18 | import torch 19 | import tokenizers 20 | 21 | import minimal20b_flax.utils as utils 22 | import minimal20b_flax.model as model 23 | 24 | 25 | def load_model_weights(checkpoint_path, config: model.NeoX20BConfig = model.default_neox20b_config): 26 | """Loads the weights from a checkpoint and shard to 8 TPU devices.""" 27 | pbar = tqdm_lib.tqdm(total=47) 28 | 29 | # 1. Load embed_in 30 | pbar.set_description("Loading embed_in") 31 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 32 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 33 | shared_embedding = np.concatenate([ 34 | loaded_tp1["word_embeddings.weight"], 35 | loaded_tp2["word_embeddings.weight"], 36 | ], axis=0) 37 | del loaded_tp1 38 | del loaded_tp2 39 | # 1.1. Shard to device 40 | embed_in_params = traverse_util.unflatten_dict({ 41 | ("embed", "kernel"): utils.shard_to_devices(shared_embedding, axis=0), 42 | }) 43 | pbar.update(1) 44 | 45 | # 2. Load layer weights 46 | # These are stacked because we will later run a jax.lax.scan over them to iterate 47 | # over layers. 48 | # Note: this next line loads all the layers into CPU memory, which is a lot. 
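# (At roughly 20B fp16 parameters that is on the order of 40 GB of host RAM; the Colab-specific loader further down instead pre-allocates the stacked layer arrays on the devices and copies one layer in at a time.)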
49 | layer_params_list = [] 50 | for i in range(config.num_layers): 51 | pbar.set_description(f"Loading layer {i}") 52 | layer_params_list.append(traverse_util.flatten_dict(frozen_dict.unfreeze( 53 | load_single_layer_params(checkpoint_path, i) 54 | ))) 55 | pbar.update(1) 56 | # 2.1. Shard to device 57 | sharding = model.GPTNeoX20BModel.get_sharding() 58 | flat_stacked_layers_sharding = traverse_util.flatten_dict(frozen_dict.unfreeze( 59 | sharding["transformer"])) 60 | pbar.set_description(f"Sharding transformer layers to TPUs") 61 | stacked_layer_params = {} 62 | for k, v in layer_params_list[0].items(): 63 | stacked = np.stack([ 64 | layer_params[k] 65 | for layer_params in layer_params_list 66 | ], axis=0) 67 | shard_strategy = flat_stacked_layers_sharding[k] 68 | if shard_strategy == P(None, None): 69 | stacked = utils.replicate_to_devices(stacked) 70 | elif shard_strategy == P(None, None, "tp"): 71 | stacked = utils.shard_to_devices(stacked, axis=2) 72 | elif shard_strategy == P(None, "tp", None): 73 | stacked = utils.shard_to_devices(stacked, axis=1) 74 | else: 75 | raise RuntimeError() 76 | stacked_layer_params[k] = stacked 77 | stacked_layer_params = frozen_dict.freeze(traverse_util.unflatten_dict( 78 | stacked_layer_params 79 | )) 80 | pbar.update(1) 81 | 82 | # 3. Load final layer norm and embed_out (jointly "embed_out") 83 | pbar.set_description(f"Load embed_out") 84 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 85 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 86 | # noinspection PyDictCreation 87 | embed_out_params = { 88 | ("norm", "bias"): (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2, 89 | ("norm", "scale"): (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2, 90 | } 91 | del loaded_tp1 92 | del loaded_tp2 93 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 94 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 95 | embed_out_params["embed_out", "kernel"] = np.concatenate([ 96 | loaded_tp1["final_linear.weight"].T, 97 | loaded_tp2["final_linear.weight"].T, 98 | ], axis=1) 99 | del loaded_tp1 100 | del loaded_tp2 101 | # 3.1. Shard to device 102 | embed_out_params["norm", "bias"] = utils.replicate_to_devices( 103 | embed_out_params["norm", "bias"]) 104 | embed_out_params["norm", "scale"] = utils.replicate_to_devices( 105 | embed_out_params["norm", "scale"]) 106 | embed_out_params["embed_out", "kernel"] = utils.shard_to_devices( 107 | embed_out_params["embed_out", "kernel"], axis=1) 108 | embed_out_params = frozen_dict.freeze(traverse_util.unflatten_dict(embed_out_params)) 109 | pbar.update(1) 110 | pbar.set_description("Done.") 111 | 112 | # 4. 
Combine 113 | all_params = frozen_dict.freeze({ 114 | "embed_in": embed_in_params, 115 | "transformer": stacked_layer_params, 116 | "embed_out": embed_out_params 117 | }) 118 | return all_params 119 | 120 | 121 | def load_single_layer_params(checkpoint_path, layer_i): 122 | filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt" 123 | filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt" 124 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, filename_tp1)) 125 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, filename_tp2)) 126 | # noinspection PyDictCreation 127 | layer_params = {} 128 | layer_params["attn_norm", "bias"] = ( 129 | loaded_tp1["input_layernorm.bias"] 130 | + loaded_tp2["input_layernorm.bias"] 131 | ) / 2 132 | layer_params["attn_norm", "scale"] = ( 133 | loaded_tp1["input_layernorm.weight"] 134 | + loaded_tp2["input_layernorm.weight"] 135 | ) / 2 136 | layer_params["qkv_proj", "kernel"] = np.concatenate([ 137 | loaded_tp1["attention.query_key_value.weight"].T, 138 | loaded_tp2["attention.query_key_value.weight"].T, 139 | ], axis=1).reshape((6144, 8, 8, 3, 96)).swapaxes(2, 3).reshape((6144, 18432)) 140 | # input_dim, num_heads1(tp), numheads2(heads per device), qkv, dim_per_head 141 | layer_params["qkv_proj", "bias"] = np.concatenate([ 142 | loaded_tp1["attention.query_key_value.bias"], 143 | loaded_tp2["attention.query_key_value.bias"], 144 | ]).reshape((8, 8, 3, 96)).swapaxes(1, 2).reshape(18432) 145 | layer_params["output_proj", "kernel"] = np.concatenate([ 146 | loaded_tp1["attention.dense.weight"].T, 147 | loaded_tp2["attention.dense.weight"].T, 148 | ], axis=0) 149 | layer_params["output_proj", "bias"] = ( 150 | loaded_tp1["attention.dense.bias"] 151 | + loaded_tp2["attention.dense.bias"] 152 | ) 153 | layer_params["ff_norm", "bias"] = ( 154 | loaded_tp1["post_attention_layernorm.bias"] 155 | + loaded_tp2["post_attention_layernorm.bias"] 156 | ) / 2 157 | layer_params["ff_norm", "scale"] = ( 158 | loaded_tp1["post_attention_layernorm.weight"] 159 | + loaded_tp2["post_attention_layernorm.weight"] 160 | ) / 2 161 | layer_params["ff_up_proj", "kernel"] = np.concatenate([ 162 | loaded_tp1["mlp.dense_h_to_4h.weight"].T, 163 | loaded_tp2["mlp.dense_h_to_4h.weight"].T, 164 | ], axis=1) 165 | layer_params["ff_up_proj", "bias"] = np.concatenate([ 166 | loaded_tp1["mlp.dense_h_to_4h.bias"], 167 | loaded_tp2["mlp.dense_h_to_4h.bias"], 168 | ]) 169 | layer_params["ff_down_proj", "kernel"] = np.concatenate([ 170 | loaded_tp1["mlp.dense_4h_to_h.weight"].T, 171 | loaded_tp2["mlp.dense_4h_to_h.weight"].T, 172 | ], axis=0) 173 | layer_params["ff_down_proj", "bias"] = ( 174 | loaded_tp1["mlp.dense_4h_to_h.bias"] 175 | + loaded_tp2["mlp.dense_4h_to_h.bias"] 176 | ) 177 | layer_params = frozen_dict.freeze(traverse_util.unflatten_dict(layer_params)) 178 | del loaded_tp1 179 | del loaded_tp2 180 | return layer_params 181 | 182 | 183 | def load_to_numpy(path, **kwargs): 184 | return {k: v.numpy() for k, v in torch.load(path, **kwargs).items()} 185 | 186 | 187 | def create_tokenizer(tokenizer_path): 188 | return tokenizers.Tokenizer.from_file(tokenizer_path) 189 | 190 | 191 | # === Colab specific === 192 | 193 | def colab_load_model_weights(checkpoint_path, config: model.NeoX20BConfig = model.default_neox20b_config): 194 | """Loads the weights from a checkpoint and shard to 8 TPU devices.""" 195 | pbar = tqdm_lib.tqdm(total=311) 196 | 197 | # 1. 
Load embed_in 198 | pbar.set_description("Loading embed_in") 199 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 200 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 201 | shared_embedding = np.concatenate([ 202 | loaded_tp1["word_embeddings.weight"], 203 | loaded_tp2["word_embeddings.weight"], 204 | ], axis=0) 205 | del loaded_tp1 206 | del loaded_tp2 207 | # 1.1. Shard to device 208 | embed_in_params = traverse_util.unflatten_dict({ 209 | ("embed", "kernel"): utils.shard_to_devices(shared_embedding, axis=0), 210 | }) 211 | pbar.update(1) 212 | 213 | stacked_layer_params = {} 214 | sharding = model.GPTNeoX20BModel.get_sharding() 215 | flat_stacked_layers_sharding = traverse_util.flatten_dict(frozen_dict.unfreeze( 216 | sharding["transformer"])) 217 | 218 | # 2.1 Preallocate 219 | def initialize_layer_params(): 220 | shape_dict = { 221 | ('attn_norm', 'scale'): (44, 6144,), 222 | ('attn_norm', 'bias'): (44, 6144,), 223 | ('qkv_proj', 'kernel'): (44, 6144, 18432), 224 | ('qkv_proj', 'bias'): (44, 18432,), 225 | ('output_proj', 'kernel'): (44, 6144, 6144), 226 | ('output_proj', 'bias'): (44, 6144,), 227 | ('ff_norm', 'scale'): (44, 6144,), 228 | ('ff_norm', 'bias'): (44, 6144,), 229 | ('ff_up_proj', 'kernel'): (44, 6144, 24576), 230 | ('ff_up_proj', 'bias'): (44, 24576,), 231 | ('ff_down_proj', 'kernel'): (44, 24576, 6144), 232 | ('ff_down_proj', 'bias'): (44, 6144,), 233 | } 234 | layer_params = {} 235 | for k, v in shape_dict.items(): 236 | layer_params[k] = jnp.zeros(v, dtype=jnp.float16) 237 | return layer_params 238 | 239 | initialize_layer_params_pjit = pjit( 240 | initialize_layer_params, 241 | in_axis_resources=None, 242 | out_axis_resources=flat_stacked_layers_sharding, 243 | ) 244 | mesh = utils.get_default_mesh() 245 | with maps.mesh(mesh.devices, mesh.axis_names): 246 | pbar.set_description(f"Initializing layer params on device") 247 | stacked_layer_params = initialize_layer_params_pjit() 248 | pbar.update(1) 249 | 250 | def assign_to_sharded_device_array(old_state, new_layer, layer_idx): 251 | new_state = old_state.at[layer_idx].set(new_layer) 252 | return new_state 253 | 254 | assign_funcs_dict = {} 255 | for k in stacked_layer_params: 256 | assign_funcs_dict[k] = pjit( 257 | assign_to_sharded_device_array, 258 | in_axis_resources=( 259 | flat_stacked_layers_sharding[k], 260 | P(*flat_stacked_layers_sharding[k][1:]), 261 | ), 262 | out_axis_resources=flat_stacked_layers_sharding[k], 263 | donate_argnums=(0,), 264 | static_argnums=(2,), 265 | ) 266 | 267 | for layer_i in range(config.num_layers): 268 | pbar.set_description(f"Loading layer {layer_i}") 269 | single_layer_params = load_single_layer_params(checkpoint_path, layer_i) 270 | flattened_layer_params = traverse_util.flatten_dict(frozen_dict.unfreeze(single_layer_params)) 271 | with maps.mesh(mesh.devices, mesh.axis_names): 272 | for k in stacked_layer_params: 273 | stacked_layer_params[k] = assign_funcs_dict[k]( 274 | stacked_layer_params[k], 275 | flattened_layer_params[k], 276 | layer_i, 277 | ) 278 | pbar.update(1) 279 | 280 | stacked_layer_params = frozen_dict.freeze(traverse_util.unflatten_dict(stacked_layer_params)) 281 | 282 | # 3. 
Load final layer norm and embed_out (jointly "embed_out") 283 | pbar.set_description(f"Load embed_out") 284 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 285 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 286 | # noinspection PyDictCreation 287 | embed_out_params = { 288 | ("norm", "bias"): (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2, 289 | ("norm", "scale"): (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2, 290 | } 291 | del loaded_tp1 292 | del loaded_tp2 293 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 294 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 295 | embed_out_params["embed_out", "kernel"] = np.concatenate([ 296 | loaded_tp1["final_linear.weight"].T, 297 | loaded_tp2["final_linear.weight"].T, 298 | ], axis=1) 299 | del loaded_tp1 300 | del loaded_tp2 301 | # 3.1. Shard to device 302 | embed_out_params["norm", "bias"] = utils.replicate_to_devices( 303 | embed_out_params["norm", "bias"]) 304 | embed_out_params["norm", "scale"] = utils.replicate_to_devices( 305 | embed_out_params["norm", "scale"]) 306 | embed_out_params["embed_out", "kernel"] = utils.shard_to_devices( 307 | embed_out_params["embed_out", "kernel"], axis=1) 308 | embed_out_params = frozen_dict.freeze(traverse_util.unflatten_dict(embed_out_params)) 309 | pbar.update(1) 310 | pbar.set_description("Done.") 311 | 312 | # 4. Combine 313 | all_params = frozen_dict.freeze({ 314 | "embed_in": embed_in_params, 315 | "transformer": stacked_layer_params, 316 | "embed_out": embed_out_params 317 | }) 318 | return all_params 319 | 320 | 321 | def colab_load_single_layer_params(checkpoint_path, layer_i): 322 | filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt" 323 | filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt" 324 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, filename_tp1)) 325 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, filename_tp2)) 326 | # noinspection PyDictCreation 327 | layer_params = {} 328 | layer_params["attn_norm", "bias"] = ( 329 | loaded_tp1["input_layernorm.bias"] 330 | + loaded_tp2["input_layernorm.bias"] 331 | ) / 2 332 | layer_params["attn_norm", "scale"] = ( 333 | loaded_tp1["input_layernorm.weight"] 334 | + loaded_tp2["input_layernorm.weight"] 335 | ) / 2 336 | layer_params["qkv_proj", "kernel"] = np.concatenate([ 337 | loaded_tp1["attention.query_key_value.weight"].T, 338 | loaded_tp2["attention.query_key_value.weight"].T, 339 | ], axis=1).reshape((6144, 8, 8, 3, 96)).swapaxes(2, 3).reshape((6144, 18432)) 340 | # input_dim, num_heads1(tp), numheads2(heads per device), qkv, dim_per_head 341 | layer_params["qkv_proj", "bias"] = np.concatenate([ 342 | loaded_tp1["attention.query_key_value.bias"], 343 | loaded_tp2["attention.query_key_value.bias"], 344 | ]).reshape((8, 8, 3, 96)).swapaxes(1, 2).reshape(18432) 345 | layer_params["output_proj", "kernel"] = np.concatenate([ 346 | loaded_tp1["attention.dense.weight"].T, 347 | loaded_tp2["attention.dense.weight"].T, 348 | ], axis=0) 349 | layer_params["output_proj", "bias"] = ( 350 | loaded_tp1["attention.dense.bias"] 351 | + loaded_tp2["attention.dense.bias"] 352 | ) 353 | layer_params["ff_norm", "bias"] = ( 354 | loaded_tp1["post_attention_layernorm.bias"] 355 | + loaded_tp2["post_attention_layernorm.bias"] 356 | ) / 2 357 | layer_params["ff_norm", "scale"] = ( 358 | 
loaded_tp1["post_attention_layernorm.weight"] 359 | + loaded_tp2["post_attention_layernorm.weight"] 360 | ) / 2 361 | layer_params["ff_up_proj", "bias"] = np.concatenate([ 362 | loaded_tp1["mlp.dense_h_to_4h.bias"], 363 | loaded_tp2["mlp.dense_h_to_4h.bias"], 364 | ]) 365 | layer_params["ff_down_proj", "bias"] = ( 366 | loaded_tp1["mlp.dense_4h_to_h.bias"] 367 | + loaded_tp2["mlp.dense_4h_to_h.bias"] 368 | ) 369 | layer_params = frozen_dict.freeze(traverse_util.unflatten_dict(layer_params)) 370 | del loaded_tp1 371 | del loaded_tp2 372 | return layer_params 373 | 374 | 375 | def colab_load_single_layer_qkv_kernel_params(checkpoint_path, layer_i, original_shard: int): 376 | filename = f"layer_{layer_i + 2:02d}-model_{original_shard:02d}-model_states.pt" 377 | loaded = load_to_numpy(os.path.join(checkpoint_path, filename)) 378 | return loaded["attention.query_key_value.weight"].T.reshape( 379 | (6144, 4, 8, 3, 96) 380 | ).swapaxes(2, 3).reshape((6144, 9216)) 381 | 382 | 383 | def colab_load_single_layer_ff_up_kernel_params(checkpoint_path, layer_i, original_shard: int): 384 | filename = f"layer_{layer_i + 2:02d}-model_{original_shard:02d}-model_states.pt" 385 | loaded = load_to_numpy(os.path.join(checkpoint_path, filename)) 386 | return loaded["mlp.dense_h_to_4h.weight"].T 387 | 388 | 389 | def colab_load_single_layer_ff_down_kernel_params(checkpoint_path, layer_i, original_shard: int): 390 | filename = f"layer_{layer_i + 2:02d}-model_{original_shard:02d}-model_states.pt" 391 | loaded = load_to_numpy(os.path.join(checkpoint_path, filename)) 392 | return loaded["mlp.dense_4h_to_h.weight"].T 393 | 394 | 395 | # === Xmap specific === 396 | 397 | def load_single_layer_params_for_xmap(checkpoint_path, layer_i): 398 | num_shards = 8 399 | transformer_layer_params = load_single_layer_params(checkpoint_path=checkpoint_path, layer_i=layer_i) 400 | fparams = frozen_dict.unfreeze(traverse_util.flatten_dict(transformer_layer_params)) 401 | # noinspection PyDictCreation 402 | new_fparams = {} 403 | new_fparams[('attn_norm', 'bias')] = \ 404 | stack_copies(fparams[('attn_norm', 'bias')], num_shards, axis=0) 405 | new_fparams[('attn_norm', 'scale')] = \ 406 | stack_copies(fparams[('attn_norm', 'scale')], num_shards, axis=0) 407 | new_fparams[('ff_down_proj', 'bias')] = \ 408 | stack_copies(fparams[('ff_down_proj', 'bias')], num_shards, axis=0) / 8 409 | new_fparams[('ff_down_proj', 'kernel')] = \ 410 | fparams[('ff_down_proj', 'kernel')].reshape(8, 3072, 6144) 411 | new_fparams[('ff_norm', 'bias')] = \ 412 | stack_copies(fparams[('ff_norm', 'bias')], num_shards, axis=0) 413 | new_fparams[('ff_norm', 'scale')] = \ 414 | stack_copies(fparams[('ff_norm', 'scale')], num_shards, axis=0) 415 | new_fparams[('ff_up_proj', 'bias')] = \ 416 | fparams[('ff_up_proj', 'bias')].reshape(8, 3072) 417 | new_fparams[('ff_up_proj', 'kernel')] = \ 418 | fparams[('ff_up_proj', 'kernel')].reshape(6144, 8, 3072).swapaxes(0, 1) 419 | new_fparams[('output_proj', 'bias')] = \ 420 | stack_copies(fparams[('output_proj', 'bias')], num_shards, axis=0) / 8 421 | new_fparams[('output_proj', 'kernel')] = \ 422 | fparams[('output_proj', 'kernel')].reshape(8, 768, 6144) 423 | new_fparams[('qkv_proj', 'bias')] = \ 424 | fparams[('qkv_proj', 'bias')].reshape(8, 2304) 425 | new_fparams[('qkv_proj', 'kernel')] = \ 426 | fparams[('qkv_proj', 'kernel')].reshape(6144, 8, 2304).swapaxes(0, 1) 427 | return new_fparams 428 | 429 | 430 | def load_model_weights_for_xmap( 431 | checkpoint_path, 432 | config: model.NeoX20BConfig = 
model.default_neox20b_config, 433 | pool_size=None): 434 | flattened_params = {} 435 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 436 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 437 | embed_in = np.concatenate([ 438 | loaded_tp1["word_embeddings.weight"], 439 | loaded_tp2["word_embeddings.weight"], 440 | ], axis=0) 441 | flattened_params[('embed_in', 'embed', 'kernel')] = embed_in.reshape(8, 6304, 6144) 442 | del loaded_tp1 443 | del loaded_tp2 444 | 445 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 446 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 447 | # noinspection PyDictCreation 448 | final_norm_bias = (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2 449 | final_norm_scale = (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2 450 | num_shards = 8 451 | flattened_params[('embed_out', 'norm', 'bias')] = \ 452 | stack_copies(final_norm_bias, num_shards, axis=0) 453 | flattened_params[('embed_out', 'norm', 'scale')] = \ 454 | stack_copies(final_norm_scale, num_shards, axis=0) 455 | del loaded_tp1 456 | del loaded_tp2 457 | 458 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 459 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 460 | embed_out = np.concatenate([ 461 | loaded_tp1["final_linear.weight"].T, 462 | loaded_tp2["final_linear.weight"].T, 463 | ], axis=1) 464 | flattened_params[('embed_out', 'embed_out', 'kernel')] = \ 465 | embed_out.reshape(6144, 8, 6304).swapaxes(0, 1) 466 | del loaded_tp1 467 | del loaded_tp2 468 | 469 | if pool_size is None: 470 | for layer_i in tqdm_lib.trange(config.num_layers): 471 | layer_params = load_single_layer_params_for_xmap( 472 | checkpoint_path=checkpoint_path, 473 | layer_i=layer_i, 474 | ) 475 | for k, v in layer_params.items(): 476 | flattened_params[(f"layer_{layer_i:02d}",) + k] = v 477 | else: 478 | pool = Pool(processes=pool_size) 479 | 480 | pool_args = [(checkpoint_path, layer_i) for layer_i in range(config.num_layers)] 481 | for layer_i, layer_params in tqdm_lib.tqdm(pool.imap(pool_func, pool_args), total=len(pool_args)): 482 | for k, v in layer_params.items(): 483 | flattened_params[(f"layer_{layer_i:02d}",) + k] = v 484 | 485 | params = traverse_util.unflatten_dict(flattened_params) 486 | return params 487 | 488 | 489 | def colab_load_model_weights_for_xmap( 490 | checkpoint_path, 491 | config: model.NeoX20BConfig = model.default_neox20b_config): 492 | mesh = get_colab_mesh() 493 | pbar = tqdm_lib.tqdm(total=47) 494 | 495 | # Embedding 496 | pbar.set_description("Loading embed_in") 497 | flattened_params = {} 498 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt")) 499 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt")) 500 | embed_in = np.concatenate([ 501 | loaded_tp1["word_embeddings.weight"], 502 | loaded_tp2["word_embeddings.weight"], 503 | ], axis=0) 504 | flattened_params[('embed_in', 'embed', 'kernel')] = shard_to_devices_v2( 505 | embed_in.reshape(8, 6304, 6144), mesh=mesh) 506 | del loaded_tp1 507 | del loaded_tp2 508 | pbar.update(1) 509 | 510 | # Load layers 511 | for layer_i in range(config.num_layers): 512 | pbar.set_description(f"Loading layer {layer_i}") 513 | layer_params = load_single_layer_params_for_xmap( 514 | 
checkpoint_path=checkpoint_path, 515 | layer_i=layer_i, 516 | ) 517 | for k, v in layer_params.items(): 518 | flattened_params[(f"layer_{layer_i:02d}",) + k] = shard_to_devices_v2(v, mesh=mesh) 519 | del layer_params 520 | pbar.update(1) 521 | 522 | # Final layer norm 523 | pbar.set_description(f"Load final layer norm") 524 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt")) 525 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt")) 526 | # noinspection PyDictCreation 527 | final_norm_bias = (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2 528 | final_norm_scale = (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2 529 | num_shards = 8 530 | flattened_params[('embed_out', 'norm', 'bias')] = shard_to_devices_v2( 531 | stack_copies(final_norm_bias, num_shards, axis=0), mesh=mesh) 532 | flattened_params[('embed_out', 'norm', 'scale')] = shard_to_devices_v2( 533 | stack_copies(final_norm_scale, num_shards, axis=0), mesh=mesh) 534 | del loaded_tp1 535 | del loaded_tp2 536 | pbar.update(1) 537 | 538 | # Output embeddings 539 | pbar.set_description(f"Load embed_out") 540 | loaded_tp1 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt")) 541 | loaded_tp2 = load_to_numpy(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt")) 542 | embed_out = np.concatenate([ 543 | loaded_tp1["final_linear.weight"].T, 544 | loaded_tp2["final_linear.weight"].T, 545 | ], axis=1) 546 | flattened_params[('embed_out', 'embed_out', 'kernel')] = shard_to_devices_v2( 547 | embed_out.reshape(6144, 8, 6304).swapaxes(0, 1), mesh=mesh) 548 | del loaded_tp1 549 | del loaded_tp2 550 | pbar.update(1) 551 | 552 | params = traverse_util.unflatten_dict(flattened_params) 553 | return params 554 | 555 | 556 | def get_colab_mesh(): 557 | return maps.Mesh(np.asarray(jax.local_devices()).reshape(1, 8), ('dp', 'tp')) 558 | 559 | 560 | def identity(x): 561 | return x 562 | 563 | 564 | def shard_to_devices_v2(x, mesh): 565 | shard_to = jax.experimental.maps.xmap( 566 | identity, 567 | in_axes=["shard", ...], 568 | out_axes=["shard", ...], 569 | axis_resources={'shard': 'tp', 'batch': 'dp'}, 570 | ) 571 | with mesh: 572 | return shard_to(x) 573 | 574 | 575 | def shard_to_devices_v3(x): 576 | assert x.shape[0] == 8 577 | return jax.device_put_sharded(list(x), devices=jax.local_devices()) 578 | 579 | 580 | def pool_func(arg): 581 | checkpoint_path, layer_i = arg 582 | loaded_layer = load_single_layer_params_for_xmap( 583 | checkpoint_path=checkpoint_path, 584 | layer_i=layer_i, 585 | ) 586 | return layer_i, loaded_layer 587 | 588 | 589 | def stack_copies(x, num, axis=0): 590 | return np.stack([x] * num, axis=axis) -------------------------------------------------------------------------------- /minimal20b_flax/generate.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import numpy as np 3 | import minimal20b_flax.model_xmap as model_xmap 4 | import minimal20b_flax.create as create 5 | 6 | 7 | # I have no idea if this helps 8 | CACHED_FUNCS = {} 9 | 10 | 11 | def generate(input_string: str, 12 | neox_model: model_xmap.GPTNeoX20BModel, 13 | params, 14 | tokenizer, 15 | maximum_context_length: int = None, 16 | rng: jax.random.PRNGKey = None, 17 | mesh=None): 18 | input_ids = tokenizer.encode(input_string).ids 19 | input_ctx_length = len(input_ids) 20 | # Specify a maximum_context_length to prevent re-jit-ing 21 | # Set it to None for the fastest 
inference for a fixed token length 22 | if maximum_context_length is not None: 23 | assert input_ctx_length < maximum_context_length 24 | padded_input_ids = np.zeros(maximum_context_length, dtype=int) 25 | padded_input_ids[-input_ctx_length:] = input_ids 26 | else: 27 | padded_input_ids = np.array([0] * neox_model.generate_length + input_ids) 28 | 29 | if rng is None: 30 | rng = jax.random.PRNGKey(np.random.randint(100000000)) 31 | elif isinstance(rng, int): 32 | rng = jax.random.PRNGKey(rng) 33 | 34 | if "generate" not in CACHED_FUNCS: 35 | CACHED_FUNCS["generate"] = jax.experimental.maps.xmap( 36 | neox_model.generate, 37 | in_axes=( 38 | ["shard", ...], 39 | [...], 40 | [...], 41 | [...], 42 | ), 43 | out_axes={ 44 | "generated_logits": [...], 45 | "generated_tokens": [...], 46 | 47 | }, 48 | axis_resources={'shard': 'tp', 'batch': 'dp'}, 49 | ) 50 | if mesh is None: 51 | mesh = create.get_colab_mesh() 52 | with mesh: 53 | output = CACHED_FUNCS["generate"]( 54 | params, 55 | padded_input_ids, 56 | input_ctx_length, 57 | rng, 58 | ) 59 | return { 60 | "generated_string": tokenizer.decode(output["generated_tokens"]), 61 | "generated_tokens": np.array(output["generated_tokens"]), 62 | "generated_logits": np.array(output["generated_logits"]), 63 | } 64 | -------------------------------------------------------------------------------- /minimal20b_flax/layernorm.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Iterable, Tuple, Union 2 | 3 | import flax.linen as nn 4 | from jax.nn import initializers 5 | import jax.numpy as jnp 6 | import jax.lax as lax 7 | 8 | PRNGKey = Any 9 | Array = Any 10 | Shape = Tuple[int, ...] 11 | Dtype = Any # this could be a real type? 12 | 13 | Axes = Union[int, Iterable[int]] 14 | 15 | 16 | class LayerNorm(nn.Module): 17 | epsilon: float = 1e-6 18 | dtype: Any = jnp.float32 19 | param_dtype: Dtype = jnp.float32 20 | use_bias: bool = True 21 | use_scale: bool = True 22 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros 23 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones 24 | 25 | @nn.compact 26 | def __call__(self, x): 27 | """Applies layer normalization on the input. 28 | 29 | Args: 30 | x: the inputs 31 | 32 | Returns: 33 | Normalized inputs (the same shape as inputs). 34 | """ 35 | reduction_axes = (-1,) 36 | feature_axes = (-1,) 37 | 38 | # TODO(jheek) suport axis_name for model parallelism? 39 | mean, var = _compute_stats(x, reduction_axes) 40 | 41 | return _normalize( 42 | self, x, mean, var, reduction_axes, feature_axes, 43 | self.dtype, self.param_dtype, self.epsilon, 44 | self.use_bias, self.use_scale, 45 | self.bias_init, self.scale_init) 46 | 47 | 48 | def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: 49 | """Returns a tuple of deduplicated, sorted, and positive axes.""" 50 | if not isinstance(axes, Iterable): 51 | axes = (axes,) 52 | return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) 53 | 54 | 55 | def _abs_sq(x): 56 | """Computes the elementwise square of the absolute value |x|^2.""" 57 | if jnp.iscomplexobj(x): 58 | return lax.square(lax.real(x)) + lax.square(lax.imag(x)) 59 | else: 60 | return lax.square(x) 61 | 62 | 63 | def _compute_stats(x: Array, axes: Axes): 64 | """Computes mean and variance statistics. 
65 | 66 | This implementation takes care of a few important details: 67 | - Computes in float32 precision for half precision inputs 68 | - mean and variance is computable in a single XLA fusion, 69 | by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). 70 | - Clips negative variances to zero which can happen due to 71 | roundoff errors. This avoids downstream NaNs. 72 | - Supports averaging across a parallel axis and subgroups of a parallel axis 73 | with a single `lax.pmean` call to avoid latency. 74 | 75 | Arguments: 76 | x: Input array. 77 | axes: The axes in ``x`` to compute mean and variance statistics for. 78 | 79 | Returns: 80 | A pair ``(mean, var)``. 81 | """ 82 | # promote x to at least float32, this avoids half precision computation 83 | # but preserves double or complex floating points 84 | x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) 85 | mean = jnp.mean(x, axes) 86 | diff = x - mean[..., None] 87 | var = jnp.mean(_abs_sq(diff), axes) 88 | return mean, var 89 | 90 | 91 | def _normalize(mdl: nn.Module, x: Array, mean: Array, var: Array, 92 | reduction_axes: Axes, feature_axes: Axes, 93 | dtype: Dtype, param_dtype: Dtype, 94 | epsilon: float, 95 | use_bias: bool, use_scale: bool, 96 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array], 97 | scale_init: Callable[[PRNGKey, Shape, Dtype], Array]): 98 | """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. 99 | 100 | Arguments: 101 | mdl: Module to apply the normalization in (normalization params will reside 102 | in this module). 103 | x: The input. 104 | mean: Mean to use for normalization. 105 | var: Variance to use for normalization. 106 | reduction_axes: The axes in ``x`` to reduce. 107 | feature_axes: Axes containing features. A separate bias and scale is learned 108 | for each specified feature. 109 | dtype: Dtype of the returned result. 110 | param_dtype: Dtype of the parameters. 111 | epsilon: Normalization epsilon. 112 | use_bias: If true, add a bias term to the output. 113 | use_scale: If true, scale the output. 114 | bias_init: Initialization function for the bias term. 115 | scale_init: Initialization function for the scaling function. 116 | 117 | Returns: 118 | The normalized input. 
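  Concretely, the result is ``y = scale * (x - mean) * lax.rsqrt(var + epsilon) + bias``,
  with the ``scale`` and ``bias`` terms omitted when ``use_scale`` / ``use_bias`` are disabled.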
119 | """ 120 | reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) 121 | feature_axes = _canonicalize_axes(x.ndim, feature_axes) 122 | stats_shape = list(x.shape) 123 | for axis in reduction_axes: 124 | stats_shape[axis] = 1 125 | mean = mean.reshape(stats_shape) 126 | var = var.reshape(stats_shape) 127 | feature_shape = [1] * x.ndim 128 | reduced_feature_shape = [] 129 | for ax in feature_axes: 130 | feature_shape[ax] = x.shape[ax] 131 | reduced_feature_shape.append(x.shape[ax]) 132 | y = x - mean 133 | mul = lax.rsqrt(var + epsilon) 134 | if use_scale: 135 | scale = mdl.param('scale', scale_init, reduced_feature_shape, 136 | param_dtype).reshape(feature_shape) 137 | mul *= scale 138 | y *= mul 139 | if use_bias: 140 | bias = mdl.param('bias', bias_init, reduced_feature_shape, 141 | param_dtype).reshape(feature_shape) 142 | y += bias 143 | return jnp.asarray(y, dtype) 144 | -------------------------------------------------------------------------------- /minimal20b_flax/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from einops import repeat 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | # noinspection PyPep8Naming 7 | from jax.experimental import PartitionSpec as P 8 | from jax.experimental.pjit import pjit 9 | 10 | import flax.linen as nn 11 | from flax.linen.partitioning import with_sharding_constraint as shard_to 12 | from flax import struct 13 | from flax.core import frozen_dict 14 | from flax import traverse_util 15 | 16 | 17 | @struct.dataclass 18 | class NeoX20BConfig: 19 | vocab_size: int = 50432 20 | hidden_size: int = 6144 21 | num_attention_heads: int = 64 22 | rotary_pct: float = 0.25 23 | rotary_emb_base: int = 10000 24 | layernorm_epsilon: float = 1e-5 25 | num_layers: int = 44 26 | tp_num: int = 8 27 | 28 | 29 | default_neox20b_config = NeoX20BConfig() 30 | 31 | 32 | @struct.dataclass 33 | class GPTNeoX20BModel: 34 | 35 | config: NeoX20BConfig = default_neox20b_config 36 | 37 | def _eval_apply_fn(self, params, x, mask): 38 | embedded = ShardedEmbedIn(config=self.config).apply({"params": params["embed_in"]}, x) 39 | 40 | def _transformer_layer_scan_fn(layer_in, layer_params): 41 | h_, mask_ = layer_in["h"], layer_in["mask"] 42 | h_out = h_ + ShardedTransformerLayer(config=self.config).apply( 43 | {"params": layer_params}, 44 | h_, mask_, 45 | ) 46 | return {"h": h_out, "mask": mask_}, None 47 | 48 | layers_out, _ = jax.lax.scan( 49 | f=_transformer_layer_scan_fn, 50 | init={"h": embedded, "mask": mask}, 51 | xs=params["transformer"] 52 | ) 53 | layers_out = layers_out["h"] 54 | return ShardedEmbedOut(config=self.config).apply({"params": params["embed_out"]}, layers_out) 55 | 56 | def eval_apply_fn_pjit(self): 57 | return pjit( 58 | self._eval_apply_fn, 59 | in_axis_resources=( 60 | self.get_sharding(), # params 61 | P("dp", None), # input [batch, seq_len] 62 | P("dp", None, None), # mask [batch, seq_len, seq_len] 63 | ), 64 | out_axis_resources=P("dp", None, "tp"), # [batch, seq_len, hidden] 65 | ) 66 | 67 | def _get_initial_decode_state(self, params, ctx, ctx_length): 68 | # Embed initial context 69 | embedded = ShardedEmbedIn(config=self.config).apply( 70 | {"params": params["embed_in"]}, ctx) 71 | 72 | # Set up scan function for creating decode_states for each layer 73 | def _transformer_layer_init_decode_scan_fn(layer_in, layer_params): 74 | h_, ctx_length_ = layer_in["h"], layer_in["ctx_length"] 75 | new_residual, decode_state = ShardedTransformerLayer(config=self.config).apply( 76 | 
{"params": layer_params}, 77 | h_, ctx_length, 78 | method=ShardedTransformerLayer.get_init_decode_state, 79 | ) 80 | h_out = h_ + new_residual 81 | return {"h": h_out, "ctx_length": ctx_length_}, decode_state 82 | 83 | # Run scan over transformer layers 84 | layers_out, init_state = jax.lax.scan( 85 | f=_transformer_layer_init_decode_scan_fn, 86 | init={"h": embedded, "ctx_length": ctx_length}, 87 | xs=params["transformer"], 88 | ) 89 | final_logit = ShardedEmbedOut(config=self.config).apply( 90 | {"params": params["embed_out"]}, 91 | layers_out["h"][:, -1:, :], 92 | ) 93 | 94 | return {"logits": final_logit, "decode_state": init_state} 95 | 96 | def get_initial_decode_state_pjit(self): 97 | return pjit( 98 | self._get_initial_decode_state, 99 | in_axis_resources=( 100 | self.get_sharding(), 101 | P("dp", None), # input_ids [batch, seq_len] 102 | P("dp"), # ctx_length [batch] 103 | ), 104 | out_axis_resources={ 105 | "logits": P("dp", None, "tp"), 106 | "decode_state": self.get_decode_state_sharding(), 107 | } 108 | ) 109 | 110 | def _decode_once(self, params, single_x, decode_state): 111 | assert single_x.shape[1] == 1 112 | # Embed single token 113 | embedded = ShardedEmbedIn(config=self.config).apply( 114 | {"params": params["embed_in"]}, single_x) 115 | 116 | # Set up scan function for doing a single decoding step for each layer 117 | def _transformer_layer_decode_once_scan_fn(h, layer_params_and_decode_state): 118 | layer_params, layer_decode_state = layer_params_and_decode_state 119 | new_residual, new_layer_decode_state = ShardedTransformerLayer(config=self.config).apply( 120 | {"params": layer_params}, 121 | layer_decode_state, h, 122 | method=ShardedTransformerLayer.decode_once, 123 | ) 124 | h_out = h + new_residual 125 | return h_out, new_layer_decode_state 126 | 127 | # Run scan over transformer layers 128 | layers_out, new_decode_state = jax.lax.scan( 129 | f=_transformer_layer_decode_once_scan_fn, 130 | init=embedded, 131 | xs=(params["transformer"], decode_state), 132 | ) 133 | 134 | # Project to logits 135 | logits = ShardedEmbedOut(config=self.config).apply( 136 | {"params": params["embed_out"]}, 137 | layers_out, 138 | ) 139 | return { 140 | "logits": logits, 141 | "new_decode_state": new_decode_state, 142 | } 143 | 144 | def decode_once_pjit(self): 145 | decode_state_sharding = self.get_decode_state_sharding() 146 | return pjit( 147 | self._decode_once, 148 | in_axis_resources=( 149 | self.get_sharding(), 150 | P("dp"), # input_ids [batch, seq_len] 151 | decode_state_sharding, # decode_state 152 | ), 153 | out_axis_resources={ 154 | "logits": P("dp", None, "tp"), 155 | "new_decode_state": decode_state_sharding, 156 | } 157 | ) 158 | 159 | def _generate(self, params, ctx, ctx_length, rng, generate_length, sampler_args): 160 | init_out = self._get_initial_decode_state( 161 | params=params, 162 | ctx=ctx, 163 | ctx_length=ctx_length 164 | ) 165 | # Add sampling logic here 166 | initial_token = init_out["logits"].argmax(-1) 167 | 168 | init_carry = { 169 | "single_x": initial_token, 170 | "decode_state": init_out["decode_state"], 171 | } 172 | 173 | def _decode_once_scan_fn(decode_carry, step_rng): 174 | decode_out = self._decode_once( 175 | params=params, 176 | single_x=decode_carry["single_x"], 177 | decode_state=decode_carry["decode_state"], 178 | ) 179 | 180 | # Add sampling logic here 181 | next_token = decode_out["logits"].argmax(-1) 182 | # next_token = temperature_sample( 183 | # key=step_rng, 184 | # logits=decode_out["logits"], 185 | # **sampler_args, 186 | # ) 
187 | 188 | next_carry = { 189 | "single_x": next_token, 190 | "decode_state": decode_out["new_decode_state"] 191 | } 192 | outputs = { 193 | "logits": decode_out["logits"], 194 | "next_token": next_token, 195 | } 196 | return next_carry, outputs 197 | 198 | final_state, generation_outputs = jax.lax.scan( 199 | f=_decode_once_scan_fn, 200 | init=init_carry, 201 | xs=jax.random.split(rng, generate_length), 202 | ) 203 | tokens = generation_outputs["next_token"].swapaxes(0, 1)[:, :, 0] 204 | logits = generation_outputs["logits"].swapaxes(0, 1)[:, :, 0] 205 | return { 206 | # "final_state": final_state, 207 | "generated_logits": jnp.concatenate((init_out["logits"], logits), axis=1), 208 | "generated_tokens": jnp.concatenate((initial_token, tokens), axis=1), 209 | "final_state": final_state, 210 | "initial_state": init_out["decode_state"], 211 | } 212 | 213 | def generate_pjit(self): 214 | return pjit( 215 | self._generate, 216 | in_axis_resources=( 217 | self.get_sharding(), 218 | P("dp", None), # ctx [batch, seq_len] 219 | P("dp"), # ctx_length [batch] 220 | None, 221 | ), 222 | out_axis_resources={ 223 | "generated_logits": P("dp", None, "tp"), 224 | "generated_tokens": P("dp"), 225 | "final_state": P("dp"), 226 | "initial_state": P("dp"), 227 | }, 228 | static_argnums=(4, 5), 229 | ) 230 | 231 | @staticmethod 232 | def get_decode_state_sharding(): 233 | return { 234 | "tokens_decoded": P(None, "dp"), # [num_layers, batch] 235 | "k": P(None, "dp", None, "tp", None), # [num_layers, batch, seq_len, heads, dim_per_head] 236 | "v": P(None, "dp", None, "tp", None), # [num_layers, batch, seq_len, heads, dim_per_head] 237 | } 238 | 239 | @staticmethod 240 | def get_sharding(): 241 | # 1. embed_in sharding 242 | embed_in_sharding = frozen_dict.freeze(traverse_util.unflatten_dict({ 243 | ("embed", "kernel"): P("tp", None), 244 | })) 245 | 246 | # 2. layer_sharding 247 | flat_stacked_layers_sharding = { 248 | ('attn_norm', 'bias'): P(None, None, ), 249 | ('attn_norm', 'scale'): P(None, None, ), 250 | ('qkv_proj', 'bias'): P(None, None, ), 251 | ('qkv_proj', 'kernel'): P(None, None, 'tp'), 252 | ('output_proj', 'bias'): P(None, None, ), 253 | ('output_proj', 'kernel'): P(None, 'tp', None), 254 | ('ff_norm', 'bias'): P(None, None, ), 255 | ('ff_norm', 'scale'): P(None, None, ), 256 | ('ff_up_proj', 'bias'): P(None, None, ), 257 | ('ff_up_proj', 'kernel'): P(None, None, 'tp'), 258 | ('ff_down_proj', 'bias'): P(None, None), 259 | ('ff_down_proj', 'kernel'): P(None, 'tp', None), 260 | } 261 | stacked_layers_sharding = frozen_dict.freeze(traverse_util.unflatten_dict( 262 | flat_stacked_layers_sharding)) 263 | 264 | # 3. embed_out sharding 265 | embed_out_sharding = { 266 | ('norm', 'bias'): P(None), 267 | ('norm', 'scale'): P(None), 268 | ('embed_out', 'kernel'): P(None, "tp"), 269 | } 270 | embed_out_sharding = frozen_dict.freeze(traverse_util.unflatten_dict(embed_out_sharding)) 271 | 272 | # 4. 
Combine 273 | all_sharding = frozen_dict.freeze({ 274 | "embed_in": embed_in_sharding, 275 | "transformer": stacked_layers_sharding, 276 | "embed_out": embed_out_sharding, 277 | }) 278 | return all_sharding 279 | 280 | 281 | class ShardedEmbedIn(nn.Module): 282 | 283 | config: NeoX20BConfig = default_neox20b_config 284 | 285 | @nn.compact 286 | def __call__(self, input_ids): 287 | onehot_inputs = jax.nn.one_hot(input_ids, self.config.vocab_size, dtype=jnp.float16) 288 | onehot_inputs = shard_to(onehot_inputs, P("dp", None, "tp")) 289 | embedded = nn.Dense( 290 | features=self.config.hidden_size, 291 | use_bias=False, 292 | name="embed", 293 | dtype=jnp.float16, 294 | )(onehot_inputs) 295 | return embedded 296 | 297 | 298 | class ShardedTransformerLayer(nn.Module): 299 | """Sharded Transformer Layer. 300 | 301 | Note: This doesn't compute the full residual connection x + r(x), only r(x). 302 | The residual connection will be computed downstream. 303 | """ 304 | config: NeoX20BConfig = default_neox20b_config 305 | 306 | # noinspection PyAttributeOutsideInit 307 | def setup(self): 308 | config = self.config 309 | self.attn_norm = nn.LayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 310 | self.ff_norm = nn.LayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 311 | self.qkv_proj = nn.Dense( 312 | config.hidden_size * 3, 313 | name="qkv_proj", 314 | dtype=jnp.float16, 315 | ) 316 | self.output_proj = nn.Dense( 317 | config.hidden_size, 318 | name="output_proj", 319 | dtype=jnp.float16, 320 | ) 321 | self.ff_up_proj = nn.Dense( 322 | config.hidden_size * 4, 323 | name="ff_up_proj", 324 | dtype=jnp.float16, 325 | ) 326 | self.ff_down_proj = nn.Dense( 327 | config.hidden_size, 328 | name="ff_down_proj", 329 | dtype=jnp.float16, 330 | ) 331 | 332 | def __call__(self, x, attn_bias): 333 | """ 334 | :param x: [batch, seq_len, hidden_size] 335 | :param attn_bias: [*, seq_len, seq_len] 336 | :return: [batch, seq_len, hidden_size] 337 | """ 338 | attn_in = self.attn_norm(x) 339 | # -> [batch, seq_len, hidden_size] 340 | 341 | q, k, v = self.compute_qkv(attn_in) 342 | # -> 3 x [batch, seq_len, heads, dims_per_head] 343 | 344 | seq_len = attn_in.shape[1] 345 | causal_mask = np.tril(np.ones((seq_len, seq_len)))[None, :, :] # NumPy array gets cached 346 | # -> [1, seq_len, seq_len] 347 | 348 | bias = -1e4 * (1. 
- causal_mask) 349 | bias += attn_bias 350 | # -> [1, seq_len, seq_len] 351 | 352 | attn_out = self.compute_self_attn(q, k, v, bias) 353 | # -> [batch, seq_len, hidden] 354 | 355 | ff_out = self.compute_ff(x) 356 | # -> [batch, seq_len, hidden] 357 | 358 | return attn_out + ff_out 359 | 360 | def split_heads(self, x): 361 | config = self.config 362 | dims_per_head = config.hidden_size // config.num_attention_heads 363 | # reshaped = x.reshape(x.shape[:-1] + (heads_per_device, dims_per_head)) 364 | reshaped = x.reshape(x.shape[:-2] + (config.num_attention_heads, dims_per_head)) 365 | # reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + x.shape[-1:]) 366 | return shard_to(reshaped, P("dp", None, "tp", None)) 367 | 368 | def compute_qkv(self, x): 369 | config = self.config 370 | # [batch, seq, qkv_dims] 371 | qkv_arr = self.qkv_proj(x) 372 | 373 | # [batch, seq, mp, dim//mp] 374 | qkv_arr = shard_to(qkv_arr, P("dp", None, "tp")) 375 | mp_split = jnp.reshape(qkv_arr, qkv_arr.shape[:-1] + (config.tp_num, -1)) 376 | mp_split = shard_to(mp_split, P("dp", None, "tp", None)) 377 | 378 | local_dim = config.hidden_size // config.tp_num 379 | 380 | q, k, v = jnp.split(mp_split, [local_dim, local_dim * 2], axis=-1) 381 | 382 | q = self.split_heads(q) 383 | k = self.split_heads(k) 384 | v = self.split_heads(v) 385 | # -> 386 | 387 | return q, k, v 388 | 389 | def compute_self_attn(self, q, k, v, attn_bias): 390 | """ 391 | :param q: [batch, q_len, heads, dims_per_head] 392 | :param k: [batch, kv_len, heads, dims_per_head] 393 | :param v: [batch, kv_len, heads, dims_per_head] 394 | :param attn_bias: [*, q_len, kv_len] 395 | :return: [batch, q_len, hidden] 396 | """ 397 | config = self.config 398 | rotary_dims = int(config.hidden_size // config.num_attention_heads * config.rotary_pct) 399 | k_rot = k[:, :, :, :rotary_dims] 400 | k_pass = k[:, :, :, rotary_dims:] 401 | 402 | q_rot = q[:, :, :, :rotary_dims] 403 | q_pass = q[:, :, :, rotary_dims:] 404 | 405 | sincos = fixed_pos_embedding(k_rot, seq_dim=1) 406 | # return sincos 407 | q_rot = apply_rotary_pos_emb(q_rot, sincos) 408 | k_rot = apply_rotary_pos_emb(k_rot, sincos) 409 | q_rot = shard_to(q_rot, P("dp", None, "tp", None)) 410 | k_rot = shard_to(k_rot, P("dp", None, "tp", None)) 411 | 412 | k = jnp.concatenate([k_rot, k_pass], axis=-1) 413 | q = jnp.concatenate([q_rot, q_pass], axis=-1) 414 | 415 | k = shard_to(k, P("dp", None, "tp", None)) 416 | q = shard_to(q, P("dp", None, "tp", None)) 417 | 418 | attention_logits = jnp.einsum("bthd,bThd->bhtT", q, k) 419 | attention_logits = shard_to(attention_logits, P("dp", "tp", None, None)) 420 | # -> [batch, heads, q_len, kv_len] 421 | 422 | sqrt_key_size = np.sqrt(config.hidden_size // config.num_attention_heads).astype(k.dtype) 423 | attention_logits = attention_logits / sqrt_key_size 424 | attention_logits += attn_bias 425 | attention_logits = shard_to(attention_logits, P("dp", "tp", None, None)) 426 | # -> [batch, heads, q_len, kv_len] 427 | 428 | attention_weights = jax.nn.softmax(attention_logits) 429 | attention_weights = shard_to(attention_weights, P("dp", "tp", None, None)) 430 | # -> [batch, heads, q_len, kv_len] 431 | 432 | attention_vec = jnp.einsum("bhtT,bThd->bthd", attention_weights, v) 433 | attention_vec = shard_to(attention_vec, P("dp", None, "tp", None)) 434 | # -> [batch, q_len, heads, dims_per_head] 435 | 436 | attention_vec = attention_vec.reshape(attention_vec.shape[:2] + (-1,)) 437 | attention_vec = shard_to(attention_vec, P("dp", None, "tp")) 438 | # -> [batch, q_len, hidden] 439 | 
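# With the default NeoX20BConfig this is 64 heads of 96 dims each (hidden_size 6144),
# rotary embeddings applied to the first 24 dims of every head (rotary_pct=0.25), and
# attention_vec reshaped back to the full 6144-dim hidden size before the output
# projection below.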
440 | attn_out = self.output_proj(attention_vec) 441 | attn_out = shard_to(attn_out, P("dp", None, "tp")) 442 | # -> [batch, q_len, hidden] 443 | 444 | return attn_out 445 | 446 | def compute_ff(self, x): 447 | ff_out = self.ff_norm(x) 448 | ff_out = self.ff_up_proj(ff_out) 449 | ff_out = shard_to(ff_out, P("dp", None, "tp")) 450 | ff_out = jax.nn.gelu(ff_out) 451 | ff_out = self.ff_down_proj(ff_out) 452 | ff_out = shard_to(ff_out, P("dp", None, "tp")) 453 | return ff_out 454 | 455 | # iterate the decoding process by a single token 456 | def decode_once(self, decode_state, x): 457 | """ 458 | :param decode_state: 459 | :param x: [batch, q_len=1, hidden_size] 460 | """ 461 | attn_in = self.attn_norm(x) 462 | # -> [batch, q_len=1, hidden_size] 463 | q, v, k = self.compute_qkv(attn_in) 464 | # -> 3 x [batch, q_len=1, heads, dims_per_head] 465 | 466 | # add new kv to end, clip off the start 467 | v = jnp.concatenate((decode_state["v"], v), axis=1)[:, 1:] 468 | # -> [batch, kv_len+1, heads, dims_per_head] 469 | k = jnp.concatenate((decode_state["k"], k), axis=1)[:, 1:] 470 | # -> [batch, kv_len+1, heads, dims_per_head] 471 | 472 | tokens_decoded = decode_state["tokens_decoded"] + 1 473 | # -> [batch] 474 | 475 | length = v.shape[1] 476 | masked_tokens = (length - tokens_decoded)[:, None] 477 | # -> [batch, 1] 478 | attention_mask = (jnp.arange(0, length)[None, :] < masked_tokens)[:, None, :] 479 | # -> [batch, q_len=1, seq_len] 480 | 481 | bias = (-1e4 * attention_mask) 482 | # -> [batch, q_len=1, seq_len] 483 | 484 | attn_out = self.compute_self_attn(q, v, k, bias[:, None, :, :]) 485 | # -> 3 x [batch, q_len=1, hidden] 486 | 487 | ff_out = self.compute_ff(x) 488 | # -> 3 x [batch, q_len=1, hidden] 489 | 490 | return (attn_out + ff_out), { 491 | "tokens_decoded": tokens_decoded, 492 | "k": k, 493 | "v": v 494 | } 495 | 496 | # take in right aligned context tokens and generate an initial state 497 | def get_init_decode_state(self, x, given_length): 498 | """ 499 | :param x: [batch, seq_len, hidden_size] 500 | :param given_length: [batch] 501 | """ 502 | attn_in = self.attn_norm(x) 503 | # -> [batch, seq_len, hidden_size] 504 | q, v, k = self.compute_qkv(attn_in) 505 | # -> 3 x [batch, seq_len, heads, dims_per_head] 506 | 507 | batch_size, full_length = x.shape[0], x.shape[1] 508 | masked_tokens = (full_length - given_length)[:, None] 509 | # -> [batch, 1] 510 | 511 | causal_mask = np.tril(np.ones((full_length, full_length))) 512 | # -> [seq_len, seq_len] 513 | 514 | bias = -1e4 * (1. 
- causal_mask) # regular AR masking 515 | bias = jnp.repeat(bias[None, :], repeats=batch_size, axis=0) 516 | # -> [batch, seq_len, seq_len] 517 | 518 | context_length_mask = (jnp.arange(0, full_length)[None, :] < masked_tokens)[:, None, :] 519 | # -> [batch, 1, seq_len] 520 | 521 | bias -= 1e4 * context_length_mask # mask out zero tokens before context starts 522 | # -> [batch, seq_len, seq_len] 523 | 524 | attn_out = self.compute_self_attn(q, v, k, bias[:, None, :, :]) 525 | # -> [batch, seq_len, hidden] 526 | 527 | ff_out = self.compute_ff(x) 528 | # -> [batch, seq_len, hidden] 529 | 530 | return (attn_out + ff_out), { 531 | "tokens_decoded": given_length.astype(jnp.uint32), 532 | "k": k, 533 | "v": v, 534 | } 535 | 536 | 537 | class ShardedEmbedOut(nn.Module): 538 | 539 | config: NeoX20BConfig = default_neox20b_config 540 | 541 | # noinspection PyAttributeOutsideInit 542 | def setup(self): 543 | config = self.config 544 | self.norm = nn.LayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 545 | self.embed_out = nn.Dense(config.vocab_size, use_bias=False, dtype=jnp.float16) 546 | 547 | def __call__(self, x): 548 | return self.predict(x) 549 | 550 | def predict(self, x): 551 | x = self.norm(x) 552 | x = shard_to(x, P("dp", None, None)) 553 | out = self.embed_out(x) 554 | out = shard_to(out, P("dp", None, "tp")) 555 | return out 556 | 557 | def loss(self, x, targets, z_loss=1): 558 | logits = self.predict(x) 559 | targets_onehot = jax.nn.one_hot(targets, self.dim, dtype=jnp.float16) 560 | logits_for_targets = jnp.sum(jnp.multiply(targets_onehot, logits), axis=-1) 561 | 562 | # softmax denominator 563 | exp_logits = jnp.exp(logits) 564 | sum_exp_logits = exp_logits.sum(axis=-1) 565 | 566 | # compute loss 567 | loss = jnp.log(sum_exp_logits) - logits_for_targets 568 | loss += (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean() 569 | correct = (0.0 == logits_for_targets) 570 | return loss, correct 571 | 572 | 573 | def fixed_pos_embedding(x, seq_dim=0): 574 | dim = x.shape[-1] 575 | inv_freq = 1. / (10000 ** (np.arange(0, dim, 2) / dim)) .astype(np.float16) 576 | sinusoid_inp = np.einsum('i , j -> i j', np.arange(x.shape[seq_dim]), inv_freq) .astype(np.float16) 577 | return np.sin(sinusoid_inp).astype(np.float16), np.cos(sinusoid_inp).astype(np.float16) 578 | 579 | 580 | def apply_rotary_pos_emb(x, sincos): 581 | sin, cos = map(lambda t: repeat(t, '... b n -> ... 
b (j n)', j=2)[-x.shape[-3]:, None, :], sincos) 582 | return (x * cos) + (rotate_half(x) * sin) 583 | 584 | 585 | def rotate_half(x): 586 | half_dim = x.shape[-1] // 2 587 | x1 = x[:, :, :, :half_dim] 588 | x2 = x[:, :, :, half_dim:] 589 | return jnp.concatenate((-x2, x1), axis=-1) 590 | 591 | 592 | def temperature_sample(key, logits, temp=1): 593 | return jax.random.categorical(key, logits/temp, -1).astype(jnp.int32) 594 | -------------------------------------------------------------------------------- /minimal20b_flax/model_xmap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from einops import repeat 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | # noinspection PyPep8Naming 7 | from jax.nn.initializers import zeros 8 | 9 | import flax.linen as nn 10 | from flax import struct 11 | from flax.core import frozen_dict 12 | from minimal20b_flax.utils import f_psum, g_psum 13 | 14 | 15 | @struct.dataclass 16 | class NeoX20BConfig: 17 | vocab_size: int = 50432 18 | hidden_size: int = 6144 19 | num_attention_heads: int = 64 20 | rotary_pct: float = 0.25 21 | rotary_emb_base: int = 10000 22 | layernorm_epsilon: float = 1e-5 23 | num_layers: int = 44 24 | tp_num: int = 8 25 | 26 | 27 | default_neox20b_config = NeoX20BConfig() 28 | 29 | 30 | @struct.dataclass 31 | class GPTNeoX20BModel: 32 | 33 | config: NeoX20BConfig = default_neox20b_config 34 | generate_length: int = 2048 35 | sampler_args: dict = frozen_dict.FrozenDict({"temp": 1.}) 36 | 37 | def eval(self, params, x, mask): 38 | embedded = ShardedEmbedIn(config=self.config).apply({"params": params["embed_in"]}, x) 39 | h = embedded 40 | for layer_i in range(self.config.num_layers): 41 | residual = ShardedTransformerLayer(config=self.config).apply( 42 | {"params": params[f"layer_{layer_i:02d}"]}, h, mask 43 | ) 44 | h = h + residual 45 | return ShardedEmbedOut(config=self.config).apply({"params": params["embed_out"]}, h) 46 | 47 | def get_batch_eval_fn(self): 48 | return jax.vmap(self.eval, in_axes=[None, 0, 0]) 49 | 50 | def get_initial_decode_state(self, params, ctx, ctx_length): 51 | # Embed initial context 52 | embedded = ShardedEmbedIn(config=self.config).apply( 53 | {"params": params["embed_in"]}, ctx) 54 | 55 | # Set up scan function for creating decode_states for each layer 56 | decode_state_list = [] 57 | h = embedded 58 | for layer_i in range(self.config.num_layers): 59 | new_residual, layer_decode_state = ShardedTransformerLayer(config=self.config).apply( 60 | {"params": params[f"layer_{layer_i:02d}"]}, 61 | h, ctx_length, 62 | method=ShardedTransformerLayer.get_init_decode_state, 63 | ) 64 | decode_state_list.append(layer_decode_state) 65 | h = h + new_residual 66 | final_logit = ShardedEmbedOut(config=self.config).apply( 67 | {"params": params["embed_out"]}, 68 | h[-1:, :], 69 | ) 70 | return {"logits": final_logit, "decode_state": decode_state_list} 71 | 72 | def decode_once(self, params, single_x, decode_state): 73 | assert single_x.shape[0] == 1 74 | # Embed single token 75 | embedded = ShardedEmbedIn(config=self.config).apply( 76 | {"params": params["embed_in"]}, single_x) 77 | h = embedded 78 | new_decode_state_list = [] 79 | for layer_i in range(self.config.num_layers): 80 | new_residual, new_layer_decode_state = ShardedTransformerLayer(config=self.config).apply( 81 | {"params": params[f"layer_{layer_i:02d}"]}, 82 | decode_state[layer_i], h, 83 | method=ShardedTransformerLayer.decode_once, 84 | ) 85 | h = h + new_residual 86 | 
new_decode_state_list.append(new_layer_decode_state) 87 | # Project to logits 88 | logits = ShardedEmbedOut(config=self.config).apply( 89 | {"params": params["embed_out"]}, 90 | h, 91 | ) 92 | return { 93 | "logits": logits, 94 | "new_decode_state": new_decode_state_list, 95 | } 96 | 97 | def generate(self, params, ctx, ctx_length, rng): 98 | init_out = self.get_initial_decode_state( 99 | params=params, 100 | ctx=ctx, 101 | ctx_length=ctx_length 102 | ) 103 | # Add sampling logic here 104 | init_logits = init_out["logits"].swapaxes(0, 1).reshape(1, self.config.vocab_size) 105 | initial_token = init_logits.argmax(-1) 106 | 107 | init_carry = { 108 | "single_x": initial_token, 109 | "decode_state": init_out["decode_state"], 110 | } 111 | 112 | def _decode_once_scan_fn(decode_carry, step_rng): 113 | decode_out = self.decode_once( 114 | params=params, 115 | single_x=decode_carry["single_x"], 116 | decode_state=decode_carry["decode_state"], 117 | ) 118 | 119 | # Add sampling logic here 120 | decode_out_logits = decode_out["logits"].swapaxes(0, 1).reshape(1, self.config.vocab_size) 121 | next_token = temperature_sample( 122 | key=step_rng, 123 | logits=decode_out["logits"].swapaxes(0, 1).reshape(1, self.config.vocab_size), 124 | **self.sampler_args, 125 | ) 126 | 127 | next_carry = { 128 | "single_x": next_token, 129 | "decode_state": decode_out["new_decode_state"] 130 | } 131 | outputs = { 132 | "logits": decode_out_logits, 133 | "next_token": next_token, 134 | } 135 | return next_carry, outputs 136 | 137 | final_state, generation_outputs = jax.lax.scan( 138 | f=_decode_once_scan_fn, 139 | init=init_carry, 140 | xs=jax.random.split(rng, self.generate_length), 141 | ) 142 | 143 | tokens = generation_outputs["next_token"][:, 0] 144 | logits = generation_outputs["logits"][:, 0] 145 | return { 146 | # "final_state": final_state, 147 | "generated_logits": jnp.concatenate((init_logits, logits), axis=0), 148 | "generated_tokens": jnp.concatenate((initial_token, tokens), axis=0), 149 | } 150 | 151 | 152 | class ShardedEmbedIn(nn.Module): 153 | 154 | config: NeoX20BConfig = default_neox20b_config 155 | 156 | @nn.compact 157 | def __call__(self, input_ids): 158 | config = self.config 159 | dims_per_shard = self.config.vocab_size // config.tp_num 160 | shard_start_index = jax.lax.axis_index('shard') * dims_per_shard 161 | 162 | # TODO: Check if this still works 163 | input_onehot = jax.nn.one_hot(input_ids - shard_start_index, dims_per_shard, dtype=jnp.float16) 164 | embedded = nn.Dense( 165 | features=config.hidden_size, 166 | kernel_init=zero_init_fp16(), 167 | use_bias=False, 168 | name="embed", 169 | dtype=jnp.float16, 170 | )(input_onehot) 171 | embedded = g_psum(embedded) 172 | return embedded 173 | 174 | 175 | class ShardedTransformerLayer(nn.Module): 176 | """Sharded Transformer Layer. 177 | 178 | Note: This doesn't compute the full residual connection x + r(x), only r(x). 179 | The residual connection will be computed downstream. 
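    Unlike the pjit version in model.py, this xmap variant operates on per-shard arrays
    with no batch dimension ([seq_len, ...]) and combines partial results across
    tensor-parallel shards with explicit psum collectives (g_psum).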
180 | """ 181 | config: NeoX20BConfig = default_neox20b_config 182 | 183 | # noinspection PyAttributeOutsideInit 184 | def setup(self): 185 | config = self.config 186 | self.dims_per_head = config.hidden_size // config.num_attention_heads 187 | self.heads_per_shard = config.num_attention_heads // config.tp_num 188 | self.dims_per_shard = config.hidden_size // config.tp_num 189 | self.attn_norm = nn.LayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 190 | self.ff_norm = nn.LayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 191 | self.qkv_proj = nn.Dense( 192 | self.dims_per_shard * 3, 193 | name="qkv_proj", 194 | dtype=jnp.float16, 195 | kernel_init=zero_init_fp16(), 196 | bias_init=zero_init_fp16(), 197 | ) 198 | self.output_proj = nn.Dense( 199 | config.hidden_size, 200 | name="output_proj", 201 | dtype=jnp.float16, 202 | kernel_init=zero_init_fp16(), 203 | bias_init=zero_init_fp16(), 204 | ) 205 | self.ff_up_proj = nn.Dense( 206 | self.dims_per_shard * 4, 207 | name="ff_up_proj", 208 | dtype=jnp.float16, 209 | kernel_init=zero_init_fp16(), 210 | bias_init=zero_init_fp16(), 211 | ) 212 | self.ff_down_proj = nn.Dense( 213 | config.hidden_size, 214 | name="ff_down_proj", 215 | dtype=jnp.float16, 216 | kernel_init=zero_init_fp16(), 217 | bias_init=zero_init_fp16(), 218 | ) 219 | 220 | def __call__(self, x, attn_bias): 221 | """ 222 | :param x: [seq_len, hidden_size] 223 | :param attn_bias: [*, seq_len, seq_len] 224 | :return: [seq_len, hidden_size] 225 | """ 226 | attn_in = self.attn_norm(x) 227 | # -> [seq_len, hidden_size] 228 | 229 | q, k, v = self.compute_qkv(attn_in) 230 | # -> 3 x [seq_len, heads, dims_per_head] 231 | 232 | seq_len = attn_in.shape[0] 233 | causal_mask = np.tril(np.ones((seq_len, seq_len)))[None, :, :] # NumPy array gets cached 234 | # -> [1, seq_len, seq_len] 235 | 236 | bias = -1e4 * (1. 
- causal_mask) 237 | bias += attn_bias 238 | # -> [1, seq_len, seq_len] 239 | 240 | attn_out = self.compute_self_attn(q, k, v, bias) 241 | # -> [seq_len, hidden] 242 | 243 | ff_out = self.compute_ff(x) 244 | # -> [seq_len, hidden] 245 | 246 | return attn_out + ff_out 247 | 248 | def split_heads(self, x): 249 | reshaped = x.reshape(x.shape[:-1] + (self.heads_per_shard, self.dims_per_head)) 250 | return reshaped 251 | 252 | def compute_qkv(self, x): 253 | # [seq, 3*dims_per_shard] 254 | qkv_arr = self.qkv_proj(x) 255 | q, k, v = jnp.split(qkv_arr, [self.dims_per_shard, self.dims_per_shard * 2], axis=-1) 256 | 257 | # [seq, heads, dims_per_head] 258 | q = self.split_heads(q) 259 | k = self.split_heads(k) 260 | v = self.split_heads(v) 261 | 262 | return q, k, v 263 | 264 | def compute_self_attn(self, q, k, v, attn_bias): 265 | """ 266 | :param q: [q_len, heads, dims_per_head] 267 | :param k: [kv_len, heads, dims_per_head] 268 | :param v: [kv_len, heads, dims_per_head] 269 | :param attn_bias: [*, q_len, kv_len] 270 | :return: [q_len, hidden] 271 | """ 272 | config = self.config 273 | rotary_dims = int(config.hidden_size // config.num_attention_heads * config.rotary_pct) 274 | k_rot = k[..., :rotary_dims] 275 | k_pass = k[..., rotary_dims:] 276 | 277 | q_rot = q[..., :rotary_dims] 278 | q_pass = q[..., rotary_dims:] 279 | 280 | sincos = fixed_pos_embedding(k_rot, seq_dim=0) 281 | # return sincos 282 | q_rot = apply_rotary_pos_emb(q_rot, sincos) 283 | k_rot = apply_rotary_pos_emb(k_rot, sincos) 284 | 285 | k = jnp.concatenate([k_rot, k_pass], axis=-1) 286 | q = jnp.concatenate([q_rot, q_pass], axis=-1) 287 | 288 | attention_logits = jnp.einsum("thd,Thd->htT", q.astype(jnp.float32), k.astype(jnp.float32)) 289 | # -> [heads, q_len, kv_len] 290 | 291 | sqrt_key_size = np.sqrt(config.hidden_size // config.num_attention_heads).astype(k.dtype) 292 | attention_logits = attention_logits / sqrt_key_size 293 | attention_logits += attn_bias 294 | # -> [heads, q_len, kv_len] 295 | 296 | attention_weights = jax.nn.softmax(attention_logits) 297 | # -> [heads, q_len, kv_len] 298 | 299 | attention_vec = jnp.einsum("htT,Thd->thd", attention_weights, v) 300 | # -> [q_len, heads, dims_per_head] 301 | 302 | attention_vec = attention_vec.reshape(-1, self.dims_per_shard).astype(jnp.float16) 303 | # -> [q_len, hidden] 304 | 305 | attn_out = g_psum(self.output_proj(attention_vec)) 306 | # -> [q_len, hidden] 307 | 308 | return attn_out 309 | 310 | def compute_ff(self, x): 311 | ff_out = self.ff_norm(x) 312 | ff_out = self.ff_up_proj(ff_out) 313 | ff_out = jax.nn.gelu(ff_out) 314 | ff_out = g_psum(self.ff_down_proj(ff_out)) 315 | return ff_out 316 | 317 | # iterate the decoding process by a single token 318 | def decode_once(self, decode_state, x): 319 | """ 320 | :param decode_state: 321 | :param x: [q_len=1, hidden_size] 322 | """ 323 | attn_in = self.attn_norm(x) 324 | # -> [q_len=1, hidden_size] 325 | q, v, k = self.compute_qkv(attn_in) 326 | # -> 3 x [q_len=1, heads, dims_per_head] 327 | 328 | # add new kv to end, clip off the start 329 | v = jnp.concatenate((decode_state["v"], v), axis=0)[1:] 330 | # -> [kv_len+1, heads, dims_per_head] 331 | k = jnp.concatenate((decode_state["k"], k), axis=0)[1:] 332 | # -> [kv_len+1, heads, dims_per_head] 333 | 334 | tokens_decoded = decode_state["tokens_decoded"] + 1 335 | 336 | length = v.shape[0] 337 | masked_tokens = length - tokens_decoded 338 | attention_mask = (jnp.arange(0, length) < masked_tokens)[None, :] 339 | # -> [q_len=1, seq_len] 340 | 341 | bias = (-1e4 * 
attention_mask) 342 | # -> [q_len=1, seq_len] 343 | 344 | attn_out = self.compute_self_attn(q, v, k, bias) 345 | # -> 3 x [q_len=1, hidden] 346 | 347 | ff_out = self.compute_ff(x) 348 | # -> 3 x [q_len=1, hidden] 349 | 350 | return (attn_out + ff_out), { 351 | "tokens_decoded": tokens_decoded, 352 | "k": k, 353 | "v": v 354 | } 355 | 356 | # take in right aligned context tokens and generate an initial state 357 | def get_init_decode_state(self, x, given_length): 358 | """ 359 | :param x: [batch, seq_len, hidden_size] 360 | :param given_length: [batch] 361 | """ 362 | x = f_psum(x) 363 | attn_in = self.attn_norm(x) 364 | # -> [seq_len, hidden_size] 365 | q, v, k = self.compute_qkv(attn_in) 366 | # -> 3 x [seq_len, heads, dims_per_head] 367 | 368 | full_length = x.shape[0] 369 | masked_tokens = full_length - given_length 370 | 371 | causal_mask = np.tril(np.ones((full_length, full_length))) 372 | # -> [seq_len, seq_len] 373 | 374 | bias = -1e4 * (1. - causal_mask) # regular AR masking 375 | # -> [seq_len, seq_len] 376 | 377 | context_length_mask = (jnp.arange(0, full_length) < masked_tokens)[None, :] 378 | # -> [1, seq_len] 379 | 380 | bias -= 1e4 * context_length_mask # mask out zero tokens before context starts 381 | # -> [seq_len, seq_len] 382 | 383 | attn_out = self.compute_self_attn(q, v, k, bias) 384 | # -> [seq_len, hidden] 385 | 386 | ff_out = self.compute_ff(x) 387 | # -> [seq_len, hidden] 388 | 389 | return (attn_out + ff_out), { 390 | "tokens_decoded": given_length.astype(jnp.uint32), 391 | "k": k, 392 | "v": v, 393 | } 394 | 395 | 396 | class ShardedEmbedOut(nn.Module): 397 | 398 | config: NeoX20BConfig = default_neox20b_config 399 | 400 | # noinspection PyAttributeOutsideInit 401 | def setup(self): 402 | config = self.config 403 | self.vocab_per_shard = config.vocab_size // config.tp_num 404 | self.norm = ReplicatedLayerNorm(epsilon=config.layernorm_epsilon, dtype=jnp.float16) 405 | self.embed_out = nn.Dense( 406 | self.vocab_per_shard, use_bias=False, dtype=jnp.float16, 407 | kernel_init=zero_init_fp16(), 408 | bias_init=zero_init_fp16(), 409 | ) 410 | 411 | def __call__(self, x): 412 | logits = self.predict(x) 413 | logits = jax.lax.all_gather(logits, 'shard') 414 | # Transpose? 
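# all_gather adds a leading shard axis: per-shard logits of shape [seq_len, vocab_per_shard]
# come back as [tp_num, seq_len, vocab_per_shard]. Callers such as generate() above undo
# this with swapaxes(0, 1) followed by a reshape to the full vocab_size.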
415 | return logits 416 | 417 | def predict(self, x): 418 | x = self.norm(x) 419 | out = self.embed_out(x) 420 | return out 421 | 422 | def loss(self, x, targets, z_loss=1): 423 | x = f_psum(x) 424 | logits = self.predict(x) 425 | 426 | shard_start_index = jax.lax.axis_index('shard') * self.dim_per_shard 427 | global_max = jax.lax.pmax(jax.lax.stop_gradient(logits.max(-1, keepdims=True)), "shard") 428 | logits -= jax.lax.stop_gradient(global_max) 429 | 430 | gt_onehot = jax.nn.one_hot(targets - shard_start_index, self.dim_per_shard) 431 | predicted_logits = jnp.sum(jnp.multiply(gt_onehot, logits), axis=-1) 432 | predicted_logits = g_psum(predicted_logits) 433 | 434 | exp_logits = jnp.exp(logits) 435 | 436 | sum_exp_logits = exp_logits.sum(axis=-1) 437 | sum_exp_logits = g_psum(sum_exp_logits) 438 | 439 | loss = jnp.log(sum_exp_logits) - predicted_logits 440 | 441 | loss += (1e-4 * jnp.square(jnp.log(sum_exp_logits)) * z_loss).mean() 442 | 443 | correct = (0.0 == predicted_logits) 444 | 445 | return loss, correct 446 | 447 | 448 | class ReplicatedLayerNorm(nn.Module): 449 | 450 | def __init__(self, epsilon, dtype): 451 | super().__init__() 452 | self.epsilon = epsilon 453 | self.dtype = dtype 454 | 455 | @nn.compact 456 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: 457 | mean = jnp.mean(inputs, axis=-1, keepdims=True) 458 | variance = jnp.var(inputs, axis=-1, keepdims=True) 459 | 460 | param_shape = inputs.shape[-1:] 461 | scale = self.param("scale", zeros, param_shape, self.dtype) 462 | scale = jax.lax.all_gather(scale, "shard")[0] 463 | 464 | offset = self.param("bias", zeros, param_shape, self.dtype) 465 | offset = jax.lax.all_gather(offset, "shard")[0] 466 | 467 | scale = jnp.broadcast_to(scale, inputs.shape) 468 | offset = jnp.broadcast_to(offset, inputs.shape) 469 | mean = jnp.broadcast_to(mean, inputs.shape) 470 | 471 | inv = scale * jax.lax.rsqrt(variance + self.epsilon) 472 | return inv * (inputs - mean) + offset 473 | 474 | 475 | def fixed_pos_embedding(x, seq_dim=0): 476 | # x: [seq_len, head, head_dim//4] 477 | dim = x.shape[-1] 478 | inv_freq = 1. / (10000 ** (np.arange(0, dim, 2) / dim)).astype(np.float16) 479 | sinusoid_inp = np.einsum('i , j -> i j', np.arange(x.shape[seq_dim]), inv_freq).astype(np.float16) 480 | return np.sin(sinusoid_inp).astype(np.float16), np.cos(sinusoid_inp).astype(np.float16) 481 | 482 | 483 | def apply_rotary_pos_emb(x, sincos): 484 | # x: [seq_len, head, head_dim//4] 485 | sin, cos = map(lambda t: repeat(t, '... n -> ... 
(j n)', j=2)[-x.shape[-3]:, None, :], sincos) 486 | # sin: [seq_len, 1, head_dim//4//2] 487 | return (x * cos) + (rotate_half(x) * sin) 488 | 489 | 490 | def rotate_half(x): 491 | # x: [seq_len, head, head_dim//4] 492 | half_dim = x.shape[-1] // 2 493 | x1 = x[:, :, :half_dim] 494 | x2 = x[:, :, half_dim:] 495 | return jnp.concatenate((-x2, x1), axis=-1) 496 | 497 | 498 | def temperature_sample(key, logits, temp=1): 499 | return jax.random.categorical(key, logits/temp, -1).astype(jnp.int32) 500 | 501 | 502 | def zero_init_fp16(): 503 | def zeros_(key, shape, dtype): 504 | return jnp.zeros(shape, jnp.float16) 505 | return zeros_ 506 | -------------------------------------------------------------------------------- /minimal20b_flax/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import jax.interpreters.pxla 3 | import jax.numpy as jnp 4 | from jax.experimental import maps 5 | 6 | 7 | def replicate_to_devices(array, devices=None): 8 | if devices is None: 9 | devices = jax.local_devices() 10 | num_devices = len(devices) 11 | num_dims = array.ndim 12 | device_buffers = [ 13 | jax.device_put(array, device) 14 | for device in devices 15 | ] 16 | sharding = tuple([jax.interpreters.pxla.NoSharding() for _ in range(num_dims)]) 17 | # Assuming mesh=("dp", "tp") 18 | mesh_mapping = ( 19 | jax.interpreters.pxla.Replicated(1), 20 | jax.interpreters.pxla.Replicated(num_devices), 21 | ) 22 | sharding_spec = jax.interpreters.pxla.ShardingSpec(sharding, mesh_mapping) 23 | return jax.interpreters.pxla.make_sharded_device_array( 24 | aval=jax.ShapedArray(array.shape, jnp.float16), 25 | sharding_spec=sharding_spec, 26 | device_buffers=device_buffers, 27 | ) 28 | 29 | 30 | def shard_to_devices(array, axis, devices=None): 31 | if devices is None: 32 | devices = jax.local_devices() 33 | num_devices = len(devices) 34 | num_dims = array.ndim 35 | split_arrays = np.split(array, num_devices, axis=axis) 36 | device_buffers = [ 37 | jax.device_put(split_array, device) 38 | for split_array, device 39 | in zip(split_arrays, devices) 40 | ] 41 | sharding = [jax.interpreters.pxla.NoSharding() for _ in range(num_dims)] 42 | sharding[axis] = jax.interpreters.pxla.Chunked((num_devices,)) 43 | sharding = tuple(sharding) 44 | # Assuming mesh=("dp", "tp") 45 | mesh_mapping = ( 46 | jax.interpreters.pxla.Replicated(1), 47 | jax.interpreters.pxla.ShardedAxis(0), 48 | ) 49 | sharding_spec = jax.interpreters.pxla.ShardingSpec(sharding, mesh_mapping) 50 | return jax.interpreters.pxla.make_sharded_device_array( 51 | aval=jax.ShapedArray(array.shape, jnp.float16), 52 | sharding_spec=sharding_spec, 53 | device_buffers=device_buffers, 54 | ) 55 | 56 | 57 | def split_to_device_buffers(array, axis, devices=None): 58 | if devices is None: 59 | devices = jax.local_devices() 60 | num_devices = len(devices) 61 | split_arrays = np.split(array, num_devices, axis=axis) 62 | device_buffers = [ 63 | jax.device_put(split_array, device) 64 | for split_array, device 65 | in zip(split_arrays, devices) 66 | ] 67 | return device_buffers 68 | 69 | 70 | def wrap_device_buffers_in_sharded_device_array(device_buffers, array_shape, axis, devices=None): 71 | if devices is None: 72 | devices = jax.local_devices() 73 | num_devices = len(devices) 74 | num_dims = len(array_shape) 75 | sharding = [jax.interpreters.pxla.NoSharding() for _ in range(num_dims)] 76 | sharding[axis] = jax.interpreters.pxla.Chunked((num_devices,)) 77 | sharding = tuple(sharding) 78 | mesh_mapping = ( 79 | 
jax.interpreters.pxla.Replicated(1), 80 | jax.interpreters.pxla.ShardedAxis(0), 81 | ) 82 | sharding_spec = jax.interpreters.pxla.ShardingSpec(sharding, mesh_mapping) 83 | return jax.interpreters.pxla.make_sharded_device_array( 84 | aval=jax.ShapedArray(array_shape, jnp.float16), 85 | sharding_spec=sharding_spec, 86 | device_buffers=device_buffers, 87 | ) 88 | 89 | 90 | def jnp_sharded_zeros(array_shape, axis, devices=None): 91 | if devices is None: 92 | devices = jax.local_devices() 93 | num_devices = len(devices) 94 | buffer_shape = list(array_shape) 95 | buffer_shape[axis] //= num_devices 96 | device_buffers = [ 97 | jax.device_put(jnp.zeros(buffer_shape, dtype=jnp.float16), device) 98 | for device in devices 99 | ] 100 | num_dims = len(array_shape) 101 | sharding = [jax.interpreters.pxla.NoSharding() for _ in range(num_dims)] 102 | sharding[axis] = jax.interpreters.pxla.Chunked((num_devices,)) 103 | sharding = tuple(sharding) 104 | mesh_mapping = ( 105 | jax.interpreters.pxla.Replicated(1), 106 | jax.interpreters.pxla.ShardedAxis(0), 107 | ) 108 | sharding_spec = jax.interpreters.pxla.ShardingSpec(sharding, mesh_mapping) 109 | return jax.interpreters.pxla.make_sharded_device_array( 110 | aval=jax.ShapedArray(array_shape, jnp.float16), 111 | sharding_spec=sharding_spec, 112 | device_buffers=device_buffers, 113 | ) 114 | 115 | 116 | def get_default_mesh(): 117 | devices = jax.local_devices() 118 | return maps.Mesh(np.asarray(devices).reshape(1, 8), ('dp', 'tp')) 119 | 120 | 121 | # identity in forward pass, psum in backward 122 | @jax.custom_vjp 123 | def f_psum(x): 124 | return x 125 | 126 | 127 | def f_psum_fwd(x): 128 | return f_psum(x), None 129 | 130 | 131 | def f_psum_bwd(_, g): 132 | return jax.lax.psum(g, "shard"), 133 | 134 | 135 | f_psum.defvjp(f_psum_fwd, f_psum_bwd) 136 | 137 | 138 | # identity in forward pass, pmean in backward 139 | @jax.custom_vjp 140 | def f_pmean(x): 141 | return x 142 | 143 | 144 | def f_pmean_fwd(x): 145 | return f_psum(x), None 146 | 147 | 148 | def f_pmean_bwd(_, g): 149 | return jax.lax.pmean(g, "shard"), 150 | 151 | 152 | f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd) 153 | 154 | 155 | # psum in forward pass, identity in backward 156 | @jax.custom_vjp 157 | def g_psum(x): 158 | return jax.lax.psum(x, "shard") 159 | 160 | 161 | def g_psum_fwd(x): 162 | return g_psum(x), None 163 | 164 | 165 | def g_psum_bwd(_, g): 166 | return g, 167 | 168 | 169 | g_psum.defvjp(g_psum_fwd, g_psum_bwd) 170 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers==0.12.0 -------------------------------------------------------------------------------- /requirements_flax.txt: -------------------------------------------------------------------------------- 1 | tokenizers==0.12.0 2 | tqdm==4.62.3 3 | flax==0.4.1 4 | einops==0.4.1 -------------------------------------------------------------------------------- /scripts/eval/eval_harness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pprint import pprint 4 | 5 | from tqdm import tqdm 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from lm_eval.base import CacheHook 10 | from lm_eval.models.gpt2 import GPT2LM 11 | from lm_eval import tasks, evaluator, utils 12 | 13 | import minimal20b 14 | 15 | 16 | class TokenizerWrapper: 17 | def __init__(self, tokenizer): 18 | self.tokenizer = tokenizer 19 | 20 | def encode(self, 
string: str): 21 | return self.tokenizer.encode(string).ids 22 | 23 | def decode(self, tokens): 24 | return self.tokenizer.decode(tokens) 25 | 26 | 27 | class EvalHarnessAdapter(GPT2LM): 28 | """ 29 | An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks. 30 | """ 31 | 32 | def __init__(self, model, tokenizer): 33 | 34 | # self.device = torch.device(f"cuda:0") 35 | self.device = torch.device("cuda:0") 36 | self.VOCAB_SIZE = minimal20b.Args20b.vocab_size 37 | self.model = model 38 | self.tokenizer = TokenizerWrapper(tokenizer) 39 | self.EOT_TOKEN_ID = 0 40 | self.cache_hook = CacheHook(None) 41 | self.max_length = 2048 42 | self.max_gen_toks = 128 43 | 44 | self.batch_size = 4 45 | 46 | self.full_attention_mask = minimal20b.generate_mask(2048).to(self.device) 47 | 48 | def greedy_until(self, requests): 49 | raise NotImplementedError() 50 | 51 | def _loglikelihood_tokens(self, requests, disable_tqdm=False): 52 | res = [] 53 | res_len = 0 # storing the result length for later 54 | with torch.no_grad(): 55 | 56 | def _collate(x): 57 | toks = x[1] + x[2] 58 | return -len(toks), tuple(toks) 59 | 60 | reord = utils.Reorderer(requests, _collate) 61 | for chunk in utils.chunks( 62 | tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size 63 | ): 64 | inps, contlens, inplens, padding_length = [], [], [], None 65 | for _, context_enc, continuation_enc in chunk: 66 | # when too long to fit in context, truncate from the left 67 | inp = torch.tensor( 68 | (context_enc + continuation_enc)[-(self.max_length + 1):][:-1], 69 | dtype=torch.long, 70 | ).to(self.device) 71 | (inplen,) = inp.shape 72 | 73 | cont = continuation_enc 74 | 75 | # since in _collate we make sure length is descending, the longest is always the first one. 
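# padding_length is therefore taken from the first (longest) sequence and reused for the rest of
# the chunk: shorter inputs are right-padded with zeros below so the whole chunk can be stacked
# into one [batch, padding_length] tensor, while inplens/contlens record what is needed to slice
# each continuation's logits back out after the forward pass.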
76 | padding_length = ( 77 | padding_length if padding_length is not None else inplen 78 | ) 79 | 80 | # pad to length 81 | inp = torch.cat( 82 | [ 83 | inp, # [seq] 84 | torch.zeros(padding_length - inplen, dtype=torch.long).to( 85 | inp.device 86 | ), # [padding_length - seq] 87 | ], 88 | dim=0, 89 | ) 90 | 91 | inps.append(inp.unsqueeze(0)) 92 | contlens.append(cont) 93 | inplens.append(inplen) 94 | 95 | logits = self._model_call(torch.cat(inps, dim=0)) 96 | res_len += len(chunk) 97 | 98 | if logits is not None: 99 | multi_logits = F.log_softmax(logits, dim=-1) # [batch, seq, vocab] 100 | for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( 101 | chunk, multi_logits, inps, inplens, contlens 102 | ): 103 | contlen = len(cont_toks) 104 | logits = logits[inplen - contlen:inplen].unsqueeze( 105 | 0 106 | ) # [1, seq, vocab] 107 | greedy_tokens = logits.argmax(dim=-1) 108 | # cont_toks :: [1, seq] 109 | cont_toks = ( 110 | torch.tensor(cont_toks, dtype=torch.long) 111 | .unsqueeze(0) 112 | .to(multi_logits.device) 113 | ) 114 | # noinspection PyUnresolvedReferences 115 | max_equal = (greedy_tokens == cont_toks).all() 116 | logits = torch.gather( 117 | logits, 2, cont_toks.unsqueeze(-1) 118 | ).squeeze( 119 | -1 120 | ) # [1, seq] 121 | answer = (float(logits.sum()), bool(max_equal)) 122 | res.append(answer) 123 | 124 | return reord.get_original(res) 125 | 126 | def _model_call(self, inps): 127 | length = inps.shape[1] 128 | model_out = self.model( 129 | inps.to(self.device), 130 | attention_mask=self.full_attention_mask[..., :length, :length], 131 | ) 132 | return model_out 133 | 134 | @torch.no_grad() 135 | def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2): 136 | self.model.eval() 137 | results = evaluator.evaluate( 138 | lm=self, 139 | task_dict=tasks.get_task_dict(eval_tasks), 140 | provide_description=False, 141 | num_fewshot=num_fewshot, 142 | limit=None, 143 | bootstrap_iters=bootstrap_iters, 144 | ) 145 | return results 146 | 147 | 148 | def main(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--model_path', type=str, required=True) 151 | parser.add_argument('--tokenizer_path', type=str, required=True) 152 | parser.add_argument('--tasks', type=str, required=True) 153 | parser.add_argument('--output_path', type=str, default=None) 154 | args = parser.parse_args() 155 | model = minimal20b.create_model(args.model_path) 156 | tokenizer = minimal20b.create_tokenizer(args.tokenizer_path) 157 | adapter = EvalHarnessAdapter(model, tokenizer) 158 | print("Running evaluation harness...") 159 | results = adapter.run_eval( 160 | eval_tasks=args.tasks.split(","), 161 | bootstrap_iters=10000, 162 | ) 163 | pprint(results) 164 | if args.output_path: 165 | with open(args.output_path, "w") as f: 166 | f.write(json.dumps(results, indent=2)) 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /scripts/eval/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/EleutherAI/lm-evaluation-harness.git@dc937d4b70af819c5695e09d94e59e4cdb1e40ad#egg=lm_eval -------------------------------------------------------------------------------- /scripts/eval_flax/eval_harness.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.experimental import maps 3 | import jax.numpy as jnp 4 | import numpy as np 5 | 6 | import argparse 7 | import json 8 | from 
pprint import pprint 9 | 10 | from tqdm import tqdm 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | from lm_eval.base import CacheHook 15 | from lm_eval.models.gpt2 import GPT2LM 16 | from lm_eval import tasks, evaluator, utils 17 | 18 | import minimal20b_flax.model as model 19 | import minimal20b_flax.create as create 20 | 21 | 22 | class TokenizerWrapper: 23 | def __init__(self, tokenizer): 24 | self.tokenizer = tokenizer 25 | 26 | def encode(self, string: str): 27 | return self.tokenizer.encode(string).ids 28 | 29 | def decode(self, tokens): 30 | return self.tokenizer.decode(tokens) 31 | 32 | 33 | class EvalHarnessAdapter(GPT2LM): 34 | """ 35 | An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks. 36 | """ 37 | 38 | def __init__(self, weights, mesh, tokenizer): 39 | 40 | config = model.NeoX20BConfig() 41 | self.VOCAB_SIZE = config.vocab_size 42 | self.weights = weights 43 | self.mesh = mesh 44 | self.tokenizer = TokenizerWrapper(tokenizer) 45 | self.EOT_TOKEN_ID = 0 46 | self.cache_hook = CacheHook(None) 47 | self.max_length = 2048 48 | self.max_gen_toks = 128 49 | 50 | self.batch_size = 4 51 | 52 | self.eval_apply_fn_pjit = model.GPTNeoX20BModel().eval_apply_fn_pjit() 53 | 54 | def greedy_until(self, requests): 55 | raise NotImplementedError() 56 | 57 | def _loglikelihood_tokens(self, requests, disable_tqdm=False): 58 | res = [] 59 | res_len = 0 # storing the result length for later 60 | with torch.no_grad(): 61 | 62 | def _collate(x): 63 | toks = x[1] + x[2] 64 | return -len(toks), tuple(toks) 65 | 66 | reord = utils.Reorderer(requests, _collate) 67 | for chunk in utils.chunks( 68 | tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size 69 | ): 70 | inps, contlens, inplens, padding_length = [], [], [], None 71 | for _, context_enc, continuation_enc in chunk: 72 | # when too long to fit in context, truncate from the left 73 | inp = torch.tensor( 74 | (context_enc + continuation_enc)[-(self.max_length + 1):][:-1], 75 | dtype=torch.long, 76 | ) 77 | (inplen,) = inp.shape 78 | 79 | cont = continuation_enc 80 | 81 | # since in _collate we make sure length is descending, the longest is always the first one. 
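# Note: unlike the PyTorch harness, these inputs stay on CPU (no .to(device));
# _model_call below converts the stacked batch to numpy, pads it to a bucketed
# sequence length, runs the pjit-compiled apply function under the device mesh,
# and only then converts the logits back to a torch tensor for scoring.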
82 | padding_length = ( 83 | padding_length if padding_length is not None else inplen 84 | ) 85 | 86 | # pad to length 87 | inp = torch.cat( 88 | [ 89 | inp, # [seq] 90 | torch.zeros(padding_length - inplen, dtype=torch.long).to( 91 | inp.device 92 | ), # [padding_length - seq] 93 | ], 94 | dim=0, 95 | ) 96 | 97 | inps.append(inp.unsqueeze(0)) 98 | contlens.append(cont) 99 | inplens.append(inplen) 100 | 101 | logits = self._model_call(torch.cat(inps, dim=0)) 102 | res_len += len(chunk) 103 | 104 | if logits is not None: 105 | multi_logits = F.log_softmax(logits, dim=-1) # [batch, seq, vocab] 106 | for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( 107 | chunk, multi_logits, inps, inplens, contlens 108 | ): 109 | contlen = len(cont_toks) 110 | logits = logits[inplen - contlen:inplen].unsqueeze( 111 | 0 112 | ) # [1, seq, vocab] 113 | greedy_tokens = logits.argmax(dim=-1) 114 | # cont_toks :: [1, seq] 115 | cont_toks = ( 116 | torch.tensor(cont_toks, dtype=torch.long) 117 | .unsqueeze(0) 118 | .to(multi_logits.device) 119 | ) 120 | # noinspection PyUnresolvedReferences 121 | max_equal = (greedy_tokens == cont_toks).all() 122 | logits = torch.gather( 123 | logits, 2, cont_toks.unsqueeze(-1) 124 | ).squeeze( 125 | -1 126 | ) # [1, seq] 127 | answer = (float(logits.sum()), bool(max_equal)) 128 | res.append(answer) 129 | 130 | return reord.get_original(res) 131 | 132 | def _model_call(self, inps): 133 | 134 | length = inps.shape[1] 135 | max_length = 2048 136 | for candidate_length in [1024, 512, 256, 128, 64]: 137 | if length > candidate_length: 138 | break 139 | max_length = candidate_length 140 | 141 | assert max_length >= length 142 | mask = jnp.zeros([1, max_length, max_length]) 143 | 144 | inps_arr = inps.numpy() 145 | padded_inps = np.concatenate([ 146 | inps_arr, 147 | np.zeros((inps_arr.shape[0], max_length - length), dtype=np.int32), 148 | ], axis=1) 149 | 150 | with maps.mesh(self.mesh.devices, self.mesh.axis_names): 151 | logits = self.eval_apply_fn_pjit( 152 | self.weights, 153 | padded_inps, 154 | mask, 155 | ) 156 | logits = logits[:, :length, :] 157 | logits = torch.tensor(np.array(logits), dtype=torch.float32) 158 | 159 | return logits 160 | 161 | @torch.no_grad() 162 | def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2): 163 | results = evaluator.evaluate( 164 | lm=self, 165 | task_dict=tasks.get_task_dict(eval_tasks), 166 | provide_description=False, 167 | num_fewshot=num_fewshot, 168 | limit=None, 169 | bootstrap_iters=bootstrap_iters, 170 | ) 171 | return results 172 | 173 | 174 | def main(): 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--model_path', type=str, required=True) 177 | parser.add_argument('--tokenizer_path', type=str, required=True) 178 | parser.add_argument('--tasks', type=str, required=True) 179 | parser.add_argument('--output_path', type=str, default=None) 180 | args = parser.parse_args() 181 | 182 | # Set up mesh for TPU 183 | devices = jax.local_devices() 184 | mesh = maps.Mesh(np.asarray(devices).reshape(1, 8), ('dp', 'tp')) 185 | 186 | tokenizer = create.create_tokenizer(args.tokenizer_path) 187 | weights = create.load_model_weights(args.model_path) 188 | 189 | adapter = EvalHarnessAdapter(weights=weights, mesh=mesh, tokenizer=tokenizer) 190 | print("Running evaluation harness...") 191 | results = adapter.run_eval( 192 | eval_tasks=args.tasks.split(","), 193 | bootstrap_iters=10000, 194 | ) 195 | pprint(results) 196 | if args.output_path: 197 | with open(args.output_path, "w") as f: 198 | 
f.write(json.dumps(results, indent=2)) 199 | 200 | 201 | if __name__ == "__main__": 202 | main() 203 | -------------------------------------------------------------------------------- /scripts/eval_flax/eval_harness_xmap.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.experimental import maps 3 | import jax.numpy as jnp 4 | import numpy as np 5 | from functools import partial 6 | 7 | import argparse 8 | import json 9 | from pprint import pprint 10 | 11 | from tqdm import tqdm 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | from lm_eval.base import CacheHook 16 | from lm_eval.models.gpt2 import GPT2LM 17 | from lm_eval import tasks, evaluator, utils 18 | 19 | import minimal20b_flax.model_xmap as model_xmap 20 | import minimal20b_flax.create as create 21 | 22 | 23 | class TokenizerWrapper: 24 | def __init__(self, tokenizer): 25 | self.tokenizer = tokenizer 26 | 27 | def encode(self, string: str): 28 | return self.tokenizer.encode(string).ids 29 | 30 | def decode(self, tokens): 31 | return self.tokenizer.decode(tokens) 32 | 33 | 34 | class EvalHarnessAdapter(GPT2LM): 35 | """ 36 | An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks. 37 | """ 38 | 39 | def __init__(self, weights, mesh, tokenizer, batch_size): 40 | 41 | config = model_xmap.NeoX20BConfig() 42 | self.VOCAB_SIZE = config.vocab_size 43 | self.weights = weights 44 | self.mesh = mesh 45 | self.tokenizer = TokenizerWrapper(tokenizer) 46 | self.EOT_TOKEN_ID = 0 47 | self.cache_hook = CacheHook(None) 48 | self.max_length = 2048 49 | self.max_gen_toks = 128 50 | 51 | self.batch_size = batch_size 52 | 53 | neox_model = model_xmap.GPTNeoX20BModel(config=config) 54 | self.eval_apply_fn_xmap = jax.experimental.maps.xmap( 55 | neox_model.get_batch_eval_fn(), 56 | in_axes=( 57 | ["shard", ...], 58 | [...], 59 | [...], 60 | ), 61 | out_axes=[...], 62 | axis_resources={'shard': 'tp', 'batch': 'dp'}, 63 | ) 64 | 65 | def greedy_until(self, requests): 66 | raise NotImplementedError() 67 | 68 | def _loglikelihood_tokens(self, requests, disable_tqdm=False): 69 | res = [] 70 | res_len = 0 # storing the result length for later 71 | with torch.no_grad(): 72 | 73 | def _collate(x): 74 | toks = x[1] + x[2] 75 | return -len(toks), tuple(toks) 76 | 77 | reord = utils.Reorderer(requests, _collate) 78 | for chunk in utils.chunks( 79 | tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size 80 | ): 81 | inps, contlens, inplens, padding_length = [], [], [], None 82 | for _, context_enc, continuation_enc in chunk: 83 | # when too long to fit in context, truncate from the left 84 | inp = torch.tensor( 85 | (context_enc + continuation_enc)[-(self.max_length + 1):][:-1], 86 | dtype=torch.long, 87 | ) 88 | (inplen,) = inp.shape 89 | 90 | cont = continuation_enc 91 | 92 | # since in _collate we make sure length is descending, the longest is always the first one. 
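# All items in this chunk share padding_length (the first/longest item's length), so the batch is
# one rectangular array; _model_call below additionally pads it to a fixed bucket size
# (64/128/256/512/1024/2048), which keeps the set of shapes seen by the xmapped forward pass
# small and avoids re-tracing for every distinct request length.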
93 | padding_length = ( 94 | padding_length if padding_length is not None else inplen 95 | ) 96 | 97 | # pad to length 98 | inp = torch.cat( 99 | [ 100 | inp, # [seq] 101 | torch.zeros(padding_length - inplen, dtype=torch.long).to( 102 | inp.device 103 | ), # [padding_length - seq] 104 | ], 105 | dim=0, 106 | ) 107 | 108 | inps.append(inp.unsqueeze(0)) 109 | contlens.append(cont) 110 | inplens.append(inplen) 111 | 112 | logits = self._model_call(torch.cat(inps, dim=0)) 113 | res_len += len(chunk) 114 | 115 | if logits is not None: 116 | multi_logits = F.log_softmax(logits, dim=-1) # [batch, seq, vocab] 117 | for (cache_key, _, _), logits, inp, inplen, cont_toks in zip( 118 | chunk, multi_logits, inps, inplens, contlens 119 | ): 120 | contlen = len(cont_toks) 121 | logits = logits[inplen - contlen:inplen].unsqueeze( 122 | 0 123 | ) # [1, seq, vocab] 124 | greedy_tokens = logits.argmax(dim=-1) 125 | # cont_toks :: [1, seq] 126 | cont_toks = ( 127 | torch.tensor(cont_toks, dtype=torch.long) 128 | .unsqueeze(0) 129 | .to(multi_logits.device) 130 | ) 131 | # noinspection PyUnresolvedReferences 132 | max_equal = (greedy_tokens == cont_toks).all() 133 | logits = torch.gather( 134 | logits, 2, cont_toks.unsqueeze(-1) 135 | ).squeeze( 136 | -1 137 | ) # [1, seq] 138 | answer = (float(logits.sum()), bool(max_equal)) 139 | res.append(answer) 140 | 141 | return reord.get_original(res) 142 | 143 | def _model_call(self, inps): 144 | 145 | length = inps.shape[1] 146 | max_length = 2048 147 | for candidate_length in [1024, 512, 256, 128, 64]: 148 | if length > candidate_length: 149 | break 150 | max_length = candidate_length 151 | 152 | assert max_length >= length 153 | mask = jnp.zeros([inps.shape[0], max_length, max_length]) 154 | 155 | inps_arr = inps.numpy() 156 | padded_inps = np.concatenate([ 157 | inps_arr, 158 | np.zeros((inps_arr.shape[0], max_length - length), dtype=np.int32), 159 | ], axis=1) 160 | with maps.mesh(self.mesh.devices, self.mesh.axis_names): 161 | logits = self.eval_apply_fn_xmap( 162 | self.weights, 163 | padded_inps, 164 | mask, 165 | ) 166 | logits = logits.swapaxes(1, 2).reshape(-1, max_length, self.VOCAB_SIZE) 167 | logits = torch.tensor(np.array(logits), dtype=torch.float32) 168 | 169 | return logits 170 | 171 | @torch.no_grad() 172 | def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2): 173 | results = evaluator.evaluate( 174 | lm=self, 175 | task_dict=tasks.get_task_dict(eval_tasks), 176 | provide_description=False, 177 | num_fewshot=num_fewshot, 178 | limit=None, 179 | bootstrap_iters=bootstrap_iters, 180 | ) 181 | return results 182 | 183 | 184 | def main(): 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument('--model_path', type=str, required=True) 187 | parser.add_argument('--tokenizer_path', type=str, required=True) 188 | parser.add_argument('--tasks', type=str, required=True) 189 | parser.add_argument('--output_path', type=str, default=None) 190 | parser.add_argument('--batch_size', type=int, default=4) 191 | parser.add_argument('--pool_size', type=int, default=None) 192 | args = parser.parse_args() 193 | 194 | # Set up mesh for TPU 195 | devices = jax.local_devices() 196 | mesh = maps.Mesh(np.asarray(devices).reshape(1, 8), ('dp', 'tp')) 197 | 198 | tokenizer = create.create_tokenizer(args.tokenizer_path) 199 | weights = create.load_model_weights_for_xmap(args.model_path, pool_size=args.pool_size) 200 | 201 | adapter = EvalHarnessAdapter( 202 | weights=weights, 203 | mesh=mesh, 204 | tokenizer=tokenizer, 205 | 
batch_size=args.batch_size, 206 | ) 207 | print("Running evaluation harness...") 208 | results = adapter.run_eval( 209 | eval_tasks=args.tasks.split(","), 210 | bootstrap_iters=10000, 211 | ) 212 | pprint(results) 213 | if args.output_path: 214 | with open(args.output_path, "w") as f: 215 | f.write(json.dumps(results, indent=2)) 216 | 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /scripts/eval_flax/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/EleutherAI/lm-evaluation-harness.git@dc937d4b70af819c5695e09d94e59e4cdb1e40ad#egg=lm_eval --------------------------------------------------------------------------------
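For reference, here is a hypothetical invocation of the xmap evaluation script; the checkpoint and tokenizer paths below are placeholders, and the task list is just an example of lm_eval task names:

```bash
python scripts/eval_flax/eval_harness_xmap.py \
    --model_path /path/to/20B_checkpoints/global_step150000 \
    --tokenizer_path /path/to/20B_checkpoints/20B_tokenizer.json \
    --tasks lambada,piqa \
    --output_path results_xmap.json \
    --batch_size 4
```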