├── .gitignore ├── .vscode └── c_cpp_properties.json ├── LICENSE ├── README.md ├── cpm_kernels ├── __init__.py ├── device │ └── __init__.py ├── kernels │ ├── __init__.py │ ├── arith.py │ ├── base.py │ ├── embedding.py │ ├── gelu.py │ ├── gemm.py │ ├── gemv.py │ ├── layernorm.py │ ├── mask.py │ ├── position_bucket.py │ ├── softmax.py │ ├── transpose.py │ └── utils.py ├── library │ ├── __init__.py │ ├── base.py │ ├── cublaslt.py │ ├── cuda.py │ ├── cudart.py │ └── nvrtc.py └── torch │ ├── __init__.py │ ├── arith.py │ ├── embedding.py │ ├── gelu.py │ ├── gemm.py │ ├── layernorm.py │ ├── mask.py │ ├── position_embedding.py │ ├── softmax.py │ ├── transpose.py │ └── utils.py ├── cuda ├── Makefile ├── arith.cu ├── embedding.cu ├── gelu.cu ├── gemm.cu ├── gemv.cu ├── includes │ ├── common.h │ └── reduce.cuh ├── layernorm.cu ├── mask.cu ├── position_bucket.cu ├── softmax.cu ├── transpose.cu └── utils.cu ├── setup.py └── tests ├── run_test.py ├── test_arith.py ├── test_embedding.py ├── test_gelu.py ├── test_gemm.py ├── test_gemv.py ├── test_layernorm.py ├── test_position_embedding.py ├── test_softmax.py ├── test_transpose.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | *.pt 141 | 142 | *.npy 143 | 144 | bminference/version.py 145 | 146 | .DS_Store 147 | *.fatbin -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "cuda/includes", 7 | ], 8 | "defines": [ 9 | ], 10 | "compilerPath": "/usr/local/cuda/bin/nvcc", 11 | "cStandard": "gnu17", 12 | "cppStandard": "gnu++14", 13 | "intelliSenseMode": "linux-gcc-x64" 14 | } 15 | ], 16 | "version": 4 17 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CPM kernels 2 | 3 | CUDA Kernels for cpm. -------------------------------------------------------------------------------- /cpm_kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import library 2 | from .kernels import * -------------------------------------------------------------------------------- /cpm_kernels/device/__init__.py: -------------------------------------------------------------------------------- 1 | from ..library import cuda, cudart, cublaslt 2 | 3 | ATTRIBUTES = { 4 | "cudaDevAttrMaxThreadsPerBlock": 1, 5 | "cudaDevAttrMaxBlockDimX": 2, 6 | "cudaDevAttrMaxBlockDimY": 3, 7 | "cudaDevAttrMaxBlockDimZ": 4, 8 | "cudaDevAttrMaxGridDimX": 5, 9 | "cudaDevAttrMaxGridDimY": 6, 10 | "cudaDevAttrMaxGridDimZ": 7, 11 | "cudaDevAttrMaxSharedMemoryPerBlock": 8, 12 | "cudaDevAttrTotalConstantMemory": 9, 13 | "cudaDevAttrWarpSize": 10, 14 | "cudaDevAttrMaxPitch": 11, 15 | "cudaDevAttrMaxRegistersPerBlock": 12, 16 | "cudaDevAttrClockRate": 13, 17 | "cudaDevAttrTextureAlignment": 14, 18 | "cudaDevAttrGpuOverlap": 15, 19 | "cudaDevAttrMultiProcessorCount": 16, 20 | "cudaDevAttrKernelExecTimeout": 17, 21 | "cudaDevAttrIntegrated": 18, 22 | "cudaDevAttrCanMapHostMemory": 19, 23 | "cudaDevAttrComputeMode": 20, 24 | "cudaDevAttrMaxTexture1DWidth": 21, 25 | "cudaDevAttrMaxTexture2DWidth": 22, 26 | "cudaDevAttrMaxTexture2DHeight": 23, 27 | "cudaDevAttrMaxTexture3DWidth": 24, 28 | "cudaDevAttrMaxTexture3DHeight": 25, 29 | "cudaDevAttrMaxTexture3DDepth": 26, 30 | "cudaDevAttrMaxTexture2DLayeredWidth": 27, 31 | "cudaDevAttrMaxTexture2DLayeredHeight": 28, 32 | "cudaDevAttrMaxTexture2DLayeredLayers": 29, 33 | "cudaDevAttrSurfaceAlignment": 30, 34 | "cudaDevAttrConcurrentKernels": 31, 35 | "cudaDevAttrEccEnabled": 32, 36 | "cudaDevAttrPciBusId": 33, 37 | "cudaDevAttrPciDeviceId": 34, 38 | "cudaDevAttrTccDriver": 35, 39 | "cudaDevAttrMemoryClockRate": 36, 40 | "cudaDevAttrGlobalMemoryBusWidth": 37, 41 | "cudaDevAttrL2CacheSize": 38, 42 | "cudaDevAttrMaxThreadsPerMultiProcessor": 39, 43 | "cudaDevAttrAsyncEngineCount": 40, 44 | "cudaDevAttrUnifiedAddressing": 41, 45 | "cudaDevAttrMaxTexture1DLayeredWidth": 42, 46 | "cudaDevAttrMaxTexture1DLayeredLayers": 43, 47 | "cudaDevAttrMaxTexture2DGatherWidth": 45, 48 | "cudaDevAttrMaxTexture2DGatherHeight": 46, 49 | "cudaDevAttrMaxTexture3DWidthAlt": 47, 50 | "cudaDevAttrMaxTexture3DHeightAlt": 48, 51 | "cudaDevAttrMaxTexture3DDepthAlt": 49, 52 | "cudaDevAttrPciDomainId": 50, 53 | "cudaDevAttrTexturePitchAlignment": 51, 54 | "cudaDevAttrMaxTextureCubemapWidth": 52, 55 | "cudaDevAttrMaxTextureCubemapLayeredWidth": 53, 56 | "cudaDevAttrMaxTextureCubemapLayeredLayers": 54, 57 | "cudaDevAttrMaxSurface1DWidth": 55, 58 | "cudaDevAttrMaxSurface2DWidth": 56, 59 | "cudaDevAttrMaxSurface2DHeight": 57, 60 | "cudaDevAttrMaxSurface3DWidth": 58, 61 | "cudaDevAttrMaxSurface3DHeight": 59, 62 | "cudaDevAttrMaxSurface3DDepth": 60, 63 | "cudaDevAttrMaxSurface1DLayeredWidth": 61, 64 | "cudaDevAttrMaxSurface1DLayeredLayers": 62, 65 | "cudaDevAttrMaxSurface2DLayeredWidth": 63, 66 | "cudaDevAttrMaxSurface2DLayeredHeight": 64, 67 | "cudaDevAttrMaxSurface2DLayeredLayers": 65, 68 | "cudaDevAttrMaxSurfaceCubemapWidth": 66, 69 | "cudaDevAttrMaxSurfaceCubemapLayeredWidth": 67, 70 | "cudaDevAttrMaxSurfaceCubemapLayeredLayers": 68, 71 | "cudaDevAttrMaxTexture1DLinearWidth": 69, 72 | "cudaDevAttrMaxTexture2DLinearWidth": 70, 73 | "cudaDevAttrMaxTexture2DLinearHeight": 71, 74 | "cudaDevAttrMaxTexture2DLinearPitch": 72, 75 | "cudaDevAttrMaxTexture2DMipmappedWidth": 73, 76 | "cudaDevAttrMaxTexture2DMipmappedHeight": 74, 77 | "cudaDevAttrComputeCapabilityMajor": 75, 78 | "cudaDevAttrComputeCapabilityMinor": 76, 79 | "cudaDevAttrMaxTexture1DMipmappedWidth": 77, 80 | 
"cudaDevAttrStreamPrioritiesSupported": 78, 81 | "cudaDevAttrGlobalL1CacheSupported": 79, 82 | "cudaDevAttrLocalL1CacheSupported": 80, 83 | "cudaDevAttrMaxSharedMemoryPerMultiprocessor": 81, 84 | "cudaDevAttrMaxRegistersPerMultiprocessor": 82, 85 | "cudaDevAttrManagedMemory": 83, 86 | "cudaDevAttrIsMultiGpuBoard": 84, 87 | "cudaDevAttrMultiGpuBoardGroupID": 85, 88 | "cudaDevAttrHostNativeAtomicSupported": 86, 89 | "cudaDevAttrSingleToDoublePrecisionPerfRatio": 87, 90 | "cudaDevAttrPageableMemoryAccess": 88, 91 | "cudaDevAttrConcurrentManagedAccess": 89, 92 | "cudaDevAttrComputePreemptionSupported": 90, 93 | "cudaDevAttrCanUseHostPointerForRegisteredMem": 91, 94 | "cudaDevAttrReserved92": 92, 95 | "cudaDevAttrReserved93": 93, 96 | "cudaDevAttrReserved94": 94, 97 | "cudaDevAttrCooperativeLaunch": 95, 98 | "cudaDevAttrCooperativeMultiDeviceLaunch": 96, 99 | "cudaDevAttrMaxSharedMemoryPerBlockOptin": 97, 100 | "cudaDevAttrCanFlushRemoteWrites": 98, 101 | "cudaDevAttrHostRegisterSupported": 99, 102 | "cudaDevAttrPageableMemoryAccessUsesHostPageTables": 100, 103 | "cudaDevAttrDirectManagedMemAccessFromHost": 101, 104 | } 105 | 106 | class _Device: 107 | def __init__(self, index): 108 | self._index = index 109 | self.attributes = {} 110 | self._initialized = False 111 | 112 | for kw, idx in ATTRIBUTES.items(): 113 | self.attributes[kw] = cudart.cudaDeviceGetAttribute(idx, self._index) 114 | 115 | def use(self): 116 | cudart.cudaSetDevice(self._index) 117 | if not self._initialized: 118 | cudart.cudaFree( None ) # lazy initialze 119 | self._initialized = True 120 | self.cublasLtHandle = cublaslt.cublasLtCreate() 121 | 122 | 123 | 124 | if cudart.version > 0: 125 | _DEVICES = [ 126 | _Device(i) for i in range(cudart.cudaGetDeviceCount()) 127 | ] 128 | else: 129 | _DEVICES = [] 130 | 131 | class Device: 132 | def __init__(self, index) -> None: 133 | if index > len(_DEVICES): 134 | raise ValueError("Device index out of range (%d >= %d)" % (index, len(_DEVICES))) 135 | 136 | self._device = _DEVICES[index] 137 | for kw, value in self._device.attributes.items(): 138 | setattr(self, kw[len("cuda"):], value) 139 | 140 | def attr(self, name : str) -> int: 141 | return self._device.attr(name) 142 | 143 | @property 144 | def architecture(self) -> int: 145 | return self.DevAttrComputeCapabilityMajor * 10 + self.DevAttrComputeCapabilityMinor 146 | 147 | @property 148 | def cublasLtHandle(self) -> cublaslt.cublasLtHandle_t: 149 | return self._device.cublasLtHandle 150 | 151 | def use(self): 152 | self._device.use() 153 | 154 | def num_devices(): 155 | return len(_DEVICES) 156 | 157 | def current_device() -> Device: 158 | return Device(cudart.cudaGetDevice()) -------------------------------------------------------------------------------- /cpm_kernels/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import embedding_forward, embedding_backward_stage1, embedding_backward_stage2, embedding_step 2 | from .gelu import gelu_forward, gelu_backward, gelu_inplace_forward 3 | from .gemm import gemm_calc_scale, gemm_calc_scale_transpose, gemm_round, gemm_round_transpose, gemm_scale, gemm_fp16, gemm_int8, gemm_backward_round_scale, gemm_backward_scale_round, gemm_scale_x, gemm_scale_y 4 | from .gemv import gemv_fp16, gemv_broadcast_mat_int8, gemv_fp16_transpose, gemv_broadcast_mat_fp16, gemv_calc_scale, gemv_round, \ 5 | gemv_broadcast_mat_fp16_light, gemv_fp16_light, gemv_fp16_transpose_light 6 | from .mask import mask 7 | from .arith import 
arith_batch_add_backward, arith_batch_add_forward, arith_element_add, arith_ln_add_backward, arith_ln_div, arith_ln_mul, \ 8 | arith_ln_mul_add, arith_ln_mul_backward, arith_ln_sub_div, arith_element_mul, arith_ln_add, \ 9 | arith_batch_mul, arith_batch_mul_add, arith_global_scale 10 | 11 | from .layernorm import layernorm_forward, layernorm_inplace_forward, layernorm_forward_v, layernorm_forward_mv, layernorm_backward_v, layernorm_backward_mv, layernorm_step, layernorm_step_inplace 12 | from .position_bucket import position_embedding_init, position_embedding_forward, position_embedding_backward, position_embedding_step 13 | from .softmax import softmax_forward, softmax_backward, softmax_inplace_forward, softmax_step_inplace 14 | from .transpose import transpose 15 | from .utils import copy_data_to_kv, has_nan_inf -------------------------------------------------------------------------------- /cpm_kernels/kernels/arith.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | arith_kernel = Kernel( 5 | "arith", 6 | [ 7 | "cu_arith_global_scale", 8 | "cu_arith_element_add", 9 | "cu_arith_element_mul", 10 | "cu_arith_batch_add_forward", 11 | "cu_arith_batch_add_backward", 12 | "cu_arith_ln_mul_add", 13 | "cu_arith_ln_add", 14 | "cu_arith_ln_mul", 15 | "cu_arith_ln_div", 16 | "cu_arith_ln_sub_div", 17 | "cu_arith_ln_mul_backward", 18 | "cu_arith_ln_add_backward", 19 | "cu_arith_batch_mul_add", 20 | "cu_arith_batch_mul" 21 | ] 22 | ) 23 | 24 | def arith_global_scale( 25 | n : int, 26 | inp : DevicePointer, # (n,) fp16 27 | scale : float, 28 | out : DevicePointer, # (n,) fp16 29 | stream : CUDAStream 30 | ): 31 | threads = min(round_up(n, 32), 1024) 32 | gridDim = (round_up(n, threads) // threads, 1, 1) 33 | blockDim = (threads, 1, 1) 34 | arith_kernel.cu_arith_global_scale( 35 | gridDim, blockDim, 0, stream, [ 36 | ctypes.c_int64(n), 37 | ctypes.c_void_p(inp), 38 | ctypes.c_float(scale), 39 | ctypes.c_void_p(out) 40 | ] 41 | ) 42 | 43 | def arith_element_add( 44 | batch : int, n : int, 45 | x : DevicePointer, # (batch, n) fp16 46 | y : DevicePointer, # (batch, n) fp16 47 | out : DevicePointer, # (batch, n) fp16 48 | stream : CUDAStream 49 | ): 50 | """ 51 | out = x + y 52 | """ 53 | assert n % 2 == 0 54 | n = n // 2 55 | threads = min(round_up(n, 32), 1024) 56 | gridDim = (batch, round_up(n, threads) // threads, 1) 57 | blockDim = (threads, 1, 1) 58 | arith_kernel.cu_arith_element_add( 59 | gridDim, blockDim, 0, stream, [ 60 | ctypes.c_int64(batch), 61 | ctypes.c_int64(n), 62 | ctypes.c_void_p(x), 63 | ctypes.c_void_p(y), 64 | ctypes.c_void_p(out) 65 | ] 66 | ) 67 | 68 | def arith_element_mul( 69 | batch : int, n : int, 70 | x : DevicePointer, # (batch, n) fp16 71 | y : DevicePointer, # (batch, n) fp16 72 | out : DevicePointer, # (batch, n) fp16 73 | stream : CUDAStream 74 | ): 75 | """ 76 | out = x * y 77 | """ 78 | assert n % 2 == 0 79 | n = n // 2 80 | threads = min(round_up(n, 32), 1024) 81 | gridDim = (batch, round_up(n, threads) // threads, 1) 82 | blockDim = (threads, 1, 1) 83 | arith_kernel.cu_arith_element_mul( 84 | gridDim, blockDim, 0, stream, [ 85 | ctypes.c_int64(batch), 86 | ctypes.c_int64(n), 87 | ctypes.c_void_p(x), 88 | ctypes.c_void_p(y), 89 | ctypes.c_void_p(out) 90 | ] 91 | ) 92 | 93 | def arith_batch_add_forward( 94 | batch : int, n : int, 95 | x : DevicePointer, # (batch, n) fp16 96 | y : DevicePointer, # (n) fp16 97 | out : DevicePointer, # (batch, n) 
fp16 98 | stream : CUDAStream 99 | ): 100 | """ 101 | out = x + y[None, :] 102 | """ 103 | assert n % 2 == 0 104 | n = n // 2 105 | threads = min(round_up(n, 32), 1024) 106 | gridDim = (batch, round_up(n, threads) // threads, 1) 107 | blockDim = (threads, 1, 1) 108 | arith_kernel.cu_arith_batch_add_forward( 109 | gridDim, blockDim, 0, stream, [ 110 | ctypes.c_int64(batch), 111 | ctypes.c_int64(n), 112 | ctypes.c_void_p(x), 113 | ctypes.c_void_p(y), 114 | ctypes.c_void_p(out) 115 | ] 116 | ) 117 | 118 | def arith_batch_add_backward( 119 | batch : int, n : int, 120 | grad_out : DevicePointer, # (batch, n) fp16 121 | grad : DevicePointer, # (n) fp16 122 | stream : CUDAStream 123 | ): 124 | gridDim = ( round_up(n, 32) // 32, 1, 1 ) 125 | blockDim = (32, 32, 1) 126 | arith_kernel.cu_arith_batch_add_backward( 127 | gridDim, blockDim, 0, stream, [ 128 | ctypes.c_int64(batch), 129 | ctypes.c_int64(n), 130 | ctypes.c_void_p(grad_out), 131 | ctypes.c_void_p(grad) 132 | ] 133 | ) 134 | 135 | def arith_ln_mul_add( 136 | batch : int, n : int, m : int, 137 | inp : DevicePointer, # (batch, n, m) fp16 138 | alpha : DevicePointer, # (n) fp16 139 | beta : DevicePointer, # (n) fp16 140 | out : DevicePointer, # (batch, n, m) fp16 141 | stream : CUDAStream 142 | ): 143 | """ 144 | out = x * alpha[None, :, None] + beta[None, :, None] 145 | """ 146 | assert m % 2 == 0 147 | m = m // 2 148 | threads = min(round_up(m, 32), 1024) 149 | gridDim = (batch, n, round_up(m, threads) // threads) 150 | blockDim = (threads, 1, 1) 151 | arith_kernel.cu_arith_ln_mul_add( 152 | gridDim, blockDim, 0, stream, [ 153 | ctypes.c_int64(batch), 154 | ctypes.c_int64(n), 155 | ctypes.c_int64(m), 156 | ctypes.c_void_p(inp), 157 | ctypes.c_void_p(alpha), 158 | ctypes.c_void_p(beta), 159 | ctypes.c_void_p(out) 160 | ] 161 | ) 162 | 163 | def arith_ln_add( 164 | batch : int, n : int, m : int, 165 | inp : DevicePointer, # (batch, n, m) fp16 166 | beta : DevicePointer, # (n) fp16 167 | out : DevicePointer, # (batch, n, m) fp16 168 | stream : CUDAStream 169 | ): 170 | """ 171 | out = x + beta[None, :, None] 172 | """ 173 | assert m % 2 == 0 174 | m = m // 2 175 | threads = min(round_up(m, 32), 1024) 176 | gridDim = (batch, n, round_up(m, threads) // threads) 177 | blockDim = (threads, 1, 1) 178 | arith_kernel.cu_arith_ln_add( 179 | gridDim, blockDim, 0, stream, [ 180 | ctypes.c_int64(batch), 181 | ctypes.c_int64(n), 182 | ctypes.c_int64(m), 183 | ctypes.c_void_p(inp), 184 | ctypes.c_void_p(beta), 185 | ctypes.c_void_p(out) 186 | ] 187 | ) 188 | 189 | 190 | def arith_ln_mul( 191 | batch : int, n : int, m : int, 192 | inp : DevicePointer, # (batch, n, m) fp16 193 | alpha : DevicePointer, # (n) fp16 194 | out : DevicePointer, # (batch, n, m) fp16 195 | stream : CUDAStream 196 | ): 197 | """ 198 | out = x * alpha[None, :, None] 199 | """ 200 | assert m % 2 == 0 201 | m = m // 2 202 | threads = min(round_up(m, 32), 1024) 203 | gridDim = (batch, n, round_up(m, threads) // threads) 204 | blockDim = (threads, 1, 1) 205 | arith_kernel.cu_arith_ln_mul( 206 | gridDim, blockDim, 0, stream, [ 207 | ctypes.c_int64(batch), 208 | ctypes.c_int64(n), 209 | ctypes.c_int64(m), 210 | ctypes.c_void_p(inp), 211 | ctypes.c_void_p(alpha), 212 | ctypes.c_void_p(out) 213 | ] 214 | ) 215 | 216 | def arith_ln_div( 217 | batch : int, n : int, m : int, 218 | inp : DevicePointer, # (batch, n, m) fp16 219 | alpha : DevicePointer, # (n) fp16 220 | out : DevicePointer, # (batch, n, m) fp16 221 | stream : CUDAStream 222 | ): 223 | """ 224 | out = x / alpha[None, :, None] 225 | 
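    A NumPy sketch of the same broadcasting (illustrative only; the kernel itself
    reads and writes raw fp16 device pointers):

        import numpy as np
        x = np.ones((2, 3, 4), dtype=np.float16)      # (batch, n, m)
        alpha = np.full((3,), 2.0, dtype=np.float16)  # (n,)
        out = x / alpha[None, :, None]                # same result, shape (batch, n, m)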
""" 226 | assert m % 2 == 0 227 | m = m // 2 228 | threads = min(round_up(m, 32), 1024) 229 | gridDim = (batch, n, round_up(m, threads) // threads) 230 | blockDim = (threads, 1, 1) 231 | arith_kernel.cu_arith_ln_div( 232 | gridDim, blockDim, 0, stream, [ 233 | ctypes.c_int64(batch), 234 | ctypes.c_int64(n), 235 | ctypes.c_int64(m), 236 | ctypes.c_void_p(inp), 237 | ctypes.c_void_p(alpha), 238 | ctypes.c_void_p(out) 239 | ] 240 | ) 241 | 242 | def arith_ln_sub_div( 243 | batch : int, n : int, m : int, 244 | inp : DevicePointer, # (batch, n, m) fp16 245 | alpha : DevicePointer, # (n) fp16 246 | beta : DevicePointer, # (n) fp16 247 | out : DevicePointer, # (batch, n, m) fp16 248 | stream : CUDAStream 249 | ): 250 | """ 251 | out = (x - beta[None, :, None]) / alpha[None, :, None] 252 | """ 253 | assert m % 2 == 0 254 | m = m // 2 255 | threads = min(round_up(m, 32), 1024) 256 | gridDim = (batch, n, round_up(m, threads) // threads) 257 | blockDim = (threads, 1, 1) 258 | arith_kernel.cu_arith_ln_sub_div( 259 | gridDim, blockDim, 0, stream, [ 260 | ctypes.c_int64(batch), 261 | ctypes.c_int64(n), 262 | ctypes.c_int64(m), 263 | ctypes.c_void_p(inp), 264 | ctypes.c_void_p(alpha), 265 | ctypes.c_void_p(beta), 266 | ctypes.c_void_p(out) 267 | ] 268 | ) 269 | 270 | 271 | def arith_ln_mul_backward( 272 | batch : int, n : int, m : int, 273 | inp : DevicePointer, # (batch, n, m) fp16 274 | grad_out : DevicePointer, # (batch, n, m) fp16 275 | grad : DevicePointer, # (n) fp16 276 | stream : CUDAStream 277 | ): 278 | gridDim = (n, 1, 1) 279 | blockDim = (32, 32, 1) 280 | arith_kernel.cu_arith_ln_mul_backward( 281 | gridDim, blockDim, 0, stream, [ 282 | ctypes.c_int64(batch), 283 | ctypes.c_int64(n), 284 | ctypes.c_int64(m), 285 | ctypes.c_void_p(inp), 286 | ctypes.c_void_p(grad_out), 287 | ctypes.c_void_p(grad) 288 | ] 289 | ) 290 | 291 | def arith_ln_add_backward( 292 | batch : int, n : int, m : int, 293 | grad_out : DevicePointer, # (batch, n, m) fp16 294 | grad : DevicePointer, # (n) fp16 295 | stream : CUDAStream 296 | ): 297 | gridDim = (n, 1, 1) 298 | blockDim = (32, 32, 1) 299 | arith_kernel.cu_arith_ln_add_backward( 300 | gridDim, blockDim, 0, stream, [ 301 | ctypes.c_int64(batch), 302 | ctypes.c_int64(n), 303 | ctypes.c_int64(m), 304 | ctypes.c_void_p(grad_out), 305 | ctypes.c_void_p(grad) 306 | ] 307 | ) 308 | 309 | def arith_batch_mul_add( 310 | batch : int, n : int, 311 | x : DevicePointer, # (batch, n) 312 | alpha : DevicePointer, # (n) 313 | beta : DevicePointer, # (n) 314 | out : DevicePointer, # (batch, n) 315 | stream : CUDAStream 316 | ): 317 | assert n % 2 == 0 318 | n = n // 2 319 | threads = min(round_up(n, 32), 1024) 320 | gridDim = (batch, round_up(n, threads) // threads, 1) 321 | blockDim = (threads, 1, 1) 322 | arith_kernel.cu_arith_batch_mul_add( 323 | gridDim, blockDim, 0, stream, [ 324 | ctypes.c_int64(batch), 325 | ctypes.c_int64(n), 326 | ctypes.c_void_p(x), 327 | ctypes.c_void_p(alpha), 328 | ctypes.c_void_p(beta), 329 | ctypes.c_void_p(out) 330 | ] 331 | ) 332 | 333 | def arith_batch_mul( 334 | batch : int, n : int, 335 | x : DevicePointer, # (batch, n) 336 | alpha : DevicePointer, # (n) 337 | out : DevicePointer, # (batch, n) 338 | stream : CUDAStream 339 | ): 340 | assert n % 2 == 0 341 | n = n // 2 342 | threads = min(round_up(n, 32), 1024) 343 | gridDim = (batch, round_up(n, threads) // threads, 1) 344 | blockDim = (threads, 1, 1) 345 | arith_kernel.cu_arith_batch_mul( 346 | gridDim, blockDim, 0, stream, [ 347 | ctypes.c_int64(batch), 348 | ctypes.c_int64(n), 349 | 
ctypes.c_void_p(x), 350 | ctypes.c_void_p(alpha), 351 | ctypes.c_void_p(out) 352 | ] 353 | ) 354 | 355 | -------------------------------------------------------------------------------- /cpm_kernels/kernels/base.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | from typing import List, Any, Tuple 4 | from ..library import cuda, cudart 5 | from ..device import Device 6 | import pkg_resources 7 | DevicePointer = int 8 | CUDAStream = cudart.cudaStream_t 9 | 10 | RESOURCE_PACKAGE_NAME = __name__ 11 | 12 | def round_up(x : int, m : int) -> int: 13 | return (x + m - 1) // m * m 14 | 15 | class LazyKernelCModule: 16 | def __init__(self, code): 17 | self._code = code 18 | self._module = {} 19 | 20 | def get_module(self): 21 | curr_device = cudart.cudaGetDevice() 22 | if curr_device not in self._module: 23 | Device(curr_device).use() # force initialize context 24 | self._module[curr_device] = cuda.cuModuleLoadData(self._code) 25 | return self._module[curr_device] 26 | 27 | 28 | 29 | class KernelFunction: 30 | def __init__(self, cmodule : LazyKernelCModule, func_name : str) -> None: 31 | self._module = cmodule 32 | self._funcs = {} 33 | self._func_name = func_name 34 | 35 | def _prepare_func(self): 36 | curr_device = cudart.cudaGetDevice() 37 | cudart.cudaSetDevice(curr_device) # ensure cudart context 38 | if curr_device not in self._funcs: 39 | self._funcs[curr_device] = cuda.cuModuleGetFunction( 40 | self._module.get_module(), self._func_name 41 | ) 42 | return self._funcs[curr_device] 43 | 44 | def __call__(self, gridDim : Tuple[int, int, int], blockDim : Tuple[int, int, int], 45 | sharedMemBytes : int, stream : cudart.cudaStream_t, params : List[Any] ) -> None: 46 | assert len(gridDim) == 3 47 | assert len(blockDim) == 3 48 | func = self._prepare_func() 49 | 50 | cuda.cuLaunchKernel(func, 51 | gridDim[0], gridDim[1], gridDim[2], 52 | blockDim[0], blockDim[1], blockDim[2], 53 | sharedMemBytes, stream, [ 54 | ctypes.addressof(p) for p in params 55 | ] 56 | ) 57 | 58 | 59 | class Kernel: 60 | def __init__(self, filename : str, function_names : List[str]): 61 | filename = filename + ".fatbin" 62 | filename = os.path.join("cuda", filename) 63 | if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename): 64 | raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME)) 65 | self.filename = filename 66 | self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename) 67 | self._function_names = function_names 68 | self._cmodule = LazyKernelCModule(self.code) 69 | 70 | for name in self._function_names: 71 | setattr(self, name, KernelFunction(self._cmodule, name)) 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /cpm_kernels/kernels/embedding.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | embedding_kernel = Kernel( 5 | "embedding", 6 | [ 7 | "cu_embedding_forward", 8 | "cu_embedding_backward_stage1", 9 | "cu_embedding_backward_stage2", 10 | "cu_embedding_step" 11 | ] 12 | ) 13 | 14 | def embedding_forward( 15 | batch : int, 16 | hidden_size : int, # hidden size 17 | seq_len : int, # sequence length 18 | ids : DevicePointer, # (batch, m) 19 | weights : DevicePointer, # (vocab_size, n) 20 | out : DevicePointer, # (batch, n, m) 21 | stream : CUDAStream 22 | ): 23 | gridDim = (batch, 
round_up(seq_len, 32) // 32, round_up(hidden_size, 32) // 32) 24 | blockDim = (32, 32, 1) 25 | embedding_kernel.cu_embedding_forward( 26 | gridDim, blockDim, 0, stream, [ 27 | ctypes.c_int32(batch), 28 | ctypes.c_int32(hidden_size), 29 | ctypes.c_int32(seq_len), 30 | ctypes.c_void_p(ids), 31 | ctypes.c_void_p(weights), 32 | ctypes.c_void_p(out) 33 | ] 34 | ) 35 | 36 | def embedding_backward_stage1( 37 | batch : int, 38 | seq_len : int, 39 | hidden_size : int, 40 | grad_out : DevicePointer, # (batch * n, m) 41 | argsort_ids : DevicePointer, # (batch, n) 42 | sorted_ids : DevicePointer, # (batch, n) 43 | grad : DevicePointer, # (vocab_size, m) 44 | aux_grad : DevicePointer, # (batch, m) 45 | aux_grad_idx : DevicePointer, # (batch) 46 | stream : CUDAStream 47 | ): 48 | """ 49 | Sort idx and calc grad stage1 50 | """ 51 | gridDim = (batch, round_up(hidden_size, 1024) // 1024, 1) 52 | blockDim = (1024, 1, 1) 53 | embedding_kernel.cu_embedding_backward_stage1( 54 | gridDim, blockDim, 0, stream, [ 55 | ctypes.c_int32(batch), 56 | ctypes.c_int32(seq_len), 57 | ctypes.c_int32(hidden_size), 58 | ctypes.c_void_p(grad_out), 59 | ctypes.c_void_p(argsort_ids), 60 | ctypes.c_void_p(sorted_ids), 61 | ctypes.c_void_p(grad), 62 | ctypes.c_void_p(aux_grad), 63 | ctypes.c_void_p(aux_grad_idx) 64 | ] 65 | ) 66 | 67 | def embedding_backward_stage2( 68 | batch : int, 69 | hidden_size : int, 70 | aux_grad : DevicePointer, # (batch, m) 71 | aux_grad_idx : DevicePointer, # (batch) 72 | grad : DevicePointer, # (vocab_size, m) 73 | stream : CUDAStream 74 | ): 75 | 76 | gridDim = (round_up(hidden_size, 1024) // 1024, 1, 1) 77 | blockDim = (1024, 1, 1) 78 | embedding_kernel.cu_embedding_backward_stage2( 79 | gridDim, blockDim, 0, stream, [ 80 | ctypes.c_int32(batch), 81 | ctypes.c_int32(hidden_size), 82 | ctypes.c_void_p(aux_grad), 83 | ctypes.c_void_p(aux_grad_idx), 84 | ctypes.c_void_p(grad) 85 | ] 86 | ) 87 | 88 | def embedding_step( 89 | batch : int, embedding_size : int, 90 | ids : DevicePointer, # (batch,) int32 91 | weights : DevicePointer, # (vocab_size, embedding_size) fp16 92 | out : DevicePointer, # (batch, embedding_size) fp16 93 | stream : CUDAStream 94 | ): 95 | gridDim = (batch, 1, 1) 96 | blockDim = (min(1024, embedding_size), 1, 1) 97 | embedding_kernel.cu_embedding_step( 98 | gridDim, blockDim, 0, stream, [ 99 | ctypes.c_int32(batch), 100 | ctypes.c_int32(embedding_size), 101 | ctypes.c_void_p(ids), 102 | ctypes.c_void_p(weights), 103 | ctypes.c_void_p(out) 104 | ] 105 | ) -------------------------------------------------------------------------------- /cpm_kernels/kernels/gelu.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | gelu_kernel = Kernel( 5 | "gelu", 6 | [ 7 | "cu_gelu_forward", 8 | "cu_gelu_backward", 9 | ] 10 | ) 11 | 12 | 13 | def gelu_forward( 14 | batch : int, 15 | n : int, 16 | mat : DevicePointer, 17 | out : DevicePointer, 18 | stream : CUDAStream 19 | ): 20 | threads = min(round_up(n, 32), 1024) 21 | gridDim = (batch, round_up(n, threads) // threads, 1) 22 | blockDim = (threads, 1, 1) 23 | gelu_kernel.cu_gelu_forward( 24 | gridDim, blockDim, 0, stream, [ 25 | ctypes.c_int32(batch), 26 | ctypes.c_int32(n), 27 | ctypes.c_void_p(mat), 28 | ctypes.c_void_p(out) 29 | ] 30 | ) 31 | 32 | def gelu_inplace_forward( 33 | batch : int, 34 | n : int, 35 | mat : DevicePointer, 36 | stream : CUDAStream 37 | ): 38 | threads = min(round_up(n, 32), 1024) 39 | gridDim = (batch, 
round_up(n, threads) // threads, 1) 40 | blockDim = (threads, 1, 1) 41 | gelu_kernel.cu_gelu_forward( 42 | gridDim, blockDim, 0, stream, [ 43 | ctypes.c_int32(batch), 44 | ctypes.c_int32(n), 45 | ctypes.c_void_p(mat), 46 | ctypes.c_void_p(mat) 47 | ] 48 | ) 49 | 50 | def gelu_backward( 51 | batch : int, 52 | n : int, 53 | grad_out : DevicePointer, 54 | mat : DevicePointer, 55 | grad : DevicePointer, 56 | stream : CUDAStream 57 | ): 58 | threads = min(round_up(n, 32), 1024) 59 | gridDim = (batch, round_up(n, threads) // threads, 1) 60 | blockDim = (threads, 1, 1) 61 | gelu_kernel.cu_gelu_backward( 62 | gridDim, blockDim, 0, stream, [ 63 | ctypes.c_int32(batch), 64 | ctypes.c_int32(n), 65 | ctypes.c_void_p(grad_out), 66 | ctypes.c_void_p(mat), 67 | ctypes.c_void_p(grad) 68 | ] 69 | ) 70 | -------------------------------------------------------------------------------- /cpm_kernels/kernels/gemv.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | from ..library import cublaslt 3 | from ..device import current_device 4 | import ctypes 5 | 6 | gemv_kernel = Kernel( 7 | "gemv", 8 | [ 9 | "cu_gemv_calc_scale", 10 | "cu_gemv_round", 11 | "cu_gemv_broadcast_mat_int8", 12 | "cu_gemv_fp16", 13 | "cu_gemv_fp16_transpose", 14 | "cu_gemv_broadcast_mat_fp16" 15 | ] 16 | ) 17 | 18 | def gemv_calc_scale( 19 | batch : int, n : int, 20 | vec : DevicePointer, # (batch, n) fp16 21 | out : DevicePointer, # (batch,) fp16 22 | stream : CUDAStream 23 | ): 24 | gridDim = (batch, 1, 1) 25 | blockDim = (min(1024, round_up(n, 32)), 1, 1) 26 | gemv_kernel.cu_gemv_calc_scale( 27 | gridDim, blockDim, 0, stream, [ 28 | ctypes.c_int32(batch), 29 | ctypes.c_int32(n), 30 | ctypes.c_void_p(vec), 31 | ctypes.c_void_p(out) 32 | ] 33 | ) 34 | 35 | def gemv_round( 36 | batch : int, n : int, 37 | vec : DevicePointer, # (batch, n) fp16 38 | scale : DevicePointer, # (batch) fp16 39 | out : DevicePointer, # (batch, n) int8 40 | stream : CUDAStream 41 | ): 42 | threads = min(1024, round_up(n, 32)) 43 | gridDim = (batch, round_up(n, threads) // threads, 1) 44 | blockDim = (threads, 1, 1) 45 | gemv_kernel.cu_gemv_round( 46 | gridDim, blockDim, 0, stream, [ 47 | ctypes.c_int32(batch), 48 | ctypes.c_int32(n), 49 | ctypes.c_void_p(vec), 50 | ctypes.c_void_p(scale), 51 | ctypes.c_void_p(out) 52 | ] 53 | ) 54 | 55 | 56 | def gemv_broadcast_mat_int8( 57 | batch : int, dim_out : int, dim_in : int, 58 | scale_mat : DevicePointer, # (dim_out,) fp16 59 | mat : DevicePointer, # (dim_out, dim_in) int8 60 | scale_vec : DevicePointer, # (batch,) fp16 61 | vec : DevicePointer, # (batch, dim_in) int8 62 | out : DevicePointer, # (batch, dim_out) fp16 63 | stream : CUDAStream 64 | ): 65 | assert dim_in % 4 == 0 66 | gridDim = (batch, dim_out, 1) 67 | blockDim = (min(1024, round_up(dim_in // 4, 32)), 1, 1) 68 | gemv_kernel.cu_gemv_broadcast_mat_int8( 69 | gridDim, blockDim, 0, stream, [ 70 | ctypes.c_int32(batch), 71 | ctypes.c_int32(dim_out), 72 | ctypes.c_int32(dim_in), 73 | ctypes.c_void_p(scale_mat), 74 | ctypes.c_void_p(mat), 75 | ctypes.c_void_p(scale_vec), 76 | ctypes.c_void_p(vec), 77 | ctypes.c_void_p(out) 78 | ] 79 | ) 80 | 81 | 82 | def gemv_fp16( 83 | batch : int, dim_out : int, dim_in : int, 84 | mat : DevicePointer, # (batch, dim_out, dim_in) fp16 85 | vec : DevicePointer, # (batch, dim_in) fp16 86 | out : DevicePointer, # (batch, dim_out) fp16 87 | stream : CUDAStream 88 | ): 89 | device = current_device() 90 | device.use() 91 | layoutA = 
cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_in, dim_out, dim_in) 92 | layoutB = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_in, 1, dim_in) 93 | layoutC = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_out, 1, dim_out) 94 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 95 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_in * dim_out)) 96 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 97 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_in)) 98 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 99 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_out)) 100 | fallback_32f = device.architecture < 62 101 | 102 | if cublaslt.version >= 11000: 103 | if fallback_32f: 104 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_32F, cublaslt.CUDA_R_32F) 105 | else: 106 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_16F, cublaslt.CUDA_R_16F) 107 | else: 108 | if fallback_32f: 109 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_32F) 110 | else: 111 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_16F) 112 | cublaslt.cublasLtMatmulDescSetAttribute(matmulHandle, cublaslt.CUBLASLT_MATMUL_DESC_TRANSA, ctypes.c_int32(cublaslt.CUBLAS_OP_T)) 113 | cublaslt.cublasLtMatmul( 114 | device.cublasLtHandle, 115 | matmulHandle, 116 | ctypes.c_float(1.0) if fallback_32f else ctypes.c_short(15360), # half(1) 117 | mat, layoutA, 118 | vec, layoutB, 119 | ctypes.c_float(0) if fallback_32f else ctypes.c_short(0), # half(0) 120 | out, layoutC, 121 | out, layoutC, 122 | stream 123 | ) 124 | cublaslt.cublasLtMatmulDescDestroy(matmulHandle) 125 | cublaslt.cublasLtMatrixLayoutDestroy(layoutA) 126 | cublaslt.cublasLtMatrixLayoutDestroy(layoutB) 127 | cublaslt.cublasLtMatrixLayoutDestroy(layoutC) 128 | 129 | def gemv_fp16_light( 130 | batch : int, dim_out : int, dim_in : int, 131 | mat : DevicePointer, # (batch, dim_out, dim_in) fp16 132 | vec : DevicePointer, # (batch, dim_in) fp16 133 | out : DevicePointer, # (batch, dim_out) fp16 134 | stream : CUDAStream 135 | ): 136 | assert dim_in % 2 == 0 137 | gridDim = (batch, dim_out, 1) 138 | blockDim = (min(1024, round_up(dim_in // 2, 32)), 1, 1) 139 | gemv_kernel.cu_gemv_fp16( 140 | gridDim, blockDim, 0, stream, [ 141 | ctypes.c_int32(batch), 142 | ctypes.c_int32(dim_out), 143 | ctypes.c_int32(dim_in), 144 | ctypes.c_void_p(mat), 145 | ctypes.c_void_p(vec), 146 | ctypes.c_void_p(out) 147 | ] 148 | ) 149 | 150 | 151 | def gemv_fp16_transpose( 152 | batch : int, dim_out : int, dim_in : int, 153 | mat : DevicePointer, # (batch, dim_in, dim_out) fp16 154 | vec : DevicePointer, # (batch, dim_in) fp16 155 | out : DevicePointer, # (batch, dim_out) fp16 156 | stream : CUDAStream 157 | ): 158 | device = current_device() 159 | device.use() 160 | layoutA = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_out, dim_in, dim_out) 161 | layoutB = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_in, 1, dim_in) 162 | layoutC = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_out, 1, dim_out) 163 | 
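    # cuBLASLt layouts are column-major by default, so layoutA views each
    # (dim_in, dim_out) slice of `mat` as its (dim_out x dim_in) transpose; unlike
    # gemv_fp16 above, no CUBLAS_OP_T is needed here. The attribute calls below make
    # all three layouts strided-batched so one cublasLtMatmul covers every batch element.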
cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 164 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_in * dim_out)) 165 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 166 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_in)) 167 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 168 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_out)) 169 | 170 | fallback_32f = device.architecture < 62 171 | 172 | if cublaslt.version >= 11000: 173 | if fallback_32f: 174 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_32F, cublaslt.CUDA_R_32F) 175 | else: 176 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_16F, cublaslt.CUDA_R_16F) 177 | else: 178 | if fallback_32f: 179 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_32F) 180 | else: 181 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_16F) 182 | cublaslt.cublasLtMatmul( 183 | device.cublasLtHandle, 184 | matmulHandle, 185 | ctypes.c_float(1) if fallback_32f else ctypes.c_short(15360), # half(1) 186 | mat, layoutA, 187 | vec, layoutB, 188 | ctypes.c_float(0) if fallback_32f else ctypes.c_short(0), # half(0) 189 | out, layoutC, 190 | out, layoutC, 191 | stream 192 | ) 193 | cublaslt.cublasLtMatmulDescDestroy(matmulHandle) 194 | cublaslt.cublasLtMatrixLayoutDestroy(layoutA) 195 | cublaslt.cublasLtMatrixLayoutDestroy(layoutB) 196 | cublaslt.cublasLtMatrixLayoutDestroy(layoutC) 197 | 198 | def gemv_fp16_transpose_light( 199 | batch : int, dim_out : int, dim_in : int, 200 | mat : DevicePointer, # (batch, dim_in, dim_out) fp16 201 | vec : DevicePointer, # (batch, dim_in) fp16 202 | out : DevicePointer, # (batch, dim_out) fp16 203 | stream : CUDAStream 204 | ): 205 | gridDim = (batch, round_up(dim_out, 32) // 32, 1) 206 | blockDim = (32, 32, 1) 207 | gemv_kernel.cu_gemv_fp16_transpose( 208 | gridDim, blockDim, 0, stream, [ 209 | ctypes.c_int32(batch), 210 | ctypes.c_int32(dim_out), 211 | ctypes.c_int32(dim_in), 212 | ctypes.c_void_p(mat), 213 | ctypes.c_void_p(vec), 214 | ctypes.c_void_p(out) 215 | ] 216 | ) 217 | 218 | def gemv_broadcast_mat_fp16( 219 | batch : int, dim_out : int, dim_in : int, 220 | mat : DevicePointer, # (dim_out, dim_in) fp16 221 | vec : DevicePointer, # (batch, dim_in) 222 | out : DevicePointer, # (batch, dim_out) 223 | stream : CUDAStream 224 | ): 225 | device = current_device() 226 | device.use() 227 | layoutA = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_in, dim_out, dim_in) 228 | layoutB = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_in, 1, dim_in) 229 | layoutC = cublaslt.cublasLtMatrixLayoutCreate(cublaslt.CUDA_R_16F, dim_out, 1, dim_out) 230 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 231 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutA, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(0)) 232 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 233 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutB, 
cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_in)) 234 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, ctypes.c_int32(batch)) 235 | cublaslt.cublasLtMatrixLayoutSetAttribute(layoutC, cublaslt.CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, ctypes.c_int64(dim_out)) 236 | fallback_32f = device.architecture < 62 237 | if cublaslt.version >= 11000: 238 | if fallback_32f: 239 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_32F, cublaslt.CUDA_R_32F) 240 | else: 241 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUBLAS_COMPUTE_16F, cublaslt.CUDA_R_16F) 242 | else: 243 | if fallback_32f: 244 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_32F) 245 | else: 246 | matmulHandle = cublaslt.cublasLtMatmulDescCreate(cublaslt.CUDA_R_16F) 247 | cublaslt.cublasLtMatmulDescSetAttribute(matmulHandle, cublaslt.CUBLASLT_MATMUL_DESC_TRANSA, ctypes.c_int32(cublaslt.CUBLAS_OP_T)) 248 | cublaslt.cublasLtMatmul( 249 | device.cublasLtHandle, 250 | matmulHandle, 251 | ctypes.c_float(1) if fallback_32f else ctypes.c_short(15360), # half(1) 252 | mat, layoutA, 253 | vec, layoutB, 254 | ctypes.c_float(0) if fallback_32f else ctypes.c_short(0), # half(0) 255 | out, layoutC, 256 | out, layoutC, 257 | stream 258 | ) 259 | cublaslt.cublasLtMatmulDescDestroy(matmulHandle) 260 | cublaslt.cublasLtMatrixLayoutDestroy(layoutA) 261 | cublaslt.cublasLtMatrixLayoutDestroy(layoutB) 262 | cublaslt.cublasLtMatrixLayoutDestroy(layoutC) 263 | 264 | def gemv_broadcast_mat_fp16_light( 265 | batch : int, dim_out : int, dim_in : int, 266 | mat : DevicePointer, # (dim_out, dim_in) fp16 267 | vec : DevicePointer, # (batch, dim_in) 268 | out : DevicePointer, # (batch, dim_out) 269 | stream : CUDAStream 270 | ): 271 | assert dim_in % 2 == 0 272 | gridDim = (batch, dim_out, 1) 273 | blockDim = (min(1024, round_up(dim_in // 2, 32)), 1, 1) 274 | gemv_kernel.cu_gemv_broadcast_mat_fp16( 275 | gridDim, blockDim, 0, stream, [ 276 | ctypes.c_int32(batch), 277 | ctypes.c_int32(dim_out), 278 | ctypes.c_int32(dim_in), 279 | ctypes.c_void_p(mat), 280 | ctypes.c_void_p(vec), 281 | ctypes.c_void_p(out) 282 | ] 283 | ) 284 | -------------------------------------------------------------------------------- /cpm_kernels/kernels/layernorm.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | layernorm_kernel = Kernel( 5 | "layernorm", 6 | [ 7 | "cu_layernorm_forward", 8 | "cu_layernorm_inplace_forward", 9 | "cu_layernorm_forward_v", 10 | "cu_layernorm_forward_mv", 11 | "cu_layernorm_backward_v", 12 | "cu_layernorm_backward_mv", 13 | "cu_layernorm_step", 14 | "cu_layernorm_step_inplace" 15 | ] 16 | ) 17 | 18 | def layernorm_forward( 19 | batch : int, n : int, m : int, 20 | mat : DevicePointer, # (batch, n, m) 21 | out : DevicePointer, # (batch, n, m) 22 | eps : float, 23 | rd_mean : bool, 24 | stream : CUDAStream 25 | ): 26 | gridDim = (batch, round_up(m, 32) // 32, 1) 27 | blockDim = (32, 32, 1) 28 | layernorm_kernel.cu_layernorm_forward( 29 | gridDim, blockDim, 0, stream, [ 30 | ctypes.c_int32(batch), 31 | ctypes.c_int32(n), 32 | ctypes.c_int32(m), 33 | ctypes.c_void_p(mat), 34 | ctypes.c_void_p(out), 35 | ctypes.c_float(eps), 36 | ctypes.c_bool(rd_mean) 37 | ] 38 | ) 39 | 40 | def layernorm_inplace_forward( 41 | batch : int, n : int, m : int, 42 | mat : DevicePointer, # (batch, n, m) 43 | eps : float, 44 | 
rd_mean : bool, 45 | stream : CUDAStream 46 | ): 47 | gridDim = (batch, round_up(m, 32) // 32, 1) 48 | blockDim = (32, 32, 1) 49 | layernorm_kernel.cu_layernorm_inplace_forward( 50 | gridDim, blockDim, 0, stream, [ 51 | ctypes.c_int32(batch), 52 | ctypes.c_int32(n), 53 | ctypes.c_int32(m), 54 | ctypes.c_void_p(mat), 55 | ctypes.c_float(eps), 56 | ctypes.c_bool(rd_mean) 57 | ] 58 | ) 59 | 60 | def layernorm_forward_v( 61 | batch : int, n : int, m : int, 62 | mat : DevicePointer, # (batch, n, m) 63 | out : DevicePointer, # (batch, n, m) 64 | out_var : DevicePointer, # (batch, m) 65 | eps: float, 66 | stream : CUDAStream 67 | ): 68 | gridDim = (batch, round_up(m, 32) // 32, 1) 69 | blockDim = (32, 32, 1) 70 | layernorm_kernel.cu_layernorm_forward_v( 71 | gridDim, blockDim, 0, stream, [ 72 | ctypes.c_int32(batch), 73 | ctypes.c_int32(n), 74 | ctypes.c_int32(m), 75 | ctypes.c_void_p(mat), 76 | ctypes.c_void_p(out), 77 | ctypes.c_void_p(out_var), 78 | ctypes.c_float(eps) 79 | ] 80 | ) 81 | 82 | def layernorm_forward_mv( 83 | batch : int, n : int, m : int, 84 | mat : DevicePointer, # (batch, n, m) 85 | out : DevicePointer, # (batch, n, m) 86 | out_mean : DevicePointer, # (batch, m) 87 | out_var : DevicePointer, # (batch, m) 88 | eps : float, 89 | stream : CUDAStream 90 | ): 91 | gridDim = (batch, round_up(m, 32) // 32, 1) 92 | blockDim = (32, 32, 1) 93 | layernorm_kernel.cu_layernorm_forward_mv( 94 | gridDim, blockDim, 0, stream, [ 95 | ctypes.c_int32(batch), 96 | ctypes.c_int32(n), 97 | ctypes.c_int32(m), 98 | ctypes.c_void_p(mat), 99 | ctypes.c_void_p(out), 100 | ctypes.c_void_p(out_mean), 101 | ctypes.c_void_p(out_var), 102 | ctypes.c_float(eps) 103 | ] 104 | ) 105 | 106 | def layernorm_backward_v( 107 | batch : int, n : int, m : int, 108 | mat : DevicePointer, # (batch, n, m) 109 | grad_out : DevicePointer, # (batch, n, m) 110 | var : DevicePointer, # (batch, m) 111 | grad : DevicePointer, # (batch, n, m) 112 | stream : CUDAStream 113 | ): 114 | gridDim = (batch, round_up(m, 32) // 32, 1) 115 | blockDim = (32, 32, 1) 116 | layernorm_kernel.cu_layernorm_backward_v( 117 | gridDim, blockDim, 0, stream, [ 118 | ctypes.c_int32(batch), 119 | ctypes.c_int32(n), 120 | ctypes.c_int32(m), 121 | ctypes.c_void_p(mat), 122 | ctypes.c_void_p(grad_out), 123 | ctypes.c_void_p(var), 124 | ctypes.c_void_p(grad) 125 | ] 126 | ) 127 | 128 | def layernorm_backward_mv( 129 | batch : int, n : int, m : int, 130 | mat : DevicePointer, # (batch, n, m) 131 | grad_out : DevicePointer, # (batch, n, m) 132 | mean : DevicePointer, # (batch, m) 133 | var : DevicePointer, # (batch, m) 134 | grad : DevicePointer, # (batch, n, m) 135 | stream : CUDAStream 136 | ): 137 | gridDim = (batch, round_up(m, 32) // 32, 1) 138 | blockDim = (32, 32, 1) 139 | layernorm_kernel.cu_layernorm_backward_mv( 140 | gridDim, blockDim, 0, stream, [ 141 | ctypes.c_int32(batch), 142 | ctypes.c_int32(n), 143 | ctypes.c_int32(m), 144 | ctypes.c_void_p(mat), 145 | ctypes.c_void_p(grad_out), 146 | ctypes.c_void_p(mean), 147 | ctypes.c_void_p(var), 148 | ctypes.c_void_p(grad) 149 | ] 150 | ) 151 | 152 | def layernorm_step( 153 | batch : int, n : int, 154 | mat : DevicePointer, # (batch, n) fp16 155 | out : DevicePointer, # (batch, n) fp16 156 | eps : float, 157 | rd_mean : bool, 158 | stream : CUDAStream 159 | ): 160 | gridDim = (batch, 1, 1) 161 | blockDim = (min(1024, round_up(n, 32)), 1, 1) 162 | layernorm_kernel.cu_layernorm_step( 163 | gridDim, blockDim, 0, stream, [ 164 | ctypes.c_int32(batch), 165 | ctypes.c_int32(n), 166 | ctypes.c_void_p(mat), 
167 | ctypes.c_void_p(out), 168 | ctypes.c_float(eps), 169 | ctypes.c_bool(rd_mean) 170 | ] 171 | ) 172 | 173 | def layernorm_step_inplace( 174 | batch : int, n : int, 175 | mat : DevicePointer, # (batch, n) fp16 176 | eps : float, 177 | rd_mean : bool, 178 | stream : CUDAStream 179 | ): 180 | gridDim = (batch, 1, 1) 181 | blockDim = (min(1024, round_up(n, 32)), 1, 1) 182 | layernorm_kernel.cu_layernorm_step_inplace( 183 | gridDim, blockDim, 0, stream, [ 184 | ctypes.c_int32(batch), 185 | ctypes.c_int32(n), 186 | ctypes.c_void_p(mat), 187 | ctypes.c_float(eps), 188 | ctypes.c_bool(rd_mean) 189 | ] 190 | ) -------------------------------------------------------------------------------- /cpm_kernels/kernels/mask.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | 5 | mask_kernel = Kernel( 6 | "mask", 7 | [ 8 | "cu_mask" 9 | ] 10 | ) 11 | 12 | def mask( 13 | batch : int, n : int, m : int, 14 | inp : DevicePointer, # (batch, n, m) 15 | mask : DevicePointer, # (batch, m) 16 | value : float, 17 | out : DevicePointer, # (batch, n, m) 18 | stream : CUDAStream 19 | ): 20 | """ 21 | mask 22 | """ 23 | gridDim = (batch, round_up(m, 1024) // 1024, 1) 24 | blockDim = (min(m, 1024), 1, 1) 25 | mask_kernel.cu_mask( 26 | gridDim, blockDim, 0, stream, [ 27 | ctypes.c_int32(batch), 28 | ctypes.c_int32(n), 29 | ctypes.c_int32(m), 30 | ctypes.c_void_p(inp), 31 | ctypes.c_void_p(mask), 32 | ctypes.c_float(value), 33 | ctypes.c_void_p(out) 34 | ] 35 | ) 36 | -------------------------------------------------------------------------------- /cpm_kernels/kernels/position_bucket.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream 2 | import ctypes 3 | 4 | embedding_kernel = Kernel( 5 | "position_bucket", 6 | [ 7 | "cu_init_position_mapping", 8 | "cu_position_embedding_forward", 9 | "cu_position_embedding_backward", 10 | "cu_position_embedding_step" 11 | ] 12 | ) 13 | 14 | def position_embedding_init( 15 | num_buckets : int, 16 | max_distance : int, 17 | out : DevicePointer, # (max_distance) int32 18 | bidirectional : bool, 19 | stream : CUDAStream 20 | ): 21 | gridDim = (1, 1, 1) 22 | blockDim = (min(max_distance, 1024), 1, 1) 23 | embedding_kernel.cu_init_position_mapping( 24 | gridDim, blockDim, 0, stream, [ 25 | ctypes.c_int32(num_buckets), 26 | ctypes.c_int32(max_distance), 27 | ctypes.c_void_p(out), 28 | ctypes.c_bool(bidirectional) 29 | ] 30 | ) 31 | 32 | def position_embedding_forward( 33 | query_len : int, 34 | key_len : int, 35 | num_buckets : int, 36 | max_distance : int, 37 | num_heads : int, 38 | position_mapping : DevicePointer, # (max_distance) 39 | weight : DevicePointer, # (num_heads, num_bucket) 40 | out : DevicePointer, # (num_heads, key_len, query_len) 41 | bidirectional : bool, 42 | stream : CUDAStream 43 | ): 44 | gridDim = (key_len, 1, 1) 45 | blockDim = (min(query_len, 1024), 1, 1) 46 | embedding_kernel.cu_position_embedding_forward( 47 | gridDim, blockDim, 0, stream, [ 48 | ctypes.c_int32(query_len), 49 | ctypes.c_int32(key_len), 50 | ctypes.c_int32(num_buckets), 51 | ctypes.c_int32(max_distance), 52 | ctypes.c_int32(num_heads), 53 | ctypes.c_void_p(position_mapping), 54 | ctypes.c_void_p(weight), 55 | ctypes.c_void_p(out), 56 | ctypes.c_bool(bidirectional) 57 | ] 58 | ) 59 | 60 | def position_embedding_backward( 61 | query_len : int, 62 | key_len : int, 63 | num_buckets : int, 64 | 
max_distance : int, 65 | num_heads : int, # no more than 1024 66 | position_mapping : DevicePointer, # (max_distance) 67 | grad_out : DevicePointer, # (num_heads, key_len, query_len) 68 | grad : DevicePointer, # (num_heads, num_bucket) 69 | bidirectional : bool, 70 | stream : CUDAStream 71 | ): 72 | gridDim = (num_buckets, 1, 1) 73 | blockDim = (1024, 1, 1) 74 | embedding_kernel.cu_position_embedding_backward( 75 | gridDim, blockDim, 0, stream, [ 76 | ctypes.c_int32(query_len), 77 | ctypes.c_int32(key_len), 78 | ctypes.c_int32(num_buckets), 79 | ctypes.c_int32(max_distance), 80 | ctypes.c_int32(num_heads), 81 | ctypes.c_void_p(position_mapping), 82 | ctypes.c_void_p(grad_out), 83 | ctypes.c_void_p(grad), 84 | ctypes.c_bool(bidirectional) 85 | ] 86 | ) 87 | 88 | def position_embedding_step( 89 | query_pos : int, 90 | key_len : int, 91 | num_buckets : int, 92 | max_distance : int, 93 | num_heads : int, 94 | position_mapping : DevicePointer, # (max_distance) 95 | weight : DevicePointer, # (num_heads, num_bucket) 96 | out : DevicePointer, # (num_heads, key_len) 97 | bidirectional : bool, 98 | stream : CUDAStream 99 | ): 100 | gridDim = (1, 1, 1) 101 | blockDim = (min(key_len, 1024), 1, 1) 102 | embedding_kernel.cu_position_embedding_step( 103 | gridDim, blockDim, 0, stream, [ 104 | ctypes.c_int32(query_pos), 105 | ctypes.c_int32(key_len), 106 | ctypes.c_int32(num_buckets), 107 | ctypes.c_int32(max_distance), 108 | ctypes.c_int32(num_heads), 109 | ctypes.c_void_p(position_mapping), 110 | ctypes.c_void_p(weight), 111 | ctypes.c_void_p(out), 112 | ctypes.c_bool(bidirectional) 113 | ] 114 | ) -------------------------------------------------------------------------------- /cpm_kernels/kernels/softmax.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | softmax_kernel = Kernel( 5 | "softmax", 6 | [ 7 | "cu_softmax_forward", 8 | "cu_softmax_inplace_forward", 9 | "cu_softmax_backward", 10 | "cu_softmax_step_inplace" 11 | ] 12 | ) 13 | 14 | def softmax_forward( 15 | batch : int, n : int, m : int, 16 | inp : DevicePointer, # (batch, n, m) 17 | out : DevicePointer, # (batch, n, m) 18 | stream : CUDAStream 19 | ): 20 | gridDim = (batch, round_up(m, 32) // 32, 1) 21 | blockDim = (32, 32, 1) 22 | softmax_kernel.cu_softmax_forward( 23 | gridDim, blockDim, 0, stream, [ 24 | ctypes.c_int32(batch), 25 | ctypes.c_int32(n), 26 | ctypes.c_int32(m), 27 | ctypes.c_void_p(inp), 28 | ctypes.c_void_p(out) 29 | ] 30 | ) 31 | 32 | def softmax_inplace_forward( 33 | batch : int, n : int, m : int, 34 | inp : DevicePointer, # (batch, n, m) 35 | stream : CUDAStream 36 | ): 37 | gridDim = (batch, round_up(m, 32) // 32, 1) 38 | blockDim = (32, 32, 1) 39 | softmax_kernel.cu_softmax_inplace_forward( 40 | gridDim, blockDim, 0, stream, [ 41 | ctypes.c_int32(batch), 42 | ctypes.c_int32(n), 43 | ctypes.c_int32(m), 44 | ctypes.c_void_p(inp) 45 | ] 46 | ) 47 | 48 | def softmax_backward( 49 | batch : int, n : int, m : int, 50 | out : DevicePointer, # (batch, n, m) 51 | grad_out : DevicePointer, # (batch, n, m) 52 | grad : DevicePointer, # (batch, n, m) 53 | stream : CUDAStream 54 | ): 55 | gridDim = (batch, round_up(m, 32) // 32, 1) 56 | blockDim = (32, 32, 1) 57 | softmax_kernel.cu_softmax_backward( 58 | gridDim, blockDim, 0, stream, [ 59 | ctypes.c_int32(batch), 60 | ctypes.c_int32(n), 61 | ctypes.c_int32(m), 62 | ctypes.c_void_p(out), 63 | ctypes.c_void_p(grad_out), 64 | ctypes.c_void_p(grad) 65 | ] 66 | ) 67 | 
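# Illustrative note on the launch geometry shared by the softmax wrappers above
# (tensor/pointer names here are hypothetical): the softmax is taken over the
# middle dimension n of a (batch, n, m) tensor (cpm_kernels/torch/softmax.py
# applies it along dim=1), and the launch uses
#     gridDim  = (batch, ceil(m / 32), 1)
#     blockDim = (32, 32, 1)
# so each block presumably covers 32 columns of m while 32 threads cooperate
# along n. For example, batch=8, n=4096, m=100 gives gridDim=(8, 4, 1):
#
#     softmax_forward(8, 4096, 100,
#                     inp.data_ptr(), out.data_ptr(),
#                     torch.cuda.current_stream().cuda_stream)
#
# where inp and out are contiguous fp16 CUDA tensors of shape (8, 4096, 100).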
68 | def softmax_step_inplace( 69 | batch : int, n : int, 70 | x : DevicePointer, 71 | stream : CUDAStream 72 | ): 73 | gridDim = (batch, 1, 1) 74 | blockDim = (min(1024, round_up(n, 32)), 1, 1) 75 | softmax_kernel.cu_softmax_step_inplace( 76 | gridDim, blockDim, 0, stream, [ 77 | ctypes.c_int32(batch), 78 | ctypes.c_int32(n), 79 | ctypes.c_void_p(x) 80 | ] 81 | ) -------------------------------------------------------------------------------- /cpm_kernels/kernels/transpose.py: -------------------------------------------------------------------------------- 1 | from .base import Kernel, DevicePointer, CUDAStream, round_up 2 | import ctypes 3 | 4 | transpose_kernel = Kernel( 5 | "transpose", 6 | [ 7 | "cu_transpose" 8 | ] 9 | ) 10 | 11 | def transpose( 12 | batch : int, n : int, m : int, 13 | inp : DevicePointer, 14 | out : DevicePointer, 15 | stream : CUDAStream 16 | ): 17 | gridDim = (batch, round_up(n, 32) // 32, round_up(m, 32) // 32) 18 | blockDim = (32, 32, 1) 19 | transpose_kernel.cu_transpose ( 20 | gridDim, blockDim, 0, stream, [ 21 | ctypes.c_int32(batch), 22 | ctypes.c_int32(n), 23 | ctypes.c_int32(m), 24 | ctypes.c_void_p(inp), 25 | ctypes.c_void_p(out) 26 | ] 27 | ) -------------------------------------------------------------------------------- /cpm_kernels/kernels/utils.py: -------------------------------------------------------------------------------- 1 | from ..library import cublaslt 2 | from .base import Kernel, DevicePointer, CUDAStream, round_up 3 | import ctypes 4 | 5 | utils_kernel = Kernel( 6 | "utils", 7 | [ 8 | "copy_data_to_kv", 9 | "cu_array_add", 10 | "cu_adjustify_logits", 11 | "cu_copy_extend_buffer", 12 | "cu_has_nan_inf", 13 | "cu_copy_pos_hidden" 14 | ] 15 | ) 16 | 17 | 18 | def copy_data_to_kv( 19 | batch : int, buffer_len : int, n : int, 20 | inp : DevicePointer, # (batch, n) 21 | out : DevicePointer, # (batch, buffer_len, n) 22 | pos : int, 23 | stream : CUDAStream 24 | ): 25 | assert n % 2 == 0 26 | gridDim = (batch, 1, 1) 27 | blockDim = (min(1024, n // 2), 1, 1) 28 | utils_kernel.copy_data_to_kv( 29 | gridDim, blockDim, 0, stream, [ 30 | ctypes.c_int32(batch), 31 | ctypes.c_int32(buffer_len), 32 | ctypes.c_int32(n), 33 | ctypes.c_void_p(inp), 34 | ctypes.c_void_p(out), 35 | ctypes.c_int32(pos) 36 | ] 37 | ) 38 | 39 | def array_add( 40 | array : DevicePointer, 41 | pos : int, 42 | val : int, 43 | stream : CUDAStream 44 | ): 45 | gridDim = (1, 1, 1) 46 | blockDim = (1, 1, 1) 47 | utils_kernel.cu_array_add( 48 | gridDim, blockDim, 0, stream, [ 49 | ctypes.c_void_p(array), 50 | ctypes.c_int32(pos), 51 | ctypes.c_int32(val) 52 | ] 53 | ) 54 | 55 | def adjustify_logits( 56 | batch : int, n : int, 57 | logits : DevicePointer, 58 | temperature : float, 59 | frequency_penalty : float, 60 | presence_penalty : float, 61 | frequency : DevicePointer, 62 | stream : CUDAStream 63 | ): 64 | threads = min(1024, round_up(n, 32)) 65 | gridDim = (batch, round_up(n, threads) // threads, 1) 66 | blockDim = (threads, 1, 1) 67 | utils_kernel.cu_adjustify_logits( 68 | gridDim, blockDim, 0, stream, [ 69 | ctypes.c_int32(batch), 70 | ctypes.c_int32(n), 71 | ctypes.c_void_p(logits), 72 | ctypes.c_float(temperature), 73 | ctypes.c_float(frequency_penalty), 74 | ctypes.c_float(presence_penalty), 75 | ctypes.c_void_p(frequency) 76 | ] 77 | ) 78 | 79 | def copy_extend_buffer( 80 | batch : int, old_size : int, nw_size : int, 81 | old_buffer : DevicePointer, 82 | new_buffer : DevicePointer, 83 | stream : CUDAStream 84 | ): 85 | threads = min(1024, round_up(old_size, 32)) 86 | gridDim 
= (batch, round_up(old_size, threads) // threads, 1) 87 | blockDim = (threads, 1, 1) 88 | utils_kernel.cu_copy_extend_buffer( 89 | gridDim, blockDim, 0, stream, [ 90 | ctypes.c_int32(batch), 91 | ctypes.c_int32(old_size), 92 | ctypes.c_int32(nw_size), 93 | ctypes.c_void_p(old_buffer), 94 | ctypes.c_void_p(new_buffer) 95 | ] 96 | ) 97 | 98 | def has_nan_inf( 99 | n : int, 100 | inp : DevicePointer, # (n,) half 101 | out : DevicePointer, # (1,) bool 102 | stream : CUDAStream 103 | ): 104 | gridDim = (1, 1, 1) 105 | blockDim = (min(round_up(n, 32), 1024), 1, 1) 106 | utils_kernel.cu_has_nan_inf( 107 | gridDim, blockDim, 0, stream, [ 108 | ctypes.c_int32(n), 109 | ctypes.c_void_p(inp), 110 | ctypes.c_void_p(out) 111 | ] 112 | ) 113 | 114 | def copy_pos_hidden( 115 | batch : int, hidden_size : int, seq_len : int, 116 | pos : int, 117 | inp : DevicePointer, # (batch, hidden_size, seq_len) 118 | out : DevicePointer, # (batch, hidden_size) 119 | stream : CUDAStream 120 | ): 121 | threads = min(1024, round_up(hidden_size, 32)) 122 | gridDim = (batch, round_up(hidden_size, threads) // threads, 1) 123 | blockDim = (threads, 1, 1) 124 | utils_kernel.cu_copy_pos_hidden( 125 | gridDim, blockDim, 0, stream, [ 126 | ctypes.c_int32(batch), 127 | ctypes.c_int32(hidden_size), 128 | ctypes.c_int32(seq_len), 129 | ctypes.c_int32(pos), 130 | ctypes.c_void_p(inp), 131 | ctypes.c_void_p(out) 132 | ] 133 | ) -------------------------------------------------------------------------------- /cpm_kernels/library/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nvrtc 2 | from . import cuda 3 | from . import cudart 4 | from . import cublaslt -------------------------------------------------------------------------------- /cpm_kernels/library/base.py: -------------------------------------------------------------------------------- 1 | import os, sys, struct 2 | import ctypes 3 | import ctypes.util 4 | from functools import wraps 5 | from typing import Callable, TypeVar 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | LibCall = TypeVar("LibCall") 10 | 11 | def lookup_dll(prefix): 12 | paths = os.environ.get("PATH", "").split(os.pathsep) 13 | for path in paths: 14 | if not os.path.exists(path): 15 | continue 16 | for name in os.listdir(path): 17 | if name.startswith(prefix) and name.lower().endswith(".dll"): 18 | return os.path.join(path, name) 19 | return None 20 | 21 | def unix_find_lib(name): 22 | cuda_path = os.environ.get("CUDA_PATH", None) 23 | if cuda_path is not None: 24 | lib_name = os.path.join(cuda_path, "lib64", "lib%s.so" % name) 25 | if os.path.exists(lib_name): 26 | return lib_name 27 | 28 | cuda_path = "/usr/local/cuda" 29 | if cuda_path is not None: 30 | lib_name = os.path.join(cuda_path, "lib64", "lib%s.so" % name) 31 | if os.path.exists(lib_name): 32 | return lib_name 33 | 34 | lib_name = ctypes.util.find_library(name) 35 | return lib_name 36 | 37 | def windows_find_lib(name): 38 | lib_name = "%s%d_" % (name, struct.calcsize("P") * 8) 39 | return lookup_dll(lib_name) 40 | 41 | class Lib: 42 | def __init__(self, name): 43 | self.__name = name 44 | if sys.platform.startswith("win"): 45 | lib_path = windows_find_lib(self.__name) 46 | self.__lib_path = lib_path 47 | if lib_path is not None: 48 | self.__lib = ctypes.WinDLL(lib_path) 49 | else: 50 | self.__lib = None 51 | elif sys.platform.startswith("linux"): 52 | lib_path = unix_find_lib(self.__name) 53 | self.__lib_path = lib_path 54 | if lib_path is not None: 55 | self.__lib 
= ctypes.cdll.LoadLibrary(lib_path) 56 | else: 57 | self.__lib = None 58 | else: 59 | raise RuntimeError("Unknown platform: %s" % sys.platform) 60 | 61 | @staticmethod 62 | def from_lib(name, lib): 63 | ret = Lib(name) 64 | ret.__lib = lib 65 | return ret 66 | 67 | def bind(self, name, arg_types, ret_type) -> Callable[[LibCall], LibCall]: 68 | if self.__lib is None: 69 | def decorator(f): 70 | @wraps(f) 71 | def wrapper(*args, **kwargs): 72 | raise RuntimeError("Library %s is not initialized" % self.__name) 73 | return wrapper 74 | return decorator 75 | else: 76 | try: 77 | func = getattr(self.__lib, name) 78 | except AttributeError: 79 | # Name not found in library 80 | def decorator(f): 81 | @wraps(f) 82 | def wrapper(*args, **kwargs): 83 | raise AttributeError("%s: undefined symbol: %s" % (self.__lib_path, name)) 84 | return wrapper 85 | logger.warning("Symbol %s not found in %s", name, self.__lib_path) 86 | return decorator 87 | func.argtypes = arg_types 88 | func.restype = ret_type 89 | setattr(self, name, func) 90 | 91 | def decorator(f): 92 | @wraps(f) 93 | def wrapper(*args, **kwargs): 94 | return f(*args, **kwargs) 95 | return wrapper 96 | return decorator -------------------------------------------------------------------------------- /cpm_kernels/library/cublaslt.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from typing import Any, List, Tuple 3 | from .base import Lib 4 | 5 | cublasLt = Lib("cublasLt") 6 | 7 | CUDA_R_8I = 3 8 | CUDA_R_32I = 10 9 | CUDA_R_16F = 2 10 | CUDA_R_32F = 0 11 | 12 | CUBLAS_OP_N = 0 13 | CUBLAS_OP_T = 1 14 | 15 | CUBLASLT_ORDER_COL = 0 16 | CUBLASLT_ORDER_ROW = 1 17 | CUBLASLT_ORDER_COL32 = 2 18 | CUBLASLT_ORDER_COL4_4R2_8C = 3 19 | CUBLASLT_ORDER_COL32_2R_4R4 = 4 20 | 21 | 22 | CUBLASLT_MATRIX_LAYOUT_TYPE = 0 23 | CUBLASLT_MATRIX_LAYOUT_ORDER = 1 24 | CUBLASLT_MATRIX_LAYOUT_ROWS = 2 25 | CUBLASLT_MATRIX_LAYOUT_COLS = 3 26 | CUBLASLT_MATRIX_LAYOUT_LD = 4 27 | CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5 28 | CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6 29 | CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7 30 | 31 | CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0 32 | CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1 33 | CUBLASLT_MATMUL_DESC_POINTER_MODE = 2 34 | CUBLASLT_MATMUL_DESC_TRANSA = 3 35 | CUBLASLT_MATMUL_DESC_TRANSB = 4 36 | CUBLASLT_MATMUL_DESC_TRANSC = 5 37 | CUBLASLT_MATMUL_DESC_FILL_MODE = 6 38 | CUBLASLT_MATMUL_DESC_EPILOGUE = 7 39 | CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8 40 | 41 | CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE = 0 42 | CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE = 1 43 | CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA = 2 44 | CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB = 3 45 | 46 | CUBLAS_COMPUTE_16F = 64 47 | CUBLAS_COMPUTE_16F_PEDANTIC = 65 48 | CUBLAS_COMPUTE_32F = 68 49 | CUBLAS_COMPUTE_32F_PEDANTIC = 69 50 | CUBLAS_COMPUTE_32F_FAST_16F = 74 51 | CUBLAS_COMPUTE_32F_FAST_16BF = 75 52 | CUBLAS_COMPUTE_32F_FAST_TF32 = 77 53 | CUBLAS_COMPUTE_64F = 70 54 | CUBLAS_COMPUTE_64F_PEDANTIC = 71 55 | CUBLAS_COMPUTE_32I = 72 56 | CUBLAS_COMPUTE_32I_PEDANTIC = 73 57 | 58 | cublasLtHandle_t = ctypes.c_void_p 59 | cublasStatus_t = ctypes.c_int 60 | cublasLtMatrixTransformDesc_t = ctypes.c_void_p 61 | cudaStream_t = ctypes.c_void_p 62 | cublasLtMatmulDesc_t = ctypes.c_void_p 63 | cublasLtMatrixLayout_t = ctypes.c_void_p 64 | cudaDataType = ctypes.c_int 65 | cublasComputeType_t = ctypes.c_int 66 | cublasLtMatmulDescAttributes_t = ctypes.c_int 67 | cublasLtMatrixLayoutAttribute_t = ctypes.c_int 68 | 
cublasLtMatrixTransformDescAttributes_t = ctypes.c_int 69 | 70 | @cublasLt.bind("cublasLtGetVersion", [], ctypes.c_size_t) 71 | def cublasLtGetVersion() -> int: 72 | return cublasLt.cublasLtGetVersion() 73 | 74 | try: 75 | version = cublasLtGetVersion() 76 | except RuntimeError: 77 | version = 0 78 | 79 | def cublasGetStatusString(status : int) -> str: 80 | cublas_errors = { 81 | 0: "CUBLAS_STATUS_SUCCESS", 82 | 1: "CUBLAS_STATUS_NOT_INITIALIZED", 83 | 3: "CUBLAS_STATUS_ALLOC_FAILED", 84 | 7: "CUBLAS_STATUS_INVALID_VALUE", 85 | 8: "CUBLAS_STATUS_ARCH_MISMATCH", 86 | 11: "CUBLAS_STATUS_MAPPING_ERROR", 87 | 13: "CUBLAS_STATUS_EXECUTION_FAILED", 88 | 14: "CUBLAS_STATUS_INTERNAL_ERROR", 89 | 15: "CUBLAS_STATUS_NOT_SUPPORTED", 90 | 16: "CUBLAS_STATUS_LICENSE_ERROR" 91 | } 92 | if status not in cublas_errors: 93 | raise RuntimeError("Unknown cublasLt status: %d" % status) 94 | return cublas_errors[status] 95 | 96 | def checkCublasStatus(status: int) -> None: 97 | if status != 0: 98 | raise RuntimeError("CUBLAS error: {}".format( 99 | cublasGetStatusString(status) 100 | )) 101 | 102 | @cublasLt.bind("cublasLtCreate", [ctypes.POINTER(cublasLtHandle_t)], cublasStatus_t) 103 | def cublasLtCreate() -> cublasLtHandle_t: 104 | handle = cublasLtHandle_t() 105 | checkCublasStatus(cublasLt.cublasLtCreate(ctypes.byref(handle))) 106 | return handle 107 | 108 | @cublasLt.bind("cublasLtDestroy", [cublasLtHandle_t], cublasStatus_t) 109 | def cublasLtDestroy(handle: cublasLtHandle_t) -> None: 110 | checkCublasStatus(cublasLt.cublasLtDestroy(handle)) 111 | 112 | 113 | @cublasLt.bind("cublasLtMatmul", [ 114 | cublasLtHandle_t, cublasLtMatmulDesc_t, 115 | ctypes.c_void_p, 116 | ctypes.c_void_p, cublasLtMatrixLayout_t, 117 | ctypes.c_void_p, cublasLtMatrixLayout_t, 118 | ctypes.c_void_p, 119 | ctypes.c_void_p, cublasLtMatrixLayout_t, 120 | ctypes.c_void_p, cublasLtMatrixLayout_t, 121 | ctypes.c_void_p, 122 | ctypes.c_void_p, 123 | ctypes.c_size_t, 124 | cudaStream_t 125 | ], cublasStatus_t) 126 | def cublasLtMatmul( 127 | lightHandle : cublasLtHandle_t, 128 | computeDesc : cublasLtMatmulDesc_t, 129 | alpha : Any, 130 | A : ctypes.c_void_p, A_layout : cublasLtMatrixLayout_t, 131 | B : ctypes.c_void_p, B_layout : cublasLtMatrixLayout_t, 132 | beta : Any, 133 | C : ctypes.c_void_p, C_layout : cublasLtMatrixLayout_t, 134 | D : ctypes.c_void_p, D_layout : cublasLtMatrixLayout_t, 135 | stream : cudaStream_t 136 | ) -> None: 137 | checkCublasStatus(cublasLt.cublasLtMatmul( 138 | lightHandle, 139 | computeDesc, 140 | ctypes.byref(alpha), 141 | A, A_layout, 142 | B, B_layout, 143 | ctypes.byref(beta), 144 | C, C_layout, 145 | D, D_layout, 146 | 0, 147 | 0, 148 | 0, 149 | stream 150 | )) 151 | 152 | 153 | if version >= 11000: 154 | @cublasLt.bind("cublasLtMatmulDescCreate", [ctypes.POINTER(cublasLtMatmulDesc_t), cublasComputeType_t, cudaDataType], cublasStatus_t) 155 | def cublasLtMatmulDescCreate(computeType : cublasComputeType_t, dataType : cudaDataType) -> cublasLtMatmulDesc_t: 156 | desc = cublasLtMatmulDesc_t() 157 | checkCublasStatus(cublasLt.cublasLtMatmulDescCreate(ctypes.byref(desc), computeType, dataType)) 158 | return desc 159 | 160 | else: 161 | @cublasLt.bind("cublasLtMatmulDescCreate", [ctypes.POINTER(cublasLtMatmulDesc_t), cudaDataType], cublasStatus_t) 162 | def cublasLtMatmulDescCreate(computeType : cudaDataType) -> cublasLtMatmulDesc_t: 163 | desc = cublasLtMatmulDesc_t() 164 | checkCublasStatus(cublasLt.cublasLtMatmulDescCreate(ctypes.byref(desc), computeType)) 165 | return desc 166 | 167 | 
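# Illustrative sketch of how the two signatures above are used (variable names
# hypothetical): callers branch on the module-level `version`, as the gemv
# kernels in cpm_kernels/kernels/gemv.py do, because cuBLASLt 11.x takes a
# compute type plus a scale type while older releases take a single
# cudaDataType:
#
#     if version >= 11000:
#         desc = cublasLtMatmulDescCreate(CUBLAS_COMPUTE_16F, CUDA_R_16F)
#     else:
#         desc = cublasLtMatmulDescCreate(CUDA_R_16F)
#     cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA,
#                                    ctypes.c_int32(CUBLAS_OP_T))
#     ...  # create layouts, call cublasLtMatmul, then:
#     cublasLtMatmulDescDestroy(desc)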
@cublasLt.bind("cublasLtMatmulDescSetAttribute", [cublasLtMatmulDesc_t, cublasLtMatmulDescAttributes_t, ctypes.c_void_p, ctypes.c_size_t], cublasStatus_t) 168 | def cublasLtMatmulDescSetAttribute(desc : cublasLtMatmulDesc_t, attr : cublasLtMatmulDescAttributes_t, value : Any) -> None: 169 | checkCublasStatus(cublasLt.cublasLtMatmulDescSetAttribute(desc, attr, ctypes.byref(value), ctypes.sizeof(value))) 170 | 171 | @cublasLt.bind("cublasLtMatmulDescDestroy", [cublasLtMatmulDesc_t], cublasStatus_t) 172 | def cublasLtMatmulDescDestroy(desc : cublasLtMatmulDesc_t) -> None: 173 | checkCublasStatus(cublasLt.cublasLtMatmulDescDestroy(desc)) 174 | 175 | @cublasLt.bind("cublasLtMatrixLayoutCreate", [ctypes.POINTER(cublasLtMatrixLayout_t), cudaDataType, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int64], cublasStatus_t) 176 | def cublasLtMatrixLayoutCreate(dataType : cudaDataType, rows : int, cols : int, ld : int) -> cublasLtMatrixLayout_t: 177 | layout = cublasLtMatrixLayout_t() 178 | checkCublasStatus(cublasLt.cublasLtMatrixLayoutCreate(ctypes.byref(layout), dataType, rows, cols, ld)) 179 | return layout 180 | 181 | @cublasLt.bind("cublasLtMatrixLayoutDestroy", [cublasLtMatrixLayout_t], cublasStatus_t) 182 | def cublasLtMatrixLayoutDestroy(layout : cublasLtMatrixLayout_t) -> None: 183 | checkCublasStatus(cublasLt.cublasLtMatrixLayoutDestroy(layout)) 184 | 185 | @cublasLt.bind("cublasLtMatrixLayoutSetAttribute", [cublasLtMatrixLayout_t, cublasLtMatrixLayoutAttribute_t, ctypes.c_void_p, ctypes.c_size_t], cublasStatus_t) 186 | def cublasLtMatrixLayoutSetAttribute(layout : cublasLtMatrixLayout_t, attr : cublasLtMatrixLayoutAttribute_t, value : Any) -> None: 187 | checkCublasStatus(cublasLt.cublasLtMatrixLayoutSetAttribute(layout, attr, ctypes.byref(value), ctypes.sizeof(value))) 188 | 189 | @cublasLt.bind("cublasLtMatrixTransform", [ 190 | cublasLtHandle_t, cublasLtMatrixTransformDesc_t, 191 | ctypes.c_void_p, 192 | ctypes.c_void_p, cublasLtMatrixLayout_t, 193 | ctypes.c_void_p, 194 | ctypes.c_void_p, cublasLtMatrixLayout_t, 195 | ctypes.c_void_p, cublasLtMatrixLayout_t, 196 | cudaStream_t 197 | ], cublasStatus_t) 198 | def cublasLtMatrixTransform( 199 | lightHandle : cublasLtHandle_t, 200 | transformDesc : cublasLtMatrixTransformDesc_t, 201 | alpha : Any, 202 | A : ctypes.c_void_p, A_layout : cublasLtMatrixLayout_t, 203 | beta : Any, 204 | B : ctypes.c_void_p, B_layout : cublasLtMatrixLayout_t, 205 | C : ctypes.c_void_p, C_layout : cublasLtMatrixLayout_t, 206 | stream : cudaStream_t 207 | ) -> None: 208 | checkCublasStatus(cublasLt.cublasLtMatrixTransform( 209 | lightHandle, 210 | transformDesc, 211 | ctypes.byref(alpha), 212 | A, A_layout, 213 | ctypes.byref(beta), 214 | B, B_layout, 215 | C, C_layout, 216 | stream 217 | )) 218 | 219 | @cublasLt.bind("cublasLtMatrixTransformDescCreate", [ctypes.POINTER(cublasLtMatrixTransformDesc_t), cudaDataType], cublasStatus_t) 220 | def cublasLtMatrixTransformDescCreate(dataType : cudaDataType) -> cublasLtMatrixTransformDesc_t: 221 | desc = cublasLtMatrixTransformDesc_t() 222 | checkCublasStatus(cublasLt.cublasLtMatrixTransformDescCreate(ctypes.byref(desc), dataType)) 223 | return desc 224 | 225 | @cublasLt.bind("cublasLtMatrixTransformDescDestroy", [cublasLtMatrixTransformDesc_t], cublasStatus_t) 226 | def cublasLtMatrixTransformDescDestroy(desc : cublasLtMatrixTransformDesc_t) -> None: 227 | checkCublasStatus(cublasLt.cublasLtMatrixTransformDescDestroy(desc)) 228 | 229 | @cublasLt.bind("cublasLtMatrixTransformDescSetAttribute", [cublasLtMatrixTransformDesc_t, 
cublasLtMatrixTransformDescAttributes_t, ctypes.c_void_p, ctypes.c_size_t], cublasStatus_t) 230 | def cublasLtMatrixTransformDescSetAttribute(desc : cublasLtMatrixTransformDesc_t, attr : cublasLtMatrixTransformDescAttributes_t, value : Any) -> None: 231 | checkCublasStatus(cublasLt.cublasLtMatrixTransformDescSetAttribute(desc, attr, ctypes.byref(value), ctypes.sizeof(value))) 232 | -------------------------------------------------------------------------------- /cpm_kernels/library/nvrtc.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from typing import List, Tuple 3 | from .base import Lib 4 | 5 | nvrtc = Lib("nvrtc") 6 | 7 | nvrtcResult = ctypes.c_int 8 | NVRTC_SUCCESS = 0 9 | NVRTC_ERROR_OUT_OF_MEMORY = 1 10 | NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2 11 | NVRTC_ERROR_INVALID_INPUT = 3 12 | NVRTC_ERROR_INVALID_PROGRAM = 4 13 | NVRTC_ERROR_INVALID_OPTION = 5 14 | NVRTC_ERROR_COMPILATION = 6 15 | NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7 16 | NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8 17 | NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9 18 | NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10 19 | NVRTC_ERROR_INTERNAL_ERROR = 11 20 | 21 | nvrtcProgram = ctypes.c_void_p 22 | 23 | @nvrtc.bind("nvrtcGetErrorString", [nvrtcResult], ctypes.c_char_p) 24 | def nvrtcGetErrorString(status : int) -> str: 25 | return nvrtc.nvrtcGetErrorString(status).decode() 26 | 27 | def checkNVRTCStatus(status : int): 28 | if status == 0: 29 | return 30 | raise RuntimeError("NVRTC Error: %s" % nvrtcGetErrorString(status)) 31 | 32 | @nvrtc.bind("nvrtcVersion", [ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)], nvrtcResult) 33 | def nvrtcVersion() -> Tuple[int, int]: 34 | major = ctypes.c_int() 35 | minor = ctypes.c_int() 36 | checkNVRTCStatus( nvrtc.nvrtcVersion(ctypes.byref(major), ctypes.byref(minor)) ) 37 | return (major.value, minor.value) 38 | 39 | try: 40 | version = nvrtcVersion() 41 | except RuntimeError: 42 | version = (0, 0) 43 | 44 | @nvrtc.bind("nvrtcCompileProgram", [nvrtcProgram, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p)], nvrtcResult) 45 | def nvrtcCompileProgram(prog : nvrtcProgram, numOptions : int, options : List[str]): 46 | lstType = ctypes.c_char_p * numOptions 47 | lst = [ 48 | ctypes.c_char_p(opt.encode()) for opt in options 49 | ] 50 | options = lstType(*lst) 51 | status = nvrtc.nvrtcCompileProgram(prog, numOptions, options) 52 | if status == NVRTC_ERROR_COMPILATION: 53 | psize = nvrtcGetProgramLogSize(prog) 54 | log = ctypes.create_string_buffer(psize) 55 | nvrtcGetProgramLog(prog, log) 56 | raise RuntimeError( 57 | "NVRTC Error: NVRTC ERROR COMPILATION\n%s" % log.value.decode() 58 | ) 59 | else: 60 | checkNVRTCStatus( status ) 61 | 62 | @nvrtc.bind("nvrtcCreateProgram", [ ctypes.POINTER(nvrtcProgram), ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.POINTER(ctypes.c_char_p) ], nvrtcResult) 63 | def nvrtcCreateProgram(prog : nvrtcProgram, src : str, name : str, numHeaders : int, headers : List[str], includeNames : List[str]): 64 | headers = [ 65 | ctypes.c_char_p(header.encode()) for header in headers 66 | ] 67 | headers = (ctypes.c_char_p * numHeaders)(*headers) 68 | 69 | includeNames = [ 70 | ctypes.c_char_p(includeName.encode()) for includeName in includeNames 71 | ] 72 | includeNames = (ctypes.c_char_p * numHeaders)(*includeNames) 73 | checkNVRTCStatus( nvrtc.nvrtcCreateProgram(ctypes.byref(prog), src.encode(), name.encode(), numHeaders, headers, includeNames) ) 74 | 75 
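# Illustrative compile flow assembled from the bindings in this module
# (variable names hypothetical):
#
#     prog = nvrtcProgram()
#     nvrtcCreateProgram(prog, src, "kernel.cu", 0, [], [])
#     nvrtcCompileProgram(prog, len(opts), opts)   # raises with the build log on failure
#     if version[0] >= 11:
#         buf = ctypes.create_string_buffer(nvrtcGetCUBINSize(prog))
#         nvrtcGetCUBIN(prog, buf)                 # CUBIN getters exist only on NVRTC >= 11
#     else:
#         buf = ctypes.create_string_buffer(nvrtcGetPTXSize(prog))
#         nvrtcGetPTX(prog, buf)
#     nvrtcDestroyProgram(prog)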
| 76 | @nvrtc.bind("nvrtcDestroyProgram", [ ctypes.POINTER(nvrtcProgram) ], nvrtcResult) 77 | def nvrtcDestroyProgram(prog : nvrtcProgram): 78 | checkNVRTCStatus( nvrtc.nvrtcDestroyProgram( ctypes.byref(prog)) ) 79 | 80 | if version[0] >= 11: 81 | @nvrtc.bind("nvrtcGetCUBIN", [nvrtcProgram, ctypes.c_char_p], nvrtcResult) 82 | def nvrtcGetCUBIN(prog : nvrtcProgram, buf : ctypes.c_char_p): 83 | checkNVRTCStatus( nvrtc.nvrtcGetCUBIN(prog, buf) ) 84 | 85 | @nvrtc.bind("nvrtcGetCUBINSize", [nvrtcProgram, ctypes.POINTER(ctypes.c_size_t)], nvrtcResult) 86 | def nvrtcGetCUBINSize(prog) -> int: 87 | size = ctypes.c_size_t() 88 | checkNVRTCStatus( nvrtc.nvrtcGetCUBINSize(prog, ctypes.byref(size)) ) 89 | return size.value 90 | 91 | @nvrtc.bind("nvrtcGetPTX", [nvrtcProgram, ctypes.c_char_p], nvrtcResult) 92 | def nvrtcGetPTX(prog, buf): 93 | checkNVRTCStatus( nvrtc.nvrtcGetPTX(prog, buf) ) 94 | 95 | @nvrtc.bind("nvrtcGetPTXSize", [nvrtcProgram, ctypes.POINTER(ctypes.c_size_t)], nvrtcResult) 96 | def nvrtcGetPTXSize(prog) -> int: 97 | size = ctypes.c_size_t() 98 | checkNVRTCStatus( nvrtc.nvrtcGetPTXSize(prog, ctypes.byref(size)) ) 99 | return size.value 100 | 101 | @nvrtc.bind("nvrtcGetProgramLog", [nvrtcProgram, ctypes.c_char_p], nvrtcResult) 102 | def nvrtcGetProgramLog(prog, buf): 103 | checkNVRTCStatus( nvrtc.nvrtcGetProgramLog(prog, buf) ) 104 | 105 | @nvrtc.bind("nvrtcGetProgramLogSize", [nvrtcProgram, ctypes.POINTER(ctypes.c_size_t)], nvrtcResult) 106 | def nvrtcGetProgramLogSize(prog) -> int: 107 | size = ctypes.c_size_t() 108 | checkNVRTCStatus( nvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(size)) ) 109 | return size.value 110 | -------------------------------------------------------------------------------- /cpm_kernels/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import OpEmbedding, Embedding, EmbeddingTH 2 | from .gelu import gelu, geluTH, gelu_inplace 3 | from .gemm import bmm 4 | from .mask import mask, mask_inplace, maskTH 5 | from .arith import ln_div, ln_div_inplace, ln_divTH, ln_mul_add, ln_mul_add_inplace, \ 6 | ln_mul_addTH, ln_mul, ln_mul_inplace, ln_mulTH, ln_sub_div, \ 7 | ln_sub_divTH, ln_sub_div_inplace, element_add, element_add_inplace, \ 8 | element_addTH, batched_add, batched_add_inplace, batched_addTH, \ 9 | element_mul, element_mul_inplace, element_mulTH, \ 10 | ln_add, ln_add_inplace, ln_addTH, global_scale, global_scale_inplace, global_scaleTH 11 | from .layernorm import LayerNorm, LayerNormTH, normalize_inplace, normalizeTH 12 | from .position_embedding import PositionEmbedding, PositionEmbeddingTH 13 | from .softmax import softmax, softmaxTH, softmax_inplace 14 | from .transpose import transpose, transposeTH 15 | from .utils import has_nan_inf -------------------------------------------------------------------------------- /cpm_kernels/torch/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..kernels import embedding_forward, embedding_backward_stage1, embedding_backward_stage2, transpose 3 | 4 | class OpEmbedding(torch.autograd.Function): 5 | """ 6 | Embedding function for the cpm_kernels. 
7 | Input: 8 | - ids: (batch_size, seq_len) 9 | - weight: (vocab_size, embedding_size) 10 | Output: 11 | - embeddings: (batch_size, embedding_size, seq_len) 12 | """ 13 | @staticmethod 14 | def forward(ctx, ids : torch.Tensor, weight : torch.Tensor): 15 | assert ids.is_cuda and weight.is_cuda 16 | assert ids.device == weight.device 17 | assert ids.ndim == 2 18 | assert weight.ndim == 2 19 | assert ids.dtype == torch.int32 20 | assert weight.dtype == torch.half 21 | if not ids.is_contiguous(): 22 | ids = ids.contiguous() 23 | assert weight.is_contiguous() 24 | 25 | ctx.save_for_backward(ids, weight) 26 | 27 | out = torch.empty((ids.size(0), weight.size(1), ids.size(1)), device=ids.device, dtype=torch.half) 28 | assert out.is_contiguous() 29 | 30 | embedding_forward(ids.size(0), weight.size(1), ids.size(1), ids.data_ptr(), weight.data_ptr(), out.data_ptr(), torch.cuda.current_stream().cuda_stream) 31 | return out 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output : torch.Tensor): 35 | ids, weight = ctx.saved_tensors 36 | batch, n, m = grad_output.size() 37 | 38 | assert grad_output.device == ids.device 39 | assert m == ids.size(1) 40 | assert n == weight.size(1) 41 | if not grad_output.is_contiguous(): 42 | grad_output = grad_output.contiguous() 43 | 44 | sort_result = ids.view(-1).sort() 45 | indices = sort_result.indices.to(torch.int32) 46 | values = sort_result.values 47 | 48 | grad_transpose = torch.empty((batch, m, n), device=grad_output.device, dtype=torch.half) 49 | 50 | assert grad_output.is_contiguous() and grad_transpose.is_contiguous() 51 | transpose(batch, n, m, grad_output.data_ptr(), grad_transpose.data_ptr(), torch.cuda.current_stream().cuda_stream) 52 | 53 | buf = torch.empty((batch, n), device=grad_output.device, dtype=torch.half) 54 | buf_indices = torch.empty((batch,), device=grad_output.device, dtype=torch.int32) 55 | 56 | ret = torch.zeros((weight.size(0), n), device=grad_output.device, dtype=torch.half) 57 | assert grad_transpose.is_contiguous() and indices.is_contiguous() and values.is_contiguous() and ret.is_contiguous() and buf.is_contiguous() and buf_indices.is_contiguous() 58 | embedding_backward_stage1( 59 | batch, m, n, 60 | grad_transpose.data_ptr(), 61 | indices.data_ptr(), 62 | values.data_ptr(), 63 | ret.data_ptr(), 64 | buf.data_ptr(), 65 | buf_indices.data_ptr(), 66 | torch.cuda.current_stream().cuda_stream 67 | ) 68 | 69 | embedding_backward_stage2( 70 | batch, n, 71 | buf.data_ptr(), 72 | buf_indices.data_ptr(), 73 | ret.data_ptr(), 74 | torch.cuda.current_stream().cuda_stream 75 | ) 76 | 77 | return None, ret 78 | 79 | class Embedding(torch.nn.Module): 80 | def __init__(self, vocab_size : int, embedding_size : int): 81 | super().__init__() 82 | self.weight = torch.nn.Parameter(torch.empty((vocab_size, embedding_size), dtype=torch.half)) 83 | 84 | def forward(self, ids : torch.Tensor): 85 | return OpEmbedding.apply(ids, self.weight) 86 | 87 | class EmbeddingTH(torch.nn.Module): 88 | def __init__(self, vocab_size : int, embedding_size : int): 89 | super().__init__() 90 | self.weight = torch.nn.Parameter(torch.empty((vocab_size, embedding_size), dtype=torch.half)) 91 | 92 | def forward(self, ids : torch.Tensor): 93 | assert ids.ndim == 2 94 | v = torch.embedding( 95 | self.weight, 96 | ids 97 | ) 98 | assert v.ndim == 3 99 | return v.transpose(1, 2) 100 | -------------------------------------------------------------------------------- /cpm_kernels/torch/gelu.py: -------------------------------------------------------------------------------- 1 | 
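# Note: gelu_impl below is the tanh approximation
#     GeLU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
# with sqrt(2/pi) ≈ 0.7978845608028654; the code writes the inner term in the
# factored form 0.7978845608028654 * x * (1 + 0.044715 * x * x), which is the
# same expression. OpGeLU wraps the custom CUDA kernel, while geluTH is the
# pure-PyTorch reference.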
import torch 2 | from ..kernels import gelu_forward, gelu_backward, gelu_inplace_forward 3 | 4 | class OpGeLU(torch.autograd.Function): 5 | """ 6 | Element wised GeLU function. 7 | Input: 8 | - x (batch, *) 9 | Output: 10 | - y (batch, *) 11 | """ 12 | 13 | @staticmethod 14 | def forward(ctx, x : torch.Tensor): 15 | assert x.is_contiguous() and x.is_cuda and x.dtype == torch.half 16 | ctx.save_for_backward(x) 17 | ret = torch.empty_like(x) 18 | gelu_forward( 19 | x.size(0), 20 | x.stride(0), 21 | x.data_ptr(), 22 | ret.data_ptr(), 23 | torch.cuda.current_stream().cuda_stream 24 | ) 25 | return ret 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output : torch.Tensor): 29 | x = ctx.saved_tensors[0] 30 | grad = torch.empty_like(grad_output) 31 | 32 | assert grad_output.is_contiguous() and grad_output.is_cuda and grad_output.dtype == torch.half 33 | gelu_backward( 34 | x.size(0), 35 | x.stride(0), 36 | grad_output.data_ptr(), 37 | x.data_ptr(), 38 | grad.data_ptr(), 39 | torch.cuda.current_stream().cuda_stream 40 | ) 41 | return grad 42 | 43 | def gelu(x : torch.Tensor) -> torch.Tensor: 44 | return OpGeLU.apply(x) 45 | 46 | 47 | @torch.jit.script 48 | def gelu_impl(x): 49 | """OpenAI's gelu implementation.""" 50 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * 51 | (1.0 + 0.044715 * x * x))) 52 | 53 | def geluTH(x : torch.Tensor): 54 | return gelu_impl(x) 55 | 56 | def gelu_inplace(x : torch.Tensor) -> None: 57 | assert x.is_contiguous() and x.is_cuda and x.dtype == torch.half 58 | gelu_inplace_forward( 59 | x.size(0), 60 | x.stride(0), 61 | x.data_ptr(), 62 | torch.cuda.current_stream().cuda_stream 63 | ) -------------------------------------------------------------------------------- /cpm_kernels/torch/gemm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from ..kernels import gemm_calc_scale, gemm_calc_scale_transpose, gemm_round, gemm_round_transpose, gemm_scale, gemm_fp16, gemm_int8, gemm_backward_round_scale, gemm_backward_scale_round, gemm_scale_x, gemm_scale_y 4 | 5 | def calc_scale(mat : torch.Tensor, transpose : bool): 6 | assert mat.is_contiguous() and mat.is_cuda 7 | if transpose: 8 | out = torch.empty((mat.size(0), mat.size(2)), dtype=torch.half, device=mat.device) 9 | gemm_calc_scale_transpose( 10 | mat.size(0), mat.size(1), mat.size(2), 11 | mat.data_ptr(), out.data_ptr(), torch.cuda.current_stream().cuda_stream 12 | ) 13 | else: 14 | out = torch.empty((mat.size(0), mat.size(1)), dtype=torch.half, device=mat.device) 15 | gemm_calc_scale( 16 | mat.size(0), mat.size(1), mat.size(2), 17 | mat.data_ptr(), out.data_ptr(), torch.cuda.current_stream().cuda_stream 18 | ) 19 | return out 20 | 21 | def round_i8(mat : torch.Tensor, scale : torch.Tensor, transpose : bool): 22 | assert mat.is_contiguous() and mat.is_cuda 23 | assert scale.is_contiguous() and scale.is_cuda 24 | if transpose: 25 | out = torch.empty(mat.size(), dtype=torch.int8, device=mat.device) 26 | gemm_round_transpose( 27 | mat.size(0), mat.size(1), mat.size(2), 28 | mat.data_ptr(), scale.data_ptr(), out.data_ptr(), 29 | torch.cuda.current_stream().cuda_stream 30 | ) 31 | else: 32 | out = torch.empty(mat.size(), dtype=torch.int8, device=mat.device) 33 | gemm_round( 34 | mat.size(0), mat.size(1), mat.size(2), 35 | mat.data_ptr(), scale.data_ptr(), out.data_ptr(), 36 | torch.cuda.current_stream().cuda_stream 37 | ) 38 | return out 39 | 40 | def gemm_and_scale(quantA : torch.Tensor, scaleA : Optional[torch.Tensor], 
quantB : torch.Tensor, scaleB : Optional[torch.Tensor], aT, bT) -> torch.Tensor: 41 | M = quantA.size(2) if aT else quantA.size(1) 42 | K = quantA.size(1) if aT else quantA.size(2) 43 | N = quantB.size(1) if bT else quantB.size(2) 44 | result_i32 = torch.empty((max(quantA.size(0), quantB.size(0)), M, N), dtype=torch.int32, device=quantA.device) 45 | gemm_int8 ( 46 | N, K, M, 47 | quantB.size(0), quantA.size(0), 48 | bT, aT, 49 | quantB.data_ptr(), quantA.data_ptr(), 50 | result_i32.data_ptr(), 51 | torch.cuda.current_stream().cuda_stream 52 | ) 53 | result_fp = torch.empty((max(quantA.size(0), quantB.size(0)), M, N), dtype=torch.float16, device=quantA.device) 54 | 55 | if scaleA is not None and scaleB is not None: 56 | gemm_scale( 57 | result_i32.size(0), M, N, 58 | result_i32.data_ptr(), 59 | scaleA.data_ptr(), scaleB.data_ptr(), 60 | result_fp.data_ptr(), 61 | quantA.size(0) == 1, 62 | quantB.size(0) == 1, 63 | torch.cuda.current_stream().cuda_stream 64 | ) 65 | elif scaleA is not None: 66 | gemm_scale_x( 67 | result_i32.size(0), M, N, 68 | result_i32.data_ptr(), 69 | scaleA.data_ptr(), 70 | result_fp.data_ptr(), 71 | torch.cuda.current_stream().cuda_stream 72 | ) 73 | else: 74 | assert scaleB is not None 75 | gemm_scale_y( 76 | result_i32.size(0), M, N, 77 | result_i32.data_ptr(), 78 | scaleB.data_ptr(), 79 | result_fp.data_ptr(), 80 | torch.cuda.current_stream().cuda_stream 81 | ) 82 | return result_fp 83 | 84 | class GEMMInt8(torch.autograd.Function): 85 | @staticmethod 86 | def forward(ctx, A : torch.Tensor, aT : bool, B : torch.Tensor, bT : bool): 87 | """ 88 | Input: 89 | - A: (batchA, M, K) 90 | - B: (batchB, K, N) 91 | Output: 92 | - C: (batch, M, N) 93 | """ 94 | assert A.is_cuda and B.is_cuda and A.device == B.device 95 | assert A.is_contiguous() and B.is_contiguous() 96 | assert A.dtype == torch.half and B.dtype == torch.half 97 | 98 | scale_A = calc_scale(A, aT) 99 | scale_B = calc_scale(B, not bT) 100 | 101 | quantized_A = round_i8(A, scale_A, aT) 102 | quantized_B = round_i8(B, scale_B, not bT) 103 | 104 | result = gemm_and_scale(quantized_A, scale_A, quantized_B, scale_B, aT, bT) 105 | 106 | # save backward 107 | ctx.save_for_backward( 108 | scale_A, quantized_A, 109 | scale_B, quantized_B 110 | ) 111 | ctx.aT = aT 112 | ctx.bT = bT 113 | return result 114 | 115 | @staticmethod 116 | def backward(ctx, grad_f : torch.Tensor): 117 | assert grad_f.is_contiguous() and grad_f.is_cuda and grad_f.dtype == torch.float16 118 | scale_A, quantized_A, scale_B, quantized_B = ctx.saved_tensors 119 | aT, bT = ctx.aT, ctx.bT 120 | 121 | batch, m, n = grad_f.size() 122 | 123 | scale_G_a = torch.empty((batch, m), dtype=torch.half, device=grad_f.device) 124 | quant_G_a = torch.empty((batch, m, n), dtype=torch.int8, device=grad_f.device) 125 | gemm_backward_round_scale( 126 | batch, m, n, 127 | grad_f.data_ptr(), 128 | scale_B.data_ptr(), 129 | quant_G_a.data_ptr(), 130 | scale_G_a.data_ptr(), 131 | scale_B.size(0) == 1, 132 | torch.cuda.current_stream().cuda_stream 133 | ) 134 | 135 | if aT: 136 | grad_A = gemm_and_scale( 137 | quantized_B, None, 138 | quant_G_a, scale_G_a, 139 | bT, True 140 | ) 141 | else: 142 | grad_A = gemm_and_scale( 143 | quant_G_a, scale_G_a, 144 | quantized_B, None, 145 | False, not bT 146 | ) 147 | del scale_G_a 148 | del quant_G_a 149 | 150 | scale_G_b = torch.empty((batch, n), dtype=torch.half, device=grad_f.device) 151 | quant_G_b = torch.empty((batch, m, n), dtype=torch.int8, device=grad_f.device) 152 | gemm_backward_scale_round( 153 | batch, m, n, 154 | 
grad_f.data_ptr(), 155 | scale_A.data_ptr(), 156 | quant_G_b.data_ptr(), 157 | scale_G_b.data_ptr(), 158 | scale_A.size(0) == 1, 159 | torch.cuda.current_stream().cuda_stream 160 | ) 161 | if bT: 162 | grad_B = gemm_and_scale( 163 | quant_G_b, scale_G_b, 164 | quantized_A, None, 165 | True, aT 166 | ) 167 | else: 168 | grad_B = gemm_and_scale( 169 | quantized_A, None, 170 | quant_G_b, scale_G_b, 171 | not aT, False 172 | ) 173 | 174 | if scale_A.size(0) == 1 and grad_A.size(0) > 1: 175 | grad_A = grad_A.sum(dim=0, keepdim=True) 176 | if scale_B.size(0) == 1 and grad_B.size(0) > 1: 177 | grad_B = grad_B.sum(dim=0, keepdim=True) 178 | 179 | return grad_A, None, grad_B, None 180 | 181 | 182 | def gemm_pth_fp16(A : torch.Tensor, aT : bool, B : torch.Tensor,bT : bool) -> torch.Tensor: 183 | M = A.size(2) if aT else A.size(1) 184 | K = A.size(1) if aT else A.size(2) 185 | N = B.size(1) if bT else B.size(2) 186 | out = torch.empty((max(A.size(0), B.size(0)), M, N), dtype=torch.float16, device=A.device) 187 | gemm_fp16( 188 | N, K, M, 189 | B.size(0), A.size(0), 190 | bT, aT, 191 | B.data_ptr(), A.data_ptr(), 192 | out.data_ptr(), 193 | torch.cuda.current_stream().cuda_stream 194 | ) 195 | return out 196 | 197 | class GEMMFloat(torch.autograd.Function): 198 | @staticmethod 199 | def forward(ctx, A : torch.Tensor, aT, B : torch.Tensor, bT): 200 | assert A.is_cuda and A.is_contiguous() and A.dtype == torch.half 201 | assert B.is_cuda and B.is_contiguous() and B.dtype == torch.half 202 | assert A.device == B.device 203 | 204 | ctx.save_for_backward(A, B) 205 | ctx.aT = aT 206 | ctx.bT = bT 207 | 208 | return gemm_pth_fp16(A, aT, B, bT) 209 | 210 | @staticmethod 211 | def backward(ctx, grad_f): 212 | assert grad_f.is_cuda and grad_f.is_contiguous() and grad_f.dtype == torch.float16 213 | aT = ctx.aT 214 | bT = ctx.bT 215 | A, B = ctx.saved_tensors 216 | if aT: 217 | grad_A = gemm_pth_fp16(B, bT, grad_f, True) 218 | else: 219 | grad_A = gemm_pth_fp16(grad_f, False, B, not bT) 220 | 221 | if bT: 222 | grad_B = gemm_pth_fp16(grad_f, True, A, aT) 223 | else: 224 | grad_B = gemm_pth_fp16(A, not aT, grad_f, False) 225 | 226 | if A.size(0) == 1 and grad_A.size(0) > 1: 227 | grad_A = grad_A.sum(dim=0, keepdim=True) 228 | if B.size(0) == 1 and grad_B.size(0) > 1: 229 | grad_B = grad_B.sum(dim=0, keepdim=True) 230 | 231 | return grad_A, None, grad_B, None 232 | 233 | def bmm(A : torch.Tensor, aT : bool, B : torch.Tensor, bT : bool, int8 : bool =False) -> torch.Tensor: 234 | assert A.ndim == 3 235 | assert B.ndim == 3 236 | if int8: 237 | return GEMMInt8.apply(A, aT, B, bT) 238 | else: 239 | return GEMMFloat.apply(A, aT, B, bT) 240 | -------------------------------------------------------------------------------- /cpm_kernels/torch/layernorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..kernels import layernorm_forward_mv, layernorm_forward_v, layernorm_backward_mv, layernorm_backward_v, layernorm_forward, layernorm_inplace_forward, \ 4 | arith_ln_mul_add, arith_ln_add_backward, arith_ln_mul_backward, arith_ln_mul 5 | 6 | 7 | class OpLayerNormMean(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x : torch.Tensor, eps : float, weight : torch.Tensor, bias : torch.Tensor): 10 | assert x.is_cuda and x.is_contiguous() and x.ndim == 3 and x.dtype == torch.float16 11 | assert weight.is_cuda and weight.is_contiguous() and weight.ndim == 1 and weight.dtype == torch.float16 12 | out = torch.empty((x.size(0), x.size(1), x.size(2)), 
device=x.device, dtype=torch.float16) 13 | layernorm_forward( 14 | x.size(0), x.size(1), x.size(2), 15 | x.data_ptr(), 16 | out.data_ptr(), 17 | eps, 18 | True, 19 | torch.cuda.current_stream().cuda_stream 20 | ) 21 | ctx.save_for_backward(x, weight) 22 | arith_ln_mul_add( 23 | out.size(0), out.size(1), out.size(2), 24 | out.data_ptr(), weight.data_ptr(), bias.data_ptr(), 25 | out.data_ptr(), 26 | torch.cuda.current_stream().cuda_stream 27 | ) 28 | ctx.eps = eps 29 | return out 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output : torch.Tensor): 33 | assert grad_output.is_cuda and grad_output.is_contiguous() and grad_output.ndim == 3 and grad_output.dtype == torch.float16 34 | x, weight = ctx.saved_tensors 35 | 36 | mean = torch.empty((x.size(0), x.size(2)), device=x.device, dtype=torch.float16) 37 | var = torch.empty((x.size(0), x.size(2)), device=x.device, dtype=torch.float16) 38 | layer_out = torch.empty((x.size(0), x.size(1), x.size(2)), device=x.device, dtype=torch.float16) 39 | layernorm_forward_mv( 40 | x.size(0), x.size(1), x.size(2), 41 | x.data_ptr(), 42 | layer_out.data_ptr(), 43 | mean.data_ptr(), 44 | var.data_ptr(), 45 | ctx.eps, 46 | torch.cuda.current_stream().cuda_stream 47 | ) 48 | 49 | grad_bias = torch.empty((x.size(1),), device=x.device, dtype=torch.float16) 50 | 51 | arith_ln_add_backward( 52 | x.size(0), x.size(1), x.size(2), 53 | grad_output.data_ptr(), 54 | grad_bias.data_ptr(), 55 | torch.cuda.current_stream().cuda_stream 56 | ) 57 | 58 | grad_weight = torch.empty((x.size(1),), device=x.device, dtype=torch.float16) 59 | arith_ln_mul_backward( 60 | x.size(0), x.size(1), x.size(2), 61 | layer_out.data_ptr(), 62 | grad_output.data_ptr(), 63 | grad_weight.data_ptr(), 64 | torch.cuda.current_stream().cuda_stream 65 | ) 66 | 67 | grad_x = torch.empty(x.size(), device=x.device, dtype=torch.float16) 68 | arith_ln_mul( 69 | x.size(0), x.size(1), x.size(2), 70 | grad_output.data_ptr(), 71 | weight.data_ptr(), 72 | grad_x.data_ptr(), 73 | torch.cuda.current_stream().cuda_stream 74 | ) 75 | 76 | grad = torch.empty_like(x) 77 | layernorm_backward_mv( 78 | x.size(0), x.size(1), x.size(2), 79 | x.data_ptr(), 80 | grad_x.data_ptr(), 81 | mean.data_ptr(), 82 | var.data_ptr(), 83 | grad.data_ptr(), 84 | torch.cuda.current_stream().cuda_stream 85 | ) 86 | return grad, None, grad_weight, grad_bias 87 | 88 | class OpLayerNormNoMean(torch.autograd.Function): 89 | @staticmethod 90 | def forward(ctx, x : torch.Tensor, eps : float, weight : torch.Tensor): 91 | assert x.is_cuda and x.is_contiguous() and x.ndim == 3 and x.dtype == torch.float16 92 | assert weight.is_cuda and weight.is_contiguous() and weight.ndim == 1 and weight.dtype == torch.float16 93 | out = torch.empty((x.size(0), x.size(1), x.size(2)), device=x.device, dtype=torch.float16) 94 | layernorm_forward( 95 | x.size(0), x.size(1), x.size(2), 96 | x.data_ptr(), 97 | out.data_ptr(), 98 | eps, 99 | False, 100 | torch.cuda.current_stream().cuda_stream 101 | ) 102 | ctx.save_for_backward(x, weight) 103 | arith_ln_mul( 104 | out.size(0), out.size(1), out.size(2), 105 | out.data_ptr(), 106 | weight.data_ptr(), 107 | out.data_ptr(), 108 | torch.cuda.current_stream().cuda_stream 109 | ) 110 | ctx.eps = eps 111 | return out 112 | 113 | @staticmethod 114 | def backward(ctx, grad_output : torch.Tensor): 115 | assert grad_output.is_cuda and grad_output.is_contiguous() and grad_output.ndim == 3 and grad_output.dtype == torch.float16 116 | x, weight = ctx.saved_tensors 117 | 118 | layer_out = torch.empty((x.size(0), x.size(1), 
x.size(2)), device=x.device, dtype=torch.float16) 119 | var = torch.empty((x.size(0), x.size(2)), device=x.device, dtype=torch.float16) 120 | layernorm_forward_v( 121 | x.size(0), x.size(1), x.size(2), 122 | x.data_ptr(), 123 | layer_out.data_ptr(), 124 | var.data_ptr(), 125 | ctx.eps, 126 | torch.cuda.current_stream().cuda_stream 127 | ) 128 | 129 | grad_weight = torch.empty((x.size(1),), device=x.device, dtype=torch.float16) 130 | arith_ln_mul_backward( 131 | x.size(0), x.size(1), x.size(2), 132 | layer_out.data_ptr(), 133 | grad_output.data_ptr(), 134 | grad_weight.data_ptr(), 135 | torch.cuda.current_stream().cuda_stream 136 | ) 137 | 138 | grad_x = torch.empty(x.size(), device=x.device, dtype=torch.float16) 139 | arith_ln_mul( 140 | x.size(0), x.size(1), x.size(2), 141 | grad_output.data_ptr(), 142 | weight.data_ptr(), 143 | grad_x.data_ptr(), 144 | torch.cuda.current_stream().cuda_stream 145 | ) 146 | 147 | grad = torch.empty_like(x) 148 | layernorm_backward_v( 149 | x.size(0), x.size(1), x.size(2), 150 | x.data_ptr(), 151 | grad_x.data_ptr(), 152 | var.data_ptr(), 153 | grad.data_ptr(), 154 | torch.cuda.current_stream().cuda_stream 155 | ) 156 | return grad, None, grad_weight 157 | 158 | class LayerNorm(torch.nn.Module): 159 | def __init__(self, hidden_size : int, eps : float = 1e-5, bias=True): 160 | super(LayerNorm, self).__init__() 161 | self.eps = eps 162 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 163 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) if bias else None 164 | 165 | def forward(self, x : torch.Tensor): 166 | assert x.is_cuda and x.is_contiguous() and x.ndim == 3 and x.dtype == torch.float16 167 | assert x.size(1) == self.weight.size(0) 168 | 169 | if self.bias is not None: 170 | return OpLayerNormMean.apply(x, self.eps, self.weight, self.bias) 171 | else: 172 | return OpLayerNormNoMean.apply(x, self.eps, self.weight) 173 | 174 | class LayerNormTH(torch.nn.Module): 175 | def __init__(self, hidden_size : int, eps : float = 1e-5, bias=True): 176 | super(LayerNormTH, self).__init__() 177 | self.eps = eps 178 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 179 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) if bias else None 180 | 181 | def forward(self, x : torch.Tensor): 182 | old_dtype = x.dtype 183 | x = x.to(torch.float32) 184 | var = (x**2).mean(axis=1, keepdim=True) 185 | if self.bias is not None: 186 | mean = x.mean(axis=1, keepdim=True) 187 | var = var - (mean**2) # var = E(x^2) - E(x)^2 188 | x = (x - mean) * torch.rsqrt(var + self.eps) 189 | else: 190 | x = x * torch.rsqrt(var + self.eps) 191 | if self.bias is not None: 192 | x = x * self.weight[None, :, None] + self.bias[None, :, None] 193 | else: 194 | x = x * self.weight[None, :, None] 195 | x = x.to(old_dtype) 196 | return x 197 | 198 | def normalize_inplace(x : torch.Tensor, eps : float, rd_mean : bool): 199 | assert x.is_cuda and x.is_contiguous() and x.ndim == 3 and x.dtype == torch.float16 200 | layernorm_inplace_forward( 201 | x.size(0), x.size(1), x.size(2), 202 | x.data_ptr(), 203 | eps, 204 | rd_mean, 205 | torch.cuda.current_stream().cuda_stream 206 | ) 207 | 208 | def normalizeTH(x : torch.Tensor, eps : float, rd_mean : bool) -> torch.Tensor: 209 | old_dtype = x.dtype 210 | x = x.to(torch.float32) 211 | var = (x**2).mean(axis=1, keepdim=True) 212 | if rd_mean: 213 | mean = x.mean(axis=1, keepdim=True) 214 | var = var - (mean**2) 215 | x = (x - mean) * torch.rsqrt(var + eps) 216 | else: 217 | x = x * torch.rsqrt(var + eps) 218 | return 
x.to(old_dtype) -------------------------------------------------------------------------------- /cpm_kernels/torch/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..kernels import mask as mask_cuda 3 | 4 | class OpMask(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, x : torch.Tensor, mask : torch.Tensor, value : float) -> torch.Tensor: 7 | assert x.is_contiguous() and x.is_cuda and x.dtype == torch.float16 and x.ndim == 3 8 | assert mask.is_contiguous() and mask.is_cuda and mask.dtype == torch.bool and mask.ndim == 2 9 | assert x.device == mask.device 10 | batch, n, m = x.size() 11 | assert mask.size() == (batch, m) 12 | 13 | out = torch.empty(x.size(), dtype=torch.float16, device=x.device) 14 | mask_cuda( 15 | batch, n, m, 16 | x.data_ptr(), 17 | mask.data_ptr(), 18 | value, 19 | out.data_ptr(), 20 | torch.cuda.current_stream().cuda_stream 21 | ) 22 | ctx.save_for_backward(mask) 23 | return out 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output : torch.Tensor) -> torch.Tensor: 27 | mask = ctx.saved_tensors[0] 28 | batch, n, m = grad_output.size() 29 | assert grad_output.is_cuda and grad_output.is_contiguous() and grad_output.dtype == torch.float16 30 | 31 | grad = torch.empty(grad_output.size(), dtype=torch.float16, device=grad_output.device) 32 | mask_cuda( 33 | batch, n, m, 34 | grad_output.data_ptr(), 35 | mask.data_ptr(), 36 | 0.0, 37 | grad.data_ptr(), 38 | torch.cuda.current_stream().cuda_stream 39 | ) 40 | return grad, None, None 41 | 42 | 43 | def mask(x : torch.Tensor, mask : torch.Tensor, value : float) -> torch.Tensor: 44 | return OpMask.apply(x, mask, value) 45 | 46 | def mask_inplace(x : torch.Tensor, mask : torch.Tensor, value : float) -> None: 47 | assert x.is_contiguous() and x.is_cuda and x.dtype == torch.float16 and x.ndim == 3 48 | assert mask.is_contiguous() and mask.is_cuda and mask.dtype == torch.bool and mask.ndim == 2 49 | assert x.device == mask.device 50 | batch, n, m = x.size() 51 | assert mask.size() == (batch, m) 52 | 53 | mask_cuda( 54 | batch, n, m, 55 | x.data_ptr(), 56 | mask.data_ptr(), 57 | value, 58 | x.data_ptr(), 59 | torch.cuda.current_stream().cuda_stream 60 | ) 61 | 62 | def maskTH(x : torch.Tensor, mask : torch.Tensor, value : float) -> torch.Tensor: 63 | return torch.where( 64 | mask[:, None, :], 65 | x, 66 | torch.scalar_tensor(value, device=x.device, dtype=x.dtype), 67 | ) 68 | -------------------------------------------------------------------------------- /cpm_kernels/torch/position_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..kernels import position_embedding_init, position_embedding_forward, position_embedding_backward 3 | import math 4 | import torch.nn.functional as F 5 | 6 | class OpPositionEmbedding(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, query_len, key_len, num_buckets, max_distance, num_heads, weight : torch.Tensor, bidirectional : bool): 9 | assert weight.is_cuda and weight.is_contiguous() and weight.dtype == torch.float16 10 | device = weight.device 11 | 12 | mapping = torch.empty( (max_distance,), dtype=torch.int32, device=device ) 13 | position_embedding_init( 14 | num_buckets, 15 | max_distance, 16 | mapping.data_ptr(), 17 | bidirectional, 18 | torch.cuda.current_stream().cuda_stream 19 | ) 20 | out = torch.empty((num_heads, key_len, query_len), device=device, dtype=torch.float16) 21 | position_embedding_forward( 22 | query_len, 23 | 
key_len, 24 | num_buckets, 25 | max_distance, 26 | num_heads, 27 | mapping.data_ptr(), 28 | weight.data_ptr(), 29 | out.data_ptr(), 30 | bidirectional, 31 | torch.cuda.current_stream().cuda_stream 32 | ) 33 | ctx.save_for_backward(mapping) 34 | ctx.input_args = (query_len, key_len, num_buckets, max_distance, num_heads, bidirectional) 35 | return out 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output : torch.Tensor): 39 | assert grad_output.is_cuda and grad_output.is_contiguous() and grad_output.dtype == torch.float16 40 | query_len, key_len, num_buckets, max_distance, num_heads, bidirectional = ctx.input_args 41 | mapping = ctx.saved_tensors[0] 42 | grad = torch.empty((num_heads, num_buckets), device=grad_output.device, dtype=torch.float16) 43 | position_embedding_backward( 44 | query_len, 45 | key_len, 46 | num_buckets, 47 | max_distance, 48 | num_heads, 49 | mapping.data_ptr(), 50 | grad_output.data_ptr(), 51 | grad.data_ptr(), 52 | bidirectional, 53 | torch.cuda.current_stream().cuda_stream 54 | ) 55 | return None, None, None, None, None, grad, None 56 | 57 | class PositionEmbedding(torch.nn.Module): 58 | def __init__(self, num_heads, num_buckets, max_distance, bidirectional=True): 59 | super(PositionEmbedding, self).__init__() 60 | self.weight = torch.nn.Parameter(torch.randn(num_heads, num_buckets)) 61 | 62 | self.num_heads = num_heads 63 | self.num_buckets = num_buckets 64 | self.max_distance = max_distance 65 | self.bidirectional = bidirectional 66 | 67 | 68 | def forward(self, key_len, query_len): 69 | return OpPositionEmbedding.apply(query_len, key_len, self.num_buckets, self.max_distance, self.num_heads, self.weight, self.bidirectional) 70 | 71 | 72 | 73 | class PositionEmbeddingTH(torch.nn.Module): 74 | def __init__(self, num_heads, num_buckets, max_distance, bidirectional=True) -> None: 75 | super(PositionEmbeddingTH, self).__init__() 76 | self.num_buckets = num_buckets 77 | self.num_heads = num_heads 78 | self.max_distance = max_distance 79 | self.bidirectional = bidirectional 80 | 81 | # self.embedding = weight(self.num_buckets, self.num_heads) 82 | self.weight = torch.nn.Parameter(torch.randn(num_heads, num_buckets)) 83 | 84 | def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128): 85 | relative_buckets = 0 86 | if bidirectional: 87 | num_buckets //= 2 88 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 89 | relative_position = torch.abs(relative_position) 90 | else: 91 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 92 | # now relative_position is in the range [0, inf) 93 | 94 | # half of the buckets are for exact increments in positions 95 | max_exact = num_buckets // 2 96 | is_small = relative_position < max_exact 97 | 98 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 99 | relative_postion_if_large = max_exact + ( 100 | torch.log(relative_position.float() / max_exact) 101 | / math.log(max_distance / max_exact) 102 | * (num_buckets - max_exact) 103 | ).to(torch.long) 104 | relative_postion_if_large = torch.min( 105 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 106 | ) 107 | 108 | relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) 109 | return relative_buckets 110 | 111 | def compute_bias(self, query_length, key_length): 112 | """ Compute binned relative position bias """ 113 | device = self.weight.device 114 | 
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] 115 | memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] 116 | relative_position = memory_position - context_position # shape (query_length, key_length) 117 | relative_position_bucket = self.relative_position_bucket( 118 | relative_position, 119 | bidirectional=self.bidirectional, 120 | num_buckets=self.num_buckets, 121 | max_distance=self.max_distance, 122 | ) 123 | values = F.embedding(relative_position_bucket, self.weight.transpose(0, 1)) 124 | values = values.permute([2, 1, 0]) # shape (num_heads, key_length, query_length) 125 | return values 126 | 127 | def forward(self, key_length, query_length): 128 | return self.compute_bias(query_length, key_length) 129 | -------------------------------------------------------------------------------- /cpm_kernels/torch/softmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..kernels import softmax_forward, softmax_backward, softmax_inplace_forward 3 | 4 | class OpSoftmax(torch.autograd.Function): 5 | """ 6 | Softmax dim=1 7 | """ 8 | @staticmethod 9 | def forward(ctx, x : torch.Tensor): 10 | assert x.is_cuda and x.is_contiguous() and x.dtype == torch.half 11 | assert x.ndim == 3 12 | out = torch.empty(x.size(), device=x.device, dtype=torch.half) 13 | softmax_forward( 14 | x.size(0), x.size(1), x.size(2), 15 | x.data_ptr(), out.data_ptr(), 16 | torch.cuda.current_stream().cuda_stream 17 | ) 18 | ctx.save_for_backward(out) 19 | return out 20 | 21 | @staticmethod 22 | def backward(ctx, grad_output : torch.Tensor): 23 | assert grad_output.is_cuda and grad_output.is_contiguous() and grad_output.dtype == torch.half 24 | assert grad_output.ndim == 3 25 | out = ctx.saved_tensors[0] 26 | grad = torch.empty(grad_output.size(), device=grad_output.device, dtype=torch.half) 27 | softmax_backward( 28 | grad_output.size(0), grad_output.size(1), grad_output.size(2), 29 | out.data_ptr(), 30 | grad_output.data_ptr(), 31 | grad.data_ptr(), 32 | torch.cuda.current_stream().cuda_stream 33 | ) 34 | return grad 35 | 36 | def softmax(x : torch.Tensor) -> torch.Tensor: 37 | return OpSoftmax.apply(x) 38 | 39 | def softmaxTH(x : torch.Tensor) -> torch.Tensor: 40 | return torch.nn.functional.softmax(x, dim=1) 41 | 42 | def softmax_inplace(x : torch.Tensor) -> None: 43 | assert x.is_cuda and x.ndim == 3 and x.is_contiguous() and x.dtype == torch.half 44 | softmax_inplace_forward( 45 | x.size(0), x.size(1), x.size(2), 46 | x.data_ptr(), 47 | torch.cuda.current_stream().cuda_stream 48 | ) -------------------------------------------------------------------------------- /cpm_kernels/torch/transpose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..kernels import transpose as trans_func 3 | 4 | 5 | class OpTranspose(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, x : torch.Tensor): 8 | assert x.is_contiguous() and x.is_cuda and x.dtype == torch.half and x.ndim == 3 9 | out = torch.empty((x.size(0), x.size(2), x.size(1)), dtype=torch.half, device=x.device) 10 | trans_func( 11 | x.size(0), x.size(1), x.size(2), 12 | x.data_ptr(), 13 | out.data_ptr(), 14 | torch.cuda.current_stream().cuda_stream 15 | ) 16 | return out 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output : torch.Tensor): 20 | assert grad_output.is_contiguous() and grad_output.is_cuda and grad_output.dtype == torch.half and grad_output.ndim == 3 21 | 
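        # The gradient of a batched (b, n, m) -> (b, m, n) transpose is simply the
        # transpose of the incoming gradient, so the same kernel is reused on
        # grad_output below.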
grad = torch.empty((grad_output.size(0), grad_output.size(2), grad_output.size(1)), dtype=torch.half, device=grad_output.device) 22 | trans_func( 23 | grad_output.size(0), grad_output.size(1), grad_output.size(2), 24 | grad_output.data_ptr(), 25 | grad.data_ptr(), 26 | torch.cuda.current_stream().cuda_stream 27 | ) 28 | return grad 29 | 30 | def transpose(x : torch.Tensor) -> torch.Tensor: 31 | return OpTranspose.apply(x) 32 | 33 | def transposeTH(x : torch.Tensor) -> torch.Tensor: 34 | assert x.ndim == 3 35 | return x.transpose(1, 2) -------------------------------------------------------------------------------- /cpm_kernels/torch/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from ..kernels import has_nan_inf as kn_has_nan_inf 4 | 5 | def has_nan_inf(x : torch.Tensor, out : Optional[torch.Tensor] = None) -> torch.Tensor: 6 | assert x.is_cuda and x.is_contiguous() and x.dtype == torch.half 7 | if out is None: 8 | out = torch.empty(1, dtype=torch.bool, device=x.device)[0] 9 | kn_has_nan_inf( 10 | x.numel(), x.data_ptr(), 11 | out.data_ptr(), 12 | torch.cuda.current_stream().cuda_stream 13 | ) 14 | return out -------------------------------------------------------------------------------- /cuda/Makefile: -------------------------------------------------------------------------------- 1 | NVCC=nvcc 2 | OPTIONS=-Iincludes \ 3 | -gencode arch=compute_61,code=sm_61 \ 4 | -gencode arch=compute_62,code=sm_62 \ 5 | -gencode arch=compute_70,code=sm_70 \ 6 | -gencode arch=compute_72,code=sm_72 \ 7 | -gencode arch=compute_75,code=sm_75 \ 8 | -gencode arch=compute_80,code=sm_80 \ 9 | -gencode arch=compute_86,code=sm_86 10 | 11 | TARGETS=$(patsubst %.cu, %.fatbin, $(wildcard *.cu)) 12 | 13 | all: $(TARGETS) 14 | 15 | %.fatbin: %.cu 16 | $(NVCC) -fatbin $^ $(OPTIONS) -o $@ 17 | 18 | .PHONY : clean, copy 19 | clean: 20 | rm $(TARGETS) 21 | 22 | copy: 23 | mkdir -p ../cpm_kernels/kernels/cuda 24 | cp $(TARGETS) ../cpm_kernels/kernels/cuda 25 | -------------------------------------------------------------------------------- /cuda/arith.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include "common.h" 3 | #include 4 | 5 | // block , thread 6 | CPM_KERNEL_EXPORT void cu_arith_global_scale( 7 | int64_t n, 8 | const half *inp, // (n,) 9 | float scale, 10 | half *out // (n,) 11 | ) { 12 | int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; 13 | if (idx < n) { 14 | out[idx] = __float2half(__half2float(inp[idx]) * scale); 15 | } 16 | } 17 | 18 | // block , thread, half n 19 | CPM_KERNEL_EXPORT void cu_arith_element_add ( 20 | int64_t batch, int64_t n, 21 | const half2 *x, // (batch, n) 22 | const half2 *y, // (batch, n) 23 | half2 *out 24 | ) { 25 | int64_t col = threadIdx.x + blockIdx.y * blockDim.x; 26 | int64_t pos = blockIdx.x * n + col; 27 | if (col < n) { 28 | out[pos] = __hadd2(x[pos], y[pos]); 29 | } 30 | } 31 | 32 | // block , thread, half n 33 | CPM_KERNEL_EXPORT void cu_arith_element_mul ( 34 | int64_t batch, int64_t n, 35 | const half2 *x, // (batch, n) 36 | const half2 *y, // (batch, n) 37 | half2 *out 38 | ) { 39 | int64_t col = threadIdx.x + blockIdx.y * blockDim.x; 40 | int64_t pos = blockIdx.x * n + col; 41 | if (col < n) { 42 | out[pos] = __hmul2(x[pos], y[pos]); 43 | } 44 | } 45 | 46 | // block , thread, half n 47 | CPM_KERNEL_EXPORT void cu_arith_batch_add_forward( 48 | int64_t batch, int64_t n, 49 | const half2 *x, // (batch, 
n) 50 | const half2 *y, // (n) 51 | half2 *out // (batch, n) 52 | ) { 53 | int64_t col = threadIdx.x + blockIdx.y * blockDim.x; 54 | int64_t pos = blockIdx.x * n + col; 55 | if (col < n) { 56 | out[pos] = __hadd2(x[pos], __ldg(y + col)); 57 | } 58 | } 59 | 60 | // block , thread 61 | CPM_KERNEL_EXPORT void cu_arith_batch_add_backward( 62 | int64_t batch, int64_t n, 63 | const half *grad_out, // (batch, n) 64 | half *grad // (n) 65 | ) { 66 | int64_t col = blockIdx.x * blockDim.x + threadIdx.x; 67 | float sum = 0; 68 | for (int i = 0; i < batch; i += blockDim.y) { 69 | if (i + threadIdx.y < batch && col < n) { 70 | sum += __half2float(grad_out[(i + threadIdx.y) * n + col]); 71 | } 72 | } 73 | sum = transposeReduceSum(sum); // does not support half2 74 | if (threadIdx.y == 0) { 75 | grad[col] = __float2half(sum); 76 | } 77 | } 78 | 79 | // block thread, half m 80 | CPM_KERNEL_EXPORT void cu_arith_ln_add( 81 | int64_t batch, int64_t n, int64_t m, 82 | const half2 *x, // (batch, n, m) 83 | const half *beta, // (n) 84 | half2 *out // (batch, n, m) 85 | ) { 86 | int64_t col = threadIdx.x + blockIdx.z * blockDim.x; 87 | int64_t base_x_idx = (blockIdx.x * n + blockIdx.y) * m + col; 88 | half2 beta_v = __half2half2(__ldg(beta + blockIdx.y)); 89 | 90 | if (col < m) { 91 | out[base_x_idx] = __hadd2(x[base_x_idx], beta_v); 92 | } 93 | } 94 | 95 | 96 | // block thread, half m 97 | CPM_KERNEL_EXPORT void cu_arith_ln_mul_add( 98 | int64_t batch, int64_t n, int64_t m, 99 | const half2 *x, // (batch, n, m) 100 | const half *alpha, // (n) 101 | const half *beta, // (n) 102 | half2 *out // (batch, n, m) 103 | ) { 104 | int64_t col = threadIdx.x + blockIdx.z * blockDim.x; 105 | int64_t base_x_idx = (blockIdx.x * n + blockIdx.y) * m + col; 106 | half2 alpha_v = __half2half2(__ldg(alpha + blockIdx.y)); 107 | half2 beta_v = __half2half2(__ldg(beta + blockIdx.y)); 108 | 109 | if (col < m) { 110 | out[base_x_idx] = __hfma2(x[base_x_idx], alpha_v, beta_v); 111 | } 112 | } 113 | 114 | // block thread, half m 115 | CPM_KERNEL_EXPORT void cu_arith_ln_mul( 116 | int64_t batch, int64_t n, int64_t m, 117 | const half2 *x, // (batch, n, m) 118 | const half *alpha, // (n) 119 | half2 *out 120 | ) { 121 | int64_t col = threadIdx.x + blockIdx.z * blockDim.x; 122 | int64_t base_x_idx = (blockIdx.x * n + blockIdx.y) * m + col; 123 | half2 alpha_v = __half2half2(__ldg(alpha + blockIdx.y)); 124 | if (col < m) { 125 | out[base_x_idx] = __hmul2(x[base_x_idx], alpha_v); 126 | } 127 | } 128 | 129 | 130 | // block thread, half m 131 | CPM_KERNEL_EXPORT void cu_arith_ln_div( 132 | int64_t batch, int64_t n, int64_t m, 133 | const half2 *x, // (batch, n, m) 134 | const half *alpha, // (n) 135 | half2 *out 136 | ) { 137 | int64_t col = threadIdx.x + blockIdx.z * blockDim.x; 138 | int64_t base_x_idx = (blockIdx.x * n + blockIdx.y) * m + col; 139 | half2 alpha_v = __half2half2(__hdiv(__float2half(1.0), __ldg(alpha + blockIdx.y))); 140 | if (col < m) { 141 | out[base_x_idx] = __hmul2(x[base_x_idx], alpha_v); 142 | } 143 | } 144 | 145 | // block thread, half m 146 | CPM_KERNEL_EXPORT void cu_arith_ln_sub_div( 147 | int64_t batch, int64_t n, int64_t m, 148 | const half2 *x, // (batch, n, m) 149 | const half *alpha, // (n) 150 | const half *beta, // (n) 151 | half2* out 152 | ) { 153 | int64_t col = threadIdx.x + blockIdx.z * blockDim.x; 154 | int64_t base_x_idx = (blockIdx.x * n + blockIdx.y) * m + col; 155 | float rev_alpha = 1.0 / (float)(__ldg(alpha + blockIdx.y)); 156 | float neg_beta = - (float)(__ldg(beta + blockIdx.y)) * rev_alpha; 157 
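    // (x - beta) / alpha is rewritten as x * (1/alpha) + (-beta/alpha), so the
    // store below needs only a single fused half2 multiply-add (__hfma2).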
| 158 | half2 alpha_v = __float2half2_rn(rev_alpha); // 1 / alpha 159 | half2 beta_v = __float2half2_rn(neg_beta); // - beta / alpha 160 | if (col < m) { 161 | out[base_x_idx] = __hfma2(x[base_x_idx], alpha_v, beta_v); 162 | } 163 | } 164 | 165 | 166 | 167 | // block thread<32, 32> 168 | CPM_KERNEL_EXPORT void cu_arith_ln_mul_backward( 169 | int32_t batch, int32_t n, int32_t m, 170 | const half *x, // (batch, n, m) 171 | const half *grad_out, // (batch, n, m) 172 | half *grad // (n) 173 | ) { 174 | /* 175 | reduce_sum(x * grad_out) 176 | */ 177 | float local_sum = 0; 178 | for (int b = 0; b < batch; b += WARP_SZ * WARP_SZ) { 179 | float inner_sum = 0; 180 | for (int inner_b = 0; inner_b < WARP_SZ * WARP_SZ && inner_b + b < batch; inner_b += WARP_SZ) { 181 | int batch_idx = b + inner_b + threadIdx.y; 182 | int base_idx = batch_idx * n * m + blockIdx.x * m + threadIdx.x; 183 | 184 | float batch_sum = 0; 185 | for (int i = 0; i < m; i += WARP_SZ * WARP_SZ) { 186 | float inner_v = 0; 187 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < m; j += WARP_SZ) { 188 | float v = 0; 189 | if (batch_idx < batch && i + j + threadIdx.x < m) { 190 | v = (float)grad_out[base_idx + i + j] * (float)x[base_idx + i + j]; 191 | } 192 | v = warpReduceSum(v); // sum of 32 elements 193 | v = __shfl_sync(0xFFFFFFFF, v, 0); // broadcast to all threads in warp 194 | if (threadIdx.x * WARP_SZ == j) inner_v = v; 195 | } 196 | inner_v = warpReduceSum(inner_v); // sum of 1024 elements 197 | 198 | // stores the sum of batch (b + inner_b + threadIdx.y) in (0, threadIdx.y) 199 | batch_sum += inner_v; // sum of all elements in batch 200 | } 201 | 202 | batch_sum = transposeReduceSum(batch_sum); // sum of 32 batches 203 | if (threadIdx.y * WARP_SZ == inner_b) inner_sum = batch_sum; 204 | } 205 | inner_sum = transposeReduceSum(inner_sum); // sum of 1024 batches 206 | local_sum += inner_sum; // sum of all batches 207 | } 208 | 209 | 210 | if (threadIdx.x == 0 && threadIdx.y == 0) { 211 | grad[blockIdx.x] = __float2half(local_sum); 212 | } 213 | } 214 | 215 | 216 | // block thread<32, 32> 217 | CPM_KERNEL_EXPORT void cu_arith_ln_add_backward( 218 | int64_t batch, int64_t n, int64_t m, 219 | const half *grad_out, // (batch, n, m) 220 | half *grad // (n) 221 | ) { 222 | 223 | float local_sum = 0; 224 | for (int b = 0; b < batch; b += WARP_SZ * WARP_SZ) { 225 | float inner_sum = 0; 226 | for (int inner_b = 0; inner_b < WARP_SZ * WARP_SZ && inner_b + b < batch; inner_b += WARP_SZ) { 227 | int batch_idx = b + inner_b + threadIdx.y; 228 | int base_idx = batch_idx * n * m + blockIdx.x * m + threadIdx.x; 229 | 230 | float batch_sum = 0; 231 | for (int i = 0; i < m; i += WARP_SZ * WARP_SZ) { 232 | float inner_v = 0; 233 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < m; j += WARP_SZ) { 234 | float v = 0; 235 | if (batch_idx < batch && i + j + threadIdx.x < m) { 236 | v = (float)grad_out[base_idx + i + j]; 237 | } 238 | v = warpReduceSum(v); // sum of 32 elements 239 | v = __shfl_sync(0xFFFFFFFF, v, 0); // broadcast to all threads in warp 240 | if (threadIdx.x * WARP_SZ == j) inner_v = v; 241 | } 242 | inner_v = warpReduceSum(inner_v); // sum of 1024 elements 243 | 244 | // stores the sum of batch (b + inner_b + threadIdx.y) in (0, threadIdx.y) 245 | batch_sum += inner_v; // sum of all elements in batch 246 | } 247 | 248 | batch_sum = transposeReduceSum(batch_sum); // sum of 32 batches 249 | if (threadIdx.y * WARP_SZ == inner_b) inner_sum = batch_sum; 250 | } 251 | inner_sum = transposeReduceSum(inner_sum); // sum of 1024 batches 252 | 
local_sum += inner_sum; // sum of all batches 253 | } 254 | 255 | 256 | if (threadIdx.x == 0 && threadIdx.y == 0) { 257 | grad[blockIdx.x] = __float2half(local_sum); 258 | } 259 | } 260 | 261 | 262 | 263 | // block thread, half n 264 | CPM_KERNEL_EXPORT void cu_arith_batch_mul_add( 265 | int64_t batch, int64_t n, 266 | const half2 *x, // (batch, n) 267 | const half2 *alpha, // (n) 268 | const half2 *beta, // (n) 269 | half2 *out // (batch, n) 270 | ) { 271 | int64_t col = threadIdx.x + blockIdx.y * blockDim.x; 272 | if (col < n) { 273 | out[blockIdx.x * n + col] = __hfma2(x[blockIdx.x * n + col], __ldg(alpha + col), __ldg(beta + col)); 274 | } 275 | } 276 | 277 | // block thread, half n 278 | CPM_KERNEL_EXPORT void cu_arith_batch_mul( 279 | int64_t batch, int64_t n, 280 | const half2 *x, // (batch, n) 281 | const half2 *alpha, // (n) 282 | half2 *out // (batch, n) 283 | ) { 284 | int64_t col = threadIdx.x + blockIdx.y * blockDim.x; 285 | if (col < n) { 286 | out[blockIdx.x * n + col] = __hmul2(x[blockIdx.x * n + col], __ldg(alpha + col)); 287 | } 288 | } -------------------------------------------------------------------------------- /cuda/embedding.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include "common.h" 3 | #include 4 | 5 | // block , thread <32, 32> 6 | CPM_KERNEL_EXPORT void cu_embedding_forward( 7 | int32_t batch, int32_t n, int32_t m, 8 | const int32_t *ids, // (batch, m) 9 | const half *weights, // (vocab_size, n) 10 | half *out // (batch, n, m) 11 | ) { 12 | __shared__ half shared[WARP_SZ][WARP_SZ + 1]; 13 | 14 | int32_t col_in_idx = blockIdx.y * WARP_SZ + threadIdx.y; 15 | int32_t col_out_idx = blockIdx.y * WARP_SZ + threadIdx.x; 16 | int32_t offset_n = blockIdx.z * WARP_SZ; 17 | const half *base_weight = weights + (col_in_idx < m ? 
(ids[blockIdx.x * m + col_in_idx] * n) : 0) + threadIdx.x; 18 | 19 | int32_t base_out_idx = blockIdx.x * n * m + threadIdx.y * m + col_out_idx; 20 | 21 | if (offset_n + threadIdx.x < n) { 22 | shared[threadIdx.y][threadIdx.x] = __ldg(base_weight + offset_n); 23 | } else { 24 | shared[threadIdx.y][threadIdx.x] = __float2half(0); 25 | } 26 | // load multiple data from weights 27 | __syncthreads(); 28 | // write multiple data to out (blockIdx.x, i + threadIdx.y, col_idx) 29 | if (offset_n + threadIdx.y < n && col_out_idx < m) { 30 | out[ base_out_idx + offset_n * m ] = shared[threadIdx.x][threadIdx.y]; 31 | } 32 | 33 | } 34 | 35 | // block thread<1024> 36 | CPM_KERNEL_EXPORT void cu_embedding_backward_stage1( 37 | int32_t batch, int32_t n, int32_t m, 38 | const half *grad_out, // (batch, n, m) 39 | const int32_t *argsort_ids, // (batch, n) 40 | const int32_t *sorted_ids, // (batch, n) 41 | half *grad, // (vocab_size, m) 42 | half *aux_grad, // (batch, m) 43 | int32_t *aux_grad_idx // (batch) 44 | ) { 45 | float sum = 0; 46 | int32_t baes_n_idx = blockIdx.x * n; 47 | int32_t col = blockIdx.y * WARP_SZ * WARP_SZ + threadIdx.x; 48 | 49 | if (col < m) { 50 | for (int i = 0; i < n; ++ i) { 51 | float v = (float)(grad_out[ argsort_ids[baes_n_idx + i] * m + col ]); 52 | sum += v; 53 | if (i + 1 == n) { 54 | aux_grad[blockIdx.x * m + col] = __float2half(sum); 55 | if (col == 0) aux_grad_idx[blockIdx.x] = __ldg(sorted_ids + baes_n_idx + i); 56 | } 57 | else if ( __ldg(sorted_ids + baes_n_idx + i) != __ldg(sorted_ids + baes_n_idx + i + 1)) { 58 | grad[ __ldg(sorted_ids + baes_n_idx + i) * m + col ] = __float2half(sum); 59 | sum = 0; 60 | } 61 | } 62 | } 63 | } 64 | 65 | 66 | // block thread<1024> 67 | CPM_KERNEL_EXPORT void cu_embedding_backward_stage2( 68 | int32_t batch, int32_t m, 69 | const half *aux_grad, // (batch, m) 70 | const int32_t *aux_grad_idx, // (batch) 71 | half *grad // (vocab_size, m) 72 | ) { 73 | float sum = 0; 74 | int32_t col = blockIdx.x * WARP_SZ * WARP_SZ + threadIdx.x; 75 | if (col < m) { 76 | for (int i = 0; i < batch; ++ i) { 77 | float v = (float)(aux_grad[i * m + col]); 78 | sum += v; 79 | if (i + 1 == batch || __ldg(aux_grad_idx + i) != __ldg(aux_grad_idx + i + 1)) { 80 | float v2 = (float)(grad[ __ldg(aux_grad_idx + i) * m + col ]); 81 | grad[ __ldg(aux_grad_idx + i) * m + col ] = __float2half(v2 + sum); 82 | sum = 0; 83 | } 84 | } 85 | } 86 | } 87 | 88 | 89 | // block , thread 90 | CPM_KERNEL_EXPORT void cu_embedding_step( 91 | int32_t batch, int32_t n, 92 | const int32_t *ids, // (batch) 93 | const half *weights, // (vocab_size, n) 94 | half *out // (batch, n) 95 | ) { 96 | int32_t id = ids[blockIdx.x]; 97 | const half *base_weight = weights + id * n; 98 | half *base_out = out + blockIdx.x * n; 99 | 100 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 101 | base_out[i] = __ldg(base_weight + i); 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /cuda/gelu.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include "common.h" 3 | #include 4 | 5 | // block thread<1024> 6 | CPM_KERNEL_EXPORT void cu_gelu_forward( 7 | int32_t batch, int32_t n, 8 | const half *mat, // (batch, n) 9 | half *out // (batch, n) 10 | ) { 11 | int32_t col_idx = blockIdx.y * blockDim.x + threadIdx.x; 12 | if (col_idx < n) { 13 | float x = __half2float(mat[blockIdx.x * n + col_idx]); 14 | x = 0.5 * x * (1.0 + tanhf(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))); 15 | 
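        // x now holds the tanh approximation of GELU:
        // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))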
out[blockIdx.x * n + col_idx] = __float2half(x); 16 | } 17 | } 18 | 19 | // block thread<1024> 20 | CPM_KERNEL_EXPORT void cu_gelu_backward( 21 | int32_t batch, int32_t n, 22 | const half *grad_out, // (batch, n) 23 | const half *mat, // (batch, n) 24 | half *grad // (batch, n) 25 | ) { 26 | int32_t col_idx = blockIdx.y * blockDim.x + threadIdx.x; 27 | int32_t offset = blockIdx.x * n + col_idx; 28 | if (col_idx < n) { 29 | float v = __half2float( grad_out[offset] ); 30 | float x = __half2float( mat[offset] ); 31 | float gelu_grad; 32 | 33 | if (-5 < x && x < 5) { 34 | float x3 = x * x * x; 35 | float sech2 = 1.0 / coshf(0.797885 * x + 0.0356774 * x3); 36 | sech2 = sech2 * sech2; 37 | 38 | gelu_grad = 0.5 + (0.398942 * x + 0.0535161 * x3) * sech2 + 0.5 * tanhf(0.797885 * x + 0.0356774 * x3); 39 | } 40 | else { 41 | gelu_grad = x < 0 ? 0 : 1; 42 | } 43 | grad[offset] = __float2half(gelu_grad * v); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /cuda/gemm.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include 3 | #include "common.h" 4 | 5 | // block thread 6 | CPM_KERNEL_EXPORT void cu_gemm_round( 7 | int32_t batch, int32_t n, int32_t m, 8 | const half *mat, // b, n, m 9 | const half *scale, // b, n 10 | int8_t *out 11 | ) { 12 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m; // mat[batch][n][m], scale[batch][n] 13 | half local_scale = scale[blockIdx.x * n + blockIdx.y]; 14 | 15 | for (int32_t i = threadIdx.x; i < m; i += blockDim.x) { 16 | out[base_idx + i] = (int8_t)nearbyintf((float)__ldg(mat + base_idx + i) / (float)local_scale); 17 | } 18 | } 19 | 20 | 21 | // block thread 22 | CPM_KERNEL_EXPORT void cu_gemm_round_transpose( 23 | int32_t batch, int32_t n, int32_t m, 24 | const half *mat, // b, n, m 25 | const half *scale, // b, m 26 | int8_t *out 27 | ) { 28 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m; // mat[batch][n][m], scale[batch][m] 29 | 30 | for (int32_t i = threadIdx.x; i < m; i += blockDim.x) { 31 | out[base_idx + i] = (int8_t)nearbyintf((float)mat[base_idx + i] / (float)__ldg(scale + blockIdx.x * m + i)); 32 | } 33 | } 34 | 35 | 36 | // grid thread 37 | CPM_KERNEL_EXPORT void cu_gemm_scale( 38 | int32_t batch, int32_t n, int32_t m, 39 | const int32_t *mat, // b, n, m 40 | const half *scale_x, // b, n 41 | const half *scale_y, // b, m 42 | half *out, 43 | bool broad_cast_x, bool broad_cast_y 44 | ) { 45 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m; 46 | float scale_x_value = 0; 47 | if (broad_cast_x) { 48 | scale_x_value = (float)__ldg(scale_x + blockIdx.y); 49 | } else { 50 | scale_x_value = (float)__ldg(scale_x + blockIdx.x * n + blockIdx.y); 51 | } 52 | const half* base_scale_y = broad_cast_y ? 
scale_y : (scale_y + blockIdx.x * m); 53 | 54 | for (int32_t i = threadIdx.x; i < m; i += blockDim.x){ 55 | out[base_idx + i] = __float2half((float)mat[base_idx + i] * scale_x_value * (float)__ldg(base_scale_y + i)); 56 | } 57 | } 58 | 59 | // grid thread 60 | CPM_KERNEL_EXPORT void cu_gemm_calc_scale( 61 | int32_t batch, int32_t n, int32_t m, 62 | const half *mat, // b, n, m 63 | half *out // b, n 64 | ) { 65 | float local_max = 0; 66 | 67 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m; 68 | for (int32_t i = threadIdx.x; i < m; i += blockDim.x){ 69 | local_max = fmaxf(fabsf((float)(mat[base_idx + i])), local_max); 70 | } 71 | local_max = blockReduceMax(local_max); 72 | 73 | if (threadIdx.x == 0) { 74 | out[ blockIdx.x * n + blockIdx.y ] = __float2half(local_max / 127.0); 75 | } 76 | } 77 | 78 | // grid thread 79 | CPM_KERNEL_EXPORT void cu_gemm_calc_scale_transpose( 80 | int32_t batch, int32_t n, int32_t m, 81 | const half *in, // b, n, m 82 | half *out // b, m 83 | ) { 84 | int32_t col_idx = blockIdx.y * WARP_SZ + threadIdx.x; 85 | int32_t base_idx = (blockIdx.x * n + threadIdx.y) * m + col_idx; 86 | 87 | float local_max = 0.0; 88 | for (int32_t i = 0; i < n; i += WARP_SZ) { 89 | // put & transpose 90 | if (i + threadIdx.y < n && col_idx < m) { 91 | local_max = fmaxf(fabsf((float)(in[base_idx + i * m])), local_max); 92 | } 93 | } 94 | local_max = transposeReduceMax(local_max); 95 | if (threadIdx.y == 0 && col_idx < m) { 96 | out[blockIdx.x * m + col_idx] = __float2half(local_max / 127.0); 97 | } 98 | } 99 | 100 | // Backward 101 | 102 | // grid , thread 103 | CPM_KERNEL_EXPORT void cu_gemm_backward_round_scale( 104 | int32_t batch, int32_t n, int32_t m, 105 | const half *mat, // (batch, n, m) 106 | const half *scale_y, // (batch, m) or (1, m) if broadcast_y 107 | int8_t *out, // (batch, n, m) 108 | half *scale_x, // (batch, n) 109 | bool broad_cast_y 110 | ) { 111 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m + threadIdx.x; 112 | int32_t base_m_idx = blockIdx.x * m + threadIdx.x; 113 | if (broad_cast_y) base_m_idx = threadIdx.x; 114 | 115 | float local_max = 0; 116 | __shared__ float global_max; 117 | for (int i = 0; i < m; i += blockDim.x) { 118 | if (i + threadIdx.x < m) { 119 | local_max = fmaxf(fabsf((float)(mat[base_idx + i]) * (float)__ldg(scale_y + base_m_idx + i)), local_max); 120 | } 121 | } 122 | local_max = blockReduceMax(local_max) / 127.0; 123 | if (threadIdx.x == 0) { 124 | global_max = local_max; 125 | scale_x[blockIdx.x * n + blockIdx.y] = __float2half(local_max); 126 | } 127 | __syncthreads(); 128 | local_max = global_max; 129 | for (int i = 0; i < m; i += blockDim.x) { 130 | if (i + threadIdx.x < m) { 131 | out[base_idx + i] = (int8_t)nearbyintf((float)mat[base_idx + i] * (float)__ldg(scale_y + base_m_idx + i) / local_max); 132 | } 133 | } 134 | } 135 | 136 | // grid , thread 137 | CPM_KERNEL_EXPORT void cu_gemm_backward_scale_round( 138 | int32_t batch, int32_t n, int32_t m, 139 | const half *mat, // (batch, n, m) 140 | const half *scale_x, // (batch, n) or (1, n) if broad_cast_x 141 | int8_t *out, // (batch, n, m) 142 | half *scale_y, // (batch, m) 143 | bool broad_cast_x 144 | ) { 145 | int32_t col = blockIdx.y * WARP_SZ + threadIdx.x; 146 | int32_t base_idx = (blockIdx.x * n + threadIdx.y) * m + col; 147 | int32_t base_n_idx = blockIdx.x * n + threadIdx.y; 148 | if (broad_cast_x) base_n_idx = threadIdx.y; 149 | 150 | float local_max = 0; 151 | 152 | if (col < m) { 153 | for (int i = 0; i < n; i += blockDim.y) { 154 | if (i + threadIdx.y < n) { 155 | 
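                // column-wise absmax of mat * scale_x over the n dimension;
                // it is divided by 127 below to become scale_y for this column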
local_max = fmaxf(fabsf( (float)mat[base_idx + i * m] * (float)__ldg(scale_x + base_n_idx + i) ), local_max); 156 | } 157 | } 158 | } 159 | local_max = transposeReduceMax(local_max); // reduce max along y 160 | local_max = local_max / 127.0; 161 | if (threadIdx.y == 0 && col < m) { 162 | scale_y[blockIdx.x * m + col] = __float2half(local_max); 163 | } 164 | 165 | if (col < m) { 166 | for (int i = 0; i < n; i += blockDim.y) { 167 | if (i + threadIdx.y < n) { 168 | out[base_idx + i * m] = (int8_t)nearbyintf((float)mat[base_idx + i * m] * (float)__ldg(scale_x + base_n_idx + i) / local_max); 169 | } 170 | } 171 | } 172 | } 173 | 174 | 175 | // block , thread 176 | CPM_KERNEL_EXPORT void cu_gemm_scale_x ( 177 | int32_t batch, int32_t n, int32_t m, 178 | const int32_t *mat, // (batch, n, m) 179 | const half *scale_x, // (batch, n) 180 | half *out // (batch, n, m) 181 | ) { 182 | float scale = scale_x[blockIdx.x * n + blockIdx.y]; 183 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m + threadIdx.x; 184 | for (int i = 0; i < m; i += blockDim.x) { 185 | if (i + threadIdx.x < m) { 186 | out[base_idx + i] = __float2half(scale * (float)mat[base_idx + i]); 187 | } 188 | } 189 | } 190 | 191 | // block , thread 192 | CPM_KERNEL_EXPORT void cu_gemm_scale_y ( 193 | int32_t batch, int32_t n, int32_t m, 194 | const int32_t *mat, // (batch, n, m) 195 | const half *scale_y, // (batch, m) 196 | half *out // (batch, n, m) 197 | ) { 198 | int32_t base_idx = (blockIdx.x * n + blockIdx.y) * m + threadIdx.x; 199 | int32_t base_m_idx = blockIdx.x * m + threadIdx.x; 200 | for (int i = 0; i < m; i += blockDim.x) { 201 | if (i + threadIdx.x < m) { 202 | out[base_idx + i] = __float2half((float)mat[base_idx + i] * (float)__ldg(scale_y + base_m_idx + i)); 203 | } 204 | } 205 | } -------------------------------------------------------------------------------- /cuda/gemv.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include 3 | #include "common.h" 4 | #include 5 | 6 | 7 | __inline__ __device__ int32_t warpReduceSumInt32(int32_t x) { 8 | for (int offset = warpSize/2; offset > 0; offset /= 2) 9 | x += __shfl_down_sync(0xFFFFFFFF, x, offset); 10 | return x; 11 | } 12 | 13 | __inline__ __device__ int32_t blockReduceSumInt32(int32_t x) { 14 | static __shared__ int32_t shared[WARP_SZ]; // blockDim.x / warpSize 15 | int lane = threadIdx.x % warpSize; 16 | int wid = threadIdx.x / warpSize; 17 | x = warpReduceSumInt32(x); 18 | if (lane == 0) shared[wid] = x; 19 | __syncthreads(); 20 | x = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : 0; 21 | if (wid == 0) x = warpReduceSumInt32(x); 22 | return x; 23 | } 24 | 25 | // block , thread 26 | CPM_KERNEL_EXPORT void cu_gemv_calc_scale( 27 | int32_t batch, int32_t n, 28 | const half *vec, // 29 | half *out // 30 | ) { 31 | int32_t base_vec = blockIdx.x * n; 32 | float local_max = 0; 33 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 34 | local_max = fmaxf( local_max, fabsf(vec[base_vec + i]) ); 35 | } 36 | local_max = blockReduceMax(local_max); 37 | if (threadIdx.x == 0) { 38 | out[blockIdx.x] = __float2half(local_max / 127.0); 39 | } 40 | } 41 | 42 | // block thread 43 | CPM_KERNEL_EXPORT void cu_gemv_round( 44 | int32_t batch, int32_t n, 45 | const half *vec, // (batch, n) 46 | const half* scale, // (batch,) 47 | int8_t* out // (batch, n) 48 | ) { 49 | int32_t col_idx = blockIdx.y * blockDim.x + threadIdx.x; 50 | half v_scale = __ldg(scale + blockIdx.x); 51 | if (col_idx < n) { 52 | out[blockIdx.x * n + col_idx] = (int8_t) nearbyintf((float)vec[blockIdx.x * n + col_idx] / (float)v_scale); 53 | } 54 | } 55 | 56 | // block thread dim_in % 4 == 0 57 | CPM_KERNEL_EXPORT void cu_gemv_broadcast_mat_int8( 58 | int32_t batch, int32_t dim_out, int32_t dim_in, 59 | const half *scale_mat, // 60 | const char4 *mat, // 61 | const half *scale_vec, // 62 | const char4 *vec, // 63 | half *out // 64 | ) { 65 | int32_t quarter_dim_in = dim_in >> 2; 66 | const char4* base_mat = mat + blockIdx.y * quarter_dim_in; 67 | const char4* base_vec = vec + blockIdx.x * quarter_dim_in; 68 | int32_t local_sum = 0; 69 | for (int32_t i = threadIdx.x; i < quarter_dim_in; i += blockDim.x) { 70 | local_sum = __dp4a(base_mat[i], base_vec[i], local_sum); 71 | } 72 | local_sum = blockReduceSumInt32(local_sum); 73 | if (threadIdx.x == 0) { 74 | out[blockIdx.x * dim_out + blockIdx.y] = __float2half((float)local_sum * (float)__ldg(scale_vec + blockIdx.x) * (float)__ldg(scale_mat + blockIdx.y)); 75 | } 76 | } 77 | 78 | // block , thread 79 | CPM_KERNEL_EXPORT void cu_gemv_fp16( 80 | int32_t batch, int32_t dim_out, int32_t dim_in, 81 | const half2 *mat, // 82 | const half2 *vec, // 83 | half *out // 84 | ) { 85 | int32_t half_dim_in = dim_in >> 1; 86 | int32_t base_v = blockIdx.x * half_dim_in; 87 | int32_t base_mat = (blockIdx.x * dim_out + blockIdx.y) * half_dim_in; 88 | 89 | #if __CUDA_ARCH__ >= 620 || !defined(__CUDA_ARCH__) 90 | half2 sum = __float2half2_rn(0); 91 | for (int i = threadIdx.x; i < half_dim_in; i += blockDim.x) { 92 | sum = __hfma2(vec[base_v + i], mat[base_mat + i], sum); 93 | } 94 | float v = (float)sum.x + (float)sum.y; 95 | #else 96 | // fallback to fp32 97 | float v = 0; 98 | for (int i = threadIdx.x; i < half_dim_in; i += blockDim.x) { 99 | v += (float)vec[base_v + i].x * (float)mat[base_mat + i].x + (float)vec[base_v + i].y * (float)mat[base_mat + i].y; 100 | } 101 | #endif 102 | v = blockReduceSum(v); 103 | if (threadIdx.x == 0) { 104 | out[blockIdx.x * dim_out + blockIdx.y] = __float2half(v); 105 | } 106 | } 107 | 108 | // block , thread 109 | CPM_KERNEL_EXPORT void cu_gemv_fp16_transpose( 110 | int32_t batch, int32_t dim_out, int32_t dim_in, 111 | const half *mat, // 112 | const half *vec, // 113 | half *out // 114 | ) { 115 | int32_t col = blockIdx.y * WARP_SZ + threadIdx.x; 116 | int32_t base_idx = blockIdx.x * dim_in + threadIdx.y; 117 | float sum = 0; 118 | for (int i = 0; i < dim_in; i += WARP_SZ * WARP_SZ) { // warp * warp blocks 119 | float local_sum = 0; 120 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < dim_in; j += WARP_SZ) { // warp block 121 | float v = 
0; 122 | if (i + j + threadIdx.y < dim_in && col < dim_out) v = (float)vec[base_idx + i + j] * (float)mat[(base_idx + i + j) * dim_out + col]; 123 | v = transposeReduceSum(v); 124 | if (threadIdx.y * WARP_SZ == j) { 125 | local_sum = v; 126 | } 127 | } 128 | local_sum = transposeReduceSum(local_sum); 129 | sum += local_sum; 130 | } 131 | 132 | if (threadIdx.y == 0 && col < dim_out) { 133 | out[blockIdx.x * dim_out + col] = sum; 134 | } 135 | } 136 | 137 | // block , thread 138 | CPM_KERNEL_EXPORT void cu_gemv_broadcast_mat_fp16( 139 | int32_t batch, int32_t dim_out, int32_t dim_in, 140 | const half2 *mat, // 141 | const half2 *vec, // 142 | half *out // 143 | ) { 144 | int32_t half_dim_in = dim_in >> 1; 145 | int32_t base_vec_idx = blockIdx.x * half_dim_in; 146 | int32_t base_mat = blockIdx.y * half_dim_in; 147 | 148 | #if __CUDA_ARCH__ >= 620 || !defined(__CUDA_ARCH__) 149 | half2 sum = __float2half2_rn(0); 150 | for (int i = threadIdx.x; i < half_dim_in; i += blockDim.x) { 151 | sum = __hfma2(vec[base_vec_idx + i], mat[base_mat + i], sum); 152 | } 153 | float v = (float)sum.x + (float)sum.y; 154 | #else 155 | float v = 0; 156 | for (int i = threadIdx.x; i < half_dim_in; i += blockDim.x) { 157 | v += (float)vec[base_vec_idx + i].x * (float)mat[base_mat + i].x + (float)vec[base_vec_idx + i].y * (float)mat[base_mat + i].y; 158 | } 159 | #endif 160 | v = blockReduceSum(v); 161 | if (threadIdx.x == 0) { 162 | out[blockIdx.x * dim_out + blockIdx.y] = __float2half(v); 163 | } 164 | } -------------------------------------------------------------------------------- /cuda/includes/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define CPM_KERNEL_EXPORT extern "C" __global__ -------------------------------------------------------------------------------- /cuda/includes/reduce.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | const int WARP_SZ = 32; 4 | namespace { 5 | 6 | __inline__ __device__ float warpReduceSum(float x) { 7 | for (int offset = warpSize/2; offset > 0; offset /= 2) 8 | x += __shfl_down_sync(0xFFFFFFFF, x, offset); 9 | return x; 10 | } 11 | 12 | __inline__ __device__ float blockReduceSum(float x) { 13 | static __shared__ float shared[WARP_SZ]; // blockDim.x / warpSize 14 | int lane = threadIdx.x % warpSize; 15 | int wid = threadIdx.x / warpSize; 16 | x = warpReduceSum(x); 17 | if (lane == 0) shared[wid] = x; 18 | __syncthreads(); 19 | x = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0; 20 | if (wid == 0) x = warpReduceSum(x); 21 | return x; 22 | } 23 | 24 | __inline__ __device__ float warpReduceMax(float x) { 25 | for (int offset = warpSize/2; offset > 0; offset /= 2) 26 | x = fmaxf(x, __shfl_down_sync(0xFFFFFFFF, x, offset)); 27 | return x; 28 | } 29 | 30 | __inline__ __device__ float blockReduceMax(float x) { 31 | static __shared__ float shared[WARP_SZ]; // blockDim.x / warpSize 32 | int lane = threadIdx.x % warpSize; 33 | int wid = threadIdx.x / warpSize; 34 | x = warpReduceMax(x); 35 | if (lane == 0) shared[wid] = x; 36 | __syncthreads(); 37 | x = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : -INFINITY; 38 | if (wid == 0) x = warpReduceMax(x); 39 | return x; 40 | } 41 | 42 | __inline__ __device__ float transposeReduceSum(float x) { 43 | static __shared__ float shared[WARP_SZ][WARP_SZ + 1]; 44 | shared[threadIdx.x][threadIdx.y] = x; 45 | __syncthreads(); 46 | x = warpReduceSum(shared[threadIdx.y][threadIdx.x]); 47 | if (threadIdx.x == 0) { 48 | shared[threadIdx.y][WARP_SZ] = x; 49 | } 50 | __syncthreads(); 51 | return shared[threadIdx.x][WARP_SZ]; 52 | } 53 | 54 | __inline__ __device__ float transposeReduceMax(float x) { 55 | static __shared__ float shared[WARP_SZ][WARP_SZ + 1]; 56 | shared[threadIdx.x][threadIdx.y] = x; 57 | __syncthreads(); 58 | x = warpReduceMax(shared[threadIdx.y][threadIdx.x]); 59 | if (threadIdx.x == 0) { 60 | shared[threadIdx.y][WARP_SZ] = x; 61 | } 62 | __syncthreads(); 63 | return shared[threadIdx.x][WARP_SZ]; 64 | } 65 | 66 | 67 | } -------------------------------------------------------------------------------- /cuda/mask.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include 3 | #include "common.h" 4 | 5 | // block , thread 6 | CPM_KERNEL_EXPORT void cu_mask( 7 | int32_t batch, int32_t n, int32_t m, 8 | const half *x, // (batch, n, m) 9 | const int8_t *mask, // (batch, m) 10 | float value, 11 | half *out // (batch, n, m) 12 | ) { 13 | int32_t col_idx = threadIdx.x + blockIdx.y * blockDim.x; 14 | int32_t base_x_idx = blockIdx.x * n * m + col_idx; 15 | half half_value = __float2half(value); 16 | 17 | if (col_idx < m) { 18 | int8_t mask_val = mask[blockIdx.x * m + col_idx]; 19 | for (int i = 0; i < n; i ++) { 20 | out[base_x_idx + i * m] = (mask_val == 0) ? half_value : x[base_x_idx + i * m]; 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /cuda/position_bucket.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include 3 | #include "common.h" 4 | 5 | // block <1>, thread 6 | CPM_KERNEL_EXPORT void cu_init_position_mapping( 7 | int32_t num_buckets, 8 | int32_t max_distance, 9 | int32_t *out, // (max_distance) 10 | bool bidirectional 11 | ) { 12 | int32_t part_buckets = num_buckets / (bidirectional ? 
2 : 1); 13 | int32_t exact_buckets = part_buckets / 2; 14 | int32_t log_buckets = part_buckets - exact_buckets; 15 | 16 | float v = logf(max_distance / exact_buckets); 17 | for (int i = threadIdx.x; i < max_distance; i++) { 18 | if (i < exact_buckets) out[i] = i; 19 | else out[i] = (int32_t)(logf((float)i / (float)exact_buckets) / v * log_buckets) + exact_buckets; 20 | } 21 | } 22 | 23 | 24 | // block 25 | CPM_KERNEL_EXPORT void cu_position_embedding_forward( 26 | int32_t query_len, 27 | int32_t key_len, 28 | int32_t num_buckets, 29 | int32_t max_distance, 30 | int32_t num_head, 31 | const int32_t *position_mapping, // (max_distance) 32 | const half *weight, // (num_head, num_bucket) 33 | half *out, // (num_head, key_len, query_len) 34 | bool bidirectional 35 | ) { 36 | int32_t total_len = key_len * query_len; 37 | for (int i = threadIdx.x; i < query_len; i += blockDim.x) { 38 | int32_t relative_position = i - blockIdx.x; 39 | int32_t bucket_offset = 0; 40 | if (relative_position < 0) { 41 | if (bidirectional) { 42 | relative_position = -relative_position; 43 | bucket_offset = num_buckets / 2; 44 | } else { 45 | relative_position = 0; 46 | } 47 | } 48 | if (relative_position >= max_distance) relative_position = max_distance - 1; 49 | int32_t bucket = __ldg(position_mapping + relative_position) + bucket_offset; 50 | for (int j = 0; j < num_head; j++){ 51 | out[j * total_len + blockIdx.x * query_len + i] = __ldg(weight + j * num_buckets + bucket); 52 | } 53 | } 54 | } 55 | 56 | // block , thread <1024> 57 | CPM_KERNEL_EXPORT void cu_position_embedding_backward( 58 | int32_t query_len, 59 | int32_t key_len, 60 | int32_t num_buckets, 61 | int32_t max_distance, 62 | int32_t num_heads, // no more than 1024 heads 63 | const int32_t *position_mapping, // (max_distance) 64 | const half *grad_out, // (num_head, key_len, query_len) 65 | half *grad, // (num_head, num_bucket) 66 | bool bidirectional 67 | ) { 68 | __shared__ float sum[1024]; 69 | 70 | int32_t total_len = key_len * query_len; 71 | 72 | sum[threadIdx.x] = 0; 73 | 74 | for (int i = 0; i < total_len; i += blockDim.x) { 75 | int32_t bucket = -1; 76 | if (i + threadIdx.x < total_len) { 77 | int32_t relative_position = ((i + threadIdx.x) % query_len) - ((i + threadIdx.x) / query_len); 78 | int32_t bucket_offset = 0; 79 | if (relative_position < 0) { 80 | if (bidirectional) { 81 | relative_position = -relative_position; 82 | bucket_offset = num_buckets / 2; 83 | } else { 84 | relative_position = 0; 85 | } 86 | } 87 | if (relative_position >= max_distance) relative_position = max_distance - 1; 88 | bucket = __ldg(position_mapping + relative_position) + bucket_offset; 89 | } 90 | 91 | for (int j = 0; j < num_heads; j ++) { 92 | float v = 0; 93 | if (bucket == blockIdx.x) v = (float)__ldg(grad_out + j * total_len + i + threadIdx.x); 94 | v = blockReduceSum(v); // synchronized here 95 | if (threadIdx.x == 0) sum[j] += v; 96 | } 97 | } 98 | __syncthreads(); 99 | if (threadIdx.x < num_heads) { 100 | grad[ threadIdx.x * num_buckets + blockIdx.x ] = sum[threadIdx.x]; 101 | } 102 | } 103 | 104 | // block <1> 105 | CPM_KERNEL_EXPORT void cu_position_embedding_step( 106 | int32_t query_pos, 107 | int32_t key_len, 108 | int32_t num_buckets, 109 | int32_t max_distance, 110 | int32_t num_head, 111 | const int32_t *position_mapping, // (max_distance) 112 | const half *weight, // (num_head, num_bucket) 113 | half *out, // (num_head, key_len) 114 | bool bidirectional 115 | ) { 116 | for (int i = threadIdx.x; i < key_len; i += blockDim.x) { 117 | int32_t 
relative_position = query_pos - i; 118 | int32_t bucket_offset = 0; 119 | if (relative_position < 0) { 120 | if (bidirectional) { 121 | relative_position = -relative_position; 122 | bucket_offset = num_buckets / 2; 123 | } else { 124 | relative_position = 0; 125 | } 126 | } 127 | if (relative_position >= max_distance) relative_position = max_distance - 1; 128 | int32_t bucket = __ldg(position_mapping + relative_position) + bucket_offset; 129 | for (int j = 0; j < num_head; j++){ 130 | out[j * key_len + i] = __ldg(weight + j * num_buckets + bucket); 131 | } 132 | } 133 | } -------------------------------------------------------------------------------- /cuda/softmax.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include "common.h" 3 | #include "reduce.cuh" 4 | 5 | // grid , thread <32, 32> 6 | CPM_KERNEL_EXPORT void cu_softmax_forward( 7 | int32_t batch, int32_t n, int32_t m, 8 | const half *in, // batch, n, m 9 | half *out // batch, n, m 10 | ) { 11 | float local_max = -INFINITY; 12 | 13 | int32_t base_mat_idx = (blockIdx.x * n + threadIdx.y) * m + blockIdx.y * WARP_SZ + threadIdx.x; 14 | int32_t col_idx = blockIdx.y * WARP_SZ + threadIdx.x; 15 | for (int i = 0; i < n; i += WARP_SZ) { 16 | if (col_idx < m && i + threadIdx.y < n) { 17 | local_max = fmaxf((float)in[base_mat_idx + i * m], local_max); 18 | } 19 | } 20 | 21 | local_max = fmaxf(transposeReduceMax(local_max), -1e6); 22 | 23 | float local_sum = 0; 24 | for (int i = 0; i < n; i += WARP_SZ * WARP_SZ) { 25 | float inner_sum = 0; 26 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < n; j += WARP_SZ) { 27 | float v = 0; 28 | if (col_idx < m && i + j + threadIdx.y < n) { 29 | v = expf((float)in[base_mat_idx + (i + j) * m] - local_max); 30 | } 31 | v = transposeReduceSum(v); 32 | if (threadIdx.y * WARP_SZ == j) inner_sum = v; 33 | } 34 | local_sum += transposeReduceSum(inner_sum); 35 | } 36 | local_sum += 1e-10; // avoid nan 37 | 38 | for (int i = 0; i < n; i += WARP_SZ) { 39 | if (col_idx < m && i + threadIdx.y < n) { 40 | out[base_mat_idx + i * m] = __float2half( expf((float)in[base_mat_idx + i * m] - local_max) / local_sum ); 41 | } 42 | } 43 | } 44 | 45 | // grid , thread <32, 32> 46 | CPM_KERNEL_EXPORT void cu_softmax_inplace_forward( 47 | int32_t batch, int32_t n, int32_t m, 48 | half *x // batch, n, m 49 | ) { 50 | float local_max = -INFINITY; 51 | 52 | int32_t base_mat_idx = (blockIdx.x * n + threadIdx.y) * m + blockIdx.y * WARP_SZ + threadIdx.x; 53 | int32_t col_idx = blockIdx.y * WARP_SZ + threadIdx.x; 54 | for (int i = 0; i < n; i += WARP_SZ) { 55 | if (col_idx < m && i + threadIdx.y < n) { 56 | local_max = fmaxf((float)x[base_mat_idx + i * m], local_max); 57 | } 58 | } 59 | 60 | local_max = fmaxf(transposeReduceMax(local_max), -1e6); 61 | 62 | float local_sum = 0; 63 | for (int i = 0; i < n; i += WARP_SZ * WARP_SZ) { 64 | float inner_sum = 0; 65 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < n; j += WARP_SZ) { 66 | float v = 0; 67 | if (col_idx < m && i + j + threadIdx.y < n) { 68 | v = expf((float)x[base_mat_idx + (i + j) * m] - local_max); 69 | } 70 | v = transposeReduceSum(v); 71 | if (threadIdx.y * WARP_SZ == j) inner_sum = v; 72 | } 73 | local_sum += transposeReduceSum(inner_sum); 74 | } 75 | local_sum += 1e-10; // avoid nan 76 | 77 | for (int i = 0; i < n; i += WARP_SZ) { 78 | if (col_idx < m && i + threadIdx.y < n) { 79 | x[base_mat_idx + i * m] = __float2half( expf((float)x[base_mat_idx + i * m] - local_max) / local_sum ); 80 | } 81 | } 82 | } 83 | 84 | 85 | 
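// Softmax backward along dim=1: grad_out = out * (grad_in - sum_j out_j * grad_in_j),
// computed per column with the same transpose-reduction pattern as the forward pass.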
// grid , thread <32, 32> 86 | CPM_KERNEL_EXPORT void cu_softmax_backward( 87 | int32_t batch, int32_t n, int32_t m, 88 | const half *out, // batch, n, m 89 | const half *grad_in, // batch, n, m 90 | half *grad_out // batch, n, m 91 | ) { 92 | int32_t base_mat_idx = (blockIdx.x * n + threadIdx.y) * m + blockIdx.y * WARP_SZ + threadIdx.x; 93 | int32_t col_idx = blockIdx.y * WARP_SZ + threadIdx.x; 94 | 95 | float local_sum = 0; 96 | for (int i = 0; i < n; i += WARP_SZ * WARP_SZ) { 97 | float inner_sum = 0; 98 | for (int j = 0; j < WARP_SZ * WARP_SZ && i + j < n; j += WARP_SZ) { 99 | float v = 0; 100 | if (col_idx < m && i + j + threadIdx.y < n) { 101 | v = (float)out[base_mat_idx + (i + j) * m] * (float)grad_in[base_mat_idx + (i + j) * m]; 102 | } 103 | v = transposeReduceSum(v); 104 | if (threadIdx.y * WARP_SZ == j) inner_sum = v; 105 | } 106 | local_sum += transposeReduceSum(inner_sum); 107 | } 108 | 109 | for (int i = 0; i < n; i += WARP_SZ) { 110 | if (col_idx < m && i + threadIdx.y < n) { 111 | grad_out[base_mat_idx + i * m] = __float2half((float)__ldg(out + base_mat_idx + i * m) * ((float)__ldg(grad_in + base_mat_idx + i * m) - local_sum ) ); 112 | } 113 | } 114 | } 115 | 116 | // grid , thread 117 | CPM_KERNEL_EXPORT void cu_softmax_step_inplace( 118 | int32_t batch, int32_t n, 119 | half *x // batch, n 120 | ) { 121 | int32_t base_x_idx = blockIdx.x * n + threadIdx.x; 122 | 123 | float local_max = -INFINITY; 124 | __shared__ float global_max; 125 | 126 | for (int i = 0; i < n; i += blockDim.x) { 127 | if (i + threadIdx.x < n) { 128 | local_max = fmaxf(local_max, x[base_x_idx + i]); 129 | } 130 | } 131 | 132 | local_max = blockReduceMax(local_max); 133 | if (threadIdx.x == 0) { 134 | global_max = fmaxf(local_max, -1e6); 135 | } 136 | __syncthreads(); 137 | 138 | local_max = global_max; 139 | float local_sum = 0; 140 | __shared__ float global_sum; 141 | 142 | for (int i = 0; i < n; i += blockDim.x) { 143 | if (i + threadIdx.x < n) { 144 | local_sum += expf((float)x[base_x_idx + i] - local_max); 145 | } 146 | } 147 | local_sum = blockReduceSum(local_sum); 148 | if (threadIdx.x == 0) { 149 | global_sum = local_sum + 1e-10; // avoid nan 150 | } 151 | __syncthreads(); 152 | local_sum = global_sum; 153 | 154 | for (int i = 0; i < n; i += blockDim.x) { 155 | if (i + threadIdx.x < n) { 156 | x[base_x_idx + i] = __float2half(expf((float)x[base_x_idx + i] - local_max) / local_sum); 157 | } 158 | } 159 | } -------------------------------------------------------------------------------- /cuda/transpose.cu: -------------------------------------------------------------------------------- 1 | #include "reduce.cuh" 2 | #include 3 | #include "common.h" 4 | 5 | // block thread <32, 32> 6 | CPM_KERNEL_EXPORT void cu_transpose( 7 | int32_t batch, int32_t n, int32_t m, 8 | const half *in, 9 | half *out 10 | ) { 11 | __shared__ half shared[WARP_SZ][WARP_SZ + 1]; 12 | int32_t row = blockIdx.y * WARP_SZ + threadIdx.y; 13 | int32_t col = blockIdx.z * WARP_SZ + threadIdx.x; 14 | int32_t offset = blockIdx.x * n * m + row * m + col; 15 | if (row < n && col < m) shared[threadIdx.x][threadIdx.y] = in[offset]; 16 | __syncthreads(); 17 | row = blockIdx.z * WARP_SZ + threadIdx.y; 18 | col = blockIdx.y * WARP_SZ + threadIdx.x; 19 | offset = blockIdx.x * n * m + row * n + col; 20 | if (row < m && col < n) { 21 | out[offset] = shared[threadIdx.y][threadIdx.x]; 22 | } 23 | } -------------------------------------------------------------------------------- /cuda/utils.cu: 
-------------------------------------------------------------------------------- 1 | 2 | #include "reduce.cuh" 3 | #include 4 | #include "common.h" 5 | 6 | __inline__ __device__ bool isnan_(half v) { 7 | #if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 8 | return __hisnan(v); 9 | #else 10 | return v != v; 11 | #endif 12 | } 13 | 14 | // grid , thread 15 | CPM_KERNEL_EXPORT void copy_data_to_kv( 16 | int32_t batch, int32_t buffer_len, int32_t n, 17 | const half2 *in, // (batch, n) 18 | half2 *out, // (batch, buffer_len, n) 19 | int32_t pos 20 | ) { 21 | int32_t half_n = n >> 1; 22 | int32_t base_in_idx = blockIdx.x * half_n + threadIdx.x; 23 | int32_t base_out_idx = (blockIdx.x * buffer_len + pos) * half_n + threadIdx.x; 24 | for (int i = 0; i < half_n; i += blockDim.x) { 25 | if (threadIdx.x + i < half_n) { 26 | out[base_out_idx + i] = in[base_in_idx + i]; 27 | } 28 | } 29 | } 30 | 31 | // grid<1>, thread<1> 32 | CPM_KERNEL_EXPORT void cu_array_add( 33 | int32_t *arr, int32_t pos, int32_t val 34 | ) { 35 | if (threadIdx.x == 0) arr[pos] += val; 36 | } 37 | 38 | // grid, thread 39 | CPM_KERNEL_EXPORT void cu_adjustify_logits( 40 | int32_t batch, int32_t n, 41 | half *logits, // (batch, n) 42 | float temperature, 43 | float frequency_penalty, 44 | float presence_penalty, 45 | int32_t *frequency // (batch, n) 46 | ) { 47 | int32_t col = blockIdx.y * blockDim.x + threadIdx.x; 48 | if (col < n) { 49 | float v = __half2float(logits[ blockIdx.x * n + col ]); 50 | int32_t freq = frequency[ blockIdx.x * n + col ]; 51 | v /= temperature; 52 | v -= frequency_penalty * (float)freq; 53 | v -= presence_penalty * (freq > 0 ? 1.0f : 0.0f); 54 | logits[ blockIdx.x * n + col ] = __float2half(v); 55 | } 56 | } 57 | 58 | // grid block 59 | CPM_KERNEL_EXPORT void cu_copy_extend_buffer( 60 | int32_t batch, int32_t old_size, int32_t nw_size, 61 | const half* old_buf, // (batch, old_size) 62 | half* nw_buf // (batch, nw_size) 63 | ) { 64 | int32_t col = blockIdx.y * blockDim.x + threadIdx.x; 65 | if (col < old_size) { 66 | nw_buf[ blockIdx.x * nw_size + col ] = old_buf[ blockIdx.x * old_size + col ]; 67 | } 68 | } 69 | 70 | // grid <1>, thread 71 | CPM_KERNEL_EXPORT void cu_has_nan_inf( 72 | int32_t n, 73 | const half* inp, // (n,) 74 | int8_t* out 75 | ) { 76 | float r = 0; 77 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 78 | half v = inp[i]; 79 | if (__hisinf(v) || isnan_(v)) { 80 | r = 10; 81 | break; 82 | } 83 | } 84 | r = blockReduceSum(r); 85 | if (threadIdx.x == 0 && r > 1) { 86 | out[0] = 1; 87 | } 88 | } 89 | 90 | // grid , thread 91 | CPM_KERNEL_EXPORT void cu_copy_pos_hidden( 92 | int32_t batch, int32_t hidden_size, int32_t seq_len, 93 | int32_t pos, 94 | const half* inp, // (batch, hidden_size, seq_len) 95 | half* out // (batch, hidden_size) 96 | ) { 97 | int32_t col = blockIdx.y * blockDim.x + threadIdx.x; 98 | if (col < hidden_size) { 99 | out[ blockIdx.x * hidden_size + col ] = 100 | inp[ blockIdx.x * hidden_size * seq_len + col * seq_len + pos ]; 101 | } 102 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | if __name__ == "__main__": 4 | setup( 5 | name='cpm_kernels', 6 | version='1.0.11', 7 | packages=find_packages(), 8 | description='CPM CUDA kernels', 9 | long_description=open("./README.md", 'r').read(), 10 | long_description_content_type="text/markdown", 11 | keywords="CPM, cuda, AI", 12 | classifiers=[ 
13 | "Development Status :: 3 - Alpha", 14 | "Environment :: GPU :: NVIDIA CUDA :: 10.1", 15 | "Environment :: GPU :: NVIDIA CUDA :: 10.2", 16 | "Environment :: GPU :: NVIDIA CUDA :: 11.0", 17 | "Environment :: GPU :: NVIDIA CUDA :: 11.1", 18 | "Environment :: GPU :: NVIDIA CUDA :: 11.2", 19 | "Environment :: GPU :: NVIDIA CUDA :: 11.3", 20 | "Environment :: GPU :: NVIDIA CUDA :: 11.4", 21 | "Environment :: GPU :: NVIDIA CUDA :: 11.5", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: Apache Software License", 25 | "Natural Language :: English", 26 | "Operating System :: OS Independent", 27 | "Programming Language :: Python :: 3", 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: Software Development :: Libraries :: Python Modules", 30 | 31 | ], 32 | license='Apache 2.0', 33 | include_package_data=True, 34 | package_data={ 35 | 'cpm_kernels.kernels': ['cuda/*.fatbin'] 36 | } 37 | ) 38 | -------------------------------------------------------------------------------- /tests/run_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def main(args,): 5 | import unittest 6 | import os, sys 7 | 8 | TEST_BASE = os.path.dirname(os.path.abspath(__file__)) 9 | test_cases = unittest.defaultTestLoader.discover(TEST_BASE, pattern=args.file) 10 | runner = unittest.runner.TextTestRunner(verbosity=args.verbosity) 11 | ret = runner.run(test_cases) 12 | sys.exit(len(ret.failures) + len(ret.errors)) 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--file", default="test_*.py", type=str, 19 | ) 20 | parser.add_argument( 21 | "--verbosity", default=2, type=int, 22 | ) 23 | args = parser.parse_args() 24 | main(args) 25 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import cpm_kernels.kernels as ck 3 | import torch 4 | import unittest 5 | 6 | 7 | TEST_CASE = [ 8 | (10, 3, 1, 5), 9 | (5000, 1024, 16, 128), 10 | (999, 888, 77, 6), 11 | (123, 321, 123, 312), 12 | (25012, 4096, 32, 512) 13 | ] 14 | 15 | class TestEmbedding(unittest.TestCase): 16 | def test_embedding(self): 17 | with torch.cuda.device(2): 18 | # Test the embedding layer. 
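            # For each case, build cpm_kernels' Embedding and the PyTorch reference
            # with shared weights, then check that the forward outputs (laid out as
            # (batch, hidden_size, seq_len)) and the weight gradients produced by
            # the backward kernels agree within tolerance.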
19 | for args in TEST_CASE: 20 | vocab_size, hidden_size, batch, seq_len = args 21 | cpm_emb = ct.Embedding(vocab_size, hidden_size) 22 | pth_emb = ct.EmbeddingTH(vocab_size, hidden_size) 23 | state_dict = { 24 | 'weight': torch.randn(vocab_size, hidden_size, dtype=torch.half), 25 | } 26 | cpm_emb.load_state_dict(state_dict) 27 | pth_emb.load_state_dict(state_dict) 28 | 29 | cpm_emb = cpm_emb.to("cuda") 30 | pth_emb = pth_emb.to("cuda") 31 | 32 | ipt = torch.randint(0, vocab_size, (batch, seq_len), dtype=torch.long).to("cuda") 33 | out = cpm_emb(ipt.to(torch.int32)) 34 | ans = pth_emb(ipt) 35 | 36 | self.assertTrue(torch.isclose(out, ans, 1e-3, 1e-3).all()) 37 | 38 | graident_start = torch.randn(batch, hidden_size, seq_len, dtype=torch.half).to("cuda") / batch 39 | 40 | out.backward(gradient=graident_start) 41 | ans.backward(gradient=graident_start) 42 | 43 | self.assertTrue(torch.isclose(cpm_emb.weight.grad, pth_emb.weight.grad, 1e-3, 1e-3).all()) 44 | 45 | def test_embedding_step(self): 46 | with torch.cuda.device(2): 47 | for args in TEST_CASE: 48 | vocab_size, hidden_size, batch, _ = args 49 | weight = torch.randn(vocab_size, hidden_size, dtype=torch.half, device="cuda") 50 | ipt = torch.randint(0, vocab_size, (batch,), dtype=torch.int32).to("cuda") 51 | 52 | ans = torch.empty((batch, hidden_size), dtype=torch.half, device="cuda") 53 | ck.embedding_forward( 54 | batch, hidden_size, 1, 55 | ipt.data_ptr(), 56 | weight.data_ptr(), 57 | ans.data_ptr(), 58 | torch.cuda.current_stream().cuda_stream 59 | ) 60 | out = torch.empty((batch, hidden_size), dtype=torch.half, device="cuda") 61 | ck.embedding_step( 62 | batch, hidden_size, 63 | ipt.data_ptr(), 64 | weight.data_ptr(), 65 | out.data_ptr(), 66 | torch.cuda.current_stream().cuda_stream 67 | ) 68 | self.assertTrue(torch.isclose(out, ans, 1e-5, 1e-5).all()) 69 | -------------------------------------------------------------------------------- /tests/test_gelu.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import torch 3 | import unittest 4 | 5 | 6 | class TestGeLU(unittest.TestCase): 7 | def test_gelu(self): 8 | with torch.cuda.device(2): 9 | x = torch.randn(4, 16, 512, 1024, device="cuda").half() 10 | x1 = x.clone().requires_grad_() 11 | x2 = x.clone().requires_grad_() 12 | del x 13 | out = ct.gelu(x1) 14 | ans = ct.geluTH(x2) 15 | self.assertTrue(torch.isclose(out, ans, 1e-2, 1e-2).all()) 16 | 17 | gradient_start = torch.randn(4, 16, 512, 1024, device="cuda").half() 18 | out.backward(gradient=gradient_start) 19 | ans.backward(gradient=gradient_start) 20 | 21 | self.assertTrue(torch.isclose(x1.grad, x2.grad, 1e-2, 1e-2).all()) 22 | 23 | def test_gelu_inplace(self): 24 | with torch.cuda.device(0): 25 | x = torch.randn(4, 1237, device="cuda").half() 26 | ans = ct.geluTH(x) 27 | ct.gelu_inplace(x) 28 | 29 | self.assertTrue(torch.isclose(x, ans, 1e-2, 1e-2).all()) -------------------------------------------------------------------------------- /tests/test_gemm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cpm_kernels.torch as ct 3 | import random, math 4 | import unittest 5 | 6 | 7 | 8 | def calc_scale(x): 9 | return x.abs().max(dim=-1, keepdim=True)[0] / 127 10 | 11 | 12 | def scale_round(x, scale): 13 | return torch.round(x / scale) * scale 14 | 15 | class ScaleOnlyFunc(torch.autograd.Function): 16 | @staticmethod 17 | def forward(ctx, x : torch.Tensor): 18 | scale = calc_scale(x) 19 | return 
scale_round(x, scale) 20 | 21 | @staticmethod 22 | def backward(ctx, grad_outputs: torch.Tensor): 23 | return grad_outputs 24 | 25 | def real_bmm(a, aT, b, bT, int8): 26 | if a.ndim == 2: 27 | a = a.unsqueeze(0) 28 | if b.ndim == 2: 29 | b = b.unsqueeze(0) 30 | 31 | if aT: 32 | a = a.transpose(-1, -2) 33 | if bT: 34 | b = b.transpose(-1, -2) 35 | if int8: 36 | old_type = a.dtype 37 | a = ScaleOnlyFunc.apply(a.to(torch.float32)) 38 | b = ScaleOnlyFunc.apply(b.to(torch.float32)) 39 | return torch.matmul(a, b).to(old_type) 40 | else: 41 | return torch.matmul(a, b) 42 | 43 | TEST_CASES = [ 44 | (1, 4, 4, 4, torch.half), 45 | (2, 4, 4, 4, torch.half), 46 | (1, 8, 8, 8, torch.half), 47 | (3, 12, 12, 12, torch.half), 48 | (1, 16, 16, 16, torch.half), 49 | (1, 32, 32, 32, torch.half), 50 | (1, 64, 64, 64, torch.half), 51 | (1, 128, 128, 128, torch.half), 52 | (4, 128, 128, 128, torch.half), 53 | (8, 128, 128, 128, torch.half), 54 | (8, 112, 224, 112, torch.half), 55 | (8, 512, 4096, 2048, torch.half), 56 | (8, 500, 4096, 7996, torch.half), 57 | (8, 512, 4096, 2048, torch.half), 58 | (8, 512, 10240, 4096, torch.half), 59 | (8, 1024, 512, 4096, torch.half), 60 | ] 61 | 62 | def generate_matrix(batch, m, k, n, dtype): 63 | tp = random.randint(0, 2) 64 | skk = math.sqrt(math.sqrt(k)) 65 | if tp == 0: 66 | a = torch.randn(batch, m, k, device="cuda").to(dtype) / skk 67 | b = torch.randn(batch, k, n, device="cuda").to(dtype) / skk 68 | elif tp == 1: 69 | a = torch.randn(1, m, k, device="cuda").to(dtype) / skk 70 | b = torch.randn(batch, k, n, device="cuda").to(dtype) / skk 71 | else: 72 | a = torch.randn(batch, m, k, device="cuda").to(dtype) / skk 73 | b = torch.randn(1, k, n, device="cuda").to(dtype) / skk 74 | if random.randint(0, 1) == 1: 75 | a = a.transpose(-1, -2).contiguous() 76 | aT = True 77 | else: 78 | aT = False 79 | if random.randint(0, 1) == 1: 80 | b = b.transpose(-1, -2).contiguous() 81 | bT = True 82 | else: 83 | bT = False 84 | 85 | return a, aT, b, bT 86 | 87 | class TestGEMM(unittest.TestCase): 88 | def test_gemm_forward(self): 89 | with torch.cuda.device(0): 90 | for args in TEST_CASES: 91 | batch, m, k, n, dtype = args 92 | 93 | for _ in range(10): 94 | a, aT, b, bT = generate_matrix(batch, m, k, n, dtype) 95 | int8 = random.randint(0, 1) == 1 96 | out = ct.bmm(a, aT, b, bT, int8) 97 | ans = real_bmm(a, aT, b, bT, int8) 98 | self.assertTrue(torch.isclose(out, ans, 5e-1, 5e-1).all()) 99 | 100 | def test_gemm_backward(self): 101 | with torch.cuda.device(1): 102 | for args in TEST_CASES: 103 | batch, m, k, n, dtype = args 104 | 105 | if dtype == torch.float32: 106 | threshold = 0.1 107 | elif dtype == torch.float16: 108 | threshold = 0.5 109 | else: 110 | raise RuntimeError("Unknown dtype %s" % dtype) 111 | 112 | for _ in range(10): 113 | a, aT, b, bT = generate_matrix(batch, m, k, n, dtype) 114 | int8 = random.randint(0, 1) == 1 115 | 116 | a1 = a.clone().detach(); a2 = a 117 | b1 = b.clone().detach(); b2 = b 118 | 119 | a1.requires_grad_(); a2.requires_grad_() 120 | b1.requires_grad_(); b2.requires_grad_() 121 | 122 | ans = real_bmm(a2, aT, b2, bT, int8) 123 | out = ct.bmm(a1, aT, b1, bT, int8) 124 | 125 | gradient_start = torch.randn(out.size(), device="cuda").half() / batch 126 | out.backward(gradient=gradient_start) 127 | ans.backward(gradient=gradient_start) 128 | 129 | self.assertTrue(torch.isclose(a1.grad, a2.grad, 5e-1, 5e-1).all()) 130 | self.assertTrue(torch.isclose(b1.grad, b2.grad, 5e-1, 5e-1).all()) 
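The fake-quantization helpers at the top of this file (calc_scale, scale_round, ScaleOnlyFunc) give real_bmm a reference for ct.bmm(..., int8=True): every row is divided by its per-row scale (absolute maximum over the last dimension divided by 127), rounded to the nearest integer, and multiplied back, with the rounding treated as identity in the backward pass. The stand-alone sketch below is an editorial addition, not part of the test suite, and the function names are illustrative; it spells out the symmetric per-row int8 round trip that this fake quantization emulates.

import torch

def quantize_int8(x: torch.Tensor):
    # per-row absmax scale so the largest magnitude maps to 127
    scale = x.abs().max(dim=-1, keepdim=True)[0] / 127
    return torch.round(x / scale).to(torch.int8), scale

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale

x = torch.randn(4, 8)
q, scale = quantize_int8(x)
x_hat = dequantize_int8(q, scale)
# x_hat equals scale_round(x, calc_scale(x)) from this file, and the round-trip
# error is at most half a quantization step per element
assert ((x_hat - x).abs() <= scale / 2 + 1e-6).all()

Because the CUDA path is presumed to perform the same per-row quantization internally (the gemm_calc_scale and gemm_round kernels exercised in tests/test_gemv.py), the comparison above only needs the loose 5e-1 tolerance.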
-------------------------------------------------------------------------------- /tests/test_gemv.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import cpm_kernels.kernels as ck 3 | import torch 4 | import unittest 5 | import math 6 | 7 | class TestGemv(unittest.TestCase): 8 | def test_gemv_int8(self): 9 | with torch.cuda.device(2): 10 | for _ in range(10): 11 | BATCH = 16 12 | N = 4444 13 | M = 8888 14 | ssk = math.sqrt(math.sqrt(M)) 15 | mat = torch.randn(N, M, dtype=torch.half, device="cuda") 16 | vec = torch.randn(BATCH, M, dtype=torch.half, device="cuda") 17 | 18 | mat_scale = torch.empty(N, dtype=torch.half, device="cuda") 19 | mat_quant = torch.empty(N, M, dtype=torch.int8, device="cuda") 20 | ck.gemm_calc_scale( 21 | 1, N, M, 22 | mat.data_ptr(), 23 | mat_scale.data_ptr(), 24 | torch.cuda.current_stream().cuda_stream 25 | ) 26 | ck.gemm_round( 27 | 1, N, M, 28 | mat.data_ptr(), 29 | mat_scale.data_ptr(), 30 | mat_quant.data_ptr(), 31 | torch.cuda.current_stream().cuda_stream 32 | ) 33 | vec_scale = torch.empty(BATCH, dtype=torch.half, device="cuda") 34 | ck.gemv_calc_scale( 35 | BATCH, M, 36 | vec.data_ptr(), 37 | vec_scale.data_ptr(), 38 | torch.cuda.current_stream().cuda_stream 39 | ) 40 | vec_quant = torch.empty(BATCH, M, dtype=torch.int8, device="cuda") 41 | ck.gemv_round( 42 | BATCH, M, 43 | vec.data_ptr(), 44 | vec_scale.data_ptr(), 45 | vec_quant.data_ptr(), 46 | torch.cuda.current_stream().cuda_stream 47 | ) 48 | out = torch.empty(BATCH, N, dtype=torch.half, device="cuda") 49 | ck.gemv_broadcast_mat_int8( 50 | BATCH, N, M, 51 | mat_scale.data_ptr(), 52 | mat_quant.data_ptr(), 53 | vec_scale.data_ptr(), 54 | vec_quant.data_ptr(), 55 | out.data_ptr(), 56 | torch.cuda.current_stream().cuda_stream 57 | ) 58 | ans = ct.bmm( vec.unsqueeze(0), False, mat.unsqueeze(0), True , int8=True) 59 | self.assertTrue(torch.isclose(out, ans, 1e-3, 1e-3).all()) 60 | 61 | def test_gemv_fp16(self): 62 | with torch.cuda.device(2): 63 | for _ in range(10): 64 | BATCH = 16 65 | N = 2222 66 | M = 128 67 | ssk = math.sqrt(math.sqrt(M)) 68 | mat = torch.randn(BATCH, N, M, dtype=torch.half, device="cuda") / ssk 69 | vec = torch.randn(BATCH, M, 2, dtype=torch.half, device="cuda") / ssk 70 | vec_0 = vec[:, :, 0].clone() 71 | 72 | out = torch.empty(BATCH, N, dtype=torch.half, device="cuda") 73 | ck.gemv_fp16( 74 | BATCH, N, M, 75 | mat.data_ptr(), 76 | vec_0.data_ptr(), 77 | out.data_ptr(), 78 | torch.cuda.current_stream().cuda_stream 79 | ) 80 | ans = ct.bmm( mat, False, vec, False , int8=False)[:, :, 0] 81 | self.assertTrue(torch.isclose(out, ans, 1e-3, 1e-3).all()) 82 | 83 | def test_gemv_fp16_light(self): 84 | with torch.cuda.device(2): 85 | for _ in range(10): 86 | BATCH = 16 87 | N = 2222 88 | M = 128 89 | ssk = math.sqrt(math.sqrt(M)) 90 | mat = torch.randn(BATCH, N, M, dtype=torch.half, device="cuda") / ssk 91 | vec = torch.randn(BATCH, M, 2, dtype=torch.half, device="cuda") / ssk 92 | vec_0 = vec[:, :, 0].clone() 93 | 94 | out = torch.empty(BATCH, N, dtype=torch.half, device="cuda") 95 | ck.gemv_fp16_light( 96 | BATCH, N, M, 97 | mat.data_ptr(), 98 | vec_0.data_ptr(), 99 | out.data_ptr(), 100 | torch.cuda.current_stream().cuda_stream 101 | ) 102 | ans = ct.bmm( mat, False, vec, False , int8=False)[:, :, 0] 103 | self.assertTrue(torch.isclose(out, ans, 5e-2, 5e-2).all()) 104 | 105 | def test_gemv_fp16_transpose(self): 106 | with torch.cuda.device(2): 107 | for _ in range(10): 108 | BATCH = 16 109 | N = 128 110 | M = 2222 
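                # Note (editorial): in this transpose variant the matrix is laid out as
                # (BATCH, M, N); only column 0 of vec is handed to the kernel, and its
                # output is checked against mat^T @ vec, i.e. the matching column of
                # ct.bmm(mat, True, vec, False) below.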
111 | ssk = math.sqrt(math.sqrt(M)) 112 | mat = torch.randn(BATCH, M, N, dtype=torch.half, device="cuda") / ssk 113 | vec = torch.randn(BATCH, M, 2, dtype=torch.half, device="cuda") / ssk 114 | vec_0 = vec[:, :, 0].clone() 115 | 116 | out = torch.zeros(BATCH, N, dtype=torch.half, device="cuda") 117 | ck.gemv_fp16_transpose( 118 | BATCH, N, M, 119 | mat.data_ptr(), 120 | vec_0.data_ptr(), 121 | out.data_ptr(), 122 | torch.cuda.current_stream().cuda_stream 123 | ) 124 | 125 | ans = ct.bmm( mat, True, vec, False , int8=False)[:, :, 0] 126 | self.assertTrue(torch.isclose(out, ans, 1e-3, 1e-3).all()) 127 | 128 | def test_gemv_fp16_transpose_light(self): 129 | with torch.cuda.device(2): 130 | for _ in range(10): 131 | BATCH = 16 132 | N = 128 133 | M = 2222 134 | ssk = math.sqrt(math.sqrt(M)) 135 | mat = torch.randn(BATCH, M, N, dtype=torch.half, device="cuda") / ssk 136 | vec = torch.randn(BATCH, M, 2, dtype=torch.half, device="cuda") / ssk 137 | vec_0 = vec[:, :, 0].clone() 138 | 139 | out = torch.zeros(BATCH, N, dtype=torch.half, device="cuda") 140 | ck.gemv_fp16_transpose_light( 141 | BATCH, N, M, 142 | mat.data_ptr(), 143 | vec_0.data_ptr(), 144 | out.data_ptr(), 145 | torch.cuda.current_stream().cuda_stream 146 | ) 147 | 148 | ans = ct.bmm( mat, True, vec, False , int8=False)[:, :, 0] 149 | self.assertTrue(torch.isclose(out, ans, 5e-2, 5e-2).all()) 150 | 151 | def test_gemv_logits(self): 152 | with torch.cuda.device(2): 153 | for _ in range(5): 154 | BATCH = 16 155 | N = 22222 156 | M = 4444 157 | ssk = math.sqrt(math.sqrt(M)) 158 | mat = torch.randn(N, M, dtype=torch.half, device="cuda") / ssk 159 | vec = torch.randn(BATCH, M, 4, dtype=torch.half, device="cuda") / ssk 160 | 161 | vecs = [vec[:, :, i].contiguous() for i in range(vec.size(2))] 162 | ans = ct.bmm(mat.unsqueeze(0), False, vec, False, int8=False) 163 | 164 | for i, vec_0 in enumerate(vecs): 165 | out_0 = torch.empty(BATCH, N, dtype=torch.half, device="cuda") 166 | ck.gemv_broadcast_mat_fp16( 167 | BATCH, N, M, 168 | mat.data_ptr(), 169 | vec_0.data_ptr(), 170 | out_0.data_ptr(), 171 | torch.cuda.current_stream().cuda_stream 172 | ) 173 | ans_0 = ans[:, :, i].contiguous() 174 | self.assertTrue(torch.isclose(out_0, ans_0, 1e-3, 1e-3).all()) 175 | 176 | def test_gemv_logits_light(self): 177 | with torch.cuda.device(2): 178 | for _ in range(5): 179 | BATCH = 16 180 | N = 22222 181 | M = 4444 182 | ssk = math.sqrt(math.sqrt(M)) 183 | mat = torch.randn(N, M, dtype=torch.half, device="cuda") / ssk 184 | vec = torch.randn(BATCH, M, dtype=torch.half, device="cuda") / ssk 185 | 186 | 187 | out = torch.empty(BATCH, N, dtype=torch.half, device="cuda") 188 | ck.gemv_broadcast_mat_fp16_light( 189 | BATCH, N, M, 190 | mat.data_ptr(), 191 | vec.data_ptr(), 192 | out.data_ptr(), 193 | torch.cuda.current_stream().cuda_stream 194 | ) 195 | 196 | ans = ct.bmm( vec.unsqueeze(0), False, mat.unsqueeze(0), True , int8=False) 197 | self.assertTrue(torch.isclose(out, ans, 5e-2, 5e-2).all()) 198 | 199 | -------------------------------------------------------------------------------- /tests/test_layernorm.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import cpm_kernels.kernels as ck 3 | import torch 4 | import unittest 5 | 6 | def normalize_stepTH(x : torch.Tensor, eps : float, rd_mean : bool) -> torch.Tensor: 7 | old_dtype = x.dtype 8 | x = x.to(torch.float32) 9 | var = (x**2).mean(axis=-1, keepdim=True) 10 | if rd_mean: 11 | mean = x.mean(axis=-1, keepdim=True) 12 | var = var 
- (mean**2) 13 | x = (x - mean) * torch.rsqrt(var + eps) 14 | else: 15 | x = x * torch.rsqrt(var + eps) 16 | return x.to(old_dtype) 17 | 18 | class TestLayerNorm(unittest.TestCase): 19 | def test_layernorm_unbias(self): 20 | with torch.cuda.device(4): 21 | for shape, eps in [ 22 | (768, 1e-5), 23 | (768, 1e-6), 24 | (1024, 1e-3), 25 | (1024, 1e-6) 26 | ]: 27 | l1 = ct.LayerNormTH(shape, eps, False) 28 | l2 = ct.LayerNorm(shape, eps, False) 29 | state_dict = { 30 | "weight": torch.randn(shape) * 0.1 + 1, 31 | } 32 | l1.load_state_dict(state_dict) 33 | l2.load_state_dict(state_dict) 34 | 35 | l1 = l1.to("cuda").half() 36 | l2 = l2.to("cuda").half() 37 | 38 | for _ in range(16): 39 | x_raw = torch.randn((128, shape, 512), device="cuda").half() 40 | x1 = x_raw.clone().requires_grad_() 41 | x2 = x_raw.requires_grad_() 42 | y1 = l1(x1) 43 | y2 = l2(x2) 44 | 45 | self.assertTrue(torch.isclose(y1, y2, 1e-2, 1e-2).all()) 46 | 47 | rd = torch.randn( x_raw.size(), device="cuda").half() 48 | y1.backward(gradient=rd) 49 | y2.backward(gradient=rd) 50 | 51 | self.assertTrue(torch.isclose(x1.grad, x2.grad, 1e-2, 1e-2).all()) 52 | self.assertTrue(torch.isclose(l1.weight.grad, l2.weight.grad, 1e-1, 5e-1).all()) 53 | 54 | l1.weight.grad.zero_() 55 | l2.weight.grad.zero_() 56 | 57 | def test_layernorm_bias(self): 58 | with torch.cuda.device(4): 59 | for shape, eps in [ 60 | (768, 1e-5), 61 | (768, 1e-6), 62 | (1024, 1e-3), 63 | (1024, 1e-6) 64 | ]: 65 | l1 = ct.LayerNormTH(shape, eps, True) 66 | l2 = ct.LayerNorm(shape, eps, True) 67 | state_dict = { 68 | "weight": torch.randn(shape) * 0.1 + 1, 69 | "bias": torch.randn(shape), 70 | } 71 | l1.load_state_dict(state_dict) 72 | l2.load_state_dict(state_dict) 73 | 74 | l1 = l1.to("cuda").half() 75 | l2 = l2.to("cuda").half() 76 | 77 | for _ in range(16): 78 | x_raw = torch.randn((128, shape, 512), device="cuda").half() 79 | x1 = x_raw.clone().requires_grad_() 80 | x2 = x_raw.requires_grad_() 81 | y1 = l1(x1) 82 | y2 = l2(x2) 83 | 84 | self.assertTrue(torch.isclose(y1, y2, 1e-2, 1e-2).all()) 85 | 86 | rd = torch.randn( x_raw.size(), device="cuda").half() 87 | y1.backward(gradient=rd) 88 | y2.backward(gradient=rd) 89 | 90 | self.assertTrue(torch.isclose(x1.grad, x2.grad, 1e-2, 1e-2).all()) 91 | 92 | self.assertTrue(torch.isclose(l1.weight.grad, l2.weight.grad, 1e-1, 5e-1).all()) 93 | self.assertTrue(torch.isclose(l1.bias.grad, l2.bias.grad, 1e-2, 1e-2).all()) 94 | 95 | l1.weight.grad.zero_() 96 | l2.weight.grad.zero_() 97 | l1.bias.grad.zero_() 98 | l2.bias.grad.zero_() 99 | 100 | def test_normalize(self): 101 | with torch.cuda.device(4): 102 | for shape, eps in [ 103 | (768, 1e-5), 104 | (768, 1e-6), 105 | (1024, 1e-3), 106 | (1024, 1e-6) 107 | ]: 108 | for i in range(16): 109 | x = torch.randn((128, shape, 512), device="cuda").half() 110 | ans = ct.normalizeTH(x, eps, i < 8) 111 | ct.normalize_inplace(x, eps, i < 8) 112 | 113 | self.assertTrue(torch.isclose(ans, x, 5e-3, 5e-3).all()) 114 | 115 | def test_normalize_step(self): 116 | with torch.cuda.device(4): 117 | for shape, eps in [ 118 | (768, 1e-5), 119 | (768, 1e-6), 120 | (1024, 1e-3), 121 | (1024, 1e-6) 122 | ]: 123 | for i in range(16): 124 | x = torch.randn((128, shape), device="cuda").half() 125 | ans = torch.empty(128, shape, device="cuda", dtype=torch.half) 126 | ck.layernorm_forward( 127 | 128, shape, 1, 128 | x.data_ptr(), 129 | ans.data_ptr(), 130 | eps, 131 | i < 8, 132 | torch.cuda.current_stream().cuda_stream 133 | ) 134 | out = torch.empty(128, shape, device="cuda", dtype=torch.half) 135 | 
ck.layernorm_step( 136 | 128, shape, 137 | x.data_ptr(), 138 | out.data_ptr(), 139 | eps, 140 | i < 8, 141 | torch.cuda.current_stream().cuda_stream 142 | ) 143 | self.assertTrue(torch.isclose(ans, out, 1e-5, 1e-5).all()) 144 | 145 | ck.layernorm_step_inplace( 146 | 128, shape, 147 | x.data_ptr(), 148 | eps, 149 | i < 8, 150 | torch.cuda.current_stream().cuda_stream 151 | ) 152 | self.assertTrue(torch.isclose(ans, x, 1e-5, 1e-5).all()) 153 | -------------------------------------------------------------------------------- /tests/test_position_embedding.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import cpm_kernels.kernels as ck 3 | import torch 4 | import unittest 5 | 6 | TEST_CASE = [ 7 | (32, 12, 128, False), 8 | (32, 12, 128, True), 9 | (32, 24, 256, True), 10 | (128, 64, 128, True), 11 | (16, 64, 256, False), 12 | (16, 16, 512, True), 13 | ] 14 | 15 | class TestPositionEmbedding(unittest.TestCase): 16 | def test_position_embedding(self): 17 | with torch.cuda.device(5): 18 | for num_buckets, num_heads, max_distance, bidi in TEST_CASE: 19 | p1 = ct.PositionEmbedding(num_heads, num_buckets, max_distance, bidi) 20 | p2 = ct.PositionEmbeddingTH(num_heads, num_buckets, max_distance, bidi) 21 | state_dict = { 22 | "weight": torch.randn(num_heads, num_buckets, device="cuda").half() 23 | } 24 | p1.load_state_dict(state_dict) 25 | p2.load_state_dict(state_dict) 26 | 27 | p1 = p1.cuda().half() 28 | p2 = p2.cuda().half() 29 | 30 | out = p1(128, 128) 31 | ans = p2(128, 128) 32 | 33 | self.assertTrue(torch.isclose(out, ans, 1e-4, 1e-4).all()) 34 | 35 | gradient_start = torch.randn(out.size(), device="cuda").half() 36 | if not bidi: 37 | mask = torch.arange(128, device="cuda")[:, None] <= torch.arange(128, device="cuda")[None, :] 38 | gradient_start = torch.where( 39 | mask[None, :, :].repeat(num_heads, 1, 1), 40 | gradient_start, 41 | torch.zeros_like(gradient_start), 42 | ) 43 | 44 | out.backward(gradient=gradient_start) 45 | ans.backward(gradient=gradient_start) 46 | self.assertTrue(torch.isclose(p1.weight.grad, p2.weight.grad, 1e-3, 1e-3).all()) 47 | 48 | def test_position_embedding_step(self): 49 | with torch.cuda.device(5): 50 | for num_buckets, num_heads, max_distance, bidi in TEST_CASE: 51 | weight = torch.randn(num_heads, num_buckets, device="cuda", dtype=torch.half) 52 | p2 = ct.PositionEmbeddingTH(num_heads, num_buckets, max_distance, bidi) 53 | p2.load_state_dict({"weight": weight }) 54 | p2 = p2.cuda().half() 55 | 56 | ans = p2(128, 128) 57 | 58 | mapping = torch.empty(max_distance, dtype=torch.int32, device="cuda") 59 | ck.position_embedding_init( 60 | num_buckets, max_distance, 61 | mapping.data_ptr(), 62 | bidi, 63 | torch.cuda.current_stream().cuda_stream 64 | ) 65 | for i in range(128): 66 | out = torch.empty(num_heads, 128, dtype=torch.half, device="cuda") 67 | ck.position_embedding_step( 68 | i, 128, num_buckets, max_distance, num_heads, 69 | mapping.data_ptr(), 70 | weight.data_ptr(), 71 | out.data_ptr(), 72 | bidi, 73 | torch.cuda.current_stream().cuda_stream 74 | ) 75 | self.assertTrue(torch.isclose(out, ans[:, :, i], 1e-4, 1e-4).all()) 76 | -------------------------------------------------------------------------------- /tests/test_softmax.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import cpm_kernels.torch as ct 4 | import cpm_kernels.kernels as ck 5 | 6 | class TestSoftmax(unittest.TestCase): 7 | def 
test_softmax(self): 8 | with torch.cuda.device(6): 9 | for shape in [ 10 | (1, 2, 32), 11 | (4, 128, 128), 12 | (16, 128, 32), 13 | (4, 16, 128), 14 | (123, 512, 321), 15 | (123, 768, 321), 16 | (233, 1024, 321), 17 | (4, 123, 16), 18 | (4, 321, 16), 19 | ]: 20 | x = torch.randn(*shape, device="cuda").half() 21 | x1 = x.clone().requires_grad_() 22 | x2 = x.requires_grad_() 23 | y1 = ct.softmaxTH(x1) 24 | y2 = ct.softmax(x2) 25 | self.assertTrue(torch.isclose(y1, y2, 1e-3, 1e-3).all()) 26 | 27 | rd = torch.randn( *shape, device="cuda").half() 28 | y1.backward(gradient=rd) 29 | y2.backward(gradient=rd) 30 | 31 | self.assertTrue(torch.isclose(x1.grad, x2.grad, 1e-3, 1e-3).all()) 32 | 33 | def test_softmax_inplace(self): 34 | with torch.cuda.device(6): 35 | for shape in [ 36 | (1, 2, 32), 37 | (4, 128, 128), 38 | (16, 128, 32), 39 | (4, 16, 128), 40 | (123, 512, 321), 41 | (123, 768, 321), 42 | (233, 1024, 321), 43 | (4, 123, 16), 44 | (4, 321, 16), 45 | ]: 46 | x = torch.randn(*shape, device="cuda").half() 47 | ans = ct.softmaxTH(x) 48 | ct.softmax_inplace(x) 49 | self.assertTrue(torch.isclose(x, ans, 1e-3, 1e-3).all()) 50 | 51 | def test_softmax_step(self): 52 | with torch.cuda.device(6): 53 | for shape in [ 54 | (2, 32), 55 | (512, 128), 56 | (123, 512), 57 | (123, 321), 58 | (4, 16), 59 | (3, 12321) 60 | ]: 61 | x = torch.randn(*shape, device="cuda").half() 62 | 63 | ans = torch.empty(shape, dtype=torch.half, device="cuda") 64 | ck.softmax_forward( 65 | shape[0], shape[1], 1, 66 | x.data_ptr(), 67 | ans.data_ptr(), 68 | torch.cuda.current_stream().cuda_stream 69 | ) 70 | ck.softmax_step_inplace( 71 | shape[0], shape[1], 72 | x.data_ptr(), 73 | torch.cuda.current_stream().cuda_stream 74 | ) 75 | 76 | self.assertTrue(torch.isclose(x, ans, 1e-5, 1e-5).all()) -------------------------------------------------------------------------------- /tests/test_transpose.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import cpm_kernels.torch as ct 4 | 5 | class TestTranspose(unittest.TestCase): 6 | def test_transpose(self): 7 | with torch.cuda.device(6): 8 | for shape in [ 9 | (1, 2, 32), 10 | (4, 128, 128), 11 | (16, 128, 32), 12 | (4, 16, 128), 13 | (123, 512, 321), 14 | (123, 768, 321), 15 | (233, 1024, 321), 16 | (4, 123, 16), 17 | (4, 321, 16), 18 | ]: 19 | x = torch.randn(*shape, device="cuda").half() 20 | x1 = x.clone().requires_grad_() 21 | x2 = x.requires_grad_() 22 | y1 = ct.transposeTH(x1) 23 | y2 = ct.transpose(x2) 24 | diff = (y1 - y2).abs().max() 25 | self.assertLess(diff, 1e-5) 26 | 27 | rd = torch.randn( (shape[0], shape[2], shape[1]), device="cuda").half() 28 | y1.backward(gradient=rd) 29 | y2.backward(gradient=rd) 30 | 31 | diff_grad = (x1.grad - x2.grad).abs().max() 32 | self.assertLess(diff_grad, 1e-5) 33 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import cpm_kernels.torch as ct 2 | import cpm_kernels.kernels as ck 3 | import torch 4 | import unittest 5 | import random 6 | 7 | class TestUtils(unittest.TestCase): 8 | def test_gemv_int8(self): 9 | with torch.cuda.device(2): 10 | a = torch.randn(33, 55, 58, dtype=torch.half, device="cuda") 11 | vec = [] 12 | for i in range(55): 13 | b = torch.randn(33, 58, dtype=torch.half, device="cuda") 14 | ck.copy_data_to_kv( 15 | 33, 55, 58, 16 | b.data_ptr(), 17 | a.data_ptr(), 18 | i, 19 | 
torch.cuda.current_stream().cuda_stream 20 | ) 21 | vec.append(b) 22 | ans = torch.stack(vec, dim=1) 23 | diff = torch.abs(a - ans).max() 24 | self.assertLess(diff, 1e-5) 25 | 26 | def test_array_add(self): 27 | with torch.cuda.device(2): 28 | a = torch.randint(0, 10, (128,), dtype=torch.int32, device="cuda") 29 | 30 | for _ in range(128): 31 | pos = random.randint(0, a.size(0) - 1) 32 | val = random.randint(-10, 10) 33 | old_val = a[pos].item() 34 | ck.utils.array_add(a.data_ptr(), pos, val, torch.cuda.current_stream().cuda_stream) 35 | 36 | self.assertEqual(a[pos], old_val + val) 37 | 38 | def test_justify_logits(self): 39 | with torch.cuda.device(2): 40 | for batch, n in [ 41 | (3, 128), 42 | (16, 333), 43 | (1, 2341), 44 | (15, 2341), 45 | ]: 46 | freq = torch.randint(0, 1, (batch, n), dtype=torch.int32, device="cuda") 47 | logits = torch.randn(batch, n, dtype=torch.float16, device="cuda") 48 | 49 | temp = (random.random() + 1) / 2 50 | freq_p = random.random() 51 | prec_p = random.random() 52 | 53 | ans = logits / temp - freq_p * freq.half() - prec_p * (freq > 0).half() 54 | 55 | ck.utils.adjustify_logits( 56 | batch, n, 57 | logits.data_ptr(), 58 | temp, freq_p, prec_p, 59 | freq.data_ptr(), 60 | torch.cuda.current_stream().cuda_stream 61 | ) 62 | 63 | self.assertTrue(torch.isclose(logits, ans, 1e-3, 1e-3).all()) 64 | 65 | def test_extend_buffer(self): 66 | with torch.cuda.device(2): 67 | for batch, old_size, nw_size in [ 68 | (3, 128, 256), 69 | (16, 333, 334), 70 | (1, 2341, 4567), 71 | (15, 2341, 3451), 72 | ]: 73 | x = torch.randn(batch, old_size, dtype=torch.float16, device="cuda") 74 | nw_buf = torch.empty(batch, nw_size, dtype=torch.float16, device="cuda") 75 | ck.utils.copy_extend_buffer( 76 | batch, old_size, nw_size, 77 | x.data_ptr(), 78 | nw_buf.data_ptr(), 79 | torch.cuda.current_stream().cuda_stream 80 | ) 81 | self.assertTrue(torch.isclose(x, nw_buf[:, :old_size], 1e-5, 1e-5).all()) 82 | 83 | def test_has_nan_inf(self): 84 | with torch.cuda.device(2): 85 | for shape in [ 86 | 1234, 87 | 3213, 88 | 123 * 321 * 77, 89 | 77777, 90 | 16, 91 | 1, 92 | 33 93 | ]: 94 | out = torch.zeros(5, dtype=torch.bool, device="cuda") 95 | x = torch.randn(shape, dtype=torch.half, device="cuda") 96 | self.assertTrue(not ct.has_nan_inf(x, out[0])) 97 | 98 | pos = random.randint(0, shape - 1) 99 | x[pos] = float('inf') 100 | self.assertTrue(ct.has_nan_inf(x, out[1])) 101 | x[pos] = 0 102 | 103 | pos = random.randint(0, shape - 1) 104 | x[pos] = float('-inf') 105 | self.assertTrue(ct.has_nan_inf(x, out[2])) 106 | x[pos] = 0 107 | 108 | pos = random.randint(0, shape - 1) 109 | x[pos] = float('nan') 110 | self.assertTrue(ct.has_nan_inf(x, out[3])) 111 | x[pos] = 0 112 | 113 | out[4] = True 114 | self.assertTrue(ct.has_nan_inf(x, out[4])) 115 | 116 | def test_copy_pos_hidden(self): 117 | with torch.cuda.device(2): 118 | for batch, hidden_size, seq_len in [ 119 | (3, 128, 256), 120 | (16, 333, 334), 121 | (1, 2341, 4567), 122 | (15, 2341, 3451), 123 | ]: 124 | x = torch.randn(batch, hidden_size, seq_len, dtype=torch.float16, device="cuda") 125 | for _ in range(128): 126 | pos = random.randint(0, seq_len - 1) 127 | pos_x = torch.empty(batch, hidden_size, dtype=torch.float16, device="cuda") 128 | ck.utils.copy_pos_hidden( 129 | batch, hidden_size, seq_len, 130 | pos, 131 | x.data_ptr(), 132 | pos_x.data_ptr(), 133 | torch.cuda.current_stream().cuda_stream 134 | ) 135 | self.assertTrue(torch.isclose(x[:, :, pos], pos_x, 1e-5, 1e-5).all()) 136 | 137 | 138 | 139 | 140 | 
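For reference, the plain-PyTorch behaviour that test_justify_logits checks ck.utils.adjustify_logits against can be written as a small stand-alone function. This is an editorial sketch and the function name is illustrative; note that the kernel rewrites logits in place, whereas the sketch returns a new tensor.

import torch

def adjustify_logits_reference(logits: torch.Tensor, freq: torch.Tensor,
                               temp: float, freq_p: float, prec_p: float) -> torch.Tensor:
    # temperature scaling plus frequency and presence penalties, mirroring the
    # expected value computed in test_justify_logits above
    return logits / temp - freq_p * freq.half() - prec_p * (freq > 0).half()

These tests can be run selectively through tests/run_test.py, for example: python tests/run_test.py --file test_utils.py (this assumes cpm_kernels is installed and, since the cases above pin specific device indices such as cuda:2, that enough GPUs are available); the runner exits with the number of failures plus errors.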
--------------------------------------------------------------------------------