├── .github └── workflows │ ├── lint.yml │ ├── mypy.yml │ └── testing.yml ├── .gitignore ├── .isort.cfg ├── LICENSE ├── README.md ├── csrc └── paged_attention │ ├── README.md │ ├── attention │ ├── attention_dtypes.h │ ├── attention_generic.cuh │ ├── attention_kernels.cu │ ├── attention_utils.cuh │ ├── dtype_bfloat16.cuh │ ├── dtype_float16.cuh │ └── dtype_float32.cuh │ ├── cache.h │ ├── cache_kernels.cu │ ├── cuda_compat.h │ ├── cuda_utils.h │ ├── cuda_utils_kernels.cu │ ├── dispatch_utils.h │ ├── ops.h │ ├── pybind.cpp │ └── reduction_utils.cuh ├── fms-extras-requirements.txt ├── fms_extras ├── __init__.py ├── models │ ├── __init__.py │ ├── calico.py │ ├── hf │ │ ├── __init__.py │ │ ├── modeling_calico_hf.py │ │ └── modeling_mlp_speculator.py │ ├── paged_gpt_bigcode.py │ ├── paged_llama.py │ └── speculator.py ├── modules │ ├── __init__.py │ └── attention.py └── utils │ ├── __init__.py │ ├── cache │ ├── __init__.py │ └── paged.py │ └── generation.py ├── pyproject.toml ├── requirements-build.txt ├── requirements.txt ├── scripts └── paged_speculative_inference.py ├── setup.py ├── test-requirements.txt └── tests ├── conftest.py ├── models ├── __init__.py ├── hf │ ├── __init__.py │ └── test_mlp_speculator.py ├── hf_equivalence │ ├── __init__.py │ └── test_calico.py ├── test_calico.py ├── test_paged_llama.py └── test_speculator.py ├── resources └── expectations │ ├── models.test_calico.TestCalico.test_model_output │ └── models.test_calico.TestCalico.test_model_weight_keys └── utils ├── __init__.py ├── cache ├── __init__.py └── test_paged.py └── test_generation.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - uses: psf/black@stable 11 | with: 12 | options: "--check --diff --color" 13 | src: "." 
14 | version: "~= 23.3.0" 15 | - uses: isort/isort-action@master 16 | with: 17 | sort-paths: fms_extras 18 | 19 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies and run MyPy 2 | 3 | name: MyPy Type Checking 4 | 5 | on: [pull_request] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 3.10 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: "3.10" 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -r test-requirements.txt 25 | 26 | - name: Type check with mypy 27 | run: | 28 | mypy --exclude hf --exclude testing fms_extras 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies and run tests 2 | 3 | name: FMS Extras Testing 4 | 5 | on: [pull_request] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 3.10 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: "3.10" 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install pytest numpy 25 | pip install transformers==4.40.2 26 | pip install safetensors 27 | pip install ibm-fms 28 | pip install . 29 | 30 | - name: Test with pytest 31 | run: | 32 | pytest -vv -rP tests/ 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | .DS_Store -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | ensure_newline_before_comments = True 3 | force_grid_wrap = 0 4 | include_trailing_comma = True 5 | lines_after_imports = 2 6 | multi_line_output = 3 7 | use_parentheses = True 8 | profile = black 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fms-extras 2 | 3 | This is a repo as part of the foundation-model-stack organization which is used for new features staged to be integrated 4 | with [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack). This repo is the home 5 | for extensions, research and/or in-development work, and fms-based models trained by IBM. 
6 | 7 | ## Installation 8 | 9 | ### Local 10 | 11 | ```bash 12 | pip install -e . 13 | ``` 14 | 15 | ## Notable Features 16 | 17 | 1. `MLPSpeculator`: a lightweight speculator model that can be used alongside a generative model to speed up inference (currently deployed in IBM TGIS, with training in [fms-fsdp](https://github.com/foundation-model-stack/fms-fsdp)) 18 | 2. `PagedKVCacheManager`: an implementation of kv-cache management that provides a user with the proper input to use paged attention with their own models (currently deployed in IBM TGIS) 19 | 3. `PagedLLaMA`: a LLaMA implementation that uses paged attention in Multi-Head Attention. This model is compilable without graph breaks. 20 | 4. `speculative generation`: a reference implementation of speculative generation (`speculative_generate()`) using `PagedKVCacheManager` and `MLPSpeculator` 21 | 22 | ## Structure and contents of this Repository 23 | 24 | This repo follows a similar structure to that of [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack). 25 | 26 | * `fms_extras/models/` - Pure PyTorch implementations of popular model architectures, without requiring any specific common interface beyond `nn.Module`. Each model configuration is registered with `fms.models.register_model()` so that instances can be obtained through `fms.models.get_model('architecture', 'variant', '/path/to/data')`. Each model can also register sources/formats/versions of data to load (e.g. checkpoints provided by Meta, HF, or trained from this repo). 27 | * `fms_extras/models/hf/` - Adapters that compose our native PyTorch FMS model architecture implementations in HF-compatible wrapper interfaces. Each FMS model implements an adapter, and adapted instances are obtained via `fms.models.hf.to_hf_api(model)`. 28 | * `fms_extras/utils/` - Other operators useful in working with LLMs. These include a `speculative_generate()` function, a `PagedKVCacheManager` class for easy-to-use kv-cache management with paged attention kernels, etc. 29 | * `scripts/` - Various scripts for inference (paged generation and speculative generation) 30 | * `csrc/` - Custom kernels used in fms-extras, currently related to paged attention 31 | 32 | ## References 33 | 34 | - Huggingface TGI: https://github.com/huggingface/text-generation-inference 35 | - IBM TGIS: https://github.com/IBM/text-generation-inference 36 | -------------------------------------------------------------------------------- /csrc/paged_attention/README.md: -------------------------------------------------------------------------------- 1 | Note: The current version of the paged attention kernels is adapted from vLLM 0.2.7 (https://github.com/vllm-project/vllm). 2 | 3 | For any changes from vLLM, please mark the modified region with `//<fms>` (start) and `//<\fms>` (end) and explain the changes in this README.
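A purely hypothetical illustration of this marking convention follows; `example_kernel` and the NaN-guard change are invented for illustration and are not actual code from this repo or from vLLM, only the `//<fms>` / `//<\fms>` markers come from the convention above.

```cuda
// example_kernel is a made-up kernel used only to show the fms change markers.
// The unmarked lines stand in for unmodified vLLM 0.2.7 source.
__global__ void example_kernel(float* out, const float* in, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) {
    return;
  }
  //<fms>
  // fms-extras change (hypothetical): zero out NaN inputs before writing the
  // result; the change would also be explained in this README.
  float v = in[idx];
  out[idx] = isnan(v) ? 0.0f : v;
  //<\fms>
}
```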
-------------------------------------------------------------------------------- /csrc/paged_attention/attention/attention_dtypes.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "attention_generic.cuh" 4 | #include "dtype_float16.cuh" 5 | #include "dtype_float32.cuh" 6 | #include "dtype_bfloat16.cuh" -------------------------------------------------------------------------------- /csrc/paged_attention/attention/attention_generic.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | #include 21 | 22 | namespace vllm { 23 | 24 | // A vector type to store Q, K, V elements. 25 | template 26 | struct Vec {}; 27 | 28 | // A vector type to store FP32 accumulators. 29 | template 30 | struct FloatVec {}; 31 | 32 | // Template vector operations. 33 | template 34 | inline __device__ Acc mul(A a, B b); 35 | 36 | template 37 | inline __device__ float sum(T v); 38 | 39 | template 40 | inline __device__ float dot(T a, T b) { 41 | return sum(mul(a, b)); 42 | } 43 | 44 | template 45 | inline __device__ float dot(T a, T b) { 46 | return sum(mul(a, b)); 47 | } 48 | 49 | template 50 | inline __device__ void zero(T& dst) { 51 | constexpr int WORDS = sizeof(T) / 4; 52 | union { 53 | T raw; 54 | uint32_t words[WORDS]; 55 | } tmp; 56 | 57 | #pragma unroll 58 | for (int ii = 0; ii < WORDS; ++ii) { 59 | tmp.words[ii] = 0u; 60 | } 61 | dst = tmp.raw; 62 | } 63 | 64 | } // namespace vllm -------------------------------------------------------------------------------- /csrc/paged_attention/attention/attention_utils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 
17 | */ 18 | #pragma once 19 | 20 | #include "../cuda_compat.h" 21 | #include "attention_dtypes.h" 22 | 23 | #include 24 | #include 25 | 26 | namespace vllm { 27 | 28 | // Q*K^T operation. 29 | template 30 | inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { 31 | using A_vec = typename FloatVec::Type; 32 | // Compute the parallel products for Q*K^T (treat vector lanes separately). 33 | A_vec qk_vec = mul(q[0], k[0]); 34 | #pragma unroll 35 | for (int ii = 1; ii < N; ++ii) { 36 | qk_vec = fma(q[ii], k[ii], qk_vec); 37 | } 38 | 39 | // Finalize the reduction across lanes. 40 | float qk = sum(qk_vec); 41 | #pragma unroll 42 | for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { 43 | qk += VLLM_SHFL_XOR_SYNC(qk, mask); 44 | } 45 | return qk; 46 | } 47 | 48 | template 49 | struct Qk_dot { 50 | template 51 | static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { 52 | return qk_dot_(q, k); 53 | } 54 | }; 55 | 56 | } // namespace vllm 57 | -------------------------------------------------------------------------------- /csrc/paged_attention/attention/dtype_bfloat16.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp 3 | * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h 4 | * Copyright (c) 2023, The vLLM team. 5 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | #pragma once 20 | 21 | #include "attention_generic.cuh" 22 | #include "dtype_float32.cuh" 23 | 24 | #ifndef USE_ROCM 25 | #include 26 | #include 27 | #else 28 | #include 29 | #include 30 | 31 | typedef __hip_bfloat162 __nv_bfloat162; 32 | typedef __hip_bfloat16 __nv_bfloat16; 33 | #endif 34 | 35 | #include 36 | 37 | namespace vllm { 38 | 39 | // Define custom BF16 vector data types. 40 | struct bf16_4_t { 41 | __nv_bfloat162 x; 42 | __nv_bfloat162 y; 43 | }; 44 | 45 | struct bf16_8_t { 46 | __nv_bfloat162 x; 47 | __nv_bfloat162 y; 48 | __nv_bfloat162 z; 49 | __nv_bfloat162 w; 50 | }; 51 | 52 | // BF16 vector types for Q, K, V. 53 | template<> 54 | struct Vec<__nv_bfloat16, 1> { 55 | using Type = __nv_bfloat16; 56 | }; 57 | template<> 58 | struct Vec<__nv_bfloat16, 2> { 59 | using Type = __nv_bfloat162; 60 | }; 61 | template<> 62 | struct Vec<__nv_bfloat16, 4> { 63 | using Type = bf16_4_t; 64 | }; 65 | template<> 66 | struct Vec<__nv_bfloat16, 8> { 67 | using Type = bf16_8_t; 68 | }; 69 | 70 | // FP32 accumulator vector types corresponding to Vec. 
71 | template<> 72 | struct FloatVec<__nv_bfloat16> { 73 | using Type = float; 74 | }; 75 | template<> 76 | struct FloatVec<__nv_bfloat162> { 77 | using Type = float2; 78 | }; 79 | template<> 80 | struct FloatVec { 81 | using Type = Float4_; 82 | }; 83 | template<> 84 | struct FloatVec { 85 | using Type = Float8_; 86 | }; 87 | 88 | // Utility functions for type conversions. 89 | inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { 90 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 91 | assert(false); 92 | #else 93 | return __bfloat1622float2(val); 94 | #endif 95 | } 96 | 97 | inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { 98 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 99 | assert(false); 100 | #else 101 | return __bfloat162bfloat162(val); 102 | #endif 103 | } 104 | 105 | // Vector addition. 106 | inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { 107 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 108 | assert(false); 109 | #else 110 | #ifndef USE_ROCM 111 | return a + b; 112 | #else 113 | return __hadd(a, b); 114 | #endif 115 | #endif 116 | } 117 | 118 | inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { 119 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 120 | assert(false); 121 | #else 122 | return __hadd2(a, b); 123 | #endif 124 | } 125 | 126 | inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { 127 | bf16_4_t c; 128 | c.x = add(a.x, b.x); 129 | c.y = add(a.y, b.y); 130 | return c; 131 | } 132 | 133 | inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { 134 | bf16_8_t c; 135 | c.x = add(a.x, b.x); 136 | c.y = add(a.y, b.y); 137 | c.z = add(a.z, b.z); 138 | c.w = add(a.w, b.w); 139 | return c; 140 | } 141 | 142 | inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { 143 | float2 fa = bf1622float2(a); 144 | return add(fa, fb); 145 | } 146 | 147 | inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { 148 | Float4_ fc; 149 | fc.x = add(a.x, fb.x); 150 | fc.y = add(a.y, fb.y); 151 | return fc; 152 | } 153 | 154 | inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { 155 | Float8_ fc; 156 | fc.x = add(a.x, fb.x); 157 | fc.y = add(a.y, fb.y); 158 | fc.z = add(a.z, fb.z); 159 | fc.w = add(a.w, fb.w); 160 | return fc; 161 | } 162 | 163 | // Vector multiplication. 
164 | template<> 165 | inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { 166 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 167 | assert(false); 168 | #else 169 | return __hmul(a, b); 170 | #endif 171 | } 172 | 173 | template<> 174 | inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { 175 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 176 | assert(false); 177 | #else 178 | return __hmul2(a, b); 179 | #endif 180 | } 181 | 182 | template<> 183 | inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { 184 | return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); 185 | } 186 | 187 | template<> 188 | inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { 189 | bf16_4_t c; 190 | c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); 191 | c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); 192 | return c; 193 | } 194 | 195 | template<> 196 | inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { 197 | __nv_bfloat162 s = bf162bf162(a); 198 | bf16_4_t c; 199 | c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); 200 | c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); 201 | return c; 202 | } 203 | 204 | template<> 205 | inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { 206 | bf16_8_t c; 207 | c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); 208 | c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); 209 | c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); 210 | c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); 211 | return c; 212 | } 213 | 214 | template<> 215 | inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { 216 | __nv_bfloat162 s = bf162bf162(a); 217 | bf16_8_t c; 218 | c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); 219 | c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); 220 | c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); 221 | c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); 222 | return c; 223 | } 224 | 225 | template<> 226 | inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { 227 | float fa = __bfloat162float(a); 228 | float fb = __bfloat162float(b); 229 | return fa * fb; 230 | } 231 | 232 | template<> 233 | inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { 234 | float2 fa = bf1622float2(a); 235 | float2 fb = bf1622float2(b); 236 | return mul(fa, fb); 237 | } 238 | 239 | template<> 240 | inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { 241 | return mul(bf162bf162(a), b); 242 | } 243 | 244 | template<> 245 | inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { 246 | Float4_ fc; 247 | fc.x = mul(a.x, b.x); 248 | fc.y = mul(a.y, b.y); 249 | return fc; 250 | } 251 | 252 | template<> 253 | inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { 254 | __nv_bfloat162 s = bf162bf162(a); 255 | Float4_ fc; 256 | fc.x = mul(s, b.x); 257 | fc.y = mul(s, b.y); 258 | return fc; 259 | } 260 | 261 | template<> 262 | inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { 263 | Float8_ fc; 264 | fc.x = mul(a.x, b.x); 265 | fc.y = mul(a.y, b.y); 266 | fc.z = mul(a.z, b.z); 267 | fc.w = mul(a.w, b.w); 268 | return fc; 269 | } 270 | 271 | template<> 272 | inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { 273 | __nv_bfloat162 s = bf162bf162(a); 274 | Float8_ fc; 275 | fc.x = mul(s, b.x); 276 | fc.y = mul(s, b.y); 277 | 
fc.z = mul(s, b.z); 278 | fc.w = mul(s, b.w); 279 | return fc; 280 | } 281 | 282 | // Vector fused multiply-add. 283 | inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { 284 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 285 | assert(false); 286 | #else 287 | return __hfma2(a, b, c); 288 | #endif 289 | } 290 | 291 | inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { 292 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 293 | assert(false); 294 | #else 295 | return __hfma2(bf162bf162(a), b, c); 296 | #endif 297 | } 298 | 299 | inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { 300 | bf16_4_t d; 301 | d.x = fma(a.x, b.x, c.x); 302 | d.y = fma(a.y, b.y, c.y); 303 | return d; 304 | } 305 | 306 | inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { 307 | __nv_bfloat162 s = bf162bf162(a); 308 | bf16_4_t d; 309 | d.x = fma(s, b.x, c.x); 310 | d.y = fma(s, b.y, c.y); 311 | return d; 312 | } 313 | 314 | inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { 315 | bf16_8_t d; 316 | d.x = fma(a.x, b.x, c.x); 317 | d.y = fma(a.y, b.y, c.y); 318 | d.z = fma(a.z, b.z, c.z); 319 | d.w = fma(a.w, b.w, c.w); 320 | return d; 321 | } 322 | 323 | inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { 324 | __nv_bfloat162 s = bf162bf162(a); 325 | bf16_8_t d; 326 | d.x = fma(s, b.x, c.x); 327 | d.y = fma(s, b.y, c.y); 328 | d.z = fma(s, b.z, c.z); 329 | d.w = fma(s, b.w, c.w); 330 | return d; 331 | } 332 | 333 | inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { 334 | return __bfloat162float(a) * __bfloat162float(b) + fc; 335 | } 336 | 337 | inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { 338 | float2 fa = bf1622float2(a); 339 | float2 fb = bf1622float2(b); 340 | return fma(fa, fb, fc); 341 | } 342 | 343 | inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { 344 | return fma(bf162bf162(a), b, fc); 345 | } 346 | 347 | inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { 348 | Float4_ fd; 349 | fd.x = fma(a.x, b.x, fc.x); 350 | fd.y = fma(a.y, b.y, fc.y); 351 | return fd; 352 | } 353 | 354 | inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { 355 | __nv_bfloat162 s = bf162bf162(a); 356 | Float4_ fd; 357 | fd.x = fma(s, b.x, fc.x); 358 | fd.y = fma(s, b.y, fc.y); 359 | return fd; 360 | } 361 | 362 | inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { 363 | Float8_ fd; 364 | fd.x = fma(a.x, b.x, fc.x); 365 | fd.y = fma(a.y, b.y, fc.y); 366 | fd.z = fma(a.z, b.z, fc.z); 367 | fd.w = fma(a.w, b.w, fc.w); 368 | return fd; 369 | } 370 | 371 | inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { 372 | __nv_bfloat162 s = bf162bf162(a); 373 | Float8_ fd; 374 | fd.x = fma(s, b.x, fc.x); 375 | fd.y = fma(s, b.y, fc.y); 376 | fd.z = fma(s, b.z, fc.z); 377 | fd.w = fma(s, b.w, fc.w); 378 | return fd; 379 | } 380 | 381 | // Vector sum. 
382 | template<> 383 | inline __device__ float sum(__nv_bfloat16 v) { 384 | return __bfloat162float(v); 385 | } 386 | 387 | template<> 388 | inline __device__ float sum(__nv_bfloat162 v) { 389 | float2 vf = bf1622float2(v); 390 | return vf.x + vf.y; 391 | } 392 | 393 | template<> 394 | inline __device__ float sum(bf16_4_t v) { 395 | return sum(v.x) + sum(v.y); 396 | } 397 | 398 | template<> 399 | inline __device__ float sum(bf16_8_t v) { 400 | return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); 401 | } 402 | 403 | // From float32 to bfloat16. 404 | inline __device__ void from_float(__nv_bfloat16& dst, float src) { 405 | dst = __float2bfloat16(src); 406 | } 407 | 408 | inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { 409 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 410 | assert(false); 411 | #else 412 | dst = __float22bfloat162_rn(src); 413 | #endif 414 | } 415 | 416 | inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { 417 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 418 | assert(false); 419 | #else 420 | dst.x = __float22bfloat162_rn(src.x); 421 | dst.y = __float22bfloat162_rn(src.y); 422 | #endif 423 | } 424 | 425 | inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { 426 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 427 | assert(false); 428 | #else 429 | dst.x = __float22bfloat162_rn(src.x); 430 | dst.y = __float22bfloat162_rn(src.y); 431 | dst.z = __float22bfloat162_rn(src.z); 432 | dst.w = __float22bfloat162_rn(src.w); 433 | #endif 434 | } 435 | 436 | // From bfloat16 to float32. 437 | inline __device__ float to_float(__nv_bfloat16 u) { 438 | return __bfloat162float(u); 439 | } 440 | 441 | // Zero-out a variable. 442 | inline __device__ void zero(__nv_bfloat16& dst) { 443 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 444 | assert(false); 445 | #else 446 | // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2. 447 | dst = __ushort_as_bfloat16((unsigned short)0x0000U); 448 | #endif 449 | } 450 | 451 | } // namespace vllm -------------------------------------------------------------------------------- /csrc/paged_attention/attention/dtype_float16.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp 3 | * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h 4 | * Copyright (c) 2023, The vLLM team. 5 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | #pragma once 20 | 21 | #include "attention_generic.cuh" 22 | #include "dtype_float32.cuh" 23 | 24 | #ifdef USE_ROCM 25 | #include 26 | #endif 27 | 28 | #include 29 | 30 | namespace vllm { 31 | 32 | // FP16 vector types for Q, K, V. 
33 | template<> 34 | struct Vec { 35 | using Type = uint16_t; 36 | }; 37 | template<> 38 | struct Vec { 39 | using Type = uint32_t; 40 | }; 41 | template<> 42 | struct Vec { 43 | using Type = uint2; 44 | }; 45 | template<> 46 | struct Vec { 47 | using Type = uint4; 48 | }; 49 | 50 | // FP32 accumulator vector types corresponding to Vec. 51 | template<> 52 | struct FloatVec { 53 | using Type = float; 54 | }; 55 | template<> 56 | struct FloatVec { 57 | using Type = float2; 58 | }; 59 | template<> 60 | struct FloatVec { 61 | using Type = Float4_; 62 | }; 63 | template<> 64 | struct FloatVec { 65 | using Type = Float8_; 66 | }; 67 | 68 | // Utility functions for type conversions. 69 | inline __device__ uint32_t h0_h0(uint16_t a) { 70 | #ifndef USE_ROCM 71 | uint32_t b; 72 | asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); 73 | return b; 74 | #else 75 | union { 76 | uint32_t u32; 77 | uint16_t u16[2]; 78 | } tmp; 79 | tmp.u16[0] = a; 80 | tmp.u16[1] = a; 81 | return tmp.u32; 82 | #endif 83 | } 84 | 85 | inline __device__ float half_to_float(uint16_t h) { 86 | float f; 87 | #ifndef USE_ROCM 88 | asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); 89 | #else 90 | asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); 91 | #endif 92 | return f; 93 | } 94 | 95 | inline __device__ float2 half2_to_float2(uint32_t v) { 96 | #ifndef USE_ROCM 97 | uint16_t lo, hi; 98 | asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); 99 | return make_float2(half_to_float(lo), half_to_float(hi)); 100 | #else 101 | union { 102 | uint32_t u32; 103 | uint16_t u16[2]; 104 | } tmp; 105 | tmp.u32 = v; 106 | float2 ret; 107 | ret.x = half_to_float(tmp.u16[0]); 108 | ret.y = half_to_float(tmp.u16[1]); 109 | return ret; 110 | #endif 111 | } 112 | 113 | inline __device__ uint16_t float_to_half(float f) { 114 | union { 115 | uint32_t u32; 116 | uint16_t u16[2]; 117 | } tmp; 118 | #ifndef USE_ROCM 119 | asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); 120 | #else 121 | asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); 122 | #endif 123 | return tmp.u16[0]; 124 | } 125 | 126 | inline __device__ uint32_t float2_to_half2(float2 f) { 127 | union { 128 | uint32_t u32; 129 | uint16_t u16[2]; 130 | } tmp; 131 | #ifndef USE_ROCM 132 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 133 | asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); 134 | #else 135 | asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); 136 | asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); 137 | #endif 138 | #else 139 | tmp.u16[0] = float_to_half(f.x); 140 | tmp.u16[1] = float_to_half(f.y); 141 | #endif 142 | return tmp.u32; 143 | } 144 | 145 | // Vector addition. 
146 | inline __device__ uint16_t add(uint16_t a, uint16_t b) { 147 | uint16_t c; 148 | #ifndef USE_ROCM 149 | asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); 150 | #else 151 | asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); 152 | #endif 153 | return c; 154 | } 155 | 156 | inline __device__ uint32_t add(uint32_t a, uint32_t b) { 157 | uint32_t c; 158 | #ifndef USE_ROCM 159 | asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); 160 | #else 161 | asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); 162 | #endif 163 | return c; 164 | } 165 | 166 | inline __device__ uint2 add(uint2 a, uint2 b) { 167 | uint2 c; 168 | c.x = add(a.x, b.x); 169 | c.y = add(a.y, b.y); 170 | return c; 171 | } 172 | 173 | inline __device__ uint4 add(uint4 a, uint4 b) { 174 | uint4 c; 175 | c.x = add(a.x, b.x); 176 | c.y = add(a.y, b.y); 177 | c.z = add(a.z, b.z); 178 | c.w = add(a.w, b.w); 179 | return c; 180 | } 181 | 182 | inline __device__ float2 add(uint32_t a, float2 fb) { 183 | float2 fa = half2_to_float2(a); 184 | return add(fa, fb); 185 | } 186 | 187 | inline __device__ Float4_ add(uint2 a, Float4_ fb) { 188 | Float4_ fc; 189 | fc.x = add(a.x, fb.x); 190 | fc.y = add(a.y, fb.y); 191 | return fc; 192 | } 193 | 194 | inline __device__ Float8_ add(uint4 a, Float8_ fb) { 195 | Float8_ fc; 196 | fc.x = add(a.x, fb.x); 197 | fc.y = add(a.y, fb.y); 198 | fc.z = add(a.z, fb.z); 199 | fc.w = add(a.w, fb.w); 200 | return fc; 201 | } 202 | 203 | // Vector multiplication. 204 | template<> 205 | inline __device__ uint16_t mul(uint16_t a, uint16_t b) { 206 | uint16_t c; 207 | #ifndef USE_ROCM 208 | asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); 209 | #else 210 | asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); 211 | #endif 212 | return c; 213 | } 214 | 215 | template<> 216 | inline __device__ uint32_t mul(uint32_t a, uint32_t b) { 217 | uint32_t c; 218 | #ifndef USE_ROCM 219 | asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); 220 | #else 221 | asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); 222 | #endif 223 | return c; 224 | } 225 | 226 | template<> 227 | inline __device__ uint32_t mul(uint16_t a, uint32_t b) { 228 | return mul(h0_h0(a), b); 229 | } 230 | 231 | template<> 232 | inline __device__ uint2 mul(uint2 a, uint2 b) { 233 | uint2 c; 234 | c.x = mul(a.x, b.x); 235 | c.y = mul(a.y, b.y); 236 | return c; 237 | } 238 | 239 | template<> 240 | inline __device__ uint2 mul(uint16_t a, uint2 b) { 241 | uint32_t s = h0_h0(a); 242 | uint2 c; 243 | c.x = mul(s, b.x); 244 | c.y = mul(s, b.y); 245 | return c; 246 | } 247 | 248 | template<> 249 | inline __device__ uint4 mul(uint4 a, uint4 b) { 250 | uint4 c; 251 | c.x = mul(a.x, b.x); 252 | c.y = mul(a.y, b.y); 253 | c.z = mul(a.z, b.z); 254 | c.w = mul(a.w, b.w); 255 | return c; 256 | } 257 | 258 | template<> 259 | inline __device__ uint4 mul(uint16_t a, uint4 b) { 260 | uint32_t s = h0_h0(a); 261 | uint4 c; 262 | c.x = mul(s, b.x); 263 | c.y = mul(s, b.y); 264 | c.z = mul(s, b.z); 265 | c.w = mul(s, b.w); 266 | return c; 267 | } 268 | 269 | template<> 270 | inline __device__ float mul(uint16_t a, uint16_t b) { 271 | float fa = half_to_float(a); 272 | float fb = half_to_float(b); 273 | return fa * fb; 274 | } 275 | 276 | template<> 277 | inline __device__ float2 mul(uint32_t a, uint32_t b) { 278 | float2 fa = half2_to_float2(a); 279 | float2 fb = half2_to_float2(b); 280 | return mul(fa, fb); 281 | } 282 | 283 | template<> 284 | 
inline __device__ float2 mul(uint16_t a, uint32_t b) { 285 | return mul(h0_h0(a), b); 286 | } 287 | 288 | template<> 289 | inline __device__ Float4_ mul(uint2 a, uint2 b) { 290 | Float4_ fc; 291 | fc.x = mul(a.x, b.x); 292 | fc.y = mul(a.y, b.y); 293 | return fc; 294 | } 295 | 296 | template<> 297 | inline __device__ Float4_ mul(uint16_t a, uint2 b) { 298 | uint32_t s = h0_h0(a); 299 | Float4_ fc; 300 | fc.x = mul(s, b.x); 301 | fc.y = mul(s, b.y); 302 | return fc; 303 | } 304 | 305 | template<> 306 | inline __device__ Float8_ mul(uint4 a, uint4 b) { 307 | Float8_ fc; 308 | fc.x = mul(a.x, b.x); 309 | fc.y = mul(a.y, b.y); 310 | fc.z = mul(a.z, b.z); 311 | fc.w = mul(a.w, b.w); 312 | return fc; 313 | } 314 | 315 | template<> 316 | inline __device__ Float8_ mul(uint16_t a, uint4 b) { 317 | uint32_t s = h0_h0(a); 318 | Float8_ fc; 319 | fc.x = mul(s, b.x); 320 | fc.y = mul(s, b.y); 321 | fc.z = mul(s, b.z); 322 | fc.w = mul(s, b.w); 323 | return fc; 324 | } 325 | 326 | // Vector fused multiply-add. 327 | inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { 328 | uint32_t d; 329 | #ifndef USE_ROCM 330 | asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); 331 | #else 332 | asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); 333 | #endif 334 | return d; 335 | } 336 | 337 | inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { 338 | return fma(h0_h0(a), b, c); 339 | } 340 | 341 | inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { 342 | uint2 d; 343 | d.x = fma(a.x, b.x, c.x); 344 | d.y = fma(a.y, b.y, c.y); 345 | return d; 346 | } 347 | 348 | inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { 349 | uint32_t s = h0_h0(a); 350 | uint2 d; 351 | d.x = fma(s, b.x, c.x); 352 | d.y = fma(s, b.y, c.y); 353 | return d; 354 | } 355 | 356 | inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { 357 | uint4 d; 358 | d.x = fma(a.x, b.x, c.x); 359 | d.y = fma(a.y, b.y, c.y); 360 | d.z = fma(a.z, b.z, c.z); 361 | d.w = fma(a.w, b.w, c.w); 362 | return d; 363 | } 364 | 365 | inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { 366 | uint32_t s = h0_h0(a); 367 | uint4 d; 368 | d.x = fma(s, b.x, c.x); 369 | d.y = fma(s, b.y, c.y); 370 | d.z = fma(s, b.z, c.z); 371 | d.w = fma(s, b.w, c.w); 372 | return d; 373 | } 374 | 375 | inline __device__ float fma(uint16_t a, uint16_t b, float fc) { 376 | float fa = half_to_float(a); 377 | float fb = half_to_float(b); 378 | return fa * fb + fc; 379 | } 380 | 381 | inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { 382 | float2 fa = half2_to_float2(a); 383 | float2 fb = half2_to_float2(b); 384 | return fma(fa, fb, fc); 385 | } 386 | 387 | inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { 388 | return fma(h0_h0(a), b, fc); 389 | } 390 | 391 | inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { 392 | Float4_ fd; 393 | fd.x = fma(a.x, b.x, fc.x); 394 | fd.y = fma(a.y, b.y, fc.y); 395 | return fd; 396 | } 397 | 398 | inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { 399 | uint32_t s = h0_h0(a); 400 | Float4_ fd; 401 | fd.x = fma(s, b.x, fc.x); 402 | fd.y = fma(s, b.y, fc.y); 403 | return fd; 404 | } 405 | 406 | inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { 407 | Float8_ fd; 408 | fd.x = fma(a.x, b.x, fc.x); 409 | fd.y = fma(a.y, b.y, fc.y); 410 | fd.z = fma(a.z, b.z, fc.z); 411 | fd.w = fma(a.w, b.w, fc.w); 412 | return fd; 413 | } 414 | 415 | inline __device__ Float8_ fma(uint16_t a, uint4 b, 
Float8_ fc) { 416 | uint32_t s = h0_h0(a); 417 | Float8_ fd; 418 | fd.x = fma(s, b.x, fc.x); 419 | fd.y = fma(s, b.y, fc.y); 420 | fd.z = fma(s, b.z, fc.z); 421 | fd.w = fma(s, b.w, fc.w); 422 | return fd; 423 | } 424 | 425 | // Vector sum. 426 | template<> 427 | inline __device__ float sum(uint16_t v) { 428 | return half_to_float(v); 429 | } 430 | 431 | template<> 432 | inline __device__ float sum(uint32_t v) { 433 | float2 tmp = half2_to_float2(v); 434 | return tmp.x + tmp.y; 435 | } 436 | 437 | template<> 438 | inline __device__ float sum(uint2 v) { 439 | uint32_t c = add(v.x, v.y); 440 | return sum(c); 441 | } 442 | 443 | template<> 444 | inline __device__ float sum(uint4 v) { 445 | uint32_t c = add(v.x, v.y); 446 | c = add(c, v.z); 447 | c = add(c, v.w); 448 | return sum(c); 449 | } 450 | 451 | // From float32 to float16. 452 | inline __device__ void from_float(uint16_t& dst, float src) { 453 | dst = float_to_half(src); 454 | } 455 | 456 | inline __device__ void from_float(uint32_t& dst, float2 src) { 457 | dst = float2_to_half2(src); 458 | } 459 | 460 | inline __device__ void from_float(uint2& dst, Float4_ src) { 461 | dst.x = float2_to_half2(src.x); 462 | dst.y = float2_to_half2(src.y); 463 | } 464 | 465 | inline __device__ void from_float(uint4& dst, Float8_ src) { 466 | dst.x = float2_to_half2(src.x); 467 | dst.y = float2_to_half2(src.y); 468 | dst.z = float2_to_half2(src.z); 469 | dst.w = float2_to_half2(src.w); 470 | } 471 | 472 | // From float16 to float32. 473 | inline __device__ float to_float(uint16_t u) { 474 | return half_to_float(u); 475 | } 476 | 477 | inline __device__ float2 to_float(uint32_t u) { 478 | return half2_to_float2(u); 479 | } 480 | 481 | inline __device__ Float4_ to_float(uint2 u) { 482 | Float4_ tmp; 483 | tmp.x = half2_to_float2(u.x); 484 | tmp.y = half2_to_float2(u.y); 485 | return tmp; 486 | } 487 | 488 | inline __device__ Float8_ to_float(uint4 u) { 489 | Float8_ tmp; 490 | tmp.x = half2_to_float2(u.x); 491 | tmp.y = half2_to_float2(u.y); 492 | tmp.z = half2_to_float2(u.z); 493 | tmp.w = half2_to_float2(u.w); 494 | return tmp; 495 | } 496 | 497 | // Zero-out a variable. 498 | inline __device__ void zero(uint16_t& dst) { 499 | dst = uint16_t(0); 500 | } 501 | 502 | } // namespace vllm -------------------------------------------------------------------------------- /csrc/paged_attention/attention/dtype_float32.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp 3 | * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h 4 | * Copyright (c) 2023, The vLLM team. 5 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 6 | * 7 | * Licensed under the Apache License, Version 2.0 (the "License"); 8 | * you may not use this file except in compliance with the License. 9 | * You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 
18 | */ 19 | #pragma once 20 | 21 | #include "attention_generic.cuh" 22 | 23 | #include 24 | 25 | namespace vllm { 26 | 27 | // Define custom FP32 vector data types. 28 | struct Float4_ { 29 | float2 x; 30 | float2 y; 31 | }; 32 | 33 | struct Float8_ { 34 | float2 x; 35 | float2 y; 36 | float2 z; 37 | float2 w; 38 | }; 39 | 40 | // FP32 vector types for Q, K, V. 41 | template<> 42 | struct Vec { 43 | using Type = float; 44 | }; 45 | template<> 46 | struct Vec { 47 | using Type = float2; 48 | }; 49 | template<> 50 | struct Vec { 51 | using Type = float4; 52 | }; 53 | 54 | // FP32 accumulator vector types corresponding to Vec. 55 | template<> 56 | struct FloatVec { 57 | using Type = float; 58 | }; 59 | template<> 60 | struct FloatVec { 61 | using Type = float2; 62 | }; 63 | template<> 64 | struct FloatVec { 65 | using Type = float4; 66 | }; 67 | 68 | // Vector addition. 69 | inline __device__ float add(float a, float b) { 70 | return a + b; 71 | } 72 | 73 | inline __device__ float2 add(float2 a, float2 b) { 74 | float2 c; 75 | c.x = add(a.x, b.x); 76 | c.y = add(a.y, b.y); 77 | return c; 78 | } 79 | 80 | inline __device__ float4 add(float4 a, float4 b) { 81 | float4 c; 82 | c.x = add(a.x, b.x); 83 | c.y = add(a.y, b.y); 84 | c.z = add(a.z, b.z); 85 | c.w = add(a.w, b.w); 86 | return c; 87 | } 88 | 89 | // Vector multiplication. 90 | template<> 91 | inline __device__ float mul(float a, float b) { 92 | return a * b; 93 | } 94 | 95 | template<> 96 | inline __device__ float2 mul(float2 a, float2 b) { 97 | float2 c; 98 | c.x = a.x * b.x; 99 | c.y = a.y * b.y; 100 | return c; 101 | } 102 | 103 | template<> 104 | inline __device__ float2 mul(float a, float2 b) { 105 | float2 c; 106 | c.x = a * b.x; 107 | c.y = a * b.y; 108 | return c; 109 | } 110 | 111 | template<> 112 | inline __device__ float4 mul(float4 a, float4 b) { 113 | float4 c; 114 | c.x = a.x * b.x; 115 | c.y = a.y * b.y; 116 | c.z = a.z * b.z; 117 | c.w = a.w * b.w; 118 | return c; 119 | } 120 | 121 | template<> 122 | inline __device__ float4 mul(float a, float4 b) { 123 | float4 c; 124 | c.x = a * b.x; 125 | c.y = a * b.y; 126 | c.z = a * b.z; 127 | c.w = a * b.w; 128 | return c; 129 | } 130 | 131 | // Vector fused multiply-add. 
132 | inline __device__ float fma(float a, float b, float c) { 133 | return a * b + c; 134 | } 135 | 136 | inline __device__ float2 fma(float2 a, float2 b, float2 c) { 137 | float2 d; 138 | d.x = fma(a.x, b.x, c.x); 139 | d.y = fma(a.y, b.y, c.y); 140 | return d; 141 | } 142 | 143 | inline __device__ float2 fma(float a, float2 b, float2 c) { 144 | float2 d; 145 | d.x = fma(a, b.x, c.x); 146 | d.y = fma(a, b.y, c.y); 147 | return d; 148 | } 149 | 150 | inline __device__ float4 fma(float4 a, float4 b, float4 c) { 151 | float4 d; 152 | d.x = fma(a.x, b.x, c.x); 153 | d.y = fma(a.y, b.y, c.y); 154 | d.z = fma(a.z, b.z, c.z); 155 | d.w = fma(a.w, b.w, c.w); 156 | return d; 157 | } 158 | 159 | inline __device__ float4 fma(float a, float4 b, float4 c) { 160 | float4 d; 161 | d.x = fma(a, b.x, c.x); 162 | d.y = fma(a, b.y, c.y); 163 | d.z = fma(a, b.z, c.z); 164 | d.w = fma(a, b.w, c.w); 165 | return d; 166 | } 167 | 168 | inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { 169 | Float4_ d; 170 | d.x = fma(a, b.x, c.x); 171 | d.y = fma(a, b.y, c.y); 172 | return d; 173 | } 174 | 175 | inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { 176 | Float8_ d; 177 | d.x = fma(a, b.x, c.x); 178 | d.y = fma(a, b.y, c.y); 179 | d.z = fma(a, b.z, c.z); 180 | d.w = fma(a, b.w, c.w); 181 | return d; 182 | } 183 | 184 | // Vector sum. 185 | template<> 186 | inline __device__ float sum(float v) { 187 | return v; 188 | } 189 | 190 | template<> 191 | inline __device__ float sum(float2 v) { 192 | return v.x + v.y; 193 | } 194 | 195 | template<> 196 | inline __device__ float sum(float4 v) { 197 | return v.x + v.y + v.z + v.w; 198 | } 199 | 200 | template<> 201 | inline __device__ float sum(Float4_ v) { 202 | return v.x.x + v.x.y + v.y.x + v.y.y; 203 | } 204 | 205 | template<> 206 | inline __device__ float sum(Float8_ v) { 207 | return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; 208 | } 209 | 210 | // Vector dot product. 211 | inline __device__ float dot(float a, float b) { 212 | return a * b; 213 | } 214 | 215 | inline __device__ float dot(float2 a, float2 b) { 216 | float2 c = mul(a, b); 217 | return c.x + c.y; 218 | } 219 | 220 | inline __device__ float dot(Float4_ a, Float4_ b) { 221 | float2 acc = mul(a.x, b.x); 222 | acc = fma(a.y, b.y, acc); 223 | return acc.x + acc.y; 224 | } 225 | 226 | inline __device__ float dot(Float8_ a, Float8_ b) { 227 | float2 acc = mul(a.x, b.x); 228 | acc = fma(a.y, b.y, acc); 229 | acc = fma(a.z, b.z, acc); 230 | acc = fma(a.w, b.w, acc); 231 | return acc.x + acc.y; 232 | } 233 | 234 | // From float to float. 235 | inline __device__ void from_float(float& dst, float src) { 236 | dst = src; 237 | } 238 | 239 | inline __device__ void from_float(float2& dst, float2 src) { 240 | dst = src; 241 | } 242 | 243 | inline __device__ void from_float(float4& dst, float4 src) { 244 | dst = src; 245 | } 246 | 247 | // From float to float. 248 | inline __device__ float to_float(float u) { 249 | return u; 250 | } 251 | 252 | inline __device__ float2 to_float(float2 u) { 253 | return u; 254 | } 255 | 256 | inline __device__ float4 to_float(float4 u) { 257 | return u; 258 | } 259 | 260 | inline __device__ Float4_ to_float(Float4_ u) { 261 | return u; 262 | } 263 | 264 | inline __device__ Float8_ to_float(Float8_ u) { 265 | return u; 266 | } 267 | 268 | // Zero-out a variable. 
269 | inline __device__ void zero(float& dst) { 270 | dst = 0.f; 271 | } 272 | 273 | } // namespace vllm 274 | -------------------------------------------------------------------------------- /csrc/paged_attention/cache.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | void copy_blocks( 9 | std::vector& key_caches, 10 | std::vector& value_caches, 11 | const std::map>& block_mapping); 12 | 13 | void reshape_and_cache( 14 | torch::Tensor& key, 15 | torch::Tensor& value, 16 | torch::Tensor& key_cache, 17 | torch::Tensor& value_cache, 18 | torch::Tensor& slot_mapping); 19 | -------------------------------------------------------------------------------- /csrc/paged_attention/cache_kernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_compat.h" 6 | #include "dispatch_utils.h" 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | namespace vllm { 14 | 15 | // Grid: (num_layers, num_pairs) 16 | template 17 | __global__ void copy_blocks_kernel( 18 | int64_t* key_cache_ptrs, 19 | int64_t* value_cache_ptrs, 20 | const int64_t* __restrict__ block_mapping, 21 | const int numel_per_block) { 22 | const int layer_idx = blockIdx.x; 23 | const int pair_idx = blockIdx.y; 24 | 25 | scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); 26 | scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); 27 | int64_t src_block_number = block_mapping[2 * pair_idx]; 28 | int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; 29 | 30 | const int64_t src_block_offset = src_block_number * numel_per_block; 31 | const int64_t dst_block_offset = dst_block_number * numel_per_block; 32 | for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { 33 | int64_t src_offset = src_block_offset + i; 34 | int64_t dst_offset = dst_block_offset + i; 35 | key_cache[dst_offset] = key_cache[src_offset]; 36 | } 37 | for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { 38 | int64_t src_offset = src_block_offset + i; 39 | int64_t dst_offset = dst_block_offset + i; 40 | value_cache[dst_offset] = value_cache[src_offset]; 41 | } 42 | } 43 | 44 | } // namespace vllm 45 | 46 | void copy_blocks( 47 | std::vector& key_caches, 48 | std::vector& value_caches, 49 | const std::map>& block_mapping) { 50 | int num_layers = key_caches.size(); 51 | TORCH_CHECK(num_layers == value_caches.size()); 52 | if (num_layers == 0) { 53 | return; 54 | } 55 | torch::Device cache_device = key_caches[0].device(); 56 | TORCH_CHECK(cache_device.is_cuda()); 57 | 58 | // Create data structures for the kernel. 59 | // Create an array of pointers to the key and value caches. 60 | int64_t key_cache_ptrs[num_layers]; 61 | int64_t value_cache_ptrs[num_layers]; 62 | for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { 63 | key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); 64 | value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); 65 | } 66 | // Create block mapping array. 
67 | std::vector block_mapping_vec; 68 | for (const auto& pair : block_mapping) { 69 | int64_t src_block_number = pair.first; 70 | for (int64_t dst_block_number : pair.second) { 71 | block_mapping_vec.push_back(src_block_number); 72 | block_mapping_vec.push_back(dst_block_number); 73 | } 74 | } 75 | int64_t* block_mapping_array = block_mapping_vec.data(); 76 | int num_pairs = block_mapping_vec.size() / 2; 77 | 78 | // Move the data structures to the GPU. 79 | // NOTE: This synchronizes the CPU and GPU. 80 | torch::Tensor key_cache_ptrs_tensor = torch::from_blob( 81 | key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); 82 | torch::Tensor value_cache_ptrs_tensor = torch::from_blob( 83 | value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); 84 | torch::Tensor block_mapping_tensor = torch::from_blob( 85 | block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); 86 | 87 | // Launch the kernel. 88 | const int numel_per_block = key_caches[0][0].numel(); 89 | dim3 grid(num_layers, num_pairs); 90 | dim3 block(std::min(1024, numel_per_block)); 91 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 92 | VLLM_DISPATCH_FLOATING_TYPES( 93 | key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { 94 | vllm::copy_blocks_kernel<<>>( 95 | key_cache_ptrs_tensor.data_ptr(), 96 | value_cache_ptrs_tensor.data_ptr(), 97 | block_mapping_tensor.data_ptr(), 98 | numel_per_block); 99 | })); 100 | } 101 | 102 | namespace vllm { 103 | 104 | template 105 | __global__ void reshape_and_cache_kernel( 106 | const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] 107 | const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] 108 | scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 109 | scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] 110 | const int64_t* __restrict__ slot_mapping, // [num_tokens] 111 | const int key_stride, 112 | const int value_stride, 113 | const int num_heads, 114 | const int head_size, 115 | const int block_size, 116 | const int x) { 117 | const int64_t token_idx = blockIdx.x; 118 | const int64_t slot_idx = slot_mapping[token_idx]; 119 | if (slot_idx < 0) { 120 | // Padding token that should be ignored. 
121 | return; 122 | } 123 | 124 | const int64_t block_idx = slot_idx / block_size; 125 | const int64_t block_offset = slot_idx % block_size; 126 | 127 | const int n = num_heads * head_size; 128 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 129 | const int64_t src_key_idx = token_idx * key_stride + i; 130 | const int64_t src_value_idx = token_idx * value_stride + i; 131 | 132 | const int head_idx = i / head_size; 133 | const int head_offset = i % head_size; 134 | const int x_idx = head_offset / x; 135 | const int x_offset = head_offset % x; 136 | 137 | const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x 138 | + head_idx * (head_size / x) * block_size * x 139 | + x_idx * block_size * x 140 | + block_offset * x 141 | + x_offset; 142 | const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size 143 | + head_idx * head_size * block_size 144 | + head_offset * block_size 145 | + block_offset; 146 | key_cache[tgt_key_idx] = key[src_key_idx]; 147 | value_cache[tgt_value_idx] = value[src_value_idx]; 148 | } 149 | } 150 | 151 | } // namespace vllm 152 | 153 | void reshape_and_cache( 154 | torch::Tensor& key, // [num_tokens, num_heads, head_size] 155 | torch::Tensor& value, // [num_tokens, num_heads, head_size] 156 | torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] 157 | torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] 158 | torch::Tensor& slot_mapping) // [num_tokens] 159 | { 160 | int num_tokens = key.size(0); 161 | int num_heads = key.size(1); 162 | int head_size = key.size(2); 163 | int block_size = key_cache.size(3); 164 | int x = key_cache.size(4); 165 | 166 | int key_stride = key.stride(0); 167 | int value_stride = value.stride(0); 168 | 169 | dim3 grid(num_tokens); 170 | dim3 block(std::min(num_heads * head_size, 512)); 171 | const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); 172 | const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 173 | VLLM_DISPATCH_FLOATING_TYPES( 174 | key.scalar_type(), 175 | "reshape_and_cache_kernel", 176 | [&] { 177 | vllm::reshape_and_cache_kernel<<>>( 178 | key.data_ptr(), 179 | value.data_ptr(), 180 | key_cache.data_ptr(), 181 | value_cache.data_ptr(), 182 | slot_mapping.data_ptr(), 183 | key_stride, 184 | value_stride, 185 | num_heads, 186 | head_size, 187 | block_size, 188 | x); 189 | }); 190 | } -------------------------------------------------------------------------------- /csrc/paged_attention/cuda_compat.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef USE_ROCM 4 | #define VLLM_LDG(arg) __ldg(arg) 5 | #else 6 | #define VLLM_LDG(arg) *(arg) 7 | #endif 8 | 9 | #ifndef USE_ROCM 10 | #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) 11 | #else 12 | #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) 13 | #endif 14 | 15 | #ifndef USE_ROCM 16 | #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) 17 | #else 18 | #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) 19 | #endif 20 | 21 | #ifndef USE_ROCM 22 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 23 | cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) 24 | #else 25 | #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ 26 | hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) 27 | #endif 
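Note: the cache kernels above (cache_kernels.cu) are exposed to Python through the pybind bindings further below, where reshape_and_cache is registered under the cache_ops submodule. As a rough, untested sketch of how the documented tensor layouts fit together, consider the following; the extension's import path is an assumption (written here as the placeholder paged_attention_ext), while the shapes and the padding convention come from the kernel comments themselves.

import torch

# Placeholder import: the real module name is whatever setup.py assigns as
# TORCH_EXTENSION_NAME when building csrc/paged_attention; adjust accordingly.
import paged_attention_ext as ext  # hypothetical name

num_tokens, num_heads, head_size = 4, 8, 128
num_blocks, block_size, x = 16, 16, 8  # x is the innermost vector width of the key cache

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)

# Layouts follow the comments in reshape_and_cache_kernel:
#   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
#   value_cache: [num_blocks, num_heads, head_size, block_size]
key_cache = torch.zeros(
    num_blocks, num_heads, head_size // x, block_size, x,
    dtype=torch.float16, device="cuda",
)
value_cache = torch.zeros(
    num_blocks, num_heads, head_size, block_size,
    dtype=torch.float16, device="cuda",
)

# slot_mapping[i] = block_idx * block_size + block_offset for token i;
# negative entries mark padding tokens, which the kernel skips.
slot_mapping = torch.tensor([0, 1, 2, -1], dtype=torch.int64, device="cuda")

# Bound as cache_ops.reshape_and_cache in pybind.cpp (further below).
ext.cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)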
-------------------------------------------------------------------------------- /csrc/paged_attention/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | int get_device_attribute( 6 | int attribute, 7 | int device_id); -------------------------------------------------------------------------------- /csrc/paged_attention/cuda_utils_kernels.cu: -------------------------------------------------------------------------------- 1 | #ifdef USE_ROCM 2 | #include 3 | #endif 4 | int get_device_attribute( 5 | int attribute, 6 | int device_id) 7 | { 8 | int device, value; 9 | if (device_id < 0) { 10 | cudaGetDevice(&device); 11 | } 12 | else { 13 | device = device_id; 14 | } 15 | cudaDeviceGetAttribute(&value, static_cast(attribute), device); 16 | return value; 17 | } -------------------------------------------------------------------------------- /csrc/paged_attention/dispatch_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from 3 | * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h 4 | */ 5 | #pragma once 6 | 7 | #include 8 | 9 | #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ 10 | AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ 11 | AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ 12 | AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) 13 | 14 | #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ 15 | AT_DISPATCH_SWITCH( \ 16 | TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -------------------------------------------------------------------------------- /csrc/paged_attention/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void paged_attention_v1( 6 | torch::Tensor& out, 7 | torch::Tensor& query, 8 | torch::Tensor& key_cache, 9 | torch::Tensor& value_cache, 10 | int num_kv_heads, 11 | float scale, 12 | torch::Tensor& block_tables, 13 | torch::Tensor& context_lens, 14 | int block_size, 15 | int max_context_len, 16 | const c10::optional& alibi_slopes); 17 | 18 | void paged_attention_v2( 19 | torch::Tensor& out, 20 | torch::Tensor& exp_sums, 21 | torch::Tensor& max_logits, 22 | torch::Tensor& tmp_out, 23 | torch::Tensor& query, 24 | torch::Tensor& key_cache, 25 | torch::Tensor& value_cache, 26 | int num_kv_heads, 27 | float scale, 28 | torch::Tensor& block_tables, 29 | torch::Tensor& context_lens, 30 | int block_size, 31 | int max_context_len, 32 | const c10::optional& alibi_slopes); 33 | -------------------------------------------------------------------------------- /csrc/paged_attention/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include "cache.h" 2 | #include "cuda_utils.h" 3 | #include "ops.h" 4 | #include 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // vLLM custom ops 8 | pybind11::module ops = m.def_submodule("attn_ops", "vLLM attn operators"); 9 | 10 | // Attention ops 11 | ops.def( 12 | "paged_attention_v1", 13 | &paged_attention_v1, 14 | "Compute the attention between an input query and the cached keys/values using PagedAttention."); 15 | ops.def( 16 | "paged_attention_v2", 17 | &paged_attention_v2, 18 | "PagedAttention V2."); 19 | 20 | // Cache ops 21 | pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); 22 | cache_ops.def( 23 | "reshape_and_cache", 24 | &reshape_and_cache, 25 | "Reshape the key and value tensors and 
cache them"); 26 | 27 | cache_ops.def( 28 | "copy_blocks", 29 | ©_blocks, 30 | "Copy the cache blocks from src to dst"); 31 | 32 | // Cuda utils 33 | pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); 34 | cuda_utils.def( 35 | "get_device_attribute", 36 | &get_device_attribute, 37 | "Gets the specified device attribute."); 38 | } 39 | -------------------------------------------------------------------------------- /csrc/paged_attention/reduction_utils.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh 3 | * Copyright (c) 2023, The vLLM team. 4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | #pragma once 19 | 20 | namespace vllm { 21 | 22 | template 23 | __inline__ __device__ T warpReduceSum(T val) { 24 | #pragma unroll 25 | for (int mask = 16; mask > 0; mask >>= 1) 26 | val += __shfl_xor_sync(0xffffffff, val, mask, 32); 27 | return val; 28 | } 29 | 30 | /* Calculate the sum of all elements in a block */ 31 | template 32 | __inline__ __device__ T blockReduceSum(T val) { 33 | static __shared__ T shared[32]; 34 | int lane = threadIdx.x & 0x1f; 35 | int wid = threadIdx.x >> 5; 36 | 37 | val = warpReduceSum(val); 38 | 39 | if (lane == 0) 40 | shared[wid] = val; 41 | 42 | __syncthreads(); 43 | 44 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 45 | // blockDim.x is not divided by 32 46 | val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : (T)(0.0f); 47 | val = warpReduceSum(val); 48 | return val; 49 | } 50 | 51 | } // namespace vllm 52 | -------------------------------------------------------------------------------- /fms-extras-requirements.txt: -------------------------------------------------------------------------------- 1 | # This requirement files can be used for development purposes, as it points to an 2 | # unreleased version (the main branch) for the ibm-fms package 3 | # It pins all other fms-extra dependencies to a specific version 4 | # for repeatable installs 5 | 6 | accelerate==0.26.1 7 | 8 | ibm-fms @ git+https://github.com/foundation-model-stack/foundation-model-stack@main -------------------------------------------------------------------------------- /fms_extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/fms_extras/__init__.py -------------------------------------------------------------------------------- /fms_extras/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/fms_extras/models/__init__.py -------------------------------------------------------------------------------- /fms_extras/models/calico.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import re 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Any, Mapping, Optional, OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | from fms import models 12 | from fms.distributed.strategy import ( 13 | DistributedStrategy, 14 | NoOpStrategy, 15 | TensorParallelStrategy, 16 | UniformModelParallelStrategy, 17 | ) 18 | from fms.modules.attention import MultiHeadAttention 19 | from fms.modules.embedding import WordEmbedding 20 | from fms.modules.feedforward import GatedLinearUnit 21 | from fms.modules.layernorm import LayerNormParameterized 22 | from fms.modules.positions import RotaryEmbedding 23 | from fms.utils import serialization 24 | from fms.utils.activation import str_to_activation 25 | from fms.utils.config import ModelConfig 26 | from fms.utils.serialization import ( 27 | _legacy_attn_unfused_to_fused_adapter, 28 | _legacy_mlp_glu_unfused_to_fused_adapter, 29 | ) 30 | from fms.utils.tokenizers import _has_hf, get_tokenizer 31 | 32 | 33 | # params emb_dim heads layers lr 34 | # 7B 4096 32 32 3.0E-04 35 | # 13B 5120 40 40 3.0E-04 36 | # 33B 6656 52 60 1.5.E-04 37 | # 65B 8192 64 80 1.5.E-04 38 | 39 | 40 | @dataclass 41 | class CalicoConfig(ModelConfig): 42 | src_vocab_size: int = 65024 # can be set by tokenizer 43 | emb_dim: int = 4608 44 | norm_eps: float = 1e-5 45 | nheads: int = 36 46 | kvheads: int = 4 47 | nlayers: int = 46 48 | pad_id: int = 2 49 | hidden_grow_factor: float = 8 / 3 50 | multiple_of: int = 1 51 | activation_fn: str = "swish" 52 | p_dropout: float = 0.1 53 | max_expected_seq_len: int = 4096 54 | ntk_scaling: bool = False 55 | 56 | 57 | class CalicoBlock(nn.Module): 58 | def __init__(self, config: CalicoConfig, rotary_emb: RotaryEmbedding): 59 | super(CalicoBlock, self).__init__() 60 | self.config = config 61 | emb_kq = self.config.emb_dim // self.config.nheads 62 | emb_v = self.config.emb_dim // self.config.nheads 63 | 64 | self.ln = LayerNormParameterized( 65 | 
self.config.emb_dim, 66 | elementwise_scale=True, 67 | elementwise_shift=False, 68 | use_mean=False, 69 | eps=self.config.norm_eps, 70 | use_high_precision_pow=True, 71 | ) 72 | self.ff_ln = LayerNormParameterized( 73 | self.config.emb_dim, 74 | elementwise_scale=True, 75 | elementwise_shift=False, 76 | use_mean=False, 77 | eps=self.config.norm_eps, 78 | use_high_precision_pow=True, 79 | ) 80 | 81 | if self.config.kvheads == 0: 82 | kvheads = self.config.nheads 83 | else: 84 | kvheads = self.config.kvheads 85 | assert self.config.nheads % self.config.kvheads == 0 86 | 87 | self.attn = MultiHeadAttention( 88 | self.config.emb_dim, 89 | emb_kq, 90 | emb_v, 91 | self.config.nheads, 92 | kvheads, 93 | p_dropout=self.config.p_dropout, 94 | use_bias=True, 95 | position_encoder=rotary_emb, 96 | ) 97 | self.ff_sub_layer = GatedLinearUnit( 98 | self.config.emb_dim, 99 | hidden_grow_factor=self.config.hidden_grow_factor, 100 | multiple_of=self.config.multiple_of, 101 | activation_fn=str_to_activation(self.config.activation_fn), 102 | p_dropout=self.config.p_dropout, 103 | use_bias=True, 104 | ) 105 | 106 | if self.config.p_dropout != 0: 107 | self.dropout = nn.Dropout(self.config.p_dropout) 108 | 109 | def forward( 110 | self, 111 | x, 112 | *, 113 | mask=None, 114 | position_ids=None, 115 | past_key_value_state=None, 116 | use_cache=False, 117 | is_causal_mask=False, 118 | attn_algorithm=None, 119 | ): 120 | # if the cache is not empty, we need to get the kv cache for self and cross attention 121 | self_attn_past_key_value = past_key_value_state 122 | # if past_key_value_state is not None: 123 | # self_attn_past_key_value = past_key_value_state[:2] 124 | # else: 125 | # self_attn_past_key_value = None 126 | 127 | # first we do MHA and Add&Norm 128 | residual = x 129 | x = self.ln(x) 130 | x = self.attn( 131 | q=x, 132 | k=x, 133 | v=x, 134 | mask=mask, 135 | position_ids=position_ids, 136 | attn_algorithm=attn_algorithm, 137 | past_key_value_state=self_attn_past_key_value, 138 | use_cache=use_cache, 139 | is_self=True, 140 | is_causal_mask=is_causal_mask, 141 | ) 142 | cache = None 143 | if use_cache: 144 | x, cache = x 145 | if self.config.p_dropout != 0: 146 | x = self.dropout(x) 147 | # residual connection 148 | x = x + residual 149 | 150 | # then we do FF and Add&Norm 151 | residual = x 152 | x = self.ff_ln(x) 153 | x = self.ff_sub_layer(x) 154 | if self.config.p_dropout != 0: 155 | x = self.dropout(x) 156 | # another residual 157 | x = x + residual 158 | 159 | if use_cache: 160 | return (x, cache) 161 | else: 162 | return x 163 | 164 | 165 | class Calico(nn.Module): 166 | """ 167 | This is an IBM model similar to LLaMA with a few key differences: 168 | 169 | - Calico ties the weights of the input/output embeddings 170 | - Calico adds a bias to attention and mlp 171 | """ 172 | 173 | def __init__( 174 | self, 175 | config: Optional[CalicoConfig] = None, 176 | distributed_strategy: DistributedStrategy = NoOpStrategy, 177 | **kwargs, 178 | ): 179 | super(Calico, self).__init__() 180 | if config is not None: 181 | self.config = config 182 | else: 183 | self.config = CalicoConfig() 184 | self.config = self.config.updated(**kwargs) 185 | self.distributed_strategy = distributed_strategy 186 | 187 | self.width = self.config.emb_dim 188 | self.pad_id = self.config.pad_id 189 | self.max_expected_seq_len = self.config.max_expected_seq_len 190 | 191 | shared = WordEmbedding( 192 | self.config.src_vocab_size, 193 | self.config.emb_dim, 194 | padding_idx=self.config.pad_id, 195 | abs_pos=False, 196 | 
reversible=True, 197 | tie_weights=True, 198 | bias=False, 199 | ) 200 | self.shared = self.distributed_strategy.distribute_module(shared) 201 | 202 | self.rot_emb = RotaryEmbedding( 203 | dim=self.config.emb_dim // self.config.nheads, 204 | ntk_scaling=self.config.ntk_scaling, 205 | max_seq_len=self.config.max_expected_seq_len, 206 | ) 207 | if isinstance(self.distributed_strategy, UniformModelParallelStrategy): 208 | for dev_idx in set(self.distributed_strategy.layer_to_device): 209 | self.rot_emb.compute_freqs_cis( 210 | torch.device("cuda", dev_idx), self.config.max_expected_seq_len 211 | ) 212 | else: 213 | self.rot_emb.compute_freqs_cis( 214 | self.shared.emb.weight.device, self.config.max_expected_seq_len 215 | ) 216 | 217 | layers = [] 218 | for i in range(self.config.nlayers): 219 | block = CalicoBlock(self.config, self.rot_emb) 220 | block_module = self.distributed_strategy.distribute_layer(block, i) 221 | layers.append(block_module) 222 | self.layers = nn.ModuleList(layers) 223 | 224 | dec_norm = LayerNormParameterized( 225 | self.config.emb_dim, 226 | elementwise_scale=True, 227 | elementwise_shift=False, 228 | use_mean=False, 229 | eps=self.config.norm_eps, 230 | use_high_precision_pow=True, 231 | ) 232 | self.dec_norm = self.distributed_strategy.distribute_module( 233 | dec_norm, final_layers=True 234 | ) 235 | 236 | if self.config.p_dropout: 237 | self.dropout = nn.Dropout(self.config.p_dropout) 238 | 239 | self.reset_params() 240 | 241 | def get_config(self) -> CalicoConfig: 242 | return self.config 243 | 244 | @classmethod 245 | def from_config(cls, config: CalicoConfig) -> "Calico": 246 | return cls(config) 247 | 248 | def reset_params(self): 249 | # Modules are self-initializing, we're just going to down-scale the final prediction head to be 250 | # mixed-fan (inputs and gradients scale to the same inverse factors) if it isn't tied 251 | self.shared.head.weight.data.normal_( 252 | 0, 1 / math.sqrt(math.sqrt(self.width * self.shared.vocab_size)) 253 | ) 254 | 255 | def _helper( 256 | self, 257 | x_in, 258 | mask=None, 259 | position_ids=None, 260 | past_key_value_states=None, 261 | use_cache=False, 262 | attn_algorithm=None, 263 | ): 264 | # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified 265 | # x_in: batch_size x seq_len 266 | # mask: batch_size x seq_len x seq_len 267 | # bias: nheads x seq_len x seq_len 268 | if past_key_value_states is None or len(past_key_value_states) == 0: 269 | past_key_value_states = [None for _ in range(len(self.layers))] 270 | 271 | qlen = x_in.size(1) 272 | klen = x_in.size(1) 273 | 274 | # if we are using the cache, the key length needs to be extended with the past keys length 275 | if use_cache and past_key_value_states[0] is not None: 276 | klen += past_key_value_states[0][0].size(-2) 277 | 278 | # if mask is none, we need to specify causal mask 279 | if mask is None: 280 | # we are caching and can assume all 1s in the mask 281 | if use_cache and klen != 1 and qlen == 1: 282 | # b x h x qlen x kvlen 283 | is_causal_mask = False 284 | else: 285 | is_causal_mask = True 286 | else: 287 | is_causal_mask = False 288 | 289 | x_in = self.shared(x_in) 290 | 291 | # this is the output cache for all the decoder layers 292 | present_key_value_states = [] 293 | 294 | for i, layer in enumerate(self.layers): 295 | output = layer( 296 | x=x_in, 297 | mask=mask, 298 | position_ids=position_ids, 299 | past_key_value_state=past_key_value_states[i], 300 | use_cache=use_cache, 301 | 
is_causal_mask=is_causal_mask, 302 | attn_algorithm=attn_algorithm, 303 | ) 304 | 305 | if use_cache: 306 | x_in, present_key_value_state = output 307 | present_key_value_states.append(present_key_value_state) 308 | 309 | else: 310 | x_in = output 311 | 312 | dec_out = x_in 313 | dec_out = self.dec_norm(dec_out) 314 | if self.config.p_dropout: 315 | dec_out = self.dropout(dec_out) 316 | 317 | return dec_out, present_key_value_states 318 | 319 | def forward( 320 | self, 321 | x, 322 | mask=None, 323 | position_ids=None, 324 | past_key_value_states=None, 325 | use_cache=False, 326 | only_last_token=False, 327 | attn_algorithm=None, 328 | ): 329 | output, cache = self._helper( 330 | x, mask, position_ids, past_key_value_states, use_cache, attn_algorithm 331 | ) 332 | 333 | if only_last_token: 334 | output = output[:, -1, :] 335 | preds = self.shared(output, reverse=True) 336 | 337 | if use_cache: 338 | return preds, cache 339 | else: 340 | return preds 341 | 342 | 343 | _1b_config = CalicoConfig( 344 | src_vocab_size=50304, 345 | emb_dim=2048, 346 | nheads=16, 347 | kvheads=4, 348 | nlayers=24, 349 | pad_id=0, 350 | hidden_grow_factor=5464 / 2048, 351 | multiple_of=1, 352 | max_expected_seq_len=2048, 353 | ) 354 | 355 | _8b_config = CalicoConfig( 356 | src_vocab_size=65024, 357 | emb_dim=4608, 358 | nheads=36, 359 | kvheads=4, 360 | nlayers=36, 361 | pad_id=2, 362 | hidden_grow_factor=12288 / 4608, 363 | multiple_of=1, 364 | max_expected_seq_len=4096, 365 | ) 366 | 367 | _13b_config = CalicoConfig( 368 | src_vocab_size=65024, 369 | emb_dim=5120, 370 | nheads=40, 371 | kvheads=4, 372 | nlayers=48, 373 | pad_id=2, 374 | hidden_grow_factor=13696 / 5120, 375 | multiple_of=1, 376 | max_expected_seq_len=4096, 377 | ) 378 | 379 | _architecture_name = "calico" 380 | 381 | 382 | def _calico_factory_factory(config): 383 | def factory(**kwargs): 384 | return Calico(config, **kwargs) 385 | 386 | return factory 387 | 388 | 389 | models.register_model(_architecture_name, "1b", _calico_factory_factory(_1b_config)) 390 | 391 | models.register_model(_architecture_name, "8b", _calico_factory_factory(_8b_config)) 392 | 393 | models.register_model(_architecture_name, "13b", _calico_factory_factory(_13b_config)) 394 | 395 | 396 | def _megatron_sd_to_fms_sd(hf_sd: Mapping[Any, Any]) -> Mapping[Any, Any]: 397 | replacements = [ 398 | # embedding 399 | (r"^transformer\.wte\.weight", "shared.emb.weight"), 400 | # layers 401 | (r"^transformer\.h", "layers"), 402 | # attn 403 | (r"attn\.c_proj", "attn.dense"), 404 | # mlp 405 | (r"mlp\.c_proj", "ff_sub_layer.w2"), 406 | # block ln 407 | (r"ln_1\.weight", "ln.weight"), 408 | (r"ln_2\.weight", "ff_ln.weight"), 409 | # model ln 410 | (r"^transformer\.ln_f\.weight", "dec_norm.weight"), 411 | # model head 412 | (r"^lm_head\.weight", "shared.head.weight"), 413 | ] 414 | 415 | qkv_weight_pattern = re.compile("transformer.h.[0-9]+.attn.c_attn.weight") 416 | qkv_bias_pattern = re.compile("transformer.h.[0-9]+.attn.c_attn.bias") 417 | mlp_weight_pattern = re.compile("transformer.h.[0-9]+.mlp.c_fc.weight") 418 | mlp_bias_pattern = re.compile("transformer.h.[0-9]+.mlp.c_fc.bias") 419 | new_sd = {} 420 | for name, param in hf_sd.items(): 421 | new_name = name 422 | for pattern, repl in replacements: 423 | new_name = re.sub(pattern, repl, new_name) 424 | new_sd[new_name] = param 425 | 426 | # qkv fused 427 | if bool(qkv_weight_pattern.match(name)): 428 | new_sd.pop(new_name) 429 | 430 | emb_dim = param.size(1) 431 | num_heads = emb_dim // 128 432 | num_key_value_heads = 
(param.size(0) // 128 - num_heads) // 2 433 | attn_splits = [ 434 | (num_heads * 128) // num_key_value_heads, 435 | (num_key_value_heads * 128) // num_key_value_heads, 436 | (num_key_value_heads * 128) // num_key_value_heads, 437 | ] 438 | 439 | prefix = new_name.replace("c_attn.weight", "") 440 | q, k, v = param.view(num_key_value_heads, -1, emb_dim).split( 441 | attn_splits, dim=1 442 | ) 443 | q = q.reshape(-1, q.size(2)) 444 | k = k.reshape(-1, k.size(2)) 445 | v = v.reshape(-1, v.size(2)) 446 | q = q.view(num_heads, 2, -1, q.size(1)).transpose(1, 2).reshape(*q.size()) 447 | k = ( 448 | k.view(num_key_value_heads, 2, -1, k.size(1)) 449 | .transpose(1, 2) 450 | .reshape(*k.size()) 451 | ) 452 | 453 | new_sd[f"{prefix}query.weight"] = q 454 | new_sd[f"{prefix}key.weight"] = k 455 | new_sd[f"{prefix}value.weight"] = v 456 | elif bool(qkv_bias_pattern.match(name)): 457 | weight_name = name.replace("bias", "weight") 458 | new_sd.pop(new_name) 459 | 460 | emb_dim = hf_sd[weight_name].size(1) 461 | num_heads = emb_dim // 128 462 | num_key_value_heads = (param.size(0) // 128 - num_heads) // 2 463 | attn_splits = [ 464 | (num_heads * 128) // num_key_value_heads, 465 | (num_key_value_heads * 128) // num_key_value_heads, 466 | (num_key_value_heads * 128) // num_key_value_heads, 467 | ] 468 | 469 | prefix = new_name.replace("c_attn.bias", "") 470 | q, k, v = param.view(num_key_value_heads, -1).split(attn_splits, dim=1) 471 | q = q.reshape(-1) 472 | k = k.reshape(-1) 473 | v = v.reshape(-1) 474 | q = q.view(num_heads, 2, -1).transpose(1, 2).reshape(*q.size()) 475 | k = k.view(num_key_value_heads, 2, -1).transpose(1, 2).reshape(*k.size()) 476 | 477 | new_sd[f"{prefix}query.bias"] = q 478 | new_sd[f"{prefix}key.bias"] = k 479 | new_sd[f"{prefix}value.bias"] = v 480 | elif bool(mlp_weight_pattern.match(name)): 481 | new_sd.pop(new_name) 482 | prefix = new_name.replace("mlp.c_fc.weight", "") 483 | w1, wg = param.chunk(2) 484 | new_sd[f"{prefix}ff_sub_layer.w1.weight"] = w1 485 | new_sd[f"{prefix}ff_sub_layer.wg.weight"] = wg 486 | elif bool(mlp_bias_pattern.match(name)): 487 | new_sd.pop(new_name) 488 | prefix = new_name.replace("mlp.c_fc.bias", "") 489 | w1, wg = param.chunk(2) 490 | new_sd[f"{prefix}ff_sub_layer.w1.bias"] = w1 491 | new_sd[f"{prefix}ff_sub_layer.wg.bias"] = wg 492 | 493 | new_sd = _legacy_mlp_glu_unfused_to_fused_adapter( 494 | _legacy_attn_unfused_to_fused_adapter(new_sd) 495 | ) 496 | return new_sd 497 | 498 | 499 | serialization.register_adapter("calico", "megatron", _megatron_sd_to_fms_sd) 500 | -------------------------------------------------------------------------------- /fms_extras/models/hf/__init__.py: -------------------------------------------------------------------------------- 1 | def register_fms_models(): 2 | from fms.models.hf import _causal_lm_models, _fms_to_hf_adapt_map, _headless_models 3 | from fms.models.hf.utils import register_fms_models 4 | 5 | from fms_extras.models.calico import Calico 6 | from fms_extras.models.hf.modeling_calico_hf import ( 7 | HFAdaptedCalicoForCausalLM, 8 | HFAdaptedCalicoHeadless, 9 | ) 10 | 11 | # todo: should have a better registration method than this 12 | if HFAdaptedCalicoHeadless not in _headless_models: 13 | _headless_models.append(HFAdaptedCalicoHeadless) 14 | 15 | if HFAdaptedCalicoForCausalLM not in _causal_lm_models: 16 | _causal_lm_models.append(HFAdaptedCalicoForCausalLM) 17 | 18 | if Calico not in _fms_to_hf_adapt_map: 19 | _fms_to_hf_adapt_map[Calico] = HFAdaptedCalicoForCausalLM 20 | 21 | register_fms_models() 22 | 
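Note: register_fms_models() above makes the Calico adapters visible to the Hugging Face plumbing in fms. As a rough, untested sketch of what that enables, the snippet below builds a small FMS Calico model and wraps it with the HF-adapted classes defined in modeling_calico_hf.py further below. The tiny config values are arbitrary, and the direct _hf_model_from_fms call simply mirrors how the adapter itself constructs the HF model; fms may expose a higher-level conversion helper that should be preferred in practice.

import torch

from fms_extras.models.calico import Calico, CalicoConfig
from fms_extras.models.hf import register_fms_models
from fms_extras.models.hf.modeling_calico_hf import (
    HFAdaptedCalicoConfig,
    HFAdaptedCalicoForCausalLM,
)

# Importing fms_extras.models.hf already runs register_fms_models(); calling it
# again is harmless because registration is guarded by membership checks.
register_fms_models()

# Deliberately tiny, arbitrary config so the sketch runs quickly on CPU.
fms_model = Calico(
    CalicoConfig(
        src_vocab_size=256,
        emb_dim=64,
        nheads=4,
        kvheads=2,
        nlayers=2,
        pad_id=0,
        p_dropout=0.0,
        max_expected_seq_len=128,
    )
)

hf_config = HFAdaptedCalicoConfig.from_fms_config(fms_model.get_config())
hf_model = HFAdaptedCalicoForCausalLM._hf_model_from_fms(fms_model, hf_config)

ids = torch.randint(0, 256, (1, 8))
out = hf_model(input_ids=ids)  # standard Hugging Face causal-LM forward
print(out.logits.shape)        # expected: torch.Size([1, 8, 256])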
-------------------------------------------------------------------------------- /fms_extras/models/hf/modeling_calico_hf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from fms.models.hf.lm_head_mixins import LMHeadModelLMHeadMixin 6 | from fms.models.hf.modeling_hf_adapter import HFDecoder, HFDecoderModelArchitecture 7 | from transformers import PretrainedConfig 8 | from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions 9 | 10 | from fms_extras.models.calico import Calico, CalicoConfig 11 | 12 | 13 | class HFAdaptedCalicoConfig(PretrainedConfig): 14 | model_type = "hf_adapted_calico" 15 | attribute_map = { 16 | "vocab_size": "src_vocab_size", 17 | "hidden_size": "emb_dim", 18 | "num_attention_heads": "nheads", 19 | "num_hidden_layers": "nlayers", 20 | } 21 | 22 | def __init__( 23 | self, 24 | src_vocab_size: Optional[int] = 32000, 25 | emb_dim: Optional[int] = 4096, 26 | norm_eps: float = 1e-6, 27 | nheads: int = 32, 28 | kvheads: int = 0, 29 | nlayers: int = 32, 30 | # note this is different from the non-hf config (which is -1), hf keeps a different default 31 | pad_token_id: int = 0, 32 | hidden_grow_factor: float = 8 / 3, 33 | multiple_of: int = 256, 34 | activation_fn: str = "swish", 35 | p_dropout: float = 0.0, 36 | max_expected_seq_len: int = 2048, 37 | use_cache: bool = True, 38 | eos_token_id: int = 2, 39 | bos_token_id: int = 1, 40 | is_decoder: bool = True, 41 | **kwargs, 42 | ): 43 | self.src_vocab_size = src_vocab_size 44 | self.emb_dim = emb_dim 45 | self.norm_eps = norm_eps 46 | self.nheads = nheads 47 | self.kvheads = kvheads 48 | self.nlayers = nlayers 49 | self.hidden_grow_factor = hidden_grow_factor 50 | self.multiple_of = multiple_of 51 | self.activation_fn = activation_fn 52 | self.p_dropout = p_dropout 53 | self.max_expected_seq_len = max_expected_seq_len 54 | self.use_cache = use_cache 55 | super().__init__( 56 | pad_token_id=pad_token_id, 57 | eos_token_id=eos_token_id, 58 | bos_token_id=bos_token_id, 59 | is_decoder=is_decoder, 60 | tie_word_embeddings=kwargs.pop( 61 | "tie_word_embeddings", False 62 | ), # note: This was added here as we handle tying of heads with our underlying model, we may want to revisit this in future 63 | **kwargs, 64 | ) 65 | 66 | @classmethod 67 | def from_pretrained( 68 | cls, pretrained_model_name_or_path, **kwargs 69 | ) -> "PretrainedConfig": 70 | config_dict, kwargs = cls.get_config_dict( 71 | pretrained_model_name_or_path, **kwargs 72 | ) 73 | 74 | return cls.from_dict(config_dict, **kwargs) 75 | 76 | @classmethod 77 | def from_fms_config(cls, config: CalicoConfig, **hf_kwargs): 78 | config_dict = config.as_dict() 79 | config_dict["pad_token_id"] = config_dict.pop("pad_id") 80 | return cls.from_dict(config_dict, **hf_kwargs) 81 | 82 | 83 | class HFAdaptedCalicoDecoder(HFDecoder): 84 | """Adapter for the Calico decoder""" 85 | 86 | def __init__(self, model: Calico, config: PretrainedConfig): 87 | super().__init__(model, config, attention_mask_dim=3) 88 | 89 | def _adapt( 90 | self, 91 | input_ids: Optional[torch.LongTensor] = None, 92 | attention_mask: Optional[torch.Tensor] = None, 93 | position_ids: Optional[torch.LongTensor] = None, 94 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 95 | use_cache: Optional[bool] = None, 96 | attn_algorithm: Optional[ 97 | str 98 | ] = None, # this can be passed in from top most forward 99 | *args, 100 | **kwargs, 101 | ) -> 
BaseModelOutputWithPastAndCrossAttentions: 102 | output = self.model._helper( 103 | x_in=input_ids, 104 | mask=attention_mask, 105 | position_ids=position_ids, 106 | past_key_value_states=past_key_values, 107 | use_cache=use_cache, 108 | attn_algorithm=attn_algorithm, 109 | ) 110 | 111 | present_key_values = None 112 | if isinstance(output, tuple): 113 | output, present_key_values = output 114 | return BaseModelOutputWithPastAndCrossAttentions( 115 | last_hidden_state=output, past_key_values=present_key_values 116 | ) 117 | 118 | 119 | class HFAdaptedCalicoHeadless(HFDecoderModelArchitecture): 120 | """This is the Adapter for the base Calico architecture""" 121 | 122 | # attributes required by HF 123 | config_class = HFAdaptedCalicoConfig 124 | base_model_prefix = "hf_adapted_calico" 125 | 126 | def __init__( 127 | self, 128 | config: PretrainedConfig, 129 | decoder: Optional[nn.Module] = None, 130 | embedding: Optional[nn.Module] = None, 131 | *args, 132 | **kwargs, 133 | ): 134 | # in the case we have not yet received the encoder/decoder/embedding, initialize it here 135 | if decoder is None or embedding is None: 136 | params = config.to_dict() 137 | model = Calico(pad_id=params.pop("pad_token_id"), **params) 138 | decoder = model if decoder is None else decoder 139 | embedding = model.shared.emb if embedding is None else embedding 140 | 141 | # these are now huggingface compatible 142 | decoder = HFAdaptedCalicoDecoder(decoder, config) 143 | super().__init__(decoder, embedding, config, *args, **kwargs) 144 | 145 | def _prepare_inputs_for_generation( 146 | self, 147 | input_ids: torch.Tensor, 148 | attention_mask: Optional[torch.Tensor] = None, 149 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 150 | use_cache: Optional[bool] = None, 151 | **model_kwargs, 152 | ) -> dict: 153 | """ 154 | Overriding _prepare_inputs_for_generation to include position_ids requirements for Calico batch processing 155 | """ 156 | position_ids = model_kwargs.pop("position_ids", None) 157 | 158 | if position_ids is None and attention_mask is not None: 159 | position_ids = attention_mask.long().cumsum(-1) 160 | 161 | # Add more cached rope freqs if over cached number 162 | max_expected_len = input_ids.shape[1] + torch.max(position_ids) 163 | if max_expected_len > self.decoder.model.rot_emb.max_seq_len: 164 | self.decoder.model.rot_emb.compute_freqs_cis( 165 | input_ids.device, max_expected_len 166 | ) 167 | 168 | return { 169 | "input_ids": input_ids, 170 | "attention_mask": attention_mask, 171 | "past_key_values": past_key_values, 172 | "use_cache": use_cache, 173 | "position_ids": position_ids, 174 | **model_kwargs, 175 | } 176 | 177 | 178 | class HFAdaptedCalicoForCausalLM(LMHeadModelLMHeadMixin, HFAdaptedCalicoHeadless): 179 | _keys_to_ignore_on_load_missing = [r"lm_head.weight"] 180 | _tied_weights_keys = ["embedding.weight", "lm_head.weight"] 181 | 182 | def __init__(self, config: HFAdaptedCalicoConfig, *args, **kwargs): 183 | super().__init__(config=config, bias=False, *args, **kwargs) 184 | 185 | @classmethod 186 | def _hf_model_from_fms( 187 | cls, model: Calico, config: HFAdaptedCalicoConfig 188 | ) -> "HFAdaptedCalicoForCausalLM": 189 | return cls( 190 | config=config, 191 | decoder=model, 192 | embedding=model.shared.emb, 193 | lm_head=model.shared.head, 194 | ) 195 | 196 | # overriding this to enable tensor-parallel since it requires a WordEmbedding forward 197 | # in the future WordEmbedding should be split up 198 | def _lm_head(self, input_ids, *args, **kwargs): 199 | return 
self.decoder.model.shared(input_ids, reverse=True) 200 | -------------------------------------------------------------------------------- /fms_extras/models/hf/modeling_mlp_speculator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from fms.models.hf import _fms_to_hf_adapt_map 5 | from transformers import PretrainedConfig, PreTrainedModel 6 | 7 | from fms_extras.models.speculator import MLPSpeculator 8 | 9 | 10 | class MLPSpeculatorConfig(PretrainedConfig): 11 | model_type = "mlp_speculator" 12 | 13 | attribute_map = { 14 | "hidden_size": "emb_dim", 15 | } 16 | 17 | def __init__( 18 | self, 19 | vocab_size: int = 32000, 20 | emb_dim: int = 4096, 21 | inner_dim: int = 0, 22 | n_predict: int = 3, 23 | top_k_tokens_per_head: List[int] = [5, 4, 3], 24 | n_candidates: int = 5, 25 | tie_weights: bool = False, 26 | scale_input: bool = False, 27 | **kwargs 28 | ): 29 | """ 30 | Initialize an MLPSpeculatorConfig 31 | 32 | Args: 33 | vocab_size: int 34 | the model vocab size 35 | emb_dim: int 36 | the model embedding dimension 37 | inner_dim: int 38 | the inner dimension of the model. If 0, will be the emb_dim. 39 | n_predict: int 40 | the number of lookaheads for the speculator 41 | top_k_tokens_per_head: List[int] 42 | Number of tokens to consider from each head when forming the candidate tree. 43 | For each candidate branch in the tree, head n produces topk[n] additional sub-branches. 44 | n_candidates: int 45 | number of child candidates to create per sequence 46 | tie_weights : bool 47 | If true, use a single set of weights for every model head/stage after the first. 48 | The initial projection from the base model may have a different size, so that stays separate. 49 | scale_input: bool 50 | If true, apply an extra layernorm to the initial state vector input. 51 | Helps training dynamics, particularly when base model output has unusual scale. 
52 | """ 53 | assert len(top_k_tokens_per_head) == n_predict 54 | self.vocab_size = vocab_size 55 | self.emb_dim = emb_dim 56 | self.inner_dim = inner_dim 57 | self.n_predict = n_predict 58 | self.top_k_tokens_per_head = top_k_tokens_per_head 59 | self.n_candidates = n_candidates 60 | self.tie_weights = tie_weights 61 | self.scale_input = scale_input 62 | super().__init__(**kwargs) 63 | 64 | 65 | class MLPSpeculatorPreTrainedModel(PreTrainedModel): 66 | """ 67 | Huggingface MLPSpeculator which provides loading/saving in huggingface 68 | """ 69 | 70 | config_class = MLPSpeculatorConfig 71 | 72 | def __init__( 73 | self, config: MLPSpeculatorConfig, speculator: Optional[MLPSpeculator] = None 74 | ): 75 | super().__init__( 76 | config=config, 77 | emb_dim=config.emb_dim, 78 | inner_dim=config.inner_dim, 79 | vocab_size=config.vocab_size, 80 | n_predict=config.n_predict, 81 | tie_weights=config.tie_weights, 82 | scale_input=config.scale_input, 83 | ) 84 | if speculator is None: 85 | self.speculator = MLPSpeculator( 86 | config.emb_dim, 87 | config.inner_dim, 88 | config.vocab_size, 89 | config.n_predict, 90 | tie_weights=config.tie_weights, 91 | scale_input=config.scale_input, 92 | ) 93 | self.speculator.reset_parameters() 94 | else: 95 | self.speculator = speculator 96 | 97 | @classmethod 98 | def from_fms_model( 99 | cls, 100 | model: MLPSpeculator, 101 | top_k_tokens_per_head: List[int], 102 | n_candidates: int, 103 | tie_weights: bool = False, 104 | scale_input: bool = False, 105 | *args, 106 | **kwargs 107 | ): 108 | config = MLPSpeculatorConfig( 109 | vocab_size=model.vsize, 110 | emb_dim=model.emb_dim, 111 | inner_dim=model.inner_dim, 112 | n_predict=model.n_predict, 113 | top_k_tokens_per_head=top_k_tokens_per_head, 114 | n_candidates=n_candidates, 115 | tie_weights=tie_weights, 116 | scale_input=scale_input, 117 | ) 118 | return cls(config, model) 119 | 120 | def generate_suffixes( 121 | self, 122 | state: torch.Tensor, 123 | ind: torch.Tensor, 124 | topk: List[int] = [5, 4, 3], 125 | n: int = 5, 126 | ) -> torch.Tensor: 127 | """ 128 | FOR INFERENCE 129 | Generate tree of candidate sequences. 130 | ... 131 | Args 132 | ---- 133 | state : torch.Tensor 134 | Most recent embedding vector from the base model (pre-classification head). 135 | Expects size [b 1 d] where b is batch size and d is model width. 136 | ind : torch.Tensor 137 | Token indices of the base model's most recent predicted token(s). 138 | Expects size [b 1] where b is batch size. 139 | topk : List(int) 140 | Number of tokens to consider from each head when forming the candidate tree. 141 | For each candidate branch in the tree, head n produces topk[n] additional sub-branches. 142 | n : int 143 | Given the final tree of prod(topk) candidates, return only the top n most confident. 144 | ... 145 | Output : torch.Tensor 146 | The tensor of most likely candidate sequences. 147 | Has size [b n self.n_predict], where b is batch size and n is provided above. 148 | """ 149 | return self.speculator.generate_suffixes(state, ind, topk, n) 150 | 151 | def forward( 152 | self, 153 | state: torch.Tensor, 154 | inds: torch.Tensor, 155 | ) -> torch.Tensor: 156 | """ 157 | FOR TRAINING 158 | A parallel forward pass on pre-existing ground-truth tokens in pretraining contexts. 159 | Produces self.n_predict predicted tokens for each token embedding in state. 160 | Inds requires self.n_predict extra tokens on the right to "simulate" recursive 161 | behavior for end positions. 162 | ... 
163 | Args 164 | ---- 165 | state : torch.Tensor 166 | Embedding vectors from the base model for a given sequence. 167 | Expects size [b n d] where b is batch size, n is seq len, and d is model width. 168 | inds : torch.Tensor 169 | Ground-truth token indices. inds[:,i] is the prediction coming from state[:,i] 170 | (or the legal fiction ground truth corresponding to that prediction). 171 | Expects size [b n+self.n_predict]. 172 | ... 173 | Output : torch.Tensor 174 | Prediction logits at each position, for each head of the speculator. 175 | Has size [self.n_predict b n v] where v is vocab size. 176 | """ 177 | return self.speculator(state, inds) 178 | 179 | def reset_parameters(self): 180 | self.speculator.reset_parameters() 181 | 182 | 183 | _fms_to_hf_adapt_map[MLPSpeculator] = MLPSpeculatorPreTrainedModel 184 | -------------------------------------------------------------------------------- /fms_extras/models/paged_gpt_bigcode.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Mapping, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from fms import models 8 | from fms.distributed.strategy import DistributedStrategy, NoOpStrategy 9 | from fms.modules.feedforward import FeedForwardBlock 10 | from fms.utils import serialization 11 | from fms.utils.activation import str_to_activation 12 | from fms.utils.config import ModelConfig 13 | 14 | from fms_extras.modules.attention import PagedMultiHeadAttention 15 | from fms_extras.utils.cache.paged import ( 16 | PagedAttentionCacheData, 17 | PagedAttentionCacheDataLayer, 18 | ) 19 | 20 | 21 | @dataclass 22 | class PagedGPTBigCodeConfig(ModelConfig): 23 | src_vocab_size: int = 49157 # This param default is based on https://huggingface.co/bigcode/gpt_bigcode-santacoder 24 | emb_dim: int = 2048 # This param default is based on https://huggingface.co/bigcode/gpt_bigcode-santacoder 25 | nheads: int = 12 26 | nlayers: int = 12 27 | pad_id: int = 0 28 | max_pos: int = 512 29 | hidden_grow_factor: float = 4.0 30 | activation_fn: str = "gelu-tanh" 31 | p_dropout: float = 0.0 32 | emb_dropout: float = 0.0 33 | multiquery_attn: bool = True 34 | ln_eps: float = 1e-5 35 | 36 | 37 | class PagedGPTBigCodeBlock(nn.Module): 38 | def __init__(self, config: PagedGPTBigCodeConfig): 39 | super().__init__() 40 | self.config = config 41 | 42 | self.ln = nn.LayerNorm(self.config.emb_dim, self.config.ln_eps) 43 | self.ff_ln = nn.LayerNorm(self.config.emb_dim, self.config.ln_eps) 44 | 45 | self.attn = PagedMultiHeadAttention( 46 | self.config.emb_dim, 47 | self.config.emb_dim // self.config.nheads, 48 | self.config.emb_dim // self.config.nheads, 49 | self.config.nheads, 50 | kvheads=1 if self.config.multiquery_attn else self.config.nheads, 51 | p_dropout=self.config.p_dropout, 52 | use_bias=True, 53 | ) 54 | 55 | self.ff_sub_layer = FeedForwardBlock( 56 | self.config.emb_dim, 57 | hidden_grow_factor=self.config.hidden_grow_factor, 58 | activation_fn=str_to_activation(self.config.activation_fn), 59 | p_dropout=self.config.p_dropout, 60 | use_bias=True, 61 | ) 62 | 63 | if self.config.p_dropout != 0: 64 | self.dropout = nn.Dropout(self.config.p_dropout) 65 | 66 | def forward( 67 | self, 68 | x: torch.Tensor, 69 | *, 70 | mask: Optional[torch.Tensor] = None, 71 | cache_data_layer: Optional[PagedAttentionCacheDataLayer] = None, 72 | use_cache: bool = False, 73 | is_causal_mask: bool = False, 74 | attn_algorithm: Optional[str] = None, 75 | ): 76 | # first 
we do MHA and Add&Norm 77 | residual = x 78 | x = self.ln(x) 79 | # self attention 80 | x = self.attn( 81 | q=x, 82 | k=x, 83 | v=x, 84 | mask=mask, 85 | attn_algorithm=attn_algorithm, 86 | cache_data_layer=cache_data_layer, 87 | use_cache=use_cache, 88 | is_self=True, 89 | is_causal_mask=is_causal_mask, 90 | ) 91 | 92 | cache = None 93 | if use_cache: 94 | x, cache = x 95 | if self.config.p_dropout != 0: 96 | x = self.dropout(x) 97 | # residual connection 98 | x = x + residual 99 | 100 | # then we do FF and Add&Norm 101 | residual = x 102 | x = self.ff_ln(x) 103 | x = self.ff_sub_layer(x) 104 | if self.config.p_dropout != 0: 105 | x = self.dropout(x) 106 | # another residual 107 | x = x + residual 108 | 109 | if use_cache: 110 | return x, cache 111 | else: 112 | return x 113 | 114 | 115 | class PagedGPTBigCodeHeadless(nn.Module): 116 | def __init__( 117 | self, config: PagedGPTBigCodeConfig, distributed_strategy: DistributedStrategy 118 | ): 119 | super().__init__() 120 | self.config = config 121 | self.distributed_strategy = distributed_strategy 122 | 123 | layers = [] 124 | for i in range(self.config.nlayers): 125 | block = PagedGPTBigCodeBlock(self.config) 126 | block_module = self.distributed_strategy.distribute_layer(block, i) 127 | layers.append(block_module) 128 | self.layers = nn.ModuleList(layers) 129 | 130 | self.embedding = nn.Embedding(self.config.src_vocab_size, self.config.emb_dim) 131 | self.position_embedding = nn.Embedding(self.config.max_pos, self.config.emb_dim) 132 | 133 | self.dec_norm = self.distributed_strategy.distribute_module( 134 | nn.LayerNorm(self.config.emb_dim, eps=self.config.ln_eps), final_layers=True 135 | ) 136 | 137 | if self.config.emb_dropout: 138 | self.emb_dropout = nn.Dropout(self.config.emb_dropout) 139 | 140 | if self.config.p_dropout: 141 | self.dropout = nn.Dropout(self.config.p_dropout) 142 | 143 | def forward( 144 | self, 145 | x: torch.LongTensor, 146 | mask: Optional[torch.Tensor] = None, 147 | cache_data: Optional[PagedAttentionCacheData] = None, 148 | use_cache: bool = False, 149 | attn_algorithm: Optional[str] = None, 150 | ): 151 | # Embed the given vocabulary indices using the given attention mask, with pre-/post-norm and dropout as specified 152 | # x_in: batch_size x seq_len 153 | # mask: batch_size x seq_len x seq_len 154 | # bias: nheads x seq_len x seq_len 155 | 156 | qlen = x.size(1) 157 | filled_cache = False 158 | 159 | # if we are using the cache, the key length needs to be extended with the past keys length 160 | if use_cache: 161 | if cache_data: 162 | filled_cache = cache_data.is_filled() 163 | 164 | # if mask is none, we need to specify causal mask 165 | if mask is None: 166 | # we are caching and can assume all 1s in the mask 167 | if use_cache and filled_cache and qlen == 1: 168 | # b x h x qlen x kvlen 169 | is_causal_mask = False 170 | else: 171 | is_causal_mask = True 172 | else: 173 | is_causal_mask = False 174 | 175 | x_emb = self.embedding(x) 176 | 177 | # if pad_id exists 178 | # is_pad will be a BoolTensor 179 | # otherwise pad_id will not be taken into account 180 | if self.config.pad_id is None: 181 | is_pad = torch.zeros_like(x, dtype=bool, device=x.device) 182 | else: 183 | is_pad = x == self.config.pad_id 184 | 185 | if cache_data is None or cache_data.position_ids is None: 186 | position_ids = ((~is_pad).cumsum(1) - 1).clamp(min=0) 187 | 188 | if cache_data is not None: 189 | cache_data.position_ids = position_ids 190 | else: 191 | position_ids = cache_data.position_ids 192 | 193 | # look up position 
embeddings 194 | position_out = self.position_embedding(position_ids) 195 | 196 | # zero out the associated position embeddings 197 | if self.config.pad_id is not None: 198 | position_out = position_out.mul(~is_pad.unsqueeze(-1)) 199 | 200 | # perform absolute position embedding 201 | x = x_emb + position_out 202 | 203 | # apply dropout to embeddings 204 | if self.config.emb_dropout: 205 | x = self.emb_dropout(x) 206 | 207 | # this is the output cache for all the decoder layers 208 | present_key_value_states = [] 209 | 210 | for i, layer in enumerate(self.layers): 211 | output = layer( 212 | x=x, 213 | mask=mask, 214 | cache_data_layer=None 215 | if cache_data is None 216 | else cache_data.get_layer(i), 217 | use_cache=use_cache, 218 | is_causal_mask=is_causal_mask, 219 | attn_algorithm=attn_algorithm, 220 | ) 221 | 222 | if use_cache: 223 | x, present_key_value_state = output 224 | present_key_value_states.append(present_key_value_state) 225 | 226 | else: 227 | x = output 228 | 229 | dec_out = self.dec_norm(x) 230 | if self.config.p_dropout: 231 | dec_out = self.dropout(dec_out) 232 | 233 | return dec_out, present_key_value_states 234 | 235 | 236 | # Implements the decoder-only PagedGPTBigCodeModel 237 | class PagedGPTBigCode(nn.Module): 238 | def __init__( 239 | self, 240 | config: Optional[PagedGPTBigCodeConfig] = None, 241 | distributed_strategy: DistributedStrategy = NoOpStrategy, 242 | **kwargs, 243 | ): 244 | super(PagedGPTBigCode, self).__init__() 245 | if config is not None: 246 | self.config = config 247 | else: 248 | self.config = PagedGPTBigCodeConfig() 249 | self.config = self.config.updated(**kwargs) 250 | self.distributed_strategy = distributed_strategy 251 | 252 | self.base_model = PagedGPTBigCodeHeadless( 253 | self.config, self.distributed_strategy 254 | ) 255 | self.head = nn.Linear( 256 | self.config.emb_dim, self.config.src_vocab_size, bias=False 257 | ) 258 | 259 | # this model ties weights, so we tie here 260 | self.head.weight = self.base_model.embedding.weight 261 | 262 | self.reset_parameters() 263 | 264 | @classmethod 265 | def from_config(cls, config: PagedGPTBigCodeConfig) -> "PagedGPTBigCode": 266 | return cls(config) 267 | 268 | def get_config(self) -> PagedGPTBigCodeConfig: 269 | return self.config 270 | 271 | def reset_parameters(self): 272 | # Call reset_parameters for relevant sub-layers 273 | for m in self.modules(): 274 | if isinstance(m, PagedMultiHeadAttention) or isinstance( 275 | m, FeedForwardBlock 276 | ): 277 | m.reset_parameters() 278 | 279 | def forward( 280 | self, 281 | x: torch.LongTensor, 282 | mask: Optional[torch.Tensor] = None, 283 | cache_data: Optional[PagedAttentionCacheData] = None, 284 | use_cache: bool = False, 285 | attn_algorithm: Optional[str] = None, 286 | return_embeds: bool = False, 287 | ): 288 | embeds, cache = self.base_model( 289 | x, 290 | mask, 291 | cache_data=cache_data, 292 | use_cache=use_cache, 293 | attn_algorithm=attn_algorithm, 294 | ) 295 | 296 | preds = self.head(embeds) 297 | 298 | out = [preds] 299 | if use_cache: 300 | out.append(cache) 301 | if return_embeds: 302 | out.append(embeds) 303 | 304 | if len(out) == 1: 305 | return out[0] 306 | else: 307 | return tuple(out) 308 | 309 | 310 | _santacoder_config = PagedGPTBigCodeConfig( 311 | src_vocab_size=49280, 312 | emb_dim=2048, 313 | nheads=16, 314 | nlayers=24, 315 | pad_id=-1, 316 | max_pos=2048, 317 | p_dropout=0.1, 318 | emb_dropout=0.1, 319 | ) 320 | 321 | _13b_config = PagedGPTBigCodeConfig( 322 | src_vocab_size=50304, 323 | emb_dim=5632, 324 | 
nheads=44, 325 | nlayers=40, 326 | pad_id=50280, 327 | max_pos=8192, 328 | hidden_grow_factor=4.0, 329 | p_dropout=0.1, 330 | emb_dropout=0.1, 331 | ln_eps=1e-5, 332 | ) 333 | _20b_config = PagedGPTBigCodeConfig( 334 | src_vocab_size=49152, 335 | emb_dim=6144, 336 | nheads=48, 337 | nlayers=52, 338 | pad_id=0, 339 | max_pos=8192, 340 | hidden_grow_factor=4.0, 341 | p_dropout=0.1, 342 | emb_dropout=0.1, 343 | ln_eps=1e-5, 344 | ) 345 | 346 | 347 | _architecture_name = "paged_gpt_bigcode" 348 | 349 | 350 | def _gpt_bigcode_factory_factory(config): 351 | def factory(**kwargs): 352 | return PagedGPTBigCode(config, **kwargs) 353 | 354 | return factory 355 | 356 | 357 | models.register_model( 358 | _architecture_name, "santacoder", _gpt_bigcode_factory_factory(_santacoder_config) 359 | ) 360 | models.register_model( 361 | _architecture_name, "ibm.13b", _gpt_bigcode_factory_factory(_13b_config) 362 | ) 363 | models.register_model( 364 | _architecture_name, "ibm.20b", _gpt_bigcode_factory_factory(_20b_config) 365 | ) 366 | 367 | 368 | def _hf_sd_to_fms_sd(hf_sd: Mapping) -> Mapping: 369 | import re 370 | 371 | replacements = [ 372 | ("lm_head.weight", "head.weight"), 373 | (r"^transformer.wte.weight", "base_model.embedding.weight"), 374 | (r"^transformer.wpe.weight", "base_model.position_embedding.weight"), 375 | (r"^transformer.ln_f", "base_model.dec_norm"), 376 | (r"^transformer.h", "base_model.layers"), 377 | (r"attn\.c_attn", "attn.qkv_fused"), 378 | (r"attn\.c_proj", "attn.dense"), 379 | (r"mlp\.c_fc", "ff_sub_layer.w1"), 380 | (r"mlp\.c_proj", "ff_sub_layer.w2"), 381 | (r"ln_1", "ln"), 382 | (r"ln_2", "ff_ln"), 383 | ] 384 | 385 | new_sd = {} 386 | for name, param in hf_sd.items(): 387 | new_name = name 388 | for pattern, repl in replacements: 389 | new_name = re.sub(pattern, repl, new_name) 390 | 391 | new_sd[new_name] = param 392 | 393 | return new_sd 394 | 395 | 396 | serialization.register_adapter(_architecture_name, "hf", _hf_sd_to_fms_sd) 397 | -------------------------------------------------------------------------------- /fms_extras/models/speculator.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from fms import models 8 | from fms.modules.layernorm import LayerNormParameterized 9 | from fms.utils import serialization 10 | 11 | 12 | class MLPSpeculator(nn.Module): 13 | """ 14 | This is a simple MLP-based speculator that functions similarly to Medusa 15 | (https://arxiv.org/abs/2401.10774), ingesting context via the final embedding 16 | vector from the base model. However, this model also conditions on previously 17 | predicted tokens, similarly to an RNN, allowing it to generate better-quality n-grams. 18 | 19 | The architecture is as flat and simple as possible: for each prediction head, 20 | the current state vector is projected into a new latent space and added to the 21 | previous token's embedding. This sum goes through layernorm and activation, forming 22 | the new state vector. This state predicts the next token (or set of candidate tokens) 23 | for the current head, and then is passed on to the next. 24 | ... 25 | Args 26 | ---- 27 | emb_dim : int 28 | Dimensionality of the input vector from the base model. 29 | inner_dim : int 30 | Latent dimensionality of the speculator model. 31 | vocab_size : int 32 | Number of entries in the tokenizer associated with the base model. 
33 | n_predict : int 34 | Number of heads / number of tokens to guess ahead. Model size and speed scale with this value. 35 | tie_weights : bool 36 | If true, use a single set of weights for every model head/stage after the first. 37 | The initial projection from the base model may have a different size, so that stays separate. 38 | scale_input: bool 39 | If true, apply an extra layernorm to the initial state vector input. 40 | Helps training dynamics, particularly when base model output has unusual scale. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | emb_dim=4096, 46 | inner_dim=0, 47 | vocab_size=32000, 48 | n_predict=3, 49 | tie_weights=False, 50 | scale_input=False, 51 | ): 52 | super().__init__() 53 | self.n_predict = n_predict 54 | self.emb_dim = emb_dim 55 | inner_dim = inner_dim if inner_dim != 0 else emb_dim 56 | self.inner_dim = inner_dim 57 | self.vsize = vocab_size 58 | self.scale_input = scale_input 59 | self.emb = nn.ModuleList( 60 | [nn.Embedding(vocab_size, inner_dim) for _ in range(n_predict)] 61 | ) 62 | self.proj = nn.ModuleList( 63 | [ 64 | nn.Linear((emb_dim if i == 0 else inner_dim), inner_dim, bias=False) 65 | for i in range(n_predict) 66 | ] 67 | ) 68 | self.head = nn.ModuleList( 69 | [nn.Linear(inner_dim, vocab_size, bias=False) for _ in range(n_predict)] 70 | ) 71 | self.ln = nn.ModuleList( 72 | [ 73 | LayerNormParameterized( 74 | inner_dim, elementwise_shift=True, elementwise_scale=True 75 | ) 76 | for _ in range(n_predict) 77 | ] 78 | ) 79 | if self.scale_input: 80 | self.ln0 = LayerNormParameterized( 81 | emb_dim, elementwise_shift=False, elementwise_scale=False 82 | ) 83 | # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation 84 | self.state_weight = 0.5 ** (0.5 / n_predict) 85 | self.emb_weight = math.sqrt((1 - self.state_weight**2) * (self.inner_dim / 2)) 86 | self.activation = nn.GELU() 87 | 88 | # Handle weight tying as specified 89 | if tie_weights: 90 | assert ( 91 | n_predict > 1 92 | ), "You cannot tie weights between stages when only 1 exists" 93 | 94 | for emb in self.emb: 95 | emb.weight = self.emb[0].weight 96 | 97 | for head in self.head: 98 | head.weight = self.head[0].weight 99 | 100 | for ln in self.ln: 101 | ln.weight = self.ln[0].weight 102 | ln.bias = self.ln[0].bias 103 | 104 | # Since first proj has different size, allow different initial proj from base into model 105 | for i in range(2, n_predict): 106 | self.proj[i].weight = self.proj[1].weight 107 | 108 | def reset_parameters(self): 109 | for m in self.modules(): 110 | if isinstance(m, nn.Embedding) or isinstance(m, nn.Linear): 111 | nn.init.normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim)) 112 | elif isinstance(m, LayerNormParameterized) and hasattr(m, "weight"): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def generate_suffixes( 117 | self, 118 | state: torch.Tensor, 119 | ind: torch.Tensor, 120 | topk: List[int] = [5, 4, 3], 121 | n: int = 5, 122 | ) -> torch.Tensor: 123 | """ 124 | FOR INFERENCE 125 | Generate tree of candidate sequences. 126 | ... 127 | Args 128 | ---- 129 | state : torch.Tensor 130 | Most recent embedding vector from the base model (pre-classification head). 131 | Expects size [b 1 d] where b is batch size and d is model width. 132 | ind : torch.Tensor 133 | Token indices of the base model's most recent predicted token(s). 134 | Expects size [b 1] where b is batch size. 135 | topk : List(int) 136 | Number of tokens to consider from each head when forming the candidate tree. 
137 | For each candidate branch in the tree, head n produces topk[n] additional sub-branches. 138 | n : int 139 | Given the final tree of prod(topk) candidates, return only the top n most confident. 140 | ... 141 | Output : torch.Tensor 142 | The tensor of most likely candidate sequences. 143 | Has size [b n self.n_predict], where b is batch size and n is provided above. 144 | """ 145 | # k indicates # of candidates 146 | # h indicates # of generated tokens 147 | b = state.size(0) 148 | k = math.prod(topk) 149 | out = torch.empty( 150 | b, 1, k, self.n_predict, device=state.device 151 | ).int() # b 1 k h -> b k 1 h 152 | log_probs = torch.zeros(b, 1, k, device=state.device) # b 1 k -> b k 1 153 | assert ( 154 | len(topk) == self.n_predict 155 | ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(topk)} provided)" 156 | if self.scale_input: 157 | state = self.ln0(state) / (2**0.5) 158 | for i in range(self.n_predict): 159 | # Project and predict 160 | z = self.emb[i](ind) # b k d 161 | state = self.proj[i](state) 162 | # Weighted add of state_weight*state and emb_weight*z 163 | # Let subsequent LN take care of denominator 164 | # state_weight is close to 1, so shouldn't be any precision issues 165 | state = torch.add(state, z, alpha=self.emb_weight / self.state_weight) 166 | state = self.activation(self.ln[i](state)) # b k d 167 | probs = F.log_softmax(self.head[i](state), dim=2) # b k v 168 | probs, preds = probs.topk(topk[i], dim=2) # b k k' 169 | 170 | # Update candidate set with new predictions, repeating shared prefixes as needed 171 | out = out.view(b, preds.size(1) * preds.size(2), -1, self.n_predict) 172 | out[:, :, :, i] = preds.view(b, -1, 1) 173 | 174 | # Update state, log_probs and ind for new predictions 175 | state = state.unsqueeze(2).expand(-1, -1, topk[i], -1) # b k k' d 176 | state = state.reshape(b, -1, state.size(3)) # b kk' d 177 | ind = preds.view(b, -1) # b kk' 178 | log_probs = log_probs.view(b, probs.size(1) * probs.size(2), -1) 179 | log_probs = log_probs.add(probs.view(b, -1, 1)) 180 | 181 | # Take only top n best guesses 182 | out = out.view(b, k, self.n_predict) 183 | log_probs = log_probs.view(b, k) 184 | best_guesses = log_probs.topk(n, dim=1)[1] # b k 185 | return out.gather( 186 | 1, best_guesses.unsqueeze(2).expand(-1, -1, self.n_predict) 187 | ) # b n h 188 | 189 | def forward( 190 | self, 191 | state: torch.Tensor, 192 | inds: torch.Tensor, 193 | ) -> torch.Tensor: 194 | """ 195 | FOR TRAINING 196 | A parallel forward pass on pre-existing ground-truth tokens in pretraining contexts. 197 | Produces self.n_predict predicted tokens for each token embedding in state. 198 | Inds requires self.n_predict extra tokens on the right to "simulate" recursive 199 | behavior for end positions. 200 | ... 201 | Args 202 | ---- 203 | state : torch.Tensor 204 | Embedding vectors from the base model for a given sequence. 205 | Expects size [b n d] where b is batch size, n is seq len, and d is model width. 206 | inds : torch.Tensor 207 | Ground-truth token indices. inds[:,i] is the prediction coming from state[:,i] 208 | (or the legal fiction ground truth corresponding to that prediction). 209 | Expects size [b n+self.n_predict]. 210 | ... 211 | Output : torch.Tensor 212 | Prediction logits at each position, for each head of the speculator. 213 | Has size [self.n_predict b n v] where v is vocab size. 
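        Illustrative shape example (assumption: the default emb_dim=4096, vocab_size=32000,
        n_predict=3): for state of size [2, 16, 4096], inds must have size [2, 19]
        (= [b, n + self.n_predict]), and the returned logits have size [3, 2, 16, 32000].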
214 | """ 215 | out = [] 216 | if self.scale_input: 217 | state = self.ln0(state) / (2**0.5) 218 | for i in range(self.n_predict): 219 | z = self.emb[i](inds[:, i : i + state.size(1)]) # b n d 220 | state = self.proj[i](state) 221 | # Weighted add of state_weight*state and emb_weight*z 222 | # Let subsequent LN take care of denominator 223 | # state_weight is close to 1, so shouldn't be any precision issues 224 | state = torch.add(state, z, alpha=self.emb_weight / self.state_weight) 225 | state = self.activation(self.ln[i](state)) # b n d 226 | out.append(self.head[i](state)) # b n v 227 | return torch.stack(out, dim=0) # h b n v 228 | 229 | 230 | def apply_index_map( 231 | inp: torch.Tensor, inds: torch.Tensor, dim: int = 0 232 | ) -> torch.Tensor: 233 | """ 234 | Applies index map to specified dimension of input tensor. Used for batch flattening/unflattening. 235 | 236 | More precisely, takes input of size ([...], n, [...]), with n in the dim-th dimension, 237 | and tensor of indices of size (a, ..., z). Using those indices we draw from the input 238 | on dimension dim, to create output tensor with size ([...], (a, ..., z), [...]). 239 | 240 | i.e. if dim=0, inp has size (6,3,2), and inds has size (8,4), then: 241 | 1) max(inds) < 6 242 | 2) output has size (8,4,3,2) 243 | 3) the output contains repeated values (8*4 > 6) 244 | 245 | Args: 246 | inp: torch.Tensor 247 | tensor of inputs 248 | inds: torch.Tensor 249 | tensor of indices 250 | dim: int 251 | dimension to index on 252 | 253 | Returns: 254 | torch.Tensor 255 | output tensor with new size ([...], (a, ..., z), [...]) 256 | """ 257 | inds_shape = inds.size() 258 | inp_shape = inp.size() 259 | out = inp.index_select(dim, inds.view(-1)) 260 | return out.view(*inp_shape[:dim], *inds_shape, *inp_shape[dim + 1 :]) 261 | 262 | 263 | def flatten_batch(inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 264 | """ 265 | Takes a speculator suffix tree: a bsize x n_candidates x candidate_len rectangular batch 266 | of token indices, and flattens it while removing redundant tokens. 
267 | 268 | For example, given: 269 | 270 | a b c 271 | a b d 272 | a e f 273 | 274 | Tokens 'a b' in line 2 and token 'a' in line 3 are functionally equivalent to 'a b' in 275 | line 1, so the flattened batch returns `a b c d e f` 276 | 277 | Args: 278 | inp: torch.Tensor 279 | speculator suffix tree 280 | 281 | Returns: 282 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 283 | 1) the flattened, pruned input 284 | 2) a tensor, sized as input, mapping each input token to its slot in output 285 | 3) a tensor, sized as output, mapping each output token to its slot in the flattened input 286 | """ 287 | unflat_map = torch.zeros_like(inp, dtype=torch.int) 288 | inp_list = inp.tolist() 289 | flat_map = [] 290 | batch_offset = 0 291 | # Generate the flatten/unflatten maps 292 | for b, candidate_set in enumerate(inp_list): 293 | lineages: Dict[ 294 | Tuple[List[int]], int 295 | ] = {} # Prefix : n unique prefixes observed so far 296 | for k, candidate in enumerate(candidate_set): 297 | for n in range(len(candidate)): 298 | lineage = tuple(candidate[: n + 1]) 299 | if lineage in lineages: 300 | # Token is redundant 301 | unflat_map[b, k, n] = lineages[lineage] + batch_offset 302 | else: 303 | # Token is not redundant 304 | unflat_map[b, k, n] = len(lineages) + batch_offset 305 | lineages[lineage] = len(lineages) 306 | flat_map.append( 307 | b * len(candidate_set) * len(candidate) + k * len(candidate) + n 308 | ) 309 | batch_offset += len(lineages) 310 | # Generate the flattened batch 311 | flat_map_tensor = torch.tensor(flat_map, device=inp.device, dtype=torch.int32) 312 | out = apply_index_map(inp.view(-1), flat_map_tensor, 0) 313 | return out, unflat_map, flat_map_tensor 314 | 315 | 316 | _llama_7b = {"emb_dim": 4096, "vocab_size": 32000, "n_predict": 3, "inner_dim": 0} 317 | _ibm_llama_7b_instruct_lab = { 318 | "emb_dim": 4096, 319 | "vocab_size": 32008, 320 | "n_predict": 5, 321 | "inner_dim": 0, 322 | } 323 | 324 | _llama_13b = {"emb_dim": 5120, "vocab_size": 32000, "n_predict": 3, "inner_dim": 4096} 325 | 326 | _llama_13b_code = { 327 | "emb_dim": 5120, 328 | "vocab_size": 32016, 329 | "n_predict": 7, 330 | "inner_dim": 4096, 331 | } 332 | 333 | _llama_34b_code = { 334 | "emb_dim": 8192, 335 | "vocab_size": 32000, 336 | "n_predict": 5, 337 | "inner_dim": 8192, 338 | "scale_input": True, 339 | "tie_weights": True, 340 | } 341 | 342 | _llama3_8b_3_2b = { 343 | "emb_dim": 4096, 344 | "vocab_size": 128256, 345 | "n_predict": 4, 346 | "inner_dim": 3072, 347 | } 348 | 349 | _ibm_20b_code_instruct = { 350 | "emb_dim": 6144, 351 | "vocab_size": 49152, 352 | "n_predict": 4, 353 | "inner_dim": 4096, 354 | } 355 | 356 | _ibm_34b_code_instruct = { 357 | "emb_dim": 6144, 358 | "vocab_size": 49152, 359 | "n_predict": 5, 360 | "inner_dim": 6144, 361 | "scale_input": True, 362 | "tie_weights": True, 363 | } 364 | 365 | _llama3_70b_961m = { 366 | "emb_dim": 8192, 367 | "vocab_size": 128256, 368 | "n_predict": 4, 369 | "inner_dim": 3584, 370 | "scale_input": True, 371 | "tie_weights": True, 372 | } 373 | 374 | _calico_8b_test = { 375 | "emb_dim": 4096, 376 | "vocab_size": 49152, 377 | "n_predict": 5, 378 | "inner_dim": 4096, 379 | } 380 | 381 | 382 | _architecture_name = "mlp_speculator" 383 | 384 | 385 | def _mlp_speculator_factory_factory(variant_config_dict): 386 | def factory(**user_kwargs): 387 | return MLPSpeculator(**(variant_config_dict | user_kwargs)) 388 | 389 | return factory 390 | 391 | 392 | models.register_model( 393 | _architecture_name, 394 | "llama.7b.ibm_instruct_lab.1_4b", 395 | 
_mlp_speculator_factory_factory(_ibm_llama_7b_instruct_lab), 396 | ) 397 | 398 | models.register_model( 399 | _architecture_name, 400 | "llama.7b.840m", 401 | _mlp_speculator_factory_factory(_llama_7b), 402 | ) 403 | models.register_model( 404 | _architecture_name, 405 | "llama.13b.840m", 406 | _mlp_speculator_factory_factory(_llama_13b), 407 | ) 408 | models.register_model( 409 | _architecture_name, 410 | "llama.13b.code.2b", 411 | _mlp_speculator_factory_factory(_llama_13b_code), 412 | ) 413 | models.register_model( 414 | _architecture_name, 415 | "llama.34b.code.658m", 416 | _mlp_speculator_factory_factory(_llama_34b_code), 417 | ) 418 | models.register_model( 419 | _architecture_name, 420 | "llama.llama3.8b.3_2b", 421 | _mlp_speculator_factory_factory(_llama3_8b_3_2b), 422 | ) 423 | models.register_model( 424 | _architecture_name, 425 | "llama.llama3.70b.961m", 426 | _mlp_speculator_factory_factory(_llama3_70b_961m), 427 | ) 428 | 429 | models.register_model( 430 | _architecture_name, 431 | "gpt_bigcode.ibm.20b.1_7b", 432 | _mlp_speculator_factory_factory(_ibm_20b_code_instruct), 433 | ) 434 | 435 | models.register_model( 436 | _architecture_name, 437 | "gpt_bigcode.ibm.34b.680m", 438 | _mlp_speculator_factory_factory(_ibm_34b_code_instruct), 439 | ) 440 | 441 | models.register_model( 442 | _architecture_name, 443 | "llama.calico.8b.code.2_1b", 444 | _mlp_speculator_factory_factory(_calico_8b_test), 445 | ) 446 | 447 | 448 | def _rename_hf_weights_to_fms(orig_sd): 449 | new_sd = {} 450 | for name, param in orig_sd.items(): 451 | new_sd[name.replace("speculator.", "")] = param 452 | 453 | return new_sd 454 | 455 | 456 | serialization.register_adapter(_architecture_name, "hf", _rename_hf_weights_to_fms) 457 | -------------------------------------------------------------------------------- /fms_extras/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/fms_extras/modules/__init__.py -------------------------------------------------------------------------------- /fms_extras/modules/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Set 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from fms import distributed, models 7 | from fms.distributed.tensorparallel import ( 8 | copy_to_tensor_model_parallel_region, 9 | reduce_from_tensor_model_parallel_region, 10 | ) 11 | from fms.modules.positions import PositionEncoder 12 | from fms.modules.tp import TPModule 13 | from torch._C._distributed_c10d import ProcessGroup 14 | 15 | from fms_extras.utils.cache.paged import PagedAttentionCacheDataLayer 16 | 17 | 18 | class PagedMultiHeadAttention(nn.Module): 19 | """ 20 | Performs multi-headed self- or cross-attention, with optional attention masking. 
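    The main difference from a standard fms MultiHeadAttention is that the key/value
    states can be stored in and attended over via a PagedAttentionCacheDataLayer
    (a single layer's view of the paged KV cache) rather than a contiguous
    past_key_value tensor.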
21 | 22 | Note: this class extends MultiHeadAttention to enable tensor parallel support 23 | """ 24 | 25 | def __init__( 26 | self, 27 | emb_dim, 28 | emb_kq, 29 | emb_v, 30 | nheads, 31 | kvheads, 32 | p_dropout=None, 33 | use_bias=False, 34 | position_encoder: Optional[PositionEncoder] = None, 35 | gain=1, 36 | ): 37 | super(PagedMultiHeadAttention, self).__init__() 38 | self.nheads = nheads 39 | self.kvheads = kvheads 40 | self.emb_dim = emb_dim 41 | self.emb_kq_per_head = emb_kq 42 | self.emb_v_per_head = emb_v 43 | self.p_dropout = p_dropout if p_dropout is not None else 0.0 44 | self.use_bias = use_bias 45 | self.dense = nn.Linear( 46 | self.nheads * self.emb_v_per_head, self.emb_dim, bias=use_bias 47 | ) 48 | 49 | self.splits = [ 50 | self.nheads * self.emb_kq_per_head, 51 | self.kvheads * self.emb_kq_per_head, 52 | self.kvheads * self.emb_v_per_head, 53 | ] 54 | 55 | self.qkv_fused = nn.Linear( 56 | self.emb_dim, 57 | sum(self.splits), 58 | bias=use_bias, 59 | ) 60 | 61 | if self.p_dropout: 62 | self.attn_dropout = nn.Dropout(self.p_dropout) 63 | self.position_encoder = position_encoder 64 | # Avoiding graph breaks 65 | self.previous_flash: bool = torch.backends.cuda.flash_sdp_enabled() 66 | self.previous_mem_efficient: bool = ( 67 | torch.backends.cuda.mem_efficient_sdp_enabled() 68 | ) 69 | self.previous_math: bool = torch.backends.cuda.math_sdp_enabled() 70 | 71 | def reset_parameters(self): 72 | for m in self.modules(): 73 | if isinstance(m, nn.Linear): 74 | nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02) 75 | if self.use_bias: 76 | m.bias.data.zero_() 77 | 78 | def to_tp(self, group: ProcessGroup) -> "TPPagedMultiHeadAttention": 79 | return TPPagedMultiHeadAttention.import_module(self, group) 80 | 81 | def forward( 82 | self, 83 | q: torch.Tensor, 84 | k: torch.Tensor, 85 | v: torch.Tensor, 86 | mask: Optional[torch.Tensor] = None, 87 | attn_algorithm: Optional[str] = None, 88 | cache_data_layer: Optional[PagedAttentionCacheDataLayer] = None, 89 | use_cache: bool = False, 90 | is_self: bool = True, 91 | is_causal_mask: bool = False, 92 | ): 93 | """ 94 | cache_data_layer: PagedAttentionCacheDataLayer, optional 95 | A single layer of the cache (default is None) 96 | use_cache: bool 97 | if True, the kv states for self/cross attention will be saved, otherwise they will not be saved 98 | is_self: bool 99 | if True, this will perform self attention, otherwise this will perform cross attention. Note: This will 100 | only be used in the case that use_cache=True. This may be removed in future 101 | 102 | Returns 103 | ------- 104 | tensor or tuple 105 | If use_cache=False, only the hidden state will be returned as a tensor. 
If use_cache=True, a tuple will be 106 | returned in the form (hidden_state, cache) where hidden_state is a tensor and cache is of the form specified 107 | in past_key_value_state 108 | """ 109 | 110 | # q, k, v: batch_size x seq_len x emb_dim 111 | # mask: batch_size x seq_len x seq_len 112 | batch_size, q_len, _ = q.size() 113 | position_ids = ( 114 | None if cache_data_layer is None else cache_data_layer.position_ids 115 | ) 116 | 117 | queries, keys, values = self.qkv_fused(q).split(self.splits, dim=-1) 118 | 119 | queries = queries.view(batch_size, q_len, self.nheads, self.emb_kq_per_head) 120 | keys = keys.view(batch_size, q_len, self.kvheads, self.emb_kq_per_head) 121 | values = values.view(batch_size, q_len, self.kvheads, self.emb_v_per_head) 122 | 123 | # You want to apply rotary embeddings pre-cache 124 | if self.position_encoder is not None: 125 | queries, keys = self.position_encoder.adjusted_qk( 126 | queries, 127 | keys, 128 | position_ids, # type: ignore 129 | None, 130 | use_cache, 131 | ) 132 | 133 | # store the values in kv-cache 134 | if use_cache and cache_data_layer: 135 | keys, values = cache_data_layer.store(keys, values) 136 | 137 | if use_cache and cache_data_layer and cache_data_layer.is_filled(): 138 | attn = cache_data_layer.attend(queries) 139 | # otherwise we always fall back into SDPA as this is either a prompt or it is a single contiguous cache 140 | else: 141 | queries = queries.transpose(2, 1) 142 | keys = keys.transpose(2, 1) 143 | values = values.transpose(2, 1) 144 | 145 | # Merge rel pos bias and mask into single float mask 146 | if mask is not None: 147 | # Our expected mask format is bs x q_len x k_len, so to make it broadcastable 148 | # we need to create the nheads dimension 149 | while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len 150 | mask = mask.unsqueeze(1) 151 | 152 | if self.position_encoder is not None: 153 | attn_mask = self.position_encoder.adjusted_mask( 154 | mask, queries, keys, position_ids, use_cache # type: ignore 155 | ) 156 | else: 157 | attn_mask = mask 158 | 159 | # Expand kv so black-box attn will work 160 | expansion = self.nheads // self.kvheads 161 | # k/v: b h l d 162 | if expansion != 1: 163 | keys_e = ( 164 | keys.unsqueeze(2).expand(-1, -1, expansion, -1, -1).flatten(1, 2) 165 | ) 166 | values_e = ( 167 | values.unsqueeze(2).expand(-1, -1, expansion, -1, -1).flatten(1, 2) 168 | ) 169 | else: 170 | keys_e = keys 171 | values_e = values 172 | 173 | if attn_algorithm: 174 | # Pick which fused attn kernels will run. 
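                # Descriptive note: attn_algorithm maps onto one of the
                # torch.backends.cuda SDPA backends ("flash" -> flash attention,
                # "mem" -> memory-efficient, "math" -> reference math kernel).
                # The global flags are flipped here and then restored from the saved
                # self.previous_* values immediately after the SDPA call below, so the
                # selection does not leak outside this forward pass.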
175 | use_flash = attn_algorithm == "flash" 176 | use_mem_efficient = attn_algorithm == "mem" 177 | use_math = attn_algorithm == "math" 178 | 179 | torch.backends.cuda.enable_flash_sdp(use_flash) 180 | torch.backends.cuda.enable_mem_efficient_sdp(use_mem_efficient) 181 | torch.backends.cuda.enable_math_sdp(use_math) 182 | 183 | attn = F.scaled_dot_product_attention( 184 | queries, 185 | keys_e, 186 | values_e, 187 | attn_mask=attn_mask, 188 | dropout_p=self.p_dropout if self.training else 0.0, 189 | is_causal=is_causal_mask, 190 | ) 191 | 192 | if attn_algorithm: 193 | torch.backends.cuda.enable_flash_sdp(self.previous_flash) 194 | torch.backends.cuda.enable_mem_efficient_sdp( 195 | self.previous_mem_efficient 196 | ) 197 | torch.backends.cuda.enable_math_sdp(self.previous_math) 198 | 199 | # attn: bs x seq_len x nheads*emb_v_per_head 200 | # attn: b x h x qlen x ds 201 | # attn after permute: b x qlen x h x ds 202 | # b x qlen x (d) 203 | attn = attn.transpose(2, 1).contiguous() 204 | 205 | attn = attn.view(batch_size, q_len, self.nheads * self.emb_v_per_head) 206 | 207 | out = self.dense(attn) 208 | 209 | # if use_cache=True, we return the hidden_state as well as the kv cache 210 | if use_cache and cache_data_layer: 211 | # note: needed to add this check to return the data_layer as it fails compile otherwise 212 | return out, cache_data_layer.data_layer 213 | else: 214 | return out 215 | 216 | 217 | class TPPagedMultiHeadAttention(PagedMultiHeadAttention, TPModule): 218 | def __init__( 219 | self, 220 | emb_dim, 221 | emb_kq, 222 | emb_v, 223 | nheads, 224 | kvheads, 225 | p_dropout=None, 226 | use_bias=False, 227 | position_encoder: Optional[PositionEncoder] = None, 228 | gain=1, 229 | group: Optional[ProcessGroup] = None, 230 | ): 231 | assert torch.distributed.is_initialized() 232 | 233 | rank, world_size = distributed.rank_and_world(group) 234 | assert ( 235 | nheads % world_size == 0 236 | ), "The number of heads must be divisible by world size" 237 | PagedMultiHeadAttention.__init__( 238 | self, 239 | emb_dim, 240 | emb_kq, 241 | emb_v, 242 | nheads // world_size, 243 | (kvheads // world_size) if kvheads > 1 else kvheads, 244 | p_dropout, 245 | use_bias, 246 | position_encoder, 247 | gain, 248 | ) 249 | self.pre_tp_nheads = nheads 250 | self.pre_tp_kvheads = kvheads 251 | self.setup_tp(rank, world_size) 252 | 253 | def load_weights( 254 | self, 255 | tensor_values: Dict[str, torch.Tensor], 256 | ): 257 | # 1. Grab the weights from tensor_values 258 | used_keys: Set[str] = set() 259 | qkv_weight = self._get_sd_weight( 260 | tensor_values, used_keys, ["qkv_fused", "weight"] 261 | ) 262 | dense_weight = self._get_sd_weight( 263 | tensor_values, used_keys, ["dense", "weight"] 264 | ) 265 | if self.use_bias: 266 | qkv_bias = self._get_sd_weight( 267 | tensor_values, used_keys, ["qkv_fused", "bias"] 268 | ) 269 | dense_bias = self._get_sd_weight( 270 | tensor_values, used_keys, ["dense", "bias"] 271 | ) 272 | 273 | # 2. Raise exceptions 274 | if len(tensor_values) > (4 if self.use_bias else 2): 275 | unused_keys = set(tensor_values.keys()).difference(used_keys) 276 | raise AttributeError(f"Unused weight(s): {', '.join(unused_keys)}") 277 | 278 | # 3. Load and shard the weights 279 | # The number in max_partition_sizes will signify the largest world size 280 | # til we need to duplicate. 
For instance if we have nheads=16 and 281 | # world_size=32, then first 2 ranks will get first 1/16th of query 282 | self.sharded_copy( 283 | self.qkv_fused.weight, 284 | qkv_weight, 285 | 0, 286 | [self.pre_tp_nheads, self.pre_tp_kvheads, self.pre_tp_kvheads], 287 | ) 288 | self.sharded_copy(self.dense.weight, dense_weight, 1, [self.world_size]) 289 | if self.use_bias: 290 | self.sharded_copy( 291 | self.qkv_fused.bias, 292 | qkv_bias, 293 | 0, 294 | [self.pre_tp_nheads, self.pre_tp_kvheads, self.pre_tp_kvheads], 295 | ) 296 | self.sharded_copy(self.dense.bias, dense_bias, 1, [self.world_size], False) 297 | 298 | @staticmethod 299 | def import_module( 300 | mha: PagedMultiHeadAttention, group: ProcessGroup 301 | ) -> "TPPagedMultiHeadAttention": 302 | tp_mha = TPPagedMultiHeadAttention( 303 | emb_dim=mha.emb_dim, 304 | emb_kq=mha.emb_kq_per_head, 305 | emb_v=mha.emb_v_per_head, 306 | nheads=mha.nheads, 307 | kvheads=mha.kvheads, 308 | p_dropout=mha.p_dropout, 309 | use_bias=mha.use_bias, 310 | position_encoder=mha.position_encoder, 311 | group=group, 312 | ) 313 | return tp_mha 314 | 315 | def forward( 316 | self, 317 | q: torch.Tensor, 318 | k: torch.Tensor, 319 | v: torch.Tensor, 320 | mask: Optional[torch.Tensor] = None, 321 | attn_algorithm: Optional[str] = None, 322 | cache_data_layer: Optional[PagedAttentionCacheDataLayer] = None, 323 | use_cache: bool = False, 324 | is_self: bool = True, 325 | is_causal_mask: bool = False, 326 | ): 327 | """ 328 | Check MultiHeadAttention for up-to-date arguments and docs 329 | """ 330 | 331 | q_par = copy_to_tensor_model_parallel_region(q) 332 | k_par = copy_to_tensor_model_parallel_region(k) 333 | v_par = copy_to_tensor_model_parallel_region(v) 334 | 335 | out_par = PagedMultiHeadAttention.forward( 336 | self, 337 | q_par, 338 | k_par, 339 | v_par, 340 | mask, 341 | attn_algorithm, 342 | cache_data_layer, 343 | use_cache, 344 | is_self, 345 | is_causal_mask, 346 | ) 347 | 348 | # if use_cache=True, we return the hidden_state as well as the kv cache. 
349 | # We only reduce the output, and keep the cache thread-local 350 | if use_cache: 351 | out = reduce_from_tensor_model_parallel_region(out_par[0], self.world_size) 352 | return out, out_par[1] 353 | else: 354 | out = reduce_from_tensor_model_parallel_region(out_par, self.world_size) 355 | return out 356 | -------------------------------------------------------------------------------- /fms_extras/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/fms_extras/utils/__init__.py -------------------------------------------------------------------------------- /fms_extras/utils/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/fms_extras/utils/cache/__init__.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Should be mirrored in requirements-build.txt 3 | requires = [ 4 | "ninja", 5 | "packaging", 6 | "setuptools >= 49.4.0", 7 | "torch ~= 2.2.0", 8 | "wheel", 9 | ] 10 | build-backend = "setuptools.build_meta" 11 | -------------------------------------------------------------------------------- /requirements-build.txt: -------------------------------------------------------------------------------- 1 | # Should be mirrored in pyproject.toml 2 | ninja 3 | packaging 4 | setuptools>=49.4.0 5 | torch~=2.2.0 6 | wheel 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Used to install pinned dependencies 2 | # Useful for dev/test jobs caches 3 | # Must be kept in sync with setup.py 4 | torch ~= 2.2.0 # This is what is installed in CI today 5 | ibm-fms >= 0.0.4 6 | transformers >= 4.40.2 7 | accelerate >= 0.30.0 8 | -------------------------------------------------------------------------------- /scripts/paged_speculative_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import os 4 | import time 5 | 6 | import torch 7 | import torch._inductor.config 8 | from fms.models import get_model 9 | from fms.utils import generation, tokenizers 10 | from torch import distributed as dist 11 | 12 | import fms_extras.models.paged_gpt_bigcode 13 | import fms_extras.models.paged_llama 14 | from fms_extras.models.speculator import MLPSpeculator 15 | from fms_extras.utils.generation import paged_generate, speculative_generate 16 | 17 | 18 | # This example script validates the LLaMA implementation by running inference on a couple of prompts. 
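# In addition to the LLaMA invocation shown below, an illustrative GPT-BigCode run with its
# MLP speculator and batched code prompts (all paths here are placeholders) would look like:
# torchrun --nproc_per_node=1 scripts/paged_speculative_inference.py --architecture=gpt_bigcode \
#     --variant=ibm.20b --model_source=hf --model_path=~/models/gpt_bigcode_20b \
#     --tokenizer=~/models/gpt_bigcode_20b --speculator_path=~/models/speculator_20b.pth \
#     --speculator_variant=1_7b --top_k_tokens_per_head=5,3,2,2 --prompt_type=code --batch_input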
19 | # torchrun --nproc_per_node=1 scripts/inference.py --variant=7b --model_path=~/models/7B-F --tokenizer=~/models/tokenizer.model --model_source=meta --speculator_path=~/models/speculator_7B_F.pth --compile 20 | 21 | parser = argparse.ArgumentParser( 22 | description="Script to run inference on a causal model" 23 | ) 24 | parser.add_argument("--device_type", type=str, default="cuda") 25 | parser.add_argument( 26 | "--architecture", 27 | type=str, 28 | default="llama", 29 | help="The model architecture to benchmark", 30 | ) 31 | parser.add_argument( 32 | "--variant", 33 | type=str, 34 | default="7b", 35 | help="The model variant (configuration) to benchmark. E.g. 7b, 13b, 70b.", 36 | ) 37 | parser.add_argument( 38 | "--model_path", 39 | type=str, 40 | help="Path to the directory containing LLaMa weights (.pth files sharded by tensor parallel rank, not HF weights)", 41 | ) 42 | parser.add_argument( 43 | "--speculator_path", 44 | type=str, 45 | default=None, 46 | help="Path to the checkpoint containing speculator weights (single .pth file, not HF weights)", 47 | ) 48 | parser.add_argument( 49 | "--speculator_variant", 50 | type=str, 51 | default="840m", 52 | help="The model variant (configuration) to benchmark. E.g. 840m, 1.4b, 2b, etc.", 53 | ) 54 | parser.add_argument( 55 | "--speculator_source", 56 | type=str, 57 | default=None, 58 | choices=["hf"], 59 | help="Source format of speculator weights. Note: If the weights path specified in speculator_path are not local and " 60 | "the source is hf, the weights will be pulled using the normal Huggingface from_pretrained method.", 61 | ) 62 | parser.add_argument( 63 | "--model_source", 64 | type=str, 65 | help="Source of the checkpoint. E.g. 'meta', 'hf', None", 66 | ) 67 | 68 | parser.add_argument( 69 | "--checkpoint_sharding", 70 | type=str, 71 | default=None, 72 | help="type of weight sharding. E.g. tensor-parallel (tp), None", 73 | ) 74 | parser.add_argument( 75 | "--tokenizer", 76 | type=str, 77 | required=True, 78 | help="Path to the tokenizer (e.g. ~/tokenizer.model)", 79 | ) 80 | 81 | parser.add_argument( 82 | "--compile", 83 | action="store_true", 84 | help="Use torch.compile (slow for first inference pass)", 85 | ) 86 | parser.add_argument( 87 | "--compile_mode", 88 | type=str, 89 | help="Mode for compilation", 90 | default="default", 91 | choices=["default", "reduce-overhead"], 92 | ) 93 | parser.add_argument( 94 | "--deterministic", 95 | action="store_true", 96 | help="Set torch.use_deterministic_algorithms? Requires env variable `CUBLAS_WORKSPACE_CONFIG=:4096:8`", 97 | ) 98 | parser.add_argument( 99 | "--distributed", 100 | action="store_true", 101 | help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)", 102 | ) 103 | parser.add_argument("--context_file", type=str, default=None, help="File to summarize") 104 | parser.add_argument( 105 | "--batch_input", 106 | action="store_true", 107 | help="use a batch of prompts as input (note this is still wip for reduce-overhead=True)", 108 | ) 109 | # top_k_tokens_per_head 110 | parser.add_argument( 111 | "--top_k_tokens_per_head", 112 | type=lambda s: list(map(int, s.split(","))), 113 | default=[5, 3, 2], 114 | help="Number of tokens to consider from each head when forming the candidate tree. 
For each candidate branch in the tree, head n produces topk[n] additional sub-branches.", 115 | ) 116 | parser.add_argument( 117 | "--prompt_type", 118 | type=str, 119 | choices=["chat", "code"], 120 | default="chat", 121 | help="type of prompts to be used, either chat or code", 122 | ) 123 | args = parser.parse_args() 124 | 125 | if args.batch_input and args.compile and args.compile_mode == "reduce-overhead": 126 | print( 127 | "setting compile_mode to default as cudagraphs is not yet supported with batches" 128 | ) 129 | compile_mode = "default" 130 | else: 131 | compile_mode = args.compile_mode 132 | 133 | local_rank = int(os.getenv("LOCAL_RANK", 0)) 134 | world_size = int(os.getenv("WORLD_SIZE", 1)) 135 | if args.device_type == "cuda": 136 | device = torch.device(args.device_type, local_rank) 137 | torch.cuda.set_device(device) 138 | else: 139 | device = torch.device(args.device_type) 140 | 141 | torch.set_default_dtype(torch.half) 142 | 143 | # requires setting environment variable: `CUBLAS_WORKSPACE_CONFIG=:4096:8` 144 | if args.deterministic: 145 | torch.use_deterministic_algorithms(True) 146 | 147 | if args.distributed: 148 | dist.init_process_group() 149 | torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) 150 | 151 | print("loading model") 152 | if args.distributed: 153 | distr_param = "tp" 154 | else: 155 | if torch.cuda.device_count() > 1 and world_size == 1: 156 | distr_param = "mp" 157 | else: 158 | distr_param = None 159 | 160 | model = get_model( 161 | f"paged_{args.architecture}", 162 | args.variant, 163 | model_path=args.model_path, 164 | checkpoint_sharding=args.checkpoint_sharding, 165 | device_type=args.device_type, 166 | source=args.model_source, 167 | distributed_strategy=distr_param, 168 | group=dist.group.WORLD, 169 | ) 170 | decode_model = None 171 | 172 | tokenizer = tokenizers.get_tokenizer(args.tokenizer) 173 | model.eval() 174 | torch.set_grad_enabled(False) 175 | speculator = None 176 | if args.speculator_path is not None: 177 | print("loading speculator") 178 | # todo: handling of remote weights in get_model 179 | is_local = os.path.exists(args.speculator_path) or args.speculator_source != "hf" 180 | if is_local: 181 | speculator = get_model( 182 | "mlp_speculator", 183 | f"{args.architecture}.{args.variant}.{args.speculator_variant}", 184 | model_path=args.speculator_path, 185 | source=args.speculator_source, 186 | device_type=args.device_type, 187 | ) 188 | else: 189 | from fms_extras.models.hf.modeling_mlp_speculator import ( 190 | MLPSpeculatorPreTrainedModel, 191 | ) 192 | 193 | speculator = MLPSpeculatorPreTrainedModel.from_pretrained( 194 | args.speculator_path, device_map=args.device_type 195 | ).speculator 196 | speculator = speculator.to(device) 197 | if len(args.top_k_tokens_per_head) != speculator.n_predict: 198 | print( 199 | "length of top_k_tokens_per_head must be equal to the speculator's number of heads (n_predict)" 200 | ) 201 | exit() 202 | print("loading complete on rank", local_rank) 203 | 204 | print("initializing paged cache") 205 | # cache setup 206 | from fms_extras.utils.cache.paged import PagedKVCacheManager 207 | 208 | 209 | use_cache = True 210 | if hasattr(model.config, "kvheads"): 211 | kv_heads = model.config.kvheads 212 | else: 213 | kv_heads = 1 if model.config.multiquery_attn else model.config.nheads 214 | 215 | kv_cache_manager = PagedKVCacheManager( 216 | model.config.nlayers, 217 | model.config.nheads, 218 | model.config.emb_dim, 219 | kv_heads=kv_heads, 220 | 
tensor_parallel_size=dist.get_world_size() if args.distributed else 1, 221 | dtype=torch.get_default_dtype(), 222 | device=device, 223 | ) 224 | print("cache initialization complete on rank", local_rank) 225 | 226 | add_special_tokens = tokenizer.bos_token_id != tokenizer.eos_token_id 227 | 228 | 229 | def ids_for_prompt(prompt): 230 | tokens = tokenizer.tokenize(prompt) 231 | ids = tokenizer.convert_tokens_to_ids(tokens) 232 | if add_special_tokens: 233 | ids = [tokenizer.bos_token_id] + ids 234 | ids = torch.tensor(ids, dtype=torch.long, device=device) 235 | return ids 236 | 237 | 238 | def print_result(result, inp, n_steps): 239 | if local_rank != 0: 240 | return 241 | # stop at EOS token if present 242 | if add_special_tokens: 243 | result = generation.truncate_after_eos(result, tokenizer.eos_token_id) 244 | print(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(result))) 245 | print(f"{len(result) - len(inp)} tokens in {n_steps} steps") 246 | print() 247 | 248 | 249 | def infer(ids, warmup): 250 | # With greedy generation (do_sample=False) we _should_ always get the same results. 251 | # There is currently a bug in start_pos for batched rotary embeddings that can lead 252 | # varying results for the same prompt. 253 | if local_rank == 0: 254 | print("==================") 255 | 256 | cudagraphs = compile_mode == "reduce-overhead" 257 | max_seq_len = ( 258 | model.config.max_expected_seq_len 259 | if hasattr(model.config, "max_expected_seq_len") 260 | else model.config.max_pos 261 | ) 262 | if speculator: 263 | result, n_steps, ttft, generated_token_time_out = speculative_generate( 264 | model, 265 | ids, 266 | speculator, 267 | kv_cache_manager, 268 | new_tokens=100, 269 | max_seq_len=max_seq_len, 270 | decode_model=decode_model, 271 | # todo: we can only reduce-overhead for now when batch size is 1 272 | flattening=not (args.compile and compile_mode == "reduce-overhead"), 273 | cudagraphs=cudagraphs, 274 | threshes=args.top_k_tokens_per_head, 275 | ) 276 | else: 277 | result, n_steps, ttft, generated_token_time_out = paged_generate( 278 | model, 279 | ids, 280 | kv_cache_manager, 281 | max_new_tokens=100, 282 | max_seq_len=max_seq_len, 283 | do_sample=False, 284 | decode_model=decode_model, 285 | cudagraphs=cudagraphs, 286 | ) 287 | if not warmup: 288 | total_tokens = 0 289 | for i in range(len(result)): 290 | print_result(result[i], ids[i], n_steps) 291 | total_tokens += len(result[i]) - len(ids[i]) 292 | avg_tokens = total_tokens / len(result) 293 | print(f"time to first token: {ttft}") 294 | print(f"time per token (decode): {generated_token_time_out / avg_tokens}") 295 | 296 | 297 | if args.compile: 298 | print("compiling model") 299 | # Bug with kv-cache in PT2.1 300 | torch._inductor.config.joint_graph_constant_folding = False 301 | # compiling can make first inference pass slow 302 | decode_model = model 303 | decode_model = torch.compile(decode_model, mode=compile_mode, fullgraph=True) 304 | model = torch.compile(model, fullgraph=True, dynamic=True) 305 | if speculator: 306 | speculator = torch.compile(speculator, mode=compile_mode) 307 | speculator.generate_suffixes = torch.compile( 308 | speculator.generate_suffixes, mode=compile_mode 309 | ) 310 | 311 | if args.prompt_type == "chat": 312 | template = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:" 313 | 314 | prompt1 = template.format( 315 | "Provide a list of instructions for preparing chicken soup." 
316 | ) 317 | prompt2 = template.format("Explain some popular greetings in Spanish.") 318 | prompt3 = template.format("Explain to me why ignorance is bliss.") 319 | prompt4 = template.format( 320 | "I have just come into a very large sum of money. I received the money from my parents who told me I could do whatever I want with it. My first thought was to go to a financial advisor. Provide me a list of things that I can do with my new found wealth." 321 | ) 322 | 323 | elif args.prompt_type == "code": 324 | template = "[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```:\n{}\n[/INST]" 325 | prompt1 = template.format("Write a bubble sort function in python.") 326 | prompt2 = template.format( 327 | "Using the Java streams API, write a simple function which will get the cumulative sum of a list of integers." 328 | ) 329 | prompt3 = template.format( 330 | "In bash, how do I list all directories and sub-directories which contain a .py file." 331 | ) 332 | prompt4 = template.format( 333 | "Write a simple decorator in python which will modify all string inputs to ints if possible." 334 | ) 335 | 336 | else: 337 | print("prompt_type must be one of chat or code") 338 | exit() 339 | 340 | 341 | prompt1 = ids_for_prompt(prompt1) 342 | prompt2 = ids_for_prompt(prompt2) 343 | prompt3 = ids_for_prompt(prompt3) 344 | prompt4 = ids_for_prompt(prompt4) 345 | 346 | if args.batch_input: 347 | ids = [prompt1, prompt2, prompt3, prompt4] 348 | else: 349 | ids = [prompt1] 350 | 351 | infer(ids, warmup=True) 352 | print("generating output", local_rank) 353 | infer(ids, warmup=True) 354 | infer(ids, warmup=False) 355 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | import warnings 5 | from typing import List, Set 6 | 7 | import torch 8 | from packaging.version import Version, parse 9 | from setuptools import find_packages, setup 10 | from torch.utils.cpp_extension import ( 11 | CUDA_HOME, 12 | ROCM_HOME, 13 | BuildExtension, 14 | CUDAExtension, 15 | ) 16 | 17 | 18 | ROOT_DIR = os.path.dirname(__file__) 19 | 20 | ext_modules = [] 21 | cmdclass = {} 22 | 23 | 24 | def _is_hip() -> bool: 25 | return torch.version.hip is not None 26 | 27 | 28 | def _is_cuda() -> bool: 29 | return torch.version.cuda is not None 30 | 31 | 32 | if CUDA_HOME is not None or ROCM_HOME is not None: 33 | # vllm setup for csrc 34 | MAIN_CUDA_VERSION = "12.1" 35 | 36 | # Supported NVIDIA GPU architectures. 37 | NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} 38 | ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} 39 | 40 | # Compiler flags. 41 | CXX_FLAGS = ["-g", "-O2", "-std=c++17"] 42 | NVCC_FLAGS = ["-O2", "-std=c++17"] 43 | 44 | ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 45 | CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] 46 | NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] 47 | 48 | if _is_hip(): 49 | if ROCM_HOME is None: 50 | raise RuntimeError( 51 | "Cannot find ROCM_HOME. ROCm must be available to build the package." 
52 | ) 53 | NVCC_FLAGS += ["-DUSE_ROCM"] 54 | 55 | def get_amdgpu_offload_arch(): 56 | command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" 57 | try: 58 | output = subprocess.check_output([command]) 59 | return output.decode("utf-8").strip() 60 | except subprocess.CalledProcessError as e: 61 | error_message = f"Error: {e}" 62 | raise RuntimeError(error_message) from e 63 | except FileNotFoundError as e: 64 | # If the command is not found, print an error message 65 | error_message = f"The command {command} was not found." 66 | raise RuntimeError(error_message) from e 67 | 68 | return None 69 | 70 | def get_hipcc_rocm_version(): 71 | # Run the hipcc --version command 72 | result = subprocess.run( 73 | ["hipcc", "--version"], 74 | stdout=subprocess.PIPE, 75 | stderr=subprocess.STDOUT, 76 | text=True, 77 | ) 78 | 79 | # Check if the command was executed successfully 80 | if result.returncode != 0: 81 | print("Error running 'hipcc --version'") 82 | return None 83 | 84 | # Extract the version using a regular expression 85 | match = re.search(r"HIP version: (\S+)", result.stdout) 86 | if match: 87 | # Return the version string 88 | return match.group(1) 89 | else: 90 | print("Could not find HIP version in the output") 91 | return None 92 | 93 | def get_nvcc_cuda_version(cuda_dir: str) -> Version: 94 | """Get the CUDA version from nvcc. 95 | 96 | Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py 97 | """ 98 | nvcc_output = subprocess.check_output( 99 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 100 | ) 101 | output = nvcc_output.split() 102 | release_idx = output.index("release") + 1 103 | nvcc_cuda_version = parse(output[release_idx].split(",")[0]) 104 | return nvcc_cuda_version 105 | 106 | def get_torch_arch_list() -> Set[str]: 107 | # TORCH_CUDA_ARCH_LIST can have one or more architectures, 108 | # e.g. "8.0" or "7.5,8.0,8.6+PTX". Here, the "8.6+PTX" option asks the 109 | # compiler to additionally include PTX code that can be runtime-compiled 110 | # and executed on the 8.6 or newer architectures. While the PTX code will 111 | # not give the best performance on the newer architectures, it provides 112 | # forward compatibility. 113 | env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) 114 | if env_arch_list is None: 115 | return set() 116 | 117 | # List are separated by ; or space. 118 | torch_arch_list = set(env_arch_list.replace(" ", ";").split(";")) 119 | if not torch_arch_list: 120 | return set() 121 | 122 | # Filter out the invalid architectures and print a warning. 123 | valid_archs = NVIDIA_SUPPORTED_ARCHS.union( 124 | {s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS} 125 | ) 126 | arch_list = torch_arch_list.intersection(valid_archs) 127 | # If none of the specified architectures are valid, raise an error. 128 | if not arch_list: 129 | raise RuntimeError( 130 | "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " 131 | f"variable ({env_arch_list}) is supported. " 132 | f"Supported CUDA architectures are: {valid_archs}." 133 | ) 134 | invalid_arch_list = torch_arch_list - valid_archs 135 | if invalid_arch_list: 136 | warnings.warn( 137 | f"Unsupported CUDA architectures ({invalid_arch_list}) are " 138 | "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " 139 | f"({env_arch_list}). Supported CUDA architectures are: " 140 | f"{valid_archs}.", 141 | stacklevel=2, 142 | ) 143 | return arch_list 144 | 145 | # First, check the TORCH_CUDA_ARCH_LIST environment variable. 
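    # For example (illustrative): TORCH_CUDA_ARCH_LIST="8.0;8.9+PTX" pip install .
    # builds only for compute capabilities 8.0 and 8.9 (the +PTX suffix additionally
    # embeds PTX for forward compatibility). If the variable is unset, the capabilities
    # of the GPUs visible on the build machine are used, falling back to all supported
    # architectures for the installed NVCC version.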
146 | compute_capabilities = get_torch_arch_list() 147 | if _is_cuda() and not compute_capabilities: 148 | # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available 149 | # GPUs on the current machine. 150 | device_count = torch.cuda.device_count() 151 | for i in range(device_count): 152 | major, minor = torch.cuda.get_device_capability(i) 153 | if major < 7: 154 | raise RuntimeError( 155 | "GPUs with compute capability below 7.0 are not supported." 156 | ) 157 | compute_capabilities.add(f"{major}.{minor}") 158 | 159 | if _is_cuda(): 160 | nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) 161 | if not compute_capabilities: 162 | # If no GPU is specified nor available, add all supported architectures 163 | # based on the NVCC CUDA version. 164 | compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy() 165 | if nvcc_cuda_version < Version("11.1"): 166 | compute_capabilities.remove("8.6") 167 | if nvcc_cuda_version < Version("11.8"): 168 | compute_capabilities.remove("8.9") 169 | compute_capabilities.remove("9.0") 170 | # Validate the NVCC CUDA version. 171 | if nvcc_cuda_version < Version("11.0"): 172 | raise RuntimeError("CUDA 11.0 or higher is required to build the package.") 173 | if nvcc_cuda_version < Version("11.1") and any( 174 | cc.startswith("8.6") for cc in compute_capabilities 175 | ): 176 | raise RuntimeError( 177 | "CUDA 11.1 or higher is required for compute capability 8.6." 178 | ) 179 | if nvcc_cuda_version < Version("11.8"): 180 | if any(cc.startswith("8.9") for cc in compute_capabilities): 181 | # CUDA 11.8 is required to generate the code targeting compute capability 8.9. 182 | # However, GPUs with compute capability 8.9 can also run the code generated by 183 | # the previous versions of CUDA 11 and targeting compute capability 8.0. 184 | # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 185 | # instead of 8.9. 186 | warnings.warn( 187 | "CUDA 11.8 or higher is required for compute capability 8.9. " 188 | "Targeting compute capability 8.0 instead.", 189 | stacklevel=2, 190 | ) 191 | compute_capabilities = set( 192 | cc for cc in compute_capabilities if not cc.startswith("8.9") 193 | ) 194 | compute_capabilities.add("8.0+PTX") 195 | if any(cc.startswith("9.0") for cc in compute_capabilities): 196 | raise RuntimeError( 197 | "CUDA 11.8 or higher is required for compute capability 9.0." 198 | ) 199 | 200 | # Add target compute capabilities to NVCC flags. 201 | for capability in compute_capabilities: 202 | num = capability[0] + capability[2] 203 | NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] 204 | if capability.endswith("+PTX"): 205 | NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] 206 | 207 | # Use NVCC threads to parallelize the build. 
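        # (illustrative) the degree of build parallelism can be tuned via the
        # NVCC_THREADS environment variable, e.g. NVCC_THREADS=4 pip install . ;
        # the value is further capped by the number of available CPU cores just below.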
208 | if nvcc_cuda_version >= Version("11.2"): 209 | nvcc_threads = int(os.getenv("NVCC_THREADS", 8)) 210 | num_threads = min(os.cpu_count(), nvcc_threads) 211 | NVCC_FLAGS += ["--threads", str(num_threads)] 212 | 213 | elif _is_hip(): 214 | amd_arch = get_amdgpu_offload_arch() 215 | if amd_arch not in ROCM_SUPPORTED_ARCHS: 216 | raise RuntimeError( 217 | f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" 218 | f"amdgpu_arch_found: {amd_arch}" 219 | ) 220 | 221 | paged_attn_extension = CUDAExtension( 222 | name="fms_extras.paged_c", 223 | sources=[ 224 | "csrc/paged_attention/cache_kernels.cu", 225 | "csrc/paged_attention/attention/attention_kernels.cu", 226 | "csrc/paged_attention/cuda_utils_kernels.cu", 227 | "csrc/paged_attention/pybind.cpp", 228 | ], 229 | extra_compile_args={ 230 | "cxx": CXX_FLAGS, 231 | "nvcc": NVCC_FLAGS, 232 | }, 233 | ) 234 | ext_modules.append(paged_attn_extension) 235 | cmdclass["build_ext"] = BuildExtension 236 | 237 | 238 | def get_path(*filepath) -> str: 239 | return os.path.join(ROOT_DIR, *filepath) 240 | 241 | 242 | def get_requirements() -> List[str]: 243 | """Get Python package dependencies from requirements.txt.""" 244 | with open(get_path("requirements.txt")) as f: 245 | requirements = f.read().strip().split("\n") 246 | return requirements 247 | 248 | 249 | setup( 250 | name="fms_extras", 251 | version="0.0.1", 252 | author="Brian Vaughan, Joshua Rosenkranz, Antoni Viros i Martin, Davis Wertheimer, Supriyo Chakraborty, Raghu Kiran Ganti", 253 | author_email="bvaughan@ibm.com, jmrosenk@us.ibm.com, aviros@ibm.com, Davis.Wertheimer@ibm.com, supriyo@us.ibm.com, rganti@us.ibm.com", 254 | description="IBM Foundation Model Stack Extras", 255 | packages=find_packages(exclude=("csrc",)), 256 | install_requires=get_requirements(), 257 | ext_modules=ext_modules, 258 | cmdclass=cmdclass, 259 | url="https://github.com/foundation-model-stack/fms-extras", 260 | license="Apache License 2.0", 261 | classifiers=[ 262 | "Programming Language :: Python :: 3", 263 | "License :: OSI Approved :: Apache Software License", 264 | ], 265 | ) 266 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | # This requirements file is for test jobs. 
2 | # It pulls in general dependencies from fms-extras-requirements.txt 3 | # as well as test-only dependencies 4 | 5 | -r fms-extras-requirements.txt 6 | 7 | # Test tools 8 | mypy==1.8.0 9 | mypy-extensions==1.0.0 10 | pytest==8.0.0 11 | 12 | # Types packages 13 | pyarrow-stubs==10.0.1.7 14 | types-requests==2.31.0.20240125 15 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # content of conftest.py 2 | 3 | import pytest 4 | 5 | 6 | def pytest_addoption(parser): 7 | parser.addoption( 8 | "--runslow", action="store_true", default=False, help="run slow tests" 9 | ) 10 | parser.addoption( 11 | "--capture_expectation", 12 | action="store_true", 13 | default=False, 14 | help="capture the output expectation for a given test", 15 | ) 16 | 17 | 18 | def pytest_configure(config): 19 | config.addinivalue_line("markers", "slow: mark test as slow to run") 20 | config.addinivalue_line("markers", "capture expectation: expectation was captured") 21 | 22 | 23 | def pytest_generate_tests(metafunc): 24 | option_value = metafunc.config.option.capture_expectation 25 | if "capture_expectation" in metafunc.fixturenames and option_value is not None: 26 | metafunc.parametrize("capture_expectation", [option_value]) 27 | 28 | 29 | def pytest_collection_modifyitems(config, items): 30 | if config.getoption("--runslow"): 31 | # --runslow given in cli: do not skip slow tests 32 | return 33 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 34 | for item in items: 35 | if "slow" in item.keywords: 36 | item.add_marker(skip_slow) 37 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/tests/models/hf/__init__.py -------------------------------------------------------------------------------- /tests/models/hf/test_mlp_speculator.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | from fms.models import get_model 5 | from fms.models.hf import to_hf_api 6 | from transformers import PreTrainedModel 7 | 8 | from fms_extras.models.hf.modeling_mlp_speculator import ( 9 | MLPSpeculatorConfig, 10 | MLPSpeculatorPreTrainedModel, 11 | ) 12 | from fms_extras.models.speculator import MLPSpeculator 13 | 14 | 15 | def __test_speculator_equivalence( 16 | speculator_1, speculator_2, top_k_tokens_per_head, n_candidates 17 | ): 18 | sd1_is_hf = isinstance(speculator_1, PreTrainedModel) 19 | sd2_is_hf = isinstance(speculator_2, PreTrainedModel) 20 | emb_dim = speculator_1.config.emb_dim if sd1_is_hf else speculator_1.emb_dim 21 | sd1 = speculator_1.state_dict() 22 | sd2 = speculator_2.state_dict() 23 | # make sure state dicts are same 24 | assert len(sd1) == len(sd2) 25 | for k in sd1.keys(): 26 | sd1_key = k 27 | if sd1_is_hf and not sd2_is_hf: 28 | sd2_key = sd1_key.replace("speculator.", "") 29 | elif sd2_is_hf and not sd1_is_hf: 30 | sd2_key = 
f"speculator.{sd1_key}" 31 | else: 32 | sd2_key = sd1_key 33 | torch.testing.assert_close(sd1[sd1_key], sd2[sd2_key]) 34 | 35 | # make sure generate_suffixes produce same output 36 | state = torch.rand(4, 1, emb_dim) 37 | ind = torch.randint(low=0, high=10, size=(4, 1)) 38 | 39 | speculator1_out = speculator_1.generate_suffixes( 40 | state, ind, top_k_tokens_per_head, n_candidates 41 | ) 42 | speculator2_out = speculator_2.generate_suffixes( 43 | state, ind, top_k_tokens_per_head, n_candidates 44 | ) 45 | torch.testing.assert_close(speculator1_out, speculator2_out) 46 | 47 | 48 | def test_get_model_from_hf(): 49 | config = MLPSpeculatorConfig( 50 | vocab_size=256, 51 | emb_dim=64, 52 | inner_dim=32, 53 | n_predict=4, 54 | top_k_tokens_per_head=[5, 3, 2, 2], 55 | n_candidates=5, 56 | ) 57 | 58 | hf_model = MLPSpeculatorPreTrainedModel(config) 59 | hf_model.reset_parameters() 60 | 61 | with tempfile.TemporaryDirectory() as workdir: 62 | path = f"{workdir}/model_out" 63 | hf_model.save_pretrained(path) 64 | 65 | model = get_model( 66 | "mlp_speculator", 67 | "llama.7b.840m", 68 | model_path=path, 69 | source="hf", 70 | vocab_size=config.vocab_size, 71 | emb_dim=config.emb_dim, 72 | inner_dim=config.inner_dim, 73 | n_predict=config.n_predict, 74 | ) 75 | 76 | model.eval() 77 | 78 | __test_speculator_equivalence( 79 | model, hf_model, config.top_k_tokens_per_head, config.n_candidates 80 | ) 81 | 82 | 83 | def test_saved_hf_model_produces_same_output_as_original_fms(): 84 | vocab_size = 256 85 | emb_dim = 64 86 | inner_dim = 32 87 | n_predict = 4 88 | speculator = MLPSpeculator( 89 | emb_dim=emb_dim, vocab_size=vocab_size, inner_dim=inner_dim, n_predict=n_predict 90 | ) 91 | speculator.reset_parameters() 92 | speculator.eval() 93 | 94 | top_k_tokens_per_head = [5, 3, 2, 2] 95 | n_candidates = 5 96 | hf_speculator = to_hf_api( 97 | speculator, 98 | top_k_tokens_per_head=top_k_tokens_per_head, 99 | n_candidates=n_candidates, 100 | ) 101 | hf_speculator.eval() 102 | 103 | with tempfile.TemporaryDirectory() as workdir: 104 | hf_path = f"{workdir}/hf_speculator_out.pth" 105 | hf_speculator.save_pretrained(hf_path) 106 | 107 | loaded_hf_speculator = MLPSpeculatorPreTrainedModel.from_pretrained(hf_path) 108 | loaded_hf_speculator.eval() 109 | 110 | __test_speculator_equivalence( 111 | speculator, loaded_hf_speculator, top_k_tokens_per_head, n_candidates 112 | ) 113 | 114 | 115 | def test_to_hf_api(): 116 | vocab_size = 256 117 | emb_dim = 64 118 | inner_dim = 32 119 | n_predict = 4 120 | speculator = MLPSpeculator( 121 | emb_dim=emb_dim, vocab_size=vocab_size, inner_dim=inner_dim, n_predict=n_predict 122 | ) 123 | speculator.reset_parameters() 124 | speculator.eval() 125 | 126 | top_k_tokens_per_head = [5, 3, 2, 2] 127 | n_candidates = 5 128 | hf_speculator = to_hf_api( 129 | speculator, 130 | top_k_tokens_per_head=top_k_tokens_per_head, 131 | n_candidates=n_candidates, 132 | ) 133 | hf_speculator.eval() 134 | 135 | __test_speculator_equivalence( 136 | speculator, hf_speculator, top_k_tokens_per_head, n_candidates 137 | ) 138 | 139 | 140 | def test_round_trip(): 141 | vocab_size = 256 142 | emb_dim = 64 143 | inner_dim = 32 144 | n_predict = 4 145 | top_k_tokens_per_head = [5, 3, 2, 2] 146 | n_candidates = 5 147 | config = MLPSpeculatorConfig( 148 | emb_dim=emb_dim, 149 | vocab_size=vocab_size, 150 | inner_dim=inner_dim, 151 | n_predict=n_predict, 152 | top_k_tokens_per_head=top_k_tokens_per_head, 153 | n_candidates=n_candidates, 154 | ) 155 | original_model = MLPSpeculatorPreTrainedModel(config) 156 | 
original_model.reset_parameters() 157 | original_model.eval() 158 | 159 | with tempfile.TemporaryDirectory() as workdir: 160 | hf_path = f"{workdir}/hf_speculator_out.pth" 161 | original_model.save_pretrained(hf_path) 162 | 163 | loaded_model = MLPSpeculatorPreTrainedModel.from_pretrained(hf_path) 164 | loaded_model.eval() 165 | 166 | __test_speculator_equivalence( 167 | original_model, loaded_model, top_k_tokens_per_head, n_candidates 168 | ) 169 | -------------------------------------------------------------------------------- /tests/models/hf_equivalence/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/tests/models/hf_equivalence/__init__.py -------------------------------------------------------------------------------- /tests/models/hf_equivalence/test_calico.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from fms.models import get_model 3 | from fms.models.hf import to_hf_api 4 | 5 | from fms_extras.models.hf import register_fms_models 6 | 7 | 8 | @pytest.mark.slow 9 | def test_calico_equivalence(): 10 | pytest.importorskip("megatron_models") 11 | # TODO: model_gpt_bigcode required a change on line 999 past_length = past_key_values[0][0].shape[2] for generate to work 12 | import torch 13 | from fms.testing.comparison import ( 14 | HFModelSignatureParams, 15 | ModelSignatureParams, 16 | compare_model_signatures, 17 | ) 18 | from megatron_models import GPTMegatronForCausalLM 19 | from transformers import AutoTokenizer, pipeline 20 | 21 | register_fms_models() 22 | 23 | device = "cpu" 24 | # path to GPTMegatronForCausalLM weights 25 | path = "" 26 | # 1b, 8b, or 13b 27 | variant = "1b" 28 | 29 | gpt_megatron_model = GPTMegatronForCausalLM.from_pretrained(path, device_map=device) 30 | calico_model = get_model( 31 | "calico", variant, path, source="megatron", device_type=device 32 | ) 33 | 34 | count_parameters = lambda m: sum(p.numel() for p in m.parameters()) 35 | assert count_parameters(gpt_megatron_model) == count_parameters(calico_model) 36 | 37 | inp = torch.arange(5, 15).unsqueeze(0) 38 | params_mega = HFModelSignatureParams( 39 | model=gpt_megatron_model, params=["input_ids"], inp=inp 40 | ) 41 | params_fms = ModelSignatureParams(model=calico_model, params=1, inp=inp) 42 | 43 | compare_model_signatures(params_mega, params_fms) 44 | 45 | # huggingface model backed by fms internals 46 | calico_hf_model = to_hf_api( 47 | calico_model, 48 | pad_token_id=gpt_megatron_model.config.pad_token_id, 49 | bos_token_id=gpt_megatron_model.config.bos_token_id, 50 | eos_token_id=gpt_megatron_model.config.eos_token_id, 51 | ) 52 | 53 | # generate some text -- the first time will be slow since the model needs to be compiled, but subsequent generations should be faster. 54 | tokenizer = AutoTokenizer.from_pretrained(path) 55 | calico_generator = pipeline( 56 | task="text-generation", model=calico_hf_model, tokenizer=tokenizer 57 | ) 58 | calico_out = calico_generator( 59 | """q: how are you? a: I am good. How about you? q: What is the weather like today? a:""", 60 | max_new_tokens=25, 61 | ) 62 | print(calico_out) 63 | 64 | gpt_megatron_generator = pipeline( 65 | task="text-generation", model=gpt_megatron_model, tokenizer=tokenizer 66 | ) 67 | gpt_megatron_out = gpt_megatron_generator( 68 | """q: how are you? a: I am good. How about you? q: What is the weather like today? 
a:""", 69 | max_new_tokens=25, 70 | ) 71 | print(gpt_megatron_out) 72 | assert calico_out == gpt_megatron_out 73 | -------------------------------------------------------------------------------- /tests/models/test_calico.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from fms.testing._internal.model_test_suite import ( 4 | ConfigFixtureMixin, 5 | ModelCompileTestSuite, 6 | ModelConfigTestSuite, 7 | ModelConsistencyTestSuite, 8 | ModelFixtureMixin, 9 | ) 10 | 11 | from fms_extras.models.calico import Calico, CalicoConfig 12 | 13 | 14 | class CalicoFixtures(ConfigFixtureMixin, ModelFixtureMixin): 15 | """ 16 | Base Calico Fixtures that can be re-used for other purposes 17 | 18 | This will include the config and model signatures 19 | """ 20 | 21 | @pytest.fixture(scope="class", autouse=True) 22 | def uninitialized_model(self, config: CalicoConfig): 23 | return Calico(config) 24 | 25 | @pytest.fixture(scope="class", autouse=True) 26 | def config(self) -> CalicoConfig: 27 | return CalicoConfig( 28 | src_vocab_size=384, 29 | emb_dim=16, 30 | norm_eps=1e-5, 31 | nheads=8, 32 | nlayers=2, 33 | hidden_grow_factor=2.0, 34 | multiple_of=1, 35 | kvheads=4, 36 | activation_fn="swish", 37 | max_expected_seq_len=512, 38 | pad_id=0, 39 | ) 40 | 41 | 42 | class TestCalico( 43 | ModelConfigTestSuite, 44 | ModelConsistencyTestSuite, 45 | ModelCompileTestSuite, 46 | CalicoFixtures, 47 | ): 48 | """ 49 | Model Test Suite for Calico 50 | 51 | This suite will include tests for: 52 | - model configuration 53 | - basic load/save model 54 | - consistency of model output 55 | """ 56 | 57 | # x is the main parameter for this model which is the input tensor 58 | _get_signature_params = ["x"] 59 | 60 | def test_config_passed_to_model_and_updated(self, model, config): 61 | """test model constructor appropriately merges any passed kwargs into the config without mutating the original config""" 62 | model = type(model)(config=config, pad_id=config.pad_id + 1) 63 | # check not same reference 64 | assert model.get_config() is not config 65 | 66 | # modify pad_id to the new value expected and check equivalence 67 | config.pad_id = config.pad_id + 1 68 | assert model.get_config().as_dict() == config.as_dict() 69 | -------------------------------------------------------------------------------- /tests/models/test_paged_llama.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tempfile 3 | 4 | import pytest 5 | import torch 6 | from fms.models import get_model 7 | 8 | 9 | @pytest.mark.skipif( 10 | not torch.cuda.is_available(), 11 | reason="must have cuda to run paged llama equivalency test", 12 | ) 13 | def test_llama_and_paged_llama_equivalency(): 14 | from fms_extras.models import paged_llama 15 | from fms_extras.utils.cache.paged import PagedKVCacheManager 16 | 17 | # note: changed micro to have nheads=2 to increase head size for paged attention kernels 18 | llama = get_model("llama", "micro", device_type="cuda", nheads=2) 19 | 20 | with tempfile.TemporaryDirectory() as workdir: 21 | sd_path = f"{workdir}/model.pth" 22 | torch.save(llama.state_dict(), sd_path) 23 | 24 | paged_llama = get_model( 25 | "paged_llama", 26 | "micro", 27 | model_path=sd_path, 28 | source="fms_llama", 29 | device_type="cuda", 30 | nheads=2, 31 | ) 32 | torch.set_grad_enabled(False) 33 | llama.eval() 34 | paged_llama.eval() 35 | 36 | kv_cache_manager = PagedKVCacheManager( 37 | paged_llama.config.nlayers, 38 | 
paged_llama.config.nheads, 39 | paged_llama.config.emb_dim, 40 | kv_heads=paged_llama.config.kvheads, 41 | dtype=torch.get_default_dtype(), 42 | total_num_gpu_blocks=100, 43 | ) 44 | input_ids = torch.arange(0, 16, device="cuda").unsqueeze(0) 45 | cache_data = kv_cache_manager.allocate_tokens([input_ids.size(1)]) 46 | 47 | prefill_llama, prefill_cache = llama.forward( 48 | input_ids, position_ids=cache_data.position_ids, use_cache=True 49 | ) 50 | prefill_paged_llama, _ = paged_llama.forward( 51 | input_ids, use_cache=True, cache_data=cache_data 52 | ) 53 | 54 | torch.testing.assert_close(prefill_llama, prefill_paged_llama) 55 | 56 | input_ids = torch.argmax(prefill_llama[:, -1, :], dim=-1).unsqueeze(0).t() 57 | cache_data = kv_cache_manager.allocate_tokens([1], cache_data.sequence_ids) 58 | 59 | decode_llama, _ = llama.forward( 60 | input_ids, 61 | position_ids=cache_data.position_ids, 62 | use_cache=True, 63 | past_key_value_states=prefill_cache, 64 | ) 65 | decode_paged_llama, _ = paged_llama.forward( 66 | input_ids, use_cache=True, cache_data=cache_data 67 | ) 68 | 69 | torch.testing.assert_close(decode_llama, decode_paged_llama) 70 | -------------------------------------------------------------------------------- /tests/models/test_speculator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from fms_extras.models.speculator import apply_index_map, flatten_batch 5 | 6 | 7 | def _get_test_inp(): 8 | return list(torch.randn(100, 4, 20, 4).sign()) 9 | 10 | 11 | def test_flatten_correctness(): 12 | # Verify that compression is occurring, and flattened batch corresponds to the flattening map 13 | inps = _get_test_inp() 14 | for inp in inps: 15 | inp_flat, _, ind_map = flatten_batch(inp) 16 | # There are only 8 possible unique candidates, from a set of 20, so compression is at least 2x 17 | assert inp_flat.numel() < inp.numel() // 2 18 | torch.testing.assert_close(inp_flat, apply_index_map(inp.view(-1), ind_map)) 19 | 20 | 21 | def test_flatten_unflatten(): 22 | # Verify that unflat(flat(x)) == x 23 | inps = _get_test_inp() 24 | for inp in inps: 25 | inp_flat, unflat_map, flat_map = flatten_batch(inp) 26 | new_flat = apply_index_map(inp.view(-1), flat_map) 27 | torch.testing.assert_close(inp_flat, new_flat) 28 | new_inp = apply_index_map(new_flat, unflat_map) 29 | torch.testing.assert_close(inp, new_inp) 30 | 31 | 32 | def test_unflatten_flatten(): 33 | # Verify that flat(unflat(x)) == x 34 | inps = _get_test_inp() 35 | for inp in inps: 36 | inp_flat, unflat_map, flat_map = flatten_batch(inp) 37 | new_unflat = apply_index_map(inp_flat, unflat_map) 38 | torch.testing.assert_close(inp, new_unflat) 39 | new_inp_flat = apply_index_map(new_unflat.view(-1), flat_map) 40 | torch.testing.assert_close(inp_flat, new_inp_flat) 41 | -------------------------------------------------------------------------------- /tests/resources/expectations/models.test_calico.TestCalico.test_model_output: -------------------------------------------------------------------------------- 1 | 0.010602295398712158,0.021518290042877197,0.01032547652721405,0.022708401083946228,0.01275646686553955,0.005698159337043762,0.0005097836256027222,0.018675729632377625,0.017626173794269562,0.02315041422843933,0.02719421684741974,3.999471664428711e-05,0.014802753925323486,0.005737707018852234,0.0,0.024082735180854797 -------------------------------------------------------------------------------- 
/tests/resources/expectations/models.test_calico.TestCalico.test_model_weight_keys: -------------------------------------------------------------------------------- 1 | dec_norm.weight,layers.0.attn.dense.bias,layers.0.attn.dense.weight,layers.0.attn.in_proj.qkv_fused.bias,layers.0.attn.in_proj.qkv_fused.weight,layers.0.ff_ln.weight,layers.0.ff_sub_layer.w2.bias,layers.0.ff_sub_layer.w2.weight,layers.0.ff_sub_layer.wg1_fused.bias,layers.0.ff_sub_layer.wg1_fused.weight,layers.0.ln.weight,layers.1.attn.dense.bias,layers.1.attn.dense.weight,layers.1.attn.in_proj.qkv_fused.bias,layers.1.attn.in_proj.qkv_fused.weight,layers.1.ff_ln.weight,layers.1.ff_sub_layer.w2.bias,layers.1.ff_sub_layer.w2.weight,layers.1.ff_sub_layer.wg1_fused.bias,layers.1.ff_sub_layer.wg1_fused.weight,layers.1.ln.weight,shared.emb.weight,shared.head.weight -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-model-stack/fms-extras/16339f7c82255983d20dabf4807d124133be1d1c/tests/utils/cache/__init__.py -------------------------------------------------------------------------------- /tests/utils/cache/test_paged.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | 5 | construction_test_data = [ 6 | (10, 8, 0, 1), 7 | (2, 4, 2, 1), 8 | (10, 8, 0, 2), 9 | (10, 8, 4, 2), 10 | ] 11 | 12 | 13 | @pytest.mark.skipif( 14 | not torch.cuda.is_available(), reason="must have cuda to run paged attention tests" 15 | ) 16 | @pytest.mark.parametrize( 17 | "num_layers,num_heads,kv_heads,tensor_parallel_size", construction_test_data 18 | ) 19 | def test_construction(num_layers, num_heads, kv_heads, tensor_parallel_size): 20 | from fms_extras.utils.cache.paged import PagedKVCacheManager 21 | 22 | total_num_gpu_blocks = 100 23 | emb_dim = 100 24 | block_size = 16 25 | head_size = emb_dim // num_heads 26 | kv_cache_manager = PagedKVCacheManager( 27 | num_layers, 28 | num_heads, 29 | emb_dim, 30 | kv_heads, 31 | block_size=block_size, 32 | total_num_gpu_blocks=total_num_gpu_blocks, 33 | tensor_parallel_size=tensor_parallel_size, 34 | ) 35 | 36 | assert len(kv_cache_manager.cache) == num_layers 37 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks 38 | assert len(kv_cache_manager.cbg_map) == 0 39 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks 40 | 41 | assert kv_cache_manager.head_size == head_size 42 | if kv_heads == 0: 43 | kv_heads = num_heads 44 | kv_heads_final = kv_heads // tensor_parallel_size if kv_heads > 1 else kv_heads 45 | num_heads_final = num_heads // tensor_parallel_size if num_heads > 1 else num_heads 46 | element_size = torch.tensor([], dtype=kv_cache_manager.dtype).element_size() 47 | x = block_size // element_size 48 | 49 | assert kv_cache_manager.kv_heads == kv_heads_final 50 | assert kv_cache_manager.num_heads == num_heads_final 51 | assert kv_cache_manager.cache[0][0].shape == ( 52 | total_num_gpu_blocks, 53 | kv_heads_final, 54 | head_size // x, 55 | block_size, 56 | x, 57 | ) 58 | assert 
kv_cache_manager.cache[0][1].shape == ( 59 | total_num_gpu_blocks, 60 | kv_heads_final, 61 | head_size, 62 | block_size, 63 | ) 64 | 65 | 66 | @pytest.mark.skipif( 67 | not torch.cuda.is_available(), reason="must have cuda to run paged attention tests" 68 | ) 69 | def test_allocate_tokens(): 70 | from fms_extras.utils.cache.paged import PagedKVCacheManager 71 | 72 | total_num_gpu_blocks = 100 73 | kv_cache_manager = PagedKVCacheManager( 74 | 4, 4, 16, 0, total_num_gpu_blocks=total_num_gpu_blocks 75 | ) 76 | 77 | # test prompt 78 | # 5 - 1 block 79 | # 18 - 2 blocks 80 | # 40 - 3 blocks 81 | sequence_lengths = [5, 18, 40] 82 | cache_data = kv_cache_manager.allocate_tokens(sequence_lengths) 83 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks - 6 84 | assert not cache_data.is_filled() # this is the prompt so not yet filled 85 | assert cache_data.context_lengths is None 86 | assert cache_data.max_sequence_length == 40 87 | assert cache_data.sequence_ids == [0, 1, 2] 88 | block_mapping = torch.tensor( 89 | [[99, 0, 0], [98, 97, 0], [96, 95, 94]], dtype=torch.int32, device="cuda" 90 | ) 91 | torch.testing.assert_allclose(cache_data.block_mapping, block_mapping) 92 | slot_mapping = torch.tensor( 93 | [ 94 | [ 95 | -1, 96 | -1, 97 | -1, 98 | -1, 99 | -1, 100 | -1, 101 | -1, 102 | -1, 103 | -1, 104 | -1, 105 | -1, 106 | -1, 107 | -1, 108 | -1, 109 | -1, 110 | -1, 111 | -1, 112 | -1, 113 | -1, 114 | -1, 115 | -1, 116 | -1, 117 | -1, 118 | -1, 119 | -1, 120 | -1, 121 | -1, 122 | -1, 123 | -1, 124 | -1, 125 | -1, 126 | -1, 127 | -1, 128 | -1, 129 | -1, 130 | 1584, 131 | 1585, 132 | 1586, 133 | 1587, 134 | 1588, 135 | ], 136 | [ 137 | -1, 138 | -1, 139 | -1, 140 | -1, 141 | -1, 142 | -1, 143 | -1, 144 | -1, 145 | -1, 146 | -1, 147 | -1, 148 | -1, 149 | -1, 150 | -1, 151 | -1, 152 | -1, 153 | -1, 154 | -1, 155 | -1, 156 | -1, 157 | -1, 158 | -1, 159 | 1568, 160 | 1569, 161 | 1570, 162 | 1571, 163 | 1572, 164 | 1573, 165 | 1574, 166 | 1575, 167 | 1576, 168 | 1577, 169 | 1578, 170 | 1579, 171 | 1580, 172 | 1581, 173 | 1582, 174 | 1583, 175 | 1552, 176 | 1553, 177 | ], 178 | [ 179 | 1536, 180 | 1537, 181 | 1538, 182 | 1539, 183 | 1540, 184 | 1541, 185 | 1542, 186 | 1543, 187 | 1544, 188 | 1545, 189 | 1546, 190 | 1547, 191 | 1548, 192 | 1549, 193 | 1550, 194 | 1551, 195 | 1520, 196 | 1521, 197 | 1522, 198 | 1523, 199 | 1524, 200 | 1525, 201 | 1526, 202 | 1527, 203 | 1528, 204 | 1529, 205 | 1530, 206 | 1531, 207 | 1532, 208 | 1533, 209 | 1534, 210 | 1535, 211 | 1504, 212 | 1505, 213 | 1506, 214 | 1507, 215 | 1508, 216 | 1509, 217 | 1510, 218 | 1511, 219 | ], 220 | ], 221 | dtype=torch.int32, 222 | device="cuda", 223 | ) 224 | torch.testing.assert_allclose(cache_data.slot_mapping, slot_mapping) 225 | position_ids = [] 226 | for sequence_length in sequence_lengths: 227 | positions = [0 for _ in range(max(sequence_lengths) - sequence_length)] + [ 228 | i for i in range(sequence_length) 229 | ] 230 | position_ids.append(positions) 231 | torch.testing.assert_allclose( 232 | cache_data.position_ids, 233 | torch.tensor(position_ids, dtype=torch.long, device="cuda"), 234 | ) 235 | 236 | # test generated tokens 237 | num_tokens_per_sequence = [12, 1, 1] 238 | cache_data = kv_cache_manager.allocate_tokens( 239 | num_tokens_per_sequence, sequence_ids=cache_data.sequence_ids 240 | ) 241 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks - 7 242 | assert cache_data.is_filled() # this is the prompt so not yet filled 243 | context_lengths = [l + r for l, r in zip(sequence_lengths, 
num_tokens_per_sequence)] 244 | assert torch.allclose( 245 | cache_data.context_lengths, 246 | torch.tensor(context_lengths, dtype=torch.int32, device="cuda"), 247 | ) 248 | assert cache_data.max_sequence_length == 41 249 | assert cache_data.sequence_ids == [0, 1, 2] 250 | block_mapping = torch.tensor( 251 | [[99, 93, 0], [98, 97, 0], [96, 95, 94]], dtype=torch.int32, device="cuda" 252 | ) 253 | torch.testing.assert_allclose(cache_data.block_mapping, block_mapping) 254 | slot_mapping = torch.tensor( 255 | [ 256 | [1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1488], 257 | [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1554], 258 | [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1512], 259 | ], 260 | dtype=torch.int32, 261 | device="cuda:0", 262 | ) 263 | torch.testing.assert_allclose(cache_data.slot_mapping, slot_mapping) 264 | position_ids = [] 265 | for n_i, num_tokens in enumerate(num_tokens_per_sequence): 266 | positions = [0 for _ in range(max(num_tokens_per_sequence) - num_tokens)] + [ 267 | sequence_lengths[n_i] + i for i in range(num_tokens) 268 | ] 269 | position_ids.append(positions) 270 | torch.testing.assert_allclose( 271 | cache_data.position_ids, 272 | torch.tensor(position_ids, dtype=torch.long, device="cuda"), 273 | ) 274 | 275 | 276 | @pytest.mark.skipif( 277 | not torch.cuda.is_available(), reason="must have cuda to run paged attention tests" 278 | ) 279 | def test_free_sequences(): 280 | from fms_extras.utils.cache.paged import PagedKVCacheManager 281 | 282 | total_num_gpu_blocks = 100 283 | sequence_lengths = [5, 18, 40] 284 | kv_cache_manager = PagedKVCacheManager( 285 | 4, 4, 16, 0, total_num_gpu_blocks=total_num_gpu_blocks 286 | ) 287 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks 288 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks 289 | assert len(kv_cache_manager.cbg_map) == 0 290 | 291 | # test prompt 292 | # 5 - 1 block 293 | # 18 - 2 blocks 294 | # 40 - 3 blocks 295 | cache_data = kv_cache_manager.allocate_tokens(sequence_lengths) 296 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks - 6 297 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks - len( 298 | sequence_lengths 299 | ) 300 | assert len(kv_cache_manager.cbg_map) == len(sequence_lengths) 301 | 302 | kv_cache_manager.free_sequences(cache_data.sequence_ids) 303 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks 304 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks 305 | assert len(kv_cache_manager.cbg_map) == 0 306 | 307 | 308 | @pytest.mark.skipif( 309 | not torch.cuda.is_available(), reason="must have cuda to run paged attention tests" 310 | ) 311 | def test_free_sequences_recursive(): 312 | from fms_extras.utils.cache.paged import PagedKVCacheManager 313 | 314 | total_num_gpu_blocks = 100 315 | sequence_lengths = [5, 32] 316 | kv_cache_manager = PagedKVCacheManager( 317 | 4, 4, 16, 0, total_num_gpu_blocks=total_num_gpu_blocks 318 | ) 319 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks 320 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks 321 | assert len(kv_cache_manager.cbg_map) == 0 322 | 323 | # test prompt 324 | # 5 - 1 block 325 | # 32 - 2 blocks 326 | cache_data = kv_cache_manager.allocate_tokens(sequence_lengths) 327 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks - 3 328 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks - len( 329 | sequence_lengths 330 | ) 331 | assert 
len(kv_cache_manager.cbg_map) == len(sequence_lengths) 332 | 333 | child_sequence_ids = kv_cache_manager.add_child_sequences( 334 | cache_data.sequence_ids[0], 3 335 | ) 336 | all_leaf_sequence_ids = [cache_data.sequence_ids[1]] + child_sequence_ids 337 | kv_cache_manager.allocate_tokens( 338 | [1 for _ in all_leaf_sequence_ids], all_leaf_sequence_ids 339 | ) 340 | assert ( 341 | len(kv_cache_manager.free_blocks) == total_num_gpu_blocks - 7 342 | ) # +3 for copied child sequences, +1 for new block 343 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks - ( 344 | len(all_leaf_sequence_ids) + 1 345 | ) # +1 for the original parent of the children 346 | assert len(kv_cache_manager.cbg_map) == ( 347 | len(sequence_lengths) + 3 348 | ) # +3 for the new child sequences 349 | 350 | kv_cache_manager.free_sequences(all_leaf_sequence_ids, recursive=True) 351 | assert len(kv_cache_manager.free_blocks) == total_num_gpu_blocks 352 | assert kv_cache_manager.unused_keys.qsize() == total_num_gpu_blocks 353 | assert len(kv_cache_manager.cbg_map) == 0 354 | -------------------------------------------------------------------------------- /tests/utils/test_generation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import tempfile 4 | from typing import List 5 | 6 | import pytest 7 | import torch 8 | from fms.models import get_model 9 | from fms.utils import serialization 10 | from fms.utils.generation import generate 11 | 12 | from fms_extras.models.speculator import MLPSpeculator 13 | 14 | 15 | @pytest.mark.skipif( 16 | not torch.cuda.is_available(), 17 | reason="must have cuda to run paged llama generation test", 18 | ) 19 | def test_paged_generate(): 20 | from fms_extras.models import paged_llama 21 | from fms_extras.utils.cache.paged import PagedKVCacheManager 22 | from fms_extras.utils.generation import paged_generate 23 | 24 | torch.set_grad_enabled(False) 25 | 26 | llama = get_model("llama", "micro", device_type="cuda", nheads=2) 27 | 28 | with tempfile.TemporaryDirectory() as workdir: 29 | sd_path = f"{workdir}/model.pth" 30 | torch.save(llama.state_dict(), sd_path) 31 | 32 | paged_llama = get_model( 33 | "paged_llama", 34 | "micro", 35 | model_path=sd_path, 36 | source="fms_llama", 37 | device_type="cuda", 38 | nheads=2, 39 | ) 40 | 41 | kv_cache_manager = PagedKVCacheManager( 42 | paged_llama.config.nlayers, 43 | paged_llama.config.nheads, 44 | paged_llama.config.emb_dim, 45 | kv_heads=paged_llama.config.kvheads, 46 | dtype=torch.get_default_dtype(), 47 | total_num_gpu_blocks=100, 48 | ) 49 | 50 | input_ids = torch.tensor( 51 | [1] + [i for i in range(5, 25)], dtype=torch.long, device="cuda" 52 | ) 53 | 54 | paged_result, _, _, _ = paged_generate( 55 | paged_llama, [input_ids], kv_cache_manager, do_sample=False 56 | ) 57 | 58 | result = generate(llama, input_ids.unsqueeze(0), do_sample=False) 59 | 60 | torch.testing.assert_close(paged_result, result) 61 | 62 | 63 | class MockSpeculator(MLPSpeculator): 64 | def __init__(self, candidates_per_step: List[List[List[List[int]]]]): 65 | # candidates_per_step: decode_steps x batch x num_candidates x num_predictions 66 | super().__init__() 67 | self.n_predict = len(candidates_per_step[0][0][0]) 68 | self.step = 0 69 | self.guesses = candidates_per_step 70 | 71 | def generate_suffixes( 72 | self, 73 | state: torch.Tensor, 74 | ind: torch.Tensor, 75 | topk: List[int] = [5, 4, 3], 76 | n: int = 5, 77 | ) -> torch.Tensor: 78 | guess = self.guesses[self.step] 79 | self.step 
+= 1 80 | return torch.tensor(guess, device="cuda").int() 81 | 82 | 83 | @pytest.mark.skipif( 84 | not torch.cuda.is_available(), 85 | reason="must have cuda to run paged llama generation test", 86 | ) 87 | def test_speculative_generate(): 88 | from fms_extras.models import paged_llama 89 | from fms_extras.utils.cache.paged import PagedKVCacheManager 90 | from fms_extras.utils.generation import paged_generate, speculative_generate 91 | 92 | torch.set_grad_enabled(False) 93 | 94 | paged_llama = get_model( 95 | "paged_llama", 96 | "micro", 97 | device_type="cuda", 98 | nheads=2, 99 | ) 100 | 101 | kv_cache_manager = PagedKVCacheManager( 102 | paged_llama.config.nlayers, 103 | paged_llama.config.nheads, 104 | paged_llama.config.emb_dim, 105 | kv_heads=paged_llama.config.kvheads, 106 | dtype=torch.get_default_dtype(), 107 | total_num_gpu_blocks=100, 108 | ) 109 | 110 | input_ids1 = torch.tensor( 111 | [1] + [i for i in range(5, 25)], dtype=torch.long, device="cuda" 112 | ) 113 | 114 | input_ids2 = torch.tensor( 115 | [1] + [i for i in range(30, 35)], dtype=torch.long, device="cuda" 116 | ) 117 | input_ids_list = [input_ids1, input_ids2] 118 | max_prompt = max([input_ids.size(0) for input_ids in input_ids_list]) 119 | max_new_tokens = 20 120 | 121 | paged_result, paged_n_steps, _, _ = paged_generate( 122 | paged_llama, 123 | input_ids_list, 124 | kv_cache_manager, 125 | do_sample=False, 126 | max_new_tokens=max_new_tokens, 127 | ) 128 | 129 | # running tests for different prediction lengths and number of candidates 130 | n_predict_list = [2, 3, 4] 131 | num_candidates_list = [1, 3, 5] 132 | 133 | for n_predict in n_predict_list: 134 | for num_candidates in num_candidates_list: 135 | # randomly generate the correct number of guesses per step to mock the speculator 136 | candidates_per_step = [] 137 | # starting needle at 1 since first token is free 138 | needles = [1, 1] 139 | # need to continue til both sequences have completed max_new_tokens 140 | while any(needle <= max_new_tokens for needle in needles): 141 | candidates = [] 142 | 143 | # get candidates for each sequence 144 | for i, result_i in enumerate(paged_result): 145 | # offsetting by max_prompt to only include generated tokens 146 | tokens = result_i.tolist()[max_prompt:] 147 | candidates_per_sequence = [] 148 | max_num_correct = -1 149 | # adding a max so we reduce chance of all correct 150 | n_correct_max = random.randint(0, n_predict) 151 | 152 | # get each candidate of variable correctness 153 | for _ in range(num_candidates): 154 | n_correct = random.randint(0, n_correct_max) 155 | candidate = tokens[needles[i] : needles[i] + n_predict] 156 | 157 | # inject a wrong token if needed 158 | if n_correct < len(candidate): 159 | candidate[n_correct] = ( 160 | candidate[n_correct] - 1 161 | ) % paged_llama.config.src_vocab_size 162 | 163 | # pad if not enough ground truth tokens left 164 | if len(candidate) < n_predict: 165 | candidate = candidate + ([0] * (n_predict - len(candidate))) 166 | 167 | candidates_per_sequence.append(candidate) 168 | max_num_correct = max(max_num_correct, n_correct) 169 | candidates.append(candidates_per_sequence) 170 | 171 | # +1 for one free token 172 | needles[i] += max_num_correct + 1 173 | candidates_per_step.append(candidates) 174 | 175 | speculator = MockSpeculator(candidates_per_step) 176 | speculative_result, speculative_n_steps, _, _ = speculative_generate( 177 | paged_llama, 178 | input_ids_list, 179 | speculator, 180 | kv_cache_manager, 181 | new_tokens=max_new_tokens, 182 | 
n_candidates=num_candidates, 183 | ) 184 | 185 | # check that speculative decoding ran in exactly the expected number of decode steps 186 | assert speculative_n_steps == len(candidates_per_step) 187 | 188 | # check that the speculative output matches the non-speculative paged generation output 189 | for paged_single, speculative_single, prompt_ids in zip( 190 | paged_result, speculative_result, input_ids_list 191 | ): 192 | paged_single = paged_single[max_prompt:] 193 | num_pads = max_prompt - prompt_ids.size(0) 194 | speculative_single = speculative_single[ 195 | max_prompt - num_pads : max_prompt - num_pads + paged_single.size(0) 196 | ] 197 | torch.testing.assert_close(paged_single, speculative_single) 198 | --------------------------------------------------------------------------------
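
For orientation, the generation tests above exercise the same end-to-end flow a caller would use outside the test suite: build a model, size a PagedKVCacheManager from its config, and pass both to the generation helpers. Below is a minimal sketch of that flow, assuming a CUDA device and an installed fms/fms_extras; the prompt ids, block count, and token budget are illustrative values taken from the tests, not a prescribed configuration.

    # Minimal sketch (assumes CUDA); mirrors tests/utils/test_generation.py::test_paged_generate.
    import torch
    from fms.models import get_model

    from fms_extras.models import paged_llama  # imported so the paged_llama architecture is available to get_model, as the tests do
    from fms_extras.utils.cache.paged import PagedKVCacheManager
    from fms_extras.utils.generation import paged_generate

    torch.set_grad_enabled(False)

    # tiny model variant used throughout the tests
    model = get_model("paged_llama", "micro", device_type="cuda", nheads=2)
    model.eval()

    # paged KV cache sized from the model config, as in the tests
    kv_cache_manager = PagedKVCacheManager(
        model.config.nlayers,
        model.config.nheads,
        model.config.emb_dim,
        kv_heads=model.config.kvheads,
        dtype=torch.get_default_dtype(),
        total_num_gpu_blocks=100,
    )

    # illustrative prompt token ids
    input_ids = torch.tensor([1] + list(range(5, 25)), dtype=torch.long, device="cuda")

    # greedy paged generation; returns a 4-tuple whose first two elements are the
    # generated tokens and the number of decode steps, as unpacked in the tests
    result, n_steps, _, _ = paged_generate(
        model, [input_ids], kv_cache_manager, do_sample=False, max_new_tokens=20
    )
    print(result)

speculative_generate (also in fms_extras.utils.generation) follows the same pattern, additionally taking a speculator plus new_tokens and n_candidates arguments, as exercised in test_speculative_generate above. Tests marked slow, such as the megatron equivalence test, only run when pytest is invoked with the --runslow flag defined in tests/conftest.py.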