├── .clang-format
├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── LICENSE
├── README.md
├── example
│   ├── benchmark
│   │   └── benchmark.py
│   └── stable-diffusion
│       ├── convert_model.py
│       ├── sample_fp16.py
│       └── sample_int8.py
├── include
│   ├── asymmetric
│   │   ├── asymmetric.h
│   │   ├── asymmetric_internal.h
│   │   ├── epilogue
│   │   │   ├── thread
│   │   │   │   └── linear_combination_dequant.h
│   │   │   └── threadblock
│   │   │       ├── default_epilogue_tensor_op_dequant.h
│   │   │       ├── epilogue_dequant.h
│   │   │       ├── predicated_vcol_iterator.h
│   │   │       └── predicated_vrow_iterator.h
│   │   └── gemm
│   │       ├── device
│   │       │   ├── gemm_dequant.h
│   │       │   └── gemm_sparse_dequant.h
│   │       └── kernel
│   │           ├── default_gemm_dequant.h
│   │           ├── default_gemm_sparse_dequant.h
│   │           ├── gemm_dequant.h
│   │           └── sparse_gemm_dequant.h
│   ├── int4.h
│   ├── matmul
│   │   ├── matmul.h
│   │   └── matmul_internal.h
│   └── util.h
├── setup.py
├── src
│   ├── CMakeLists.txt
│   ├── asymmetric
│   │   ├── CMakeLists.txt
│   │   ├── asymmetric.cpp
│   │   └── asymmetric.cu
│   ├── binding.cpp
│   └── matmul
│       ├── CMakeLists.txt
│       ├── matmul.cpp
│       └── matmul.cu
└── torch_quantizer
    ├── _C
    │   ├── asymmetric.pyi
    │   └── matmul.pyi
    ├── __init__.py
    ├── src
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── converter.py
    │   ├── quant_layer.py
    │   ├── quant_model.py
    │   └── quant_utils.py
    └── version.py
/.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .idea 3 | cmake-build-* 4 | build 5 | *.pyc 6 | *.sh 7 | *.out 8 | *.log 9 | *.json 10 | *.wandb 11 | *.yaml 12 | *.DS_Store 13 | *.sbatch 14 | # Distribution / packaging 15 | __pycache__/ 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | wheelhouse/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | *.whl 36 | 37 | # C extensions 38 | *.so 39 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/cutlass"] 2 | path = third-party/cutlass 3 | url = https://github.com/NVIDIA/cutlass.git 4 | [submodule "third-party/pybind11"] 5 | path = third-party/pybind11 6 | url = https://github.com/pybind/pybind11.git 7 | [submodule "third-party/googletest"] 8 | path = third-party/googletest 9 | url = https://github.com/google/googletest.git 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.11) 2 | project(torch_quantizer LANGUAGES CXX) 3 | 4 | if (NOT CMAKE_BUILD_TYPE) 5 | set(CMAKE_BUILD_TYPE Release) 6 | endif () 7 | 8 | set(CMAKE_CXX_STANDARD 17) 9 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 10 | 11 | find_package(Threads REQUIRED) # -pthread 12 | find_package(OpenMP REQUIRED) # -Xpreprocessor -fopenmp 13 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) # -fPIC 14 | set(CMAKE_CXX_VISIBILITY_PRESET hidden) # -fvisibility=hidden 15 | 16 | if (MSVC) 17 | string(APPEND CMAKE_CXX_FLAGS " /Wall") 18 | string(APPEND CMAKE_CXX_FLAGS_DEBUG " /Zi") 19 | string(APPEND CMAKE_CXX_FLAGS_RELEASE " /O2 /Ob2") 20 | else () 21 | string(APPEND CMAKE_CXX_FLAGS " -Wall") 22 | string(APPEND CMAKE_CXX_FLAGS_DEBUG " -g -Og") 23 | string(APPEND
CMAKE_CXX_FLAGS_RELEASE " -O3") 24 | endif () 25 | 26 | set(USE_FP16 ON) 27 | 28 | if (NOT DEFINED USE_FP16 AND NOT "$ENV{USE_FP16}" STREQUAL "") 29 | set(USE_FP16 "$ENV{USE_FP16}") 30 | endif () 31 | 32 | if (NOT DEFINED USE_FP16) 33 | set(USE_FP16 OFF) 34 | message(WARNING "FP16 support disabled, compiling without torch.HalfTensor. Suppress this warning with -DUSE_FP16=ON or -DUSE_FP16=OFF.") 35 | elseif (USE_FP16) 36 | message(STATUS "FP16 support enabled, compiling with torch.HalfTensor.") 37 | else () 38 | message(STATUS "FP16 support disabled, compiling without torch.HalfTensor.") 39 | endif () 40 | 41 | if (USE_FP16) 42 | add_definitions(-DUSE_FP16) 43 | endif () 44 | 45 | find_package(CUDA REQUIRED) 46 | if (CUDA_FOUND AND NOT WIN32) 47 | message(STATUS "Found CUDA, enabling CUDA support.") 48 | enable_language(CUDA) 49 | set(CMAKE_CUDA_STANDARD "${CMAKE_CXX_STANDARD}") 50 | set(CMAKE_CUDA_STANDARD_REQUIRED ON) 51 | add_definitions(-D__USE_CUDA__) 52 | 53 | string(APPEND CMAKE_CUDA_FLAGS " $ENV{TORCH_NVCC_FLAGS}") 54 | 55 | if (NOT DEFINED TORCH_CUDA_ARCH_LIST AND NOT "$ENV{TORCH_CUDA_ARCH_LIST}" STREQUAL "") 56 | set(TORCH_CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}") 57 | endif () 58 | 59 | if (NOT TORCH_CUDA_ARCH_LIST) 60 | set(TORCH_CUDA_ARCH_LIST "Auto") 61 | message(WARNING "Torch CUDA arch list is not set, setting to \"Auto\". Suppress this warning with -DTORCH_CUDA_ARCH_LIST=Common.") 62 | endif () 63 | 64 | set(CMAKE_CUDA_ARCHITECTURES OFF) 65 | cuda_select_nvcc_arch_flags(CUDA_ARCH_FLAGS ${TORCH_CUDA_ARCH_LIST}) 66 | message(STATUS "TORCH_CUDA_ARCH_LIST: \"${TORCH_CUDA_ARCH_LIST}\"") 67 | message(STATUS "CUDA_ARCH_FLAGS: \"${CUDA_ARCH_FLAGS}\"") 68 | list(APPEND CUDA_NVCC_FLAGS ${CUDA_ARCH_FLAGS}) 69 | 70 | list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda") 71 | if (CUDA_HAS_FP16 OR NOT "${CUDA_VERSION}" VERSION_LESS "7.5") 72 | if (USE_FP16) 73 | message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor.") 74 | string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1" 75 | " -D__CUDA_NO_HALF_OPERATORS__" 76 | " -D__CUDA_NO_HALF_CONVERSIONS__" 77 | " -D__CUDA_NO_HALF2_OPERATORS__" 78 | " -D__CUDA_NO_BFLOAT16_CONVERSIONS__") 79 | else () 80 | message(STATUS "Found CUDA with FP16 support, but it is suppressed by the compile options, compiling without torch.cuda.HalfTensor.") 81 | endif () 82 | else () 83 | message(STATUS "Could not find CUDA with FP16 support, compiling without torch.cuda.HalfTensor.") 84 | endif () 85 | 86 | foreach (FLAG ${CUDA_NVCC_FLAGS}) 87 | string(FIND "${FLAG}" " " flag_space_position) 88 | if (NOT flag_space_position EQUAL -1) 89 | message(FATAL_ERROR "Found spaces in CUDA_NVCC_FLAGS entry '${FLAG}'") 90 | endif () 91 | string(APPEND CMAKE_CUDA_FLAGS " ${FLAG}") 92 | endforeach () 93 | string(STRIP "${CMAKE_CUDA_FLAGS}" CMAKE_CUDA_FLAGS) 94 | message(STATUS "CMAKE_CUDA_FLAGS: \"${CMAKE_CUDA_FLAGS}\"") 95 | 96 | if (MSVC) 97 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} /O2 /Ob2") 98 | else () 99 | set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3") 100 | endif () 101 | elseif (NOT CUDA_FOUND) 102 | message(STATUS "CUDA not found, build for CPU-only.") 103 | else () 104 | message(STATUS "CUDA found, but build for CPU-only on Windows.") 105 | endif () 106 | 107 | function(system) 108 | set(options STRIP) 109 | set(oneValueArgs OUTPUT_VARIABLE ERROR_VARIABLE WORKING_DIRECTORY) 110 | set(multiValueArgs COMMAND) 111 | cmake_parse_arguments( 112 | SYSTEM 113 | "${options}" 114 | 
"${oneValueArgs}" 115 | "${multiValueArgs}" 116 | "${ARGN}" 117 | ) 118 | 119 | if (NOT DEFINED SYSTEM_WORKING_DIRECTORY) 120 | set(SYSTEM_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}") 121 | endif () 122 | 123 | execute_process( 124 | COMMAND ${SYSTEM_COMMAND} 125 | OUTPUT_VARIABLE STDOUT 126 | ERROR_VARIABLE STDERR 127 | WORKING_DIRECTORY "${SYSTEM_WORKING_DIRECTORY}" 128 | ) 129 | 130 | if ("${SYSTEM_STRIP}") 131 | string(STRIP "${STDOUT}" STDOUT) 132 | string(STRIP "${STDERR}" STDERR) 133 | endif () 134 | 135 | set("${SYSTEM_OUTPUT_VARIABLE}" "${STDOUT}" PARENT_SCOPE) 136 | 137 | if (DEFINED SYSTEM_ERROR_VARIABLE) 138 | set("${SYSTEM_ERROR_VARIABLE}" "${STDERR}" PARENT_SCOPE) 139 | endif () 140 | endfunction() 141 | 142 | if (NOT DEFINED PYTHON_EXECUTABLE) 143 | if (WIN32) 144 | set(PYTHON_EXECUTABLE "python.exe") 145 | else () 146 | set(PYTHON_EXECUTABLE "python") 147 | endif () 148 | endif () 149 | 150 | if (UNIX) 151 | system( 152 | STRIP OUTPUT_VARIABLE PYTHON_EXECUTABLE 153 | COMMAND bash -c "type -P '${PYTHON_EXECUTABLE}'" 154 | ) 155 | endif () 156 | 157 | system( 158 | STRIP OUTPUT_VARIABLE PYTHON_VERSION 159 | COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('platform').python_version())" 160 | ) 161 | 162 | message(STATUS "Use Python version: ${PYTHON_VERSION}") 163 | message(STATUS "Use Python executable: \"${PYTHON_EXECUTABLE}\"") 164 | 165 | if (NOT DEFINED PYTHON_INCLUDE_DIR) 166 | message(STATUS "Auto detecting Python include directory...") 167 | system( 168 | STRIP OUTPUT_VARIABLE PYTHON_INCLUDE_DIR 169 | COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('platinclude'))" 170 | ) 171 | endif () 172 | 173 | if ("${PYTHON_INCLUDE_DIR}" STREQUAL "") 174 | message(FATAL_ERROR "Python include directory not found") 175 | else () 176 | message(STATUS "Detected Python include directory: \"${PYTHON_INCLUDE_DIR}\"") 177 | include_directories("${PYTHON_INCLUDE_DIR}") 178 | endif () 179 | 180 | system( 181 | STRIP OUTPUT_VARIABLE PYTHON_SITE_PACKAGES 182 | COMMAND "${PYTHON_EXECUTABLE}" -c "print(__import__('sysconfig').get_path('purelib'))" 183 | ) 184 | message(STATUS "Detected Python site packages: \"${PYTHON_SITE_PACKAGES}\"") 185 | 186 | find_package(Git REQUIRED) 187 | if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 188 | message(STATUS "Populating Git submodule.") 189 | execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive 190 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 191 | RESULT_VARIABLE GIT_SUBMOD_RESULT) 192 | if(NOT GIT_SUBMOD_RESULT EQUAL "0") 193 | message(FATAL_ERROR 194 | "git submodule updata --init --recursive failed with ${GIT_SUBMOD_RESULT}.") 195 | endif() 196 | endif() 197 | 198 | if (NOT DEFINED TORCH_INCLUDE_PATH) 199 | message(STATUS "Auto detecting Torch include directory...") 200 | system( 201 | STRIP OUTPUT_VARIABLE TORCH_INCLUDE_PATH 202 | COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).include_paths()))" 203 | ) 204 | 205 | if ("${TORCH_INCLUDE_PATH}" STREQUAL "") 206 | set(TORCH_INCLUDE_PATH "${PYTHON_SITE_PACKAGES}/torch/include") 207 | endif () 208 | endif () 209 | 210 | if ("${TORCH_INCLUDE_PATH}" STREQUAL "") 211 | message(FATAL_ERROR "Torch include directory not found. 
Got: \"${TORCH_INCLUDE_PATH}\"") 212 | else () 213 | message(STATUS "Detected Torch include directory: \"${TORCH_INCLUDE_PATH}\"") 214 | include_directories(${TORCH_INCLUDE_PATH}) 215 | endif () 216 | 217 | if (NOT DEFINED TORCH_LIBRARY_PATH) 218 | message(STATUS "Auto detecting Torch library directory...") 219 | system( 220 | STRIP OUTPUT_VARIABLE TORCH_LIBRARY_PATH 221 | COMMAND "${PYTHON_EXECUTABLE}" -c "print('\\\;'.join(__import__('torch.utils.cpp_extension', fromlist=[None]).library_paths()))" 222 | ) 223 | 224 | if ("${TORCH_LIBRARY_PATH}" STREQUAL "") 225 | set(TORCH_LIBRARY_PATH "${PYTHON_SITE_PACKAGES}/torch/lib") 226 | endif () 227 | endif () 228 | 229 | if ("${TORCH_LIBRARY_PATH}" STREQUAL "") 230 | message(FATAL_ERROR "Torch library directory not found. Got: \"${TORCH_LIBRARY_PATH}\"") 231 | else () 232 | message(STATUS "Detected Torch library directory: \"${TORCH_LIBRARY_PATH}\"") 233 | endif () 234 | 235 | unset(TORCH_LIBRARIES) 236 | 237 | foreach (VAR_PATH ${TORCH_LIBRARY_PATH}) 238 | file(GLOB TORCH_LIBRARY "${VAR_PATH}/*") 239 | message(STATUS "Detected Torch libraries: \"${TORCH_LIBRARY}\"") 240 | endforeach () 241 | 242 | foreach (VAR_PATH ${TORCH_LIBRARY_PATH}) 243 | if (WIN32) 244 | file(GLOB TORCH_LIBRARY "${VAR_PATH}/*.lib") 245 | else () 246 | file(GLOB TORCH_LIBRARY "${VAR_PATH}/libtorch_python.*") 247 | endif () 248 | list(APPEND TORCH_LIBRARIES "${TORCH_LIBRARY}") 249 | endforeach () 250 | 251 | message(STATUS "Detected Torch Python libraries: \"${TORCH_LIBRARIES}\"") 252 | 253 | add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) 254 | 255 | set(_saved_CMAKE_MESSAGE_LOG_LEVEL ${CMAKE_MESSAGE_LOG_LEVEL}) 256 | set(CMAKE_MESSAGE_LOG_LEVEL ERROR) 257 | 258 | # Set the desired flags for CUTLASS 259 | set(CUTLASS_NVCC_ARCHS 80 CACHE STRING "Set CUDA architectures for CUTLASS") 260 | set(CUTLASS_ENABLE_TESTS OFF CACHE BOOL "Disable CUTLASS tests") 261 | set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable CUTLASS Unity Build") 262 | 263 | add_subdirectory(third-party/cutlass) 264 | add_subdirectory(third-party/pybind11) 265 | add_subdirectory(third-party/googletest) 266 | set(CMAKE_MESSAGE_LOG_LEVEL ${_saved_CMAKE_MESSAGE_LOG_LEVEL}) 267 | 268 | include_directories("${CMAKE_SOURCE_DIR}") 269 | include_directories(include) 270 | add_subdirectory(src) 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch Quantizer 2 | 3 | `torch_quantizer` is a Python package designed for efficient quantization of PyTorch models, focusing on converting floating-point Linear and Conv2d modules to INT8 precision for improved inference speed on the CUDA backend. It also supports [temporal quantization](https://arxiv.org/pdf/2310.03270) and is specially optimized for diffusion models. 4 | 5 | ## Installation 6 | 7 | Before installing `torch_quantizer`, ensure you have PyTorch installed in your environment. 8 | 9 | To install the pre-built `torch_quantizer` wheel, use the following command: 10 | 11 | ```bash 12 | pip install torch_quantizer-*.whl 13 | ``` 14 | To build from source, clone the repository and compile it with NVCC: 15 | ```bash 16 | pip install -e . 17 | ``` 18 | ## Usage 19 | 20 | ### Prerequisites 21 | 22 | Import PyTorch before importing `torch_quantizer`: 23 | 24 | ```python 25 | import torch 26 | import torch_quantizer as tq 27 | ``` 28 | 29 | ### Benchmarking 30 | 31 | To benchmark and verify the inference speedup from INT8 operations, use the following commands: 32 | 33 | ```bash 34 | python3 example/benchmark/benchmark.py --o linear --bs 512 --cin 960 --cout 960 35 | python3 example/benchmark/benchmark.py --o conv2d --bs 1 --cin 512 --cout 512 36 | ``` 37 | #### Benchmarking Results 38 | 39 | Here are the results for both linear and 2D convolution operations, showing the average time taken by FP32, FP16, and INT8 (Quant+Dequant) operations on an RTX 3090. 40 | 41 | ##### Linear Operation 42 | 43 | - **Batch Size:** 512 44 | - **Channels In:** 960 45 | - **Channels Out:** 960 46 | 47 | | Operation Type | Average Time | 48 | |----------------|-----------------| 49 | | FP32 | 8.414e-05 s | 50 | | FP16 | 3.304e-05 s | 51 | | INT8 (Quant+Dequant) | 2.908e-05 s | 52 | 53 | ##### 2D Convolution Operation 54 | 55 | - **Batch Size:** 1 56 | - **Channels In:** 512 57 | - **Channels Out:** 512 58 | 59 | | Operation Type | Average Time | 60 | |----------------|-----------------| 61 | | FP32 | 0.000903 s | 62 | | FP16 | 0.000413 s | 63 | | INT8 (Quant+Dequant) | 0.000178 s | 64 | 65 | These results highlight the performance improvements achieved through quantization to INT8, demonstrating significant reductions in inference time across both operations. 66 | 67 | ### Model Conversion to INT8 68 | 69 | #### Step 1: FP16 Precision Check 70 | 71 | Ensure your model can run inference in FP16 precision. For instance: 72 | 73 | ```python 74 | unet.half() 75 | # do FP16 inference 76 | ``` 77 | 78 | #### Step 2: Fake Quantization 79 | 80 | Convert your model to a fake quantized model for calibration: 81 | 82 | ```python 83 | wq_params = {'n_bits': 8, 'channel_wise': True, 'scale_method': 'max'} 84 | aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse'} 85 | ddim_steps = 50  # number of diffusion steps; use 1 for non-temporal models 86 | 87 | fakeq_model = tq.fake_quant(unet, wq_params, aq_params, num_steps=ddim_steps) 88 | ``` 89 | 90 | #### Step 3: Calibration 91 | 92 | Run your fake quantized model with some input for calibration.
For example: 93 | 94 | ```python 95 | from some_diffusion_library import DiffusionPipeline 96 | 97 | pipe = DiffusionPipeline() 98 | prompt = "a photo of a flying dog" 99 | image = pipe(prompt, guidance_scale=7.5)["sample"][0] 100 | ``` 101 | 102 | #### Step 4: Conversion to Real INT8 Model 103 | 104 | Convert the fake quantized model to a real INT8 model: 105 | 106 | ```python 107 | qunet = tq.fake2real(fakeq_model, save_dir='.') 108 | ``` 109 | 110 | The INT8 checkpoint will be saved in the specified directory. 111 | 112 | ### Loading INT8 Model 113 | 114 | Load the INT8 model directly from the checkpoint: 115 | 116 | ```python 117 | ckpt_path = 'path_to_checkpoint' 118 | qnn = tq.real_quant(unet, n_bits, ddim_steps, ckpt_path) 119 | ``` 120 | 121 | ## Acknowledgement 122 | 123 | This repository is built upon [QUIK](https://github.com/IST-DASLab/QUIK). We thank the authors for their open-sourced code. 124 | 125 | ## License 126 | 127 | `torch_quantizer` is released under [Apache-2.0 License](LICENSE). 128 | -------------------------------------------------------------------------------- /example/benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch_quantizer as tq 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description='Benchmarking script for linear and conv2d operations.') 6 | 7 | # Common arguments 8 | parser.add_argument('--o', type=str, choices=['linear', 'conv2d'], help='The benchmark operation to perform: linear or conv2d') 9 | 10 | # Arguments for linear benchmark 11 | parser.add_argument('--bs', type=int, help='Batch size', default=1) 12 | parser.add_argument('--cin', type=int, help='Channels in', default=960) 13 | parser.add_argument('--cout', type=int, help='Channels out', default=960) 14 | 15 | # Additional arguments for conv2d benchmark 16 | parser.add_argument('--h', type=int, help='Height of the input image', default=64) 17 | parser.add_argument('--w', type=int, help='Width of the input image', default=64) 18 | parser.add_argument('--k', type=int, help='Kernel size', default=3) 19 | parser.add_argument('--p', type=int, help='Padding size', default=0) 20 | 21 | args = parser.parse_args() 22 | 23 | if args.o == 'linear': 24 | # Ensure required arguments for linear operation are provided 25 | if args.cin > 0 and args.cout > 0: 26 | tq.benchmark_linear(bs=args.bs, cin=args.cin, cout=args.cout) 27 | else: 28 | print("Error: cin and cout are required for linear benchmark.") 29 | 30 | elif args.o == 'conv2d': 31 | # Ensure required arguments for conv2d operation are provided 32 | if args.cin > 0 and args.cout > 0 and args.h > 0 and args.w > 0 and args.k > 0: 33 | tq.benchmark_conv2d(bs=args.bs, cin=args.cin, h=args.h, w=args.w, cout=args.cout, k=args.k, padding=args.p) 34 | else: 35 | print("Error: cin, cout, h, w, and k are required for conv2d benchmark.") 36 | 37 | if __name__ == '__main__': 38 | main() 39 | -------------------------------------------------------------------------------- /example/stable-diffusion/convert_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionPipeline 3 | import torch_quantizer as tq 4 | 5 | model_id = "PATH TO stable-diffusion-v1-4" 6 | device = "cuda" 7 | 8 | 9 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, local_files_only=True) 10 | 11 | wq_params = {'n_bits': 8, 'channel_wise': True, 'scale_method': 'max'} 12 | 
aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse'} 13 | ddim_steps = 50 14 | 15 | ## convert to fakeq model 16 | pipe.unet = tq.fake_quant(pipe.unet, wq_params, aq_params, num_steps=ddim_steps) 17 | 18 | ## run fakeq model to do calibration 19 | pipe = pipe.to(device) 20 | prompt = "a photo of an astronaut riding a horse on mars" 21 | image = pipe(prompt, num_inference_steps=ddim_steps).images[0] 22 | 23 | ## convert to realq model 24 | pipe.unet = tq.fake2real(pipe.unet, save_dir='.') 25 | 26 | ## sampling with INT8 model 27 | pipe = pipe.to(device) 28 | prompt = "a photo of an astronaut riding a horse on mars" 29 | image = pipe(prompt, num_inference_steps=ddim_steps).images[0] 30 | image.save("astronaut_rides_horse_8bit.png") 31 | -------------------------------------------------------------------------------- /example/stable-diffusion/sample_fp16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(3407) 3 | from diffusers import StableDiffusionPipeline 4 | 5 | model_id = "PATH TO stable-diffusion-v1-4" 6 | device = "cuda" 7 | 8 | 9 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, local_files_only=True) 10 | ddim_steps = 50 11 | 12 | ## sampling 13 | pipe = pipe.to(device) 14 | prompt = "A cozy cabin nestled in a snowy forest with smoke rising from the chimney" 15 | image = pipe(prompt, num_inference_steps=ddim_steps).images[0] 16 | image.save("cabin_8bit.png") 17 | -------------------------------------------------------------------------------- /example/stable-diffusion/sample_int8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(3407) 3 | from diffusers import StableDiffusionPipeline 4 | import torch_quantizer as tq 5 | 6 | model_id = "PATH TO stable-diffusion-v1-4" 7 | device = "cuda" 8 | 9 | 10 | pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, local_files_only=True) 11 | 12 | wq_params = {'n_bits': 8, 'channel_wise': True, 'scale_method': 'max'} 13 | aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse'} 14 | ddim_steps = 50 15 | 16 | ## convert to INT8 model 17 | pipe.unet = tq.real_quant(pipe.unet, n_bits=8, num_steps=ddim_steps, ckpt_path="PATH TO UNet2DConditionModel_8bits_{}steps.pth".format(ddim_steps)) 18 | 19 | ## sampling 20 | pipe = pipe.to(device) 21 | prompt = "A cozy cabin nestled in a snowy forest with smoke rising from the chimney" 22 | image = pipe(prompt, num_inference_steps=ddim_steps).images[0] 23 | image.save("cabin_8bit.png") 24 | -------------------------------------------------------------------------------- /include/asymmetric/asymmetric.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace TORCHQ::asymmetric { 5 | void buildSubmodule(pybind11::module &mod); 6 | } // namespace TORCHQ::asymmetric 7 | -------------------------------------------------------------------------------- /include/asymmetric/asymmetric_internal.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace TORCHQ::asymmetric { 5 | 6 | torch::Tensor myQuantizeCUDA(const torch::Tensor &src, const torch::Tensor &delta, 7 | const torch::Tensor &zp); 8 | 9 | torch::Tensor myQuantizeNCHWCUDA(const torch::Tensor &src, const torch::Tensor &delta, 10 | const torch::Tensor &zp); 11 | 12 | } // 
namespace TORCHQ::asymmetric -------------------------------------------------------------------------------- /include/asymmetric/epilogue/thread/linear_combination_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | /*! \file 33 | \brief Functor performing linear combination operations used by dequantize 34 | epilogues. 
35 | */ 36 | #pragma once 37 | 38 | #include 39 | 40 | #include "cutlass/array.h" 41 | #include "cutlass/cutlass.h" 42 | #include "cutlass/epilogue/thread/linear_combination_params.h" 43 | #include "cutlass/epilogue/thread/scale_type.h" 44 | #include "cutlass/functional.h" 45 | #include "cutlass/numeric_conversion.h" 46 | #include "cutlass/numeric_types.h" 47 | ///////////////////////////////////////////////////////////////////////////////////////////////// 48 | 49 | namespace cutlass { 50 | namespace epilogue { 51 | namespace thread { 52 | namespace asymmetric { 53 | 54 | struct MyScaleType { 55 | enum Kind { 56 | Dequantize, 57 | }; 58 | }; 59 | ///////////////////////////////////////////////////////////////////////////////////////////////// 60 | 61 | template 66 | class LinearCombinationDequant { 67 | public: 68 | using ElementOutput = ElementOutput_; 69 | using ElementSource = ElementSource_; 70 | using ElementAccumulator = ElementAccumulator_; 71 | using ElementCompute = ElementCompute_; 72 | using ElementC = ElementSource_; 73 | using ElementD = ElementOutput_; 74 | 75 | static int const kCount = Count; 76 | static const MyScaleType::Kind kScale = MyScaleType::Dequantize; 77 | 78 | using FragmentOutput = Array; 79 | using FragmentSource = Array; 80 | using FragmentAccumulator = Array; 81 | using FragmentCompute = Array; 82 | 83 | static FloatRoundStyle const kRound = Round; 84 | 85 | struct Params { 86 | ElementCompute shift_value; 87 | 88 | CUTLASS_HOST_DEVICE 89 | Params() : shift_value(ElementCompute(0)) {} 90 | 91 | CUTLASS_HOST_DEVICE 92 | Params(ElementCompute shift_value) : shift_value(shift_value) {} 93 | }; 94 | 95 | private: 96 | // 97 | // Data members 98 | // 99 | 100 | ElementCompute shift_value_ = ElementCompute(0); 101 | 102 | public: 103 | /// Constructs the function object 104 | CUTLASS_HOST_DEVICE 105 | LinearCombinationDequant(Params const ¶ms) { 106 | shift_value_ = params.shift_value; 107 | } 108 | 109 | /// Returns true if source is needed 110 | CUTLASS_HOST_DEVICE 111 | bool is_source_needed() const { return true; } 112 | 113 | CUTLASS_HOST_DEVICE 114 | void set_k_partition(int k_partition, int k_partition_count) {} 115 | 116 | CUTLASS_HOST_DEVICE 117 | FragmentOutput operator()(FragmentAccumulator const &accumulator, 118 | FragmentSource const &row_vec_alpha, 119 | FragmentSource const &col_vec_alpha, 120 | FragmentSource const &zero_row_scaling_frag, 121 | FragmentSource const &w_reduced_frag, 122 | FragmentSource const &y_frag) const { 123 | NumericArrayConverter 124 | source_converter; 125 | NumericArrayConverter 126 | accumulator_converter; 127 | 128 | NumericArrayConverter 129 | destination_converter; 130 | 131 | FragmentCompute converted_row_vec_alpha = source_converter(row_vec_alpha); 132 | FragmentCompute converted_col_vec_alpha = source_converter(col_vec_alpha); 133 | FragmentCompute converted_accumulator = accumulator_converter(accumulator); 134 | FragmentCompute converted_zero_row_scaling = 135 | source_converter(zero_row_scaling_frag); 136 | FragmentCompute converted_w_reduced = source_converter(w_reduced_frag); 137 | FragmentCompute converted_y = source_converter(y_frag); 138 | 139 | FragmentCompute result; 140 | torch::Half *result_ptr = reinterpret_cast(&result); 141 | 142 | const torch::Half *acc_ptr = 143 | reinterpret_cast(&converted_accumulator); 144 | const torch::Half *row_vec_ptr = 145 | reinterpret_cast(&converted_row_vec_alpha); 146 | const torch::Half *col_vec_ptr = 147 | reinterpret_cast(&converted_col_vec_alpha); 148 | const 
torch::Half *zero_row_scaling_ptr = 149 | reinterpret_cast(&converted_zero_row_scaling); 150 | const torch::Half *w_reduced_ptr = 151 | reinterpret_cast(&converted_w_reduced); 152 | const torch::Half *y_ptr = 153 | reinterpret_cast(&converted_y); 154 | 155 | CUTLASS_PRAGMA_UNROLL 156 | for (int i = 0; i < kCount; ++i) { 157 | result_ptr[i] = acc_ptr[i] * row_vec_ptr[i] * col_vec_ptr[i] + 158 | (zero_row_scaling_ptr[i] + 159 | (torch::Half)shift_value_ * col_vec_ptr[i]) * 160 | w_reduced_ptr[i] + 161 | y_ptr[i]; 162 | } 163 | return destination_converter(result); 164 | } 165 | }; 166 | 167 | ///////////////////////////////////////////////////////////////////////////////////////////////// 168 | 169 | } // namespace asymmetric 170 | } // namespace thread 171 | } // namespace epilogue 172 | } // namespace cutlass 173 | -------------------------------------------------------------------------------- /include/asymmetric/epilogue/threadblock/default_epilogue_tensor_op_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 
30 | * 31 | **************************************************************************************************/ 32 | #pragma once 33 | 34 | #include "asymmetric/epilogue/threadblock/epilogue_dequant.h" 35 | #include "asymmetric/epilogue/threadblock/predicated_vcol_iterator.h" 36 | #include "asymmetric/epilogue/threadblock/predicated_vrow_iterator.h" 37 | #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" 38 | //////////////////////////////////////////////////////////////////////////////// 39 | 40 | namespace cutlass { 41 | namespace epilogue { 42 | namespace threadblock { 43 | namespace asymmetric { 44 | //////////////////////////////////////////////////////////////////////////////// 45 | template 48 | struct DefaultEpilogueTensorOpDequant 49 | : public DefaultEpilogueTensorOp { 52 | using OutputOp = OutputOp_; 53 | using DefaultEpilogueTensorOp = 54 | DefaultEpilogueTensorOp; 56 | using RowVecIterator = 57 | cutlass::epilogue::threadblock::asymmetric::PredicatedVRowIterator< 58 | typename DefaultEpilogueTensorOp::OutputTileThreadMap, 59 | typename DefaultEpilogueTensorOp::ElementOutput, ScatterD, 60 | PermuteDLayout, DefaultEpilogueTensorOp::UseCUDAStore>; 61 | using ColVecIterator = 62 | cutlass::epilogue::threadblock::asymmetric::PredicatedVColIterator< 63 | typename DefaultEpilogueTensorOp::OutputTileThreadMap, 64 | typename DefaultEpilogueTensorOp::ElementOutput, ScatterD, 65 | PermuteDLayout, DefaultEpilogueTensorOp::UseCUDAStore>; 66 | 67 | using Epilogue = cutlass::epilogue::threadblock::asymmetric::EpilogueDequant< 68 | typename DefaultEpilogueTensorOp::Shape, 69 | typename DefaultEpilogueTensorOp::WarpMmaTensorOp, 70 | DefaultEpilogueTensorOp::kPartitionsK, 71 | typename DefaultEpilogueTensorOp::OutputTileIterator, RowVecIterator, 72 | ColVecIterator, 73 | typename DefaultEpilogueTensorOp::AccumulatorFragmentIterator, 74 | typename DefaultEpilogueTensorOp::WarpTileIterator, 75 | typename DefaultEpilogueTensorOp::SharedLoadIterator, OutputOp, 76 | typename DefaultEpilogueTensorOp::Padding, 77 | DefaultEpilogueTensorOp::kFragmentsPerIteration>; 78 | }; 79 | 80 | //////////////////////////////////////////////////////////////////////////////// 81 | 82 | } // namespace asymmetric 83 | } // namespace threadblock 84 | } // namespace epilogue 85 | } // namespace cutlass 86 | 87 | //////////////////////////////////////////////////////////////////////////////// 88 | -------------------------------------------------------------------------------- /include/asymmetric/epilogue/threadblock/predicated_vcol_iterator.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. 
Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | /*! \file 33 | \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 34 | 35 | The epilogue rearranges the result of a matrix product through shared memory 36 | to match canonical tensor layouts in global memory. Epilogues support 37 | conversion and reduction operations. 38 | 39 | */ 40 | 41 | #pragma once 42 | 43 | #include "cutlass/arch/arch.h" 44 | #include "cutlass/arch/memory.h" 45 | #include "cutlass/array.h" 46 | #include "cutlass/cutlass.h" 47 | #include "cutlass/epilogue/threadblock/output_tile_thread_map.h" 48 | #include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" 49 | #include "cutlass/layout/matrix.h" 50 | #include "cutlass/layout/permute.h" 51 | #include "cutlass/layout/tensor.h" 52 | #include "cutlass/matrix_shape.h" 53 | #include "cutlass/numeric_types.h" 54 | #include "cutlass/tensor_ref.h" 55 | #include "cutlass/transform/pitch_linear_thread_map.h" 56 | 57 | //////////////////////////////////////////////////////////////////////////////// 58 | 59 | namespace cutlass { 60 | 61 | //////////////////////////////////////////////////////////////////////////////// 62 | 63 | namespace epilogue { 64 | namespace threadblock { 65 | namespace asymmetric { 66 | //////////////////////////////////////////////////////////////////////////////// 67 | 68 | /// Tile iterator used to load and store output tile from global memory in 69 | /// epilogue. 
70 | /// 71 | /// Satisfies: ReadableTileIterator | PredicatedTileIterator | 72 | /// ForwardTileIterator 73 | /// 74 | template 80 | class PredicatedVColIterator { 81 | static_assert(!ScatterD); 82 | static_assert(std::is_same::value); 83 | 84 | public: 85 | using ThreadMap = ThreadMap_; 86 | using Shape = typename ThreadMap::Shape; 87 | 88 | using Element = Element_; 89 | 90 | using Layout = layout::RowMajor; 91 | using TensorRef = TensorRef; 92 | using ConstTensorRef = typename TensorRef::ConstTensorRef; 93 | 94 | using Index = typename Layout::Index; 95 | using LongIndex = typename Layout::LongIndex; 96 | using TensorCoord = MatrixCoord; 97 | 98 | static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; 99 | static int const kThreads = ThreadMap::kThreads; 100 | static int const kIterations = ThreadMap::Count::kTile; 101 | 102 | static_assert(ThreadMap::Iterations::kRow > 0, 103 | "ThreadMap::Iterations::kRow must be > 0"); 104 | static_assert(ThreadMap::Iterations::kGroup > 0, 105 | "ThreadMap::Iterations::kGroup must be > 0"); 106 | static_assert(ThreadMap::Iterations::kCluster > 0, 107 | "ThreadMap::Iterations::kCluster must be > 0"); 108 | static_assert(ThreadMap::Iterations::kColumn > 0, 109 | "ThreadMap::Iterations::kColumn must be > 0"); 110 | 111 | /// Fragment object 112 | using Fragment = Array; 117 | 118 | /// Memory access size 119 | using AccessType = AlignedArray; 120 | 121 | // 122 | // Parameters struct 123 | // 124 | 125 | /// Uses a non-template class 126 | struct Params : PredicatedTileIteratorParams { 127 | using Base = PredicatedTileIteratorParams; 128 | 129 | CUTLASS_HOST_DEVICE 130 | Params() {} 131 | 132 | CUTLASS_HOST_DEVICE 133 | Params(Layout const &layout) 134 | : PredicatedTileIteratorParams( 135 | 1 * int(sizeof(AccessType)) / kElementsPerAccess, 136 | make_OutputTileThreadMapDesc()) {} 137 | 138 | CUTLASS_HOST_DEVICE 139 | Params(Base const &base) : Base(base) {} 140 | }; 141 | 142 | /// Mask object 143 | struct Mask { 144 | static int const kCount = ThreadMap::Iterations::kColumn; 145 | 146 | /// Predicate state 147 | bool predicates[kCount]; 148 | 149 | // 150 | // Mask 151 | // 152 | CUTLASS_HOST_DEVICE 153 | Mask() { enable(); } 154 | 155 | ///< Efficiently disables all accesses guarded by mask 156 | CUTLASS_HOST_DEVICE void clear() { 157 | CUTLASS_PRAGMA_UNROLL 158 | for (int i = 0; i < kCount; ++i) { 159 | predicates[i] = false; 160 | } 161 | } 162 | 163 | ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask 164 | CUTLASS_DEVICE void enable() { 165 | CUTLASS_PRAGMA_UNROLL 166 | for (int i = 0; i < kCount; ++i) { 167 | predicates[i] = true; 168 | } 169 | } 170 | }; 171 | 172 | private: 173 | // 174 | // Data members 175 | // 176 | 177 | /// Parameters structure containing reference and precomputed state. 178 | PredicatedTileIteratorParams params_; 179 | 180 | /// Byte-level pointer. 181 | uint8_t *byte_pointer_; 182 | 183 | /// Byte-level pointer for store(). 
184 | uint8_t *store_byte_pointer_; 185 | 186 | /// Array of boolean values to contain steady-state predicates 187 | Mask mask_; 188 | 189 | /// Extent of the matrix tile in rows 190 | Index extent_row_; 191 | 192 | /// Extent of the matrix tile in rows 193 | Index extent_column_; 194 | 195 | /// A thread's starting row position (assuming steady-state predicates have 196 | /// been computed) 197 | Index thread_start_row_; 198 | 199 | /// A thread's starting column 200 | Index thread_start_column_; 201 | 202 | /// Internal state counter 203 | int state_[3]; 204 | 205 | // 206 | // Static asserts about internal strides 207 | // 208 | 209 | static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); 210 | static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); 211 | static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, 212 | "Expected 64b strides"); 213 | 214 | private: 215 | // 216 | // Methods 217 | // 218 | 219 | public: 220 | // 221 | // Methods 222 | // 223 | 224 | /// Constructor 225 | CUTLASS_DEVICE 226 | PredicatedVColIterator(PredicatedTileIteratorParams const ¶ms, 227 | Element *pointer, TensorCoord extent, int thread_idx, 228 | TensorCoord threadblock_offset = TensorCoord(), 229 | int const *indices = nullptr) 230 | : params_(params) { 231 | TensorCoord thread_offset = 232 | ThreadMap::initial_offset(thread_idx) + threadblock_offset; 233 | 234 | extent_row_ = extent.row(); 235 | extent_column_ = extent.column(); 236 | 237 | thread_start_row_ = thread_offset.row(); 238 | thread_start_column_ = thread_offset.column(); 239 | 240 | // Initialize predicates 241 | CUTLASS_PRAGMA_UNROLL 242 | for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { 243 | mask_.predicates[c] = ((thread_offset.column() + 244 | ThreadMap::Delta::kColumn * c) < extent.column()); 245 | } 246 | 247 | // Null pointer performs no accesses 248 | if (!pointer) { 249 | mask_.clear(); 250 | } 251 | 252 | // Initialize byte_pointer_ 253 | byte_pointer_ = reinterpret_cast(pointer) + 254 | LongIndex(thread_offset.row()) * LongIndex(params_.stride) + 255 | LongIndex(thread_offset.column()) * sizeof(AccessType) / 256 | kElementsPerAccess; 257 | 258 | // store_byte_pointer_ is set to be the same with byte_pointer_ 259 | store_byte_pointer_ = byte_pointer_; 260 | 261 | // Initialize internal state counter 262 | state_[0] = state_[1] = state_[2] = 0; 263 | 264 | byte_pointer_ = reinterpret_cast(pointer) + 265 | LongIndex(thread_offset.row()) * LongIndex(params_.stride); 266 | } 267 | 268 | /// Adds a pointer offset in units of Element 269 | CUTLASS_HOST_DEVICE 270 | void add_pointer_offset(LongIndex pointer_offset) { 271 | store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; 272 | byte_pointer_ += pointer_offset * sizeof_bits::value / 8; 273 | } 274 | 275 | /// Loads a fragment from memory 276 | CUTLASS_DEVICE 277 | void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { 278 | uint8_t *byte_pointer = byte_pointer_; 279 | AccessType *frag_ptr = reinterpret_cast(&frag); 280 | 281 | CUTLASS_PRAGMA_UNROLL 282 | for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; 283 | ++cluster) { 284 | CUTLASS_PRAGMA_UNROLL 285 | for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { 286 | CUTLASS_PRAGMA_UNROLL 287 | for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { 288 | int frag_row_idx = 289 | (row + ThreadMap::Iterations::kRow * 290 | (group + ThreadMap::Iterations::kGroup * cluster)); 291 | 292 | int row_offset = row * ThreadMap::Delta::kRow 
+ 293 | group * ThreadMap::Delta::kGroup + 294 | cluster * ThreadMap::Delta::kCluster; 295 | 296 | bool row_guard = ((row_offset + thread_start_row_) < extent_row_); 297 | 298 | CUTLASS_PRAGMA_UNROLL 299 | for (int column = 0; column < ThreadMap::Iterations::kColumn; 300 | ++column) { 301 | bool guard = row_guard && mask_.predicates[column]; 302 | if (guard) { 303 | Element *bias = 304 | reinterpret_cast(byte_pointer + byte_offset); 305 | frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] 306 | .fill(*bias); 307 | } 308 | } 309 | 310 | if (row + 1 < ThreadMap::Iterations::kRow) { 311 | byte_pointer += params_.increment_row; 312 | } 313 | } 314 | 315 | if (group + 1 < ThreadMap::Iterations::kGroup) { 316 | byte_pointer += params_.increment_group; 317 | } 318 | } 319 | 320 | if (cluster + 1 < ThreadMap::Iterations::kCluster) { 321 | byte_pointer += params_.increment_cluster; 322 | } 323 | } 324 | } 325 | 326 | /// Loads a fragment from memory 327 | CUTLASS_DEVICE 328 | void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } 329 | 330 | CUTLASS_DEVICE 331 | MatrixCoord thread_start() const { 332 | return MatrixCoord(thread_start_row_, thread_start_column_); 333 | } 334 | 335 | /// Need to get the thread start row from the tile iterator 336 | CUTLASS_DEVICE 337 | int32_t thread_start_row() const { return thread_start_row_; } 338 | 339 | /// Need to get the thread start row from the tile iterator 340 | CUTLASS_DEVICE 341 | int32_t thread_start_column() const { return thread_start_column_; } 342 | 343 | /// Extent of the matrix in rows 344 | CUTLASS_DEVICE 345 | Index extent_row() const { return extent_row_; } 346 | 347 | /// Extent of the matrix in columns 348 | CUTLASS_DEVICE 349 | Index extent_column() const { return extent_column_; } 350 | 351 | /// Advances to the next position to load or store 352 | CUTLASS_HOST_DEVICE 353 | PredicatedVColIterator &operator++() { 354 | ++state_[0]; 355 | 356 | store_byte_pointer_ += params_.advance_row; 357 | 358 | byte_pointer_ += params_.advance_row; 359 | 360 | thread_start_row_ += ThreadMap::Shape::kRow; 361 | 362 | if (state_[0] == ThreadMap::Count::kRow) { 363 | state_[0] = 0; 364 | ++state_[1]; 365 | byte_pointer_ += params_.advance_group; 366 | store_byte_pointer_ += params_.advance_group; 367 | 368 | thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * 369 | ThreadMap::Shape::kRow * ThreadMap::Count::kRow; 370 | 371 | if (state_[1] == ThreadMap::Count::kGroup) { 372 | state_[1] = 0; 373 | ++state_[2]; 374 | byte_pointer_ += params_.advance_cluster; 375 | store_byte_pointer_ += params_.advance_cluster; 376 | 377 | thread_start_row_ += ThreadMap::Count::kGroup * 378 | ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * 379 | ThreadMap::Shape::kRow; 380 | 381 | if (state_[2] == ThreadMap::Count::kCluster) { 382 | state_[2] = 0; 383 | byte_pointer_ += params_.advance_tile; 384 | store_byte_pointer_ += params_.advance_tile; 385 | 386 | thread_start_row_ += 387 | ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * 388 | ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; 389 | } 390 | } 391 | } 392 | 393 | return *this; 394 | } 395 | 396 | ///< Efficiently disables all accesses guarded by mask 397 | CUTLASS_DEVICE void clear_mask() { mask_.clear(); } 398 | 399 | ///< Efficiently enables all accesses guarded by mask 400 | CUTLASS_DEVICE void enable_mask() { mask_.enable(); } 401 | 402 | ///< Sets the mask 403 | CUTLASS_DEVICE void get_mask(Mask &mask) const { mask = mask_; } 404 | 405 | ///< Sets the mask 406 | CUTLASS_DEVICE 
void set_mask(Mask const &mask) { mask_ = mask; } 407 | }; 408 | 409 | } // namespace asymmetric 410 | } // namespace threadblock 411 | } // namespace epilogue 412 | } // namespace cutlass 413 | 414 | //////////////////////////////////////////////////////////////////////////////// 415 | -------------------------------------------------------------------------------- /include/asymmetric/epilogue/threadblock/predicated_vrow_iterator.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | /*! \file 33 | \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 34 | 35 | The epilogue rearranges the result of a matrix product through shared memory 36 | to match canonical tensor layouts in global memory. Epilogues support 37 | conversion and reduction operations. 
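    (Editor's note, hedged interpretation rather than original documentation:
    unlike the stock CUTLASS PredicatedTileIterator, this "VRow" variant is
    constructed with a zero stride in its Params below, so every output row
    appears to map to the same source address. In effect it broadcasts a
    1 x N row vector, for example a per-output-channel dequantization scale,
    across all rows of the output tile, mirroring the PredicatedVColIterator
    above, which broadcasts an M x 1 column vector across all columns.)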
38 | 39 | */ 40 | 41 | #pragma once 42 | 43 | #include "cutlass/arch/arch.h" 44 | #include "cutlass/arch/memory.h" 45 | #include "cutlass/array.h" 46 | #include "cutlass/cutlass.h" 47 | #include "cutlass/epilogue/threadblock/output_tile_thread_map.h" 48 | #include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" 49 | #include "cutlass/layout/matrix.h" 50 | #include "cutlass/layout/permute.h" 51 | #include "cutlass/layout/tensor.h" 52 | #include "cutlass/matrix_shape.h" 53 | #include "cutlass/numeric_types.h" 54 | #include "cutlass/tensor_ref.h" 55 | #include "cutlass/transform/pitch_linear_thread_map.h" 56 | 57 | //////////////////////////////////////////////////////////////////////////////// 58 | 59 | namespace cutlass { 60 | 61 | //////////////////////////////////////////////////////////////////////////////// 62 | 63 | namespace epilogue { 64 | namespace threadblock { 65 | namespace asymmetric { 66 | //////////////////////////////////////////////////////////////////////////////// 67 | 68 | /// Tile iterator used to load and store output tile from global memory in 69 | /// epilogue. 70 | /// 71 | /// Satisfies: ReadableTileIterator | PredicatedTileIterator | 72 | /// ForwardTileIterator 73 | /// 74 | template 80 | class PredicatedVRowIterator { 81 | public: 82 | using ThreadMap = ThreadMap_; 83 | using Shape = typename ThreadMap::Shape; 84 | 85 | using Element = Element_; 86 | 87 | using Layout = layout::RowMajor; 88 | using TensorRef = TensorRef; 89 | using ConstTensorRef = typename TensorRef::ConstTensorRef; 90 | 91 | using Index = typename Layout::Index; 92 | using LongIndex = typename Layout::LongIndex; 93 | using TensorCoord = MatrixCoord; 94 | 95 | static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; 96 | static int const kThreads = ThreadMap::kThreads; 97 | static int const kIterations = ThreadMap::Count::kTile; 98 | 99 | static bool constexpr PermuteD = !layout::is_trivial_permute; 100 | 101 | static_assert(ThreadMap::Iterations::kRow > 0, 102 | "ThreadMap::Iterations::kRow must be > 0"); 103 | static_assert(ThreadMap::Iterations::kGroup > 0, 104 | "ThreadMap::Iterations::kGroup must be > 0"); 105 | static_assert(ThreadMap::Iterations::kCluster > 0, 106 | "ThreadMap::Iterations::kCluster must be > 0"); 107 | static_assert(ThreadMap::Iterations::kColumn > 0, 108 | "ThreadMap::Iterations::kColumn must be > 0"); 109 | 110 | /// Fragment object 111 | using Fragment = Array; 116 | /// Memory access size 117 | using AccessType = AlignedArray; 118 | 119 | // 120 | // Parameters struct 121 | // 122 | 123 | /// Uses a non-template class 124 | struct Params : PredicatedTileIteratorParams { 125 | using Base = PredicatedTileIteratorParams; 126 | 127 | CUTLASS_HOST_DEVICE 128 | Params() {} 129 | 130 | CUTLASS_HOST_DEVICE 131 | Params(Layout const &layout) 132 | : PredicatedTileIteratorParams( 133 | 0, make_OutputTileThreadMapDesc()) {} 134 | 135 | CUTLASS_HOST_DEVICE 136 | Params(Base const &base) : Base(base) {} 137 | }; 138 | 139 | /// Mask object 140 | struct Mask { 141 | static int const kCount = ThreadMap::Iterations::kColumn; 142 | 143 | /// Predicate state 144 | bool predicates[kCount]; 145 | 146 | // 147 | // Mask 148 | // 149 | CUTLASS_HOST_DEVICE 150 | Mask() { enable(); } 151 | 152 | ///< Efficiently disables all accesses guarded by mask 153 | CUTLASS_HOST_DEVICE void clear() { 154 | CUTLASS_PRAGMA_UNROLL 155 | for (int i = 0; i < kCount; ++i) { 156 | predicates[i] = false; 157 | } 158 | } 159 | 160 | ///< CUTLASS_HOST_DEVICE enables all accesses 
guarded by mask 161 | CUTLASS_DEVICE void enable() { 162 | CUTLASS_PRAGMA_UNROLL 163 | for (int i = 0; i < kCount; ++i) { 164 | predicates[i] = true; 165 | } 166 | } 167 | }; 168 | 169 | private: 170 | // 171 | // Data members 172 | // 173 | 174 | /// Parameters structure containing reference and precomputed state. 175 | PredicatedTileIteratorParams params_; 176 | 177 | /// Byte-level pointer. This pointer is usually for both load() and store(), 178 | /// unless PermuteD is performed. When having PermuteD, byte_pointer_ is only 179 | /// for load(). 180 | uint8_t *byte_pointer_; 181 | 182 | /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ 183 | /// may be with different address computation compared to byte_pointer_. 184 | uint8_t *store_byte_pointer_; 185 | 186 | /// Array of boolean values to contain steady-state predicates 187 | Mask mask_; 188 | 189 | /// Extent of the matrix tile in rows 190 | Index extent_row_; 191 | 192 | /// Extent of the matrix tile in rows 193 | Index extent_column_; 194 | 195 | /// A thread's starting row position (assuming steady-state predicates have 196 | /// been computed) 197 | Index thread_start_row_; 198 | 199 | /// A thread's starting column 200 | Index thread_start_column_; 201 | 202 | /// Internal state counter 203 | int state_[3]; 204 | 205 | /// Scatter indices 206 | int const *indices_; 207 | 208 | /// PermuteDLayout 209 | PermuteDLayout permute_layout_; 210 | 211 | // 212 | // Static asserts about internal strides 213 | // 214 | 215 | static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); 216 | static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); 217 | static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, 218 | "Expected 64b strides"); 219 | 220 | private: 221 | // 222 | // Methods 223 | // 224 | 225 | public: 226 | // 227 | // Methods 228 | // 229 | 230 | /// Constructor 231 | CUTLASS_DEVICE 232 | PredicatedVRowIterator(PredicatedTileIteratorParams const ¶ms, 233 | Element *pointer, TensorCoord extent, int thread_idx, 234 | TensorCoord threadblock_offset = TensorCoord(), 235 | int const *indices = nullptr) 236 | : params_(params), 237 | indices_(indices), 238 | permute_layout_( 239 | PitchLinearCoord(extent.column(), extent.row()), 240 | params_.stride * kElementsPerAccess / sizeof(AccessType)) { 241 | TensorCoord thread_offset = 242 | ThreadMap::initial_offset(thread_idx) + threadblock_offset; 243 | 244 | extent_row_ = extent.row(); 245 | extent_column_ = extent.column(); 246 | 247 | thread_start_row_ = thread_offset.row(); 248 | thread_start_column_ = thread_offset.column(); 249 | 250 | // Initialize predicates 251 | CUTLASS_PRAGMA_UNROLL 252 | for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { 253 | mask_.predicates[c] = ((thread_offset.column() + 254 | ThreadMap::Delta::kColumn * c) < extent.column()); 255 | } 256 | 257 | // Null pointer performs no accesses 258 | if (!pointer) { 259 | mask_.clear(); 260 | } 261 | 262 | if (ScatterD && !indices) { 263 | mask_.clear(); 264 | } 265 | 266 | // Initialize byte_pointer_ 267 | byte_pointer_ = reinterpret_cast(pointer) + 268 | LongIndex(thread_offset.row()) * LongIndex(params_.stride) + 269 | LongIndex(thread_offset.column()) * sizeof(AccessType) / 270 | kElementsPerAccess; 271 | 272 | if (ScatterD) { 273 | byte_pointer_ = reinterpret_cast(pointer) + 274 | LongIndex(thread_offset.column()) * sizeof(AccessType) / 275 | kElementsPerAccess; 276 | } 277 | 278 | // store_byte_pointer_ is set to be the same with byte_pointer_ unless 
279 | // PermuteD is used. 280 | store_byte_pointer_ = 281 | PermuteD ? reinterpret_cast(pointer) : byte_pointer_; 282 | 283 | // Initialize internal state counter 284 | state_[0] = state_[1] = state_[2] = 0; 285 | } 286 | 287 | /// Adds a pointer offset in units of Element 288 | CUTLASS_HOST_DEVICE 289 | void add_pointer_offset(LongIndex pointer_offset) { 290 | store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; 291 | byte_pointer_ += pointer_offset * sizeof_bits::value / 8; 292 | } 293 | 294 | /// Loads a fragment from memory 295 | CUTLASS_DEVICE 296 | void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { 297 | uint8_t *byte_pointer = byte_pointer_; 298 | AccessType *frag_ptr = reinterpret_cast(&frag); 299 | 300 | CUTLASS_PRAGMA_UNROLL 301 | for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; 302 | ++cluster) { 303 | CUTLASS_PRAGMA_UNROLL 304 | for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { 305 | CUTLASS_PRAGMA_UNROLL 306 | for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { 307 | int frag_row_idx = 308 | (row + ThreadMap::Iterations::kRow * 309 | (group + ThreadMap::Iterations::kGroup * cluster)); 310 | 311 | int row_offset = row * ThreadMap::Delta::kRow + 312 | group * ThreadMap::Delta::kGroup + 313 | cluster * ThreadMap::Delta::kCluster; 314 | 315 | bool row_guard = ((row_offset + thread_start_row_) < extent_row_); 316 | 317 | AccessType *memory_pointer = 318 | reinterpret_cast(byte_pointer + byte_offset); 319 | 320 | if (ScatterD && row_guard) { 321 | assert(indices_); 322 | 323 | memory_pointer = reinterpret_cast( 324 | byte_pointer + byte_offset + 325 | LongIndex(indices_[row_offset + thread_start_row_]) * 326 | LongIndex(params_.stride)); 327 | } 328 | 329 | CUTLASS_PRAGMA_UNROLL 330 | for (int column = 0; column < ThreadMap::Iterations::kColumn; 331 | ++column) { 332 | bool guard = row_guard && mask_.predicates[column]; 333 | 334 | cutlass::arch::global_load( 335 | frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + 336 | column], 337 | (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / 338 | kElementsPerAccess], 339 | guard); 340 | } 341 | 342 | if (row + 1 < ThreadMap::Iterations::kRow) { 343 | if (!ScatterD) { 344 | byte_pointer += params_.increment_row; 345 | } 346 | } 347 | } 348 | 349 | if (group + 1 < ThreadMap::Iterations::kGroup) { 350 | byte_pointer += params_.increment_group; 351 | } 352 | } 353 | 354 | if (cluster + 1 < ThreadMap::Iterations::kCluster) { 355 | byte_pointer += params_.increment_cluster; 356 | } 357 | } 358 | } 359 | 360 | /// Loads a fragment from memory 361 | CUTLASS_DEVICE 362 | void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } 363 | 364 | CUTLASS_DEVICE 365 | MatrixCoord thread_start() const { 366 | return MatrixCoord(thread_start_row_, thread_start_column_); 367 | } 368 | 369 | /// Need to get the thread start row from the tile iterator 370 | CUTLASS_DEVICE 371 | int32_t thread_start_row() const { return thread_start_row_; } 372 | 373 | /// Need to get the thread start row from the tile iterator 374 | CUTLASS_DEVICE 375 | int32_t thread_start_column() const { return thread_start_column_; } 376 | 377 | /// Extent of the matrix in rows 378 | CUTLASS_DEVICE 379 | Index extent_row() const { return extent_row_; } 380 | 381 | /// Extent of the matrix in columns 382 | CUTLASS_DEVICE 383 | Index extent_column() const { return extent_column_; } 384 | 385 | /// Advances to the next position to load or store 386 | CUTLASS_HOST_DEVICE 387 | 
PredicatedVRowIterator &operator++() { 388 | ++state_[0]; 389 | 390 | if (!ScatterD) { 391 | byte_pointer_ += params_.advance_row; 392 | } 393 | 394 | if (!ScatterD && !PermuteD) { 395 | store_byte_pointer_ += params_.advance_row; 396 | } 397 | 398 | thread_start_row_ += ThreadMap::Shape::kRow; 399 | 400 | if (state_[0] == ThreadMap::Count::kRow) { 401 | state_[0] = 0; 402 | ++state_[1]; 403 | 404 | if (!ScatterD) { 405 | byte_pointer_ += params_.advance_group; 406 | } 407 | 408 | if (!ScatterD && !PermuteD) { 409 | store_byte_pointer_ += params_.advance_group; 410 | } 411 | 412 | thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * 413 | ThreadMap::Shape::kRow * ThreadMap::Count::kRow; 414 | 415 | if (state_[1] == ThreadMap::Count::kGroup) { 416 | state_[1] = 0; 417 | ++state_[2]; 418 | 419 | if (!ScatterD) { 420 | byte_pointer_ += params_.advance_cluster; 421 | } 422 | 423 | if (!ScatterD && !PermuteD) { 424 | store_byte_pointer_ += params_.advance_cluster; 425 | } 426 | 427 | thread_start_row_ += ThreadMap::Count::kGroup * 428 | ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * 429 | ThreadMap::Shape::kRow; 430 | 431 | if (state_[2] == ThreadMap::Count::kCluster) { 432 | state_[2] = 0; 433 | 434 | if (!ScatterD) { 435 | byte_pointer_ += params_.advance_tile; 436 | } 437 | 438 | if (!ScatterD && !PermuteD) { 439 | store_byte_pointer_ += params_.advance_tile; 440 | } 441 | 442 | thread_start_row_ += 443 | ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow * 444 | ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; 445 | } 446 | } 447 | } 448 | 449 | return *this; 450 | } 451 | 452 | ///< Efficiently disables all accesses guarded by mask 453 | CUTLASS_DEVICE void clear_mask() { mask_.clear(); } 454 | 455 | ///< Efficiently enables all accesses guarded by mask 456 | CUTLASS_DEVICE void enable_mask() { mask_.enable(); } 457 | 458 | ///< Sets the mask 459 | CUTLASS_DEVICE void get_mask(Mask &mask) const { mask = mask_; } 460 | 461 | ///< Sets the mask 462 | CUTLASS_DEVICE void set_mask(Mask const &mask) { mask_ = mask; } 463 | }; 464 | 465 | } // namespace asymmetric 466 | } // namespace threadblock 467 | } // namespace epilogue 468 | } // namespace cutlass 469 | 470 | //////////////////////////////////////////////////////////////////////////////// -------------------------------------------------------------------------------- /include/asymmetric/gemm/device/gemm_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 
18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | /*! \file 33 | \brief Template for a pipelined GEMM kernel. Does not compute batching or 34 | support split-K. 35 | */ 36 | 37 | #pragma once 38 | 39 | #include "asymmetric/epilogue/thread/linear_combination_dequant.h" 40 | #include "asymmetric/gemm/kernel/default_gemm_dequant.h" 41 | #include "cutlass/arch/arch.h" 42 | #include "cutlass/cutlass.h" 43 | #include "cutlass/device_kernel.h" 44 | #include "cutlass/gemm/device/default_gemm_configuration.h" 45 | #include "cutlass/gemm/kernel/gemm.h" 46 | #include "cutlass/gemm/threadblock/threadblock_swizzle.h" 47 | #include "cutlass/layout/permute.h" 48 | #include "cutlass/numeric_types.h" 49 | //////////////////////////////////////////////////////////////////////////////// 50 | 51 | namespace cutlass { 52 | namespace gemm { 53 | namespace device { 54 | namespace asymmetric { 55 | ///////////////////////////////////////////////////////////////////////////////////////////////// 56 | 57 | template < 58 | /// Element type for A matrix operand 59 | typename ElementA_, 60 | /// Layout type for A matrix operand 61 | typename LayoutA_, 62 | /// Element type for B matrix operand 63 | typename ElementB_, 64 | /// Layout type for B matrix operand 65 | typename LayoutB_, 66 | /// Element type for C and D matrix operands 67 | typename ElementC_, 68 | /// Layout type for C and D matrix operands 69 | typename LayoutC_, 70 | /// Element type for internal accumulation 71 | typename ElementAccumulator_ = ElementC_, 72 | /// Operator class tag 73 | typename OperatorClass_ = arch::OpClassTensorOp, 74 | /// Tag indicating architecture to tune for 75 | typename ArchTag_ = arch::Sm80, 76 | /// Threadblock-level tile size (concept: GemmShape) 77 | typename ThreadblockShape_ = typename DefaultGemmConfiguration< 78 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 79 | ElementAccumulator_>::ThreadblockShape, 80 | /// Warp-level tile size (concept: GemmShape) 81 | typename WarpShape_ = typename DefaultGemmConfiguration< 82 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 83 | ElementAccumulator_>::WarpShape, 84 | /// Instruction-level tile size (concept: GemmShape) 85 | typename InstructionShape_ = typename DefaultGemmConfiguration< 86 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 87 | ElementAccumulator_>::InstructionShape, 88 | /// Epilogue output operator 89 | typename EpilogueOutputOp_ = 90 | cutlass::epilogue::thread::asymmetric::LinearCombinationDequant< 91 | ElementC_, 128 / cutlass::sizeof_bits::value, 92 | ElementAccumulator_, ElementC_, 93 | 
cutlass::epilogue::thread::asymmetric::MyScaleType::Dequantize, 94 | cutlass::FloatRoundStyle::round_to_nearest, ElementC_>, 95 | /// Threadblock-level swizzling operator 96 | typename ThreadblockSwizzle_ = 97 | typename threadblock::GemmIdentityThreadblockSwizzle<>, 98 | /// Number of stages used in the pipelined mainloop 99 | int Stages = 100 | DefaultGemmConfiguration::kStages, 102 | /// Access granularity of A matrix in units of elements 103 | int AlignmentA = 104 | DefaultGemmConfiguration::kAlignmentA, 106 | /// Access granularity of B matrix in units of elements 107 | int AlignmentB = 108 | DefaultGemmConfiguration::kAlignmentB, 110 | /// If true, kernel supports split-K with serial reduction 111 | bool SplitKSerial = false, 112 | /// Operation performed by GEMM 113 | typename Operator_ = typename DefaultGemmConfiguration< 114 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 115 | ElementAccumulator_>::Operator, 116 | /// Gather operand A by using an index array 117 | bool GatherA = false, 118 | /// Gather operand B by using an index array 119 | bool GatherB = false, 120 | /// Scatter result D by using an index array 121 | bool ScatterD = false, 122 | /// Permute result D 123 | typename PermuteDLayout = layout::NoPermute> 124 | class GemmDequant { 125 | public: 126 | using ElementA = ElementA_; 127 | using LayoutA = LayoutA_; 128 | using TensorRefA = TensorRef; 129 | using ElementB = ElementB_; 130 | using LayoutB = LayoutB_; 131 | using TensorRefB = TensorRef; 132 | using ElementC = ElementC_; 133 | using LayoutC = LayoutC_; 134 | using TensorRefC = TensorRef; 135 | using TensorRefD = TensorRef; 136 | using ElementAccumulator = ElementAccumulator_; 137 | using OperatorClass = OperatorClass_; 138 | using ArchTag = ArchTag_; 139 | using ThreadblockShape = ThreadblockShape_; 140 | using WarpShape = WarpShape_; 141 | using InstructionShape = InstructionShape_; 142 | using EpilogueOutputOp = EpilogueOutputOp_; 143 | using ThreadblockSwizzle = ThreadblockSwizzle_; 144 | using Operator = Operator_; 145 | static int const kStages = Stages; 146 | static int const kAlignmentA = AlignmentA; 147 | static int const kAlignmentB = AlignmentB; 148 | static int const kAlignmentC = EpilogueOutputOp::kCount; 149 | static bool const kSplitKSerial = SplitKSerial; 150 | static ComplexTransform const kTransformA = ComplexTransform::kNone; 151 | static ComplexTransform const kTransformB = ComplexTransform::kNone; 152 | 153 | /// Define the kernel 154 | using GemmKernel = typename kernel::asymmetric::DefaultGemmDequant< 155 | ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, 156 | LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, 157 | WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 158 | kStages, kSplitKSerial, Operator, SharedMemoryClearOption::kNone, GatherA, 159 | GatherB, ScatterD, PermuteDLayout>::GemmKernel; 160 | 161 | /// Argument structure 162 | struct Arguments { 163 | // 164 | // Data members 165 | // 166 | 167 | GemmCoord problem_size; 168 | TensorRef ref_A; 169 | TensorRef ref_B; 170 | TensorRef ref_C; 171 | TensorRef ref_D; 172 | TensorRef ref_row_vec; 173 | TensorRef ref_col_vec; 174 | TensorRef ref_zero_row_vec; 175 | TensorRef ref_w_reduce_vec; 176 | typename EpilogueOutputOp::Params epilogue; 177 | int split_k_slices; 178 | // For gather+scatter operations 179 | int const *gather_A_indices; 180 | int const *gather_B_indices; 181 | int const *scatter_D_indices; 182 | 183 | // 184 | // Methods 185 | // 186 | 187 | 
/// Default ctor 188 | CUTLASS_HOST_DEVICE 189 | Arguments() : problem_size(0, 0, 0), split_k_slices(1) {} 190 | 191 | /// Constructs an Arguments structure 192 | CUTLASS_HOST_DEVICE 193 | Arguments(GemmCoord problem_size_, 194 | TensorRef ref_A_, 195 | TensorRef ref_B_, 196 | TensorRef ref_C_, 197 | TensorRef ref_D_, 198 | TensorRef ref_row_vec_, 199 | TensorRef ref_col_vec_, 200 | TensorRef ref_zero_row_vec_, 201 | TensorRef ref_w_reduce_vec_, 202 | typename EpilogueOutputOp::Params epilogue_ = 203 | typename EpilogueOutputOp::Params(), 204 | int split_k_slices = 1, int const *gather_A_indices_ = nullptr, 205 | int const *gather_B_indices_ = nullptr, 206 | int const *scatter_D_indices_ = nullptr) 207 | : problem_size(problem_size_), 208 | ref_A(ref_A_), 209 | ref_B(ref_B_), 210 | ref_C(ref_C_), 211 | ref_D(ref_D_), 212 | ref_row_vec(ref_row_vec_), 213 | ref_col_vec(ref_col_vec_), 214 | ref_zero_row_vec(ref_zero_row_vec_), 215 | ref_w_reduce_vec(ref_w_reduce_vec_), 216 | epilogue(epilogue_), 217 | split_k_slices(split_k_slices), 218 | gather_A_indices(gather_A_indices_), 219 | gather_B_indices(gather_B_indices_), 220 | scatter_D_indices(scatter_D_indices_) {} 221 | }; 222 | 223 | private: 224 | /// Kernel parameters object 225 | typename GemmKernel::Params params_; 226 | 227 | public: 228 | /// Constructs the GEMM. 229 | GemmDequant() {} 230 | 231 | /// Determines whether the GEMM can execute the given problem. 232 | static Status can_implement(Arguments const &args) { 233 | if (!kSplitKSerial && args.split_k_slices > 1) { 234 | return Status::kErrorInvalidProblem; 235 | } 236 | 237 | Status status = GemmKernel::can_implement( 238 | args.problem_size, args.ref_A.non_const_ref(), 239 | args.ref_B.non_const_ref(), args.ref_C.non_const_ref(), args.ref_D, 240 | args.ref_row_vec.non_const_ref(), args.ref_col_vec.non_const_ref(), 241 | args.ref_zero_row_vec.non_const_ref(), 242 | args.ref_w_reduce_vec.non_const_ref()); 243 | 244 | if (status != Status::kSuccess) { 245 | return status; 246 | } 247 | 248 | return Status::kSuccess; 249 | } 250 | 251 | /// Gets the workspace size 252 | static size_t get_workspace_size(Arguments const &args) { 253 | size_t bytes = 0; 254 | 255 | // Determine grid shape 256 | ThreadblockSwizzle threadblock_swizzle; 257 | 258 | cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( 259 | args.problem_size, 260 | {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 261 | args.split_k_slices); 262 | 263 | if (kSplitKSerial && args.split_k_slices > 1) { 264 | bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); 265 | } 266 | 267 | return bytes; 268 | } 269 | 270 | /// Initializes GEMM state from arguments. 
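  // (Editor's note, not part of the original header.) The three-step
  // lifecycle below mirrors stock CUTLASS device-level GEMMs:
  //   * initialize() validates split-K, clears the serial-reduction
  //     workspace when needed, and builds GemmKernel::Params;
  //   * update() only swaps tensor pointers (including the four
  //     dequantization vectors) without recomputing the grid shape;
  //   * run() launches cutlass::Kernel<GemmKernel>, first raising the
  //     dynamic shared-memory limit when SharedStorage exceeds 48 KB.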
271 | Status initialize(Arguments const &args, void *workspace = nullptr, 272 | cudaStream_t stream = nullptr) { 273 | // Determine grid shape 274 | ThreadblockSwizzle threadblock_swizzle; 275 | 276 | cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( 277 | args.problem_size, 278 | {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 279 | args.split_k_slices); 280 | 281 | if (kSplitKSerial) { 282 | if (args.split_k_slices > 1) { 283 | if (!workspace) { 284 | return Status::kErrorWorkspaceNull; 285 | } 286 | 287 | size_t bytes = get_workspace_size(args); 288 | 289 | cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); 290 | 291 | if (result != cudaSuccess) { 292 | return Status::kErrorInternal; 293 | } 294 | } 295 | } else { 296 | if (args.split_k_slices > 1) { 297 | return Status::kErrorInvalidProblem; 298 | } 299 | } 300 | 301 | // Initialize the Params structure 302 | params_ = typename GemmKernel::Params{args.problem_size, 303 | grid_shape, 304 | args.ref_A.non_const_ref(), 305 | args.ref_B.non_const_ref(), 306 | args.ref_C.non_const_ref(), 307 | args.ref_D, 308 | args.ref_row_vec.non_const_ref(), 309 | args.ref_col_vec.non_const_ref(), 310 | args.ref_zero_row_vec.non_const_ref(), 311 | args.ref_w_reduce_vec.non_const_ref(), 312 | args.epilogue, 313 | static_cast(workspace), 314 | args.gather_A_indices, 315 | args.gather_B_indices, 316 | args.scatter_D_indices}; 317 | 318 | return Status::kSuccess; 319 | } 320 | 321 | /// Lightweight update given a subset of arguments 322 | Status update(Arguments const &args, void *workspace = nullptr) { 323 | if (kSplitKSerial && args.split_k_slices > 1) { 324 | if (!workspace) { 325 | return Status::kErrorWorkspaceNull; 326 | } 327 | } 328 | 329 | params_.ref_A.reset(args.ref_A.non_const_ref().data()); 330 | params_.ref_B.reset(args.ref_B.non_const_ref().data()); 331 | params_.ref_C.reset(args.ref_C.non_const_ref().data()); 332 | params_.ref_D.reset(args.ref_D.data()); 333 | params_.ref_row_vec.reset(args.ref_row_vec.non_const_ref().data()); 334 | params_.ref_col_vec.reset(args.ref_col_vec.non_const_ref().data()); 335 | params_.ref_zero_row_vec.reset( 336 | args.ref_zero_row_vec.non_const_ref().data()); 337 | params_.ref_w_reduce_vec.reset( 338 | args.ref_w_reduce_vec.non_const_ref().data()); 339 | params_.output_op = args.epilogue; 340 | params_.semaphore = static_cast(workspace); 341 | 342 | return Status::kSuccess; 343 | } 344 | 345 | /// Runs the kernel using initialized state. 346 | Status run(cudaStream_t stream = nullptr) { 347 | ThreadblockSwizzle threadblock_swizzle; 348 | 349 | dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); 350 | dim3 block(GemmKernel::kThreadCount, 1, 1); 351 | 352 | cudaError_t result; 353 | 354 | int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); 355 | 356 | if (smem_size >= (48 << 10)) { 357 | result = cudaFuncSetAttribute(Kernel, 358 | cudaFuncAttributeMaxDynamicSharedMemorySize, 359 | smem_size); 360 | 361 | if (result != cudaSuccess) { 362 | return Status::kErrorInternal; 363 | } 364 | } 365 | 366 | cutlass::Kernel<<>>(params_); 367 | 368 | result = cudaGetLastError(); 369 | 370 | return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; 371 | } 372 | 373 | /// Runs the kernel using initialized state. 374 | Status operator()(cudaStream_t stream = nullptr) { return run(stream); } 375 | 376 | /// Runs the kernel using initialized state. 
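  // (Editor's sketch, not in the original header.) Typical host-side usage,
  // assuming an int8 x int8 -> half instantiation is supported by
  // DefaultGemmConfiguration for the chosen architecture; the scale and
  // zero-point tensor names below are illustrative only:
  //
  //   using Gemm = cutlass::gemm::device::asymmetric::GemmDequant<
  //       int8_t, cutlass::layout::RowMajor,           // A
  //       int8_t, cutlass::layout::ColumnMajor,        // B
  //       cutlass::half_t, cutlass::layout::RowMajor,  // C / D
  //       int32_t>;                                    // accumulator
  //   Gemm gemm_op;
  //   Gemm::Arguments args({M, N, K}, ref_A, ref_B, ref_C, ref_D,
  //                        ref_row_vec, ref_col_vec,             // scale vectors
  //                        ref_zero_row_vec, ref_w_reduce_vec);  // zero-point terms
  //   cutlass::Status status = gemm_op(args);  // initialize(args) + run()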
377 | Status operator()(Arguments const &args, void *workspace = nullptr, 378 | cudaStream_t stream = nullptr) { 379 | Status status = initialize(args, workspace, stream); 380 | 381 | if (status == Status::kSuccess) { 382 | status = run(stream); 383 | } 384 | 385 | return status; 386 | } 387 | }; 388 | 389 | //////////////////////////////////////////////////////////////////////////////// 390 | 391 | } // namespace asymmetric 392 | } // namespace device 393 | } // namespace gemm 394 | } // namespace cutlass 395 | 396 | //////////////////////////////////////////////////////////////////////////////// 397 | -------------------------------------------------------------------------------- /include/asymmetric/gemm/device/gemm_sparse_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | /*! \file 33 | \brief Template for a pipelined GEMM kernel. Does not compute batching or 34 | support split-K. 
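    (Editor's note, hedged: this is the structured-sparse counterpart of
    GemmDequant in gemm_dequant.h. Besides the dequantization vectors, its
    Arguments take an extra metadata operand ref_E describing which elements
    of the compressed A operand are kept; kSparse, kMetaSizeInBits and
    kElementsPerElementE below expose the compression geometry inherited from
    the underlying CUTLASS sparse tensor-op kernel.)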
35 | */ 36 | #pragma once 37 | 38 | #include "asymmetric/epilogue/thread/linear_combination_dequant.h" 39 | #include "asymmetric/gemm/kernel/default_gemm_dequant.h" 40 | #include "asymmetric/gemm/kernel/default_gemm_sparse_dequant.h" 41 | #include "cutlass/arch/arch.h" 42 | #include "cutlass/cutlass.h" 43 | #include "cutlass/device_kernel.h" 44 | #include "cutlass/gemm/device/default_gemm_configuration.h" 45 | #include "cutlass/gemm/kernel/gemm.h" 46 | #include "cutlass/gemm/threadblock/threadblock_swizzle.h" 47 | #include "cutlass/layout/permute.h" 48 | #include "cutlass/numeric_types.h" 49 | //////////////////////////////////////////////////////////////////////////////// 50 | 51 | namespace cutlass { 52 | namespace gemm { 53 | namespace device { 54 | namespace asymmetric { 55 | ///////////////////////////////////////////////////////////////////////////////////////////////// 56 | template < 57 | /// Element type for A matrix operand 58 | typename ElementA_, 59 | /// Layout type for A matrix operand 60 | typename LayoutA_, 61 | /// Element type for B matrix operand 62 | typename ElementB_, 63 | /// Layout type for B matrix operand 64 | typename LayoutB_, 65 | /// Element type for C and D matrix operands 66 | typename ElementC_, 67 | /// Layout type for C and D matrix operands 68 | typename LayoutC_, 69 | /// Element type for internal accumulation 70 | typename ElementAccumulator_ = ElementC_, 71 | /// Operator class tag 72 | typename OperatorClass_ = arch::OpClassSimt, 73 | /// Tag indicating architecture to tune for 74 | typename ArchTag_ = arch::Sm70, 75 | /// Threadblock-level tile size (concept: GemmShape) 76 | typename ThreadblockShape_ = typename DefaultGemmConfiguration< 77 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 78 | ElementAccumulator_>::ThreadblockShape, 79 | /// Warp-level tile size (concept: GemmShape) 80 | typename WarpShape_ = typename DefaultGemmConfiguration< 81 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 82 | ElementAccumulator_>::WarpShape, 83 | /// Instruction-level tile size (concept: GemmShape) 84 | typename InstructionShape_ = typename DefaultGemmConfiguration< 85 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 86 | ElementAccumulator_>::InstructionShape, 87 | /// Epilogue output operator 88 | typename EpilogueOutputOp_ = 89 | cutlass::epilogue::thread::asymmetric::LinearCombinationDequant< 90 | ElementC_, 128 / cutlass::sizeof_bits::value, 91 | ElementAccumulator_, ElementC_, 92 | cutlass::epilogue::thread::asymmetric::MyScaleType::Dequantize, 93 | cutlass::FloatRoundStyle::round_to_nearest, ElementC_>, 94 | /// Threadblock-level swizzling operator 95 | typename ThreadblockSwizzle_ = 96 | typename threadblock::GemmIdentityThreadblockSwizzle<>, 97 | /// Number of stages used in the pipelined mainloop 98 | int Stages = 99 | DefaultGemmConfiguration::kStages, 101 | /// Access granularity of A matrix in units of elements 102 | int AlignmentA = 103 | DefaultGemmConfiguration::kAlignmentA, 105 | /// Access granularity of B matrix in units of elements 106 | int AlignmentB = 107 | DefaultGemmConfiguration::kAlignmentB, 109 | /// If true, kernel supports split-K with serial reduction 110 | bool SplitKSerial = false, 111 | /// Operation performed by GEMM 112 | typename Operator_ = typename DefaultGemmConfiguration< 113 | OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, 114 | ElementAccumulator_>::Operator> 115 | class SparseGemmDequant { 116 | public: 117 | using ElementA = ElementA_; 118 | using LayoutA = LayoutA_; 
119 | using TensorRefA = TensorRef; 120 | using ElementB = ElementB_; 121 | using LayoutB = LayoutB_; 122 | using TensorRefB = TensorRef; 123 | using ElementC = ElementC_; 124 | using LayoutC = LayoutC_; 125 | using TensorRefC = TensorRef; 126 | using TensorRefD = TensorRef; 127 | using ElementAccumulator = ElementAccumulator_; 128 | using OperatorClass = OperatorClass_; 129 | using ArchTag = ArchTag_; 130 | using ThreadblockShape = ThreadblockShape_; 131 | using WarpShape = WarpShape_; 132 | using InstructionShape = InstructionShape_; 133 | using EpilogueOutputOp = EpilogueOutputOp_; 134 | using ThreadblockSwizzle = ThreadblockSwizzle_; 135 | using Operator = Operator_; 136 | using MathOperator = Operator; 137 | static int const kStages = Stages; 138 | static int const kAlignmentA = AlignmentA; 139 | static int const kAlignmentB = AlignmentB; 140 | static int const kAlignmentC = EpilogueOutputOp::kCount; 141 | static bool const kSplitKSerial = SplitKSerial; 142 | static ComplexTransform const kTransformA = ComplexTransform::kNone; 143 | static ComplexTransform const kTransformB = ComplexTransform::kNone; 144 | 145 | /// Define the kernel 146 | using GemmKernel = typename kernel::asymmetric::DefaultSparseGemmDequant< 147 | ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, 148 | LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, 149 | WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 150 | kStages, kSplitKSerial, Operator>::GemmKernel; 151 | 152 | using ElementE = typename GemmKernel::ElementE; 153 | 154 | using LayoutE = typename GemmKernel::LayoutE; 155 | 156 | static int const kAlignmentE = 128 / sizeof_bits::value; 157 | 158 | static int const kSparse = GemmKernel::kSparse; 159 | static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; 160 | static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; 161 | 162 | /// Argument structure 163 | struct Arguments { 164 | // 165 | // Data members 166 | // 167 | 168 | GemmCoord problem_size; 169 | TensorRef ref_A; 170 | TensorRef ref_B; 171 | TensorRef ref_C; 172 | TensorRef ref_D; 173 | TensorRef ref_E; 174 | TensorRef ref_row_vec; 175 | TensorRef ref_col_vec; 176 | TensorRef ref_zero_row_vec; 177 | TensorRef ref_w_reduce_vec; 178 | typename EpilogueOutputOp::Params epilogue; 179 | int split_k_slices; 180 | 181 | // 182 | // Methods 183 | // 184 | 185 | /// Default ctor 186 | CUTLASS_HOST_DEVICE 187 | Arguments() : problem_size(0, 0, 0), split_k_slices(1) {} 188 | 189 | /// Constructs an Arguments structure 190 | CUTLASS_HOST_DEVICE 191 | Arguments(GemmCoord problem_size_, 192 | TensorRef ref_A_, 193 | TensorRef ref_B_, 194 | TensorRef ref_C_, 195 | TensorRef ref_D_, 196 | TensorRef ref_E_, 197 | TensorRef ref_row_vec_, 198 | TensorRef ref_col_vec_, 199 | TensorRef ref_zero_row_vec_, 200 | TensorRef ref_w_reduce_vec_, 201 | typename EpilogueOutputOp::Params epilogue_ = 202 | typename EpilogueOutputOp::Params(), 203 | int split_k_slices = 1) 204 | : problem_size(problem_size_), 205 | ref_A(ref_A_), 206 | ref_B(ref_B_), 207 | ref_C(ref_C_), 208 | ref_D(ref_D_), 209 | ref_E(ref_E_), 210 | ref_row_vec(ref_row_vec_), 211 | ref_col_vec(ref_col_vec_), 212 | ref_zero_row_vec(ref_zero_row_vec_), 213 | ref_w_reduce_vec(ref_w_reduce_vec_), 214 | epilogue(epilogue_), 215 | split_k_slices(split_k_slices) {} 216 | }; 217 | 218 | private: 219 | /// Kernel parameters object 220 | typename GemmKernel::Params params_; 221 | 222 | public: 223 | /// Constructs the GEMM. 
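  // (Editor's note, not part of the original header.) The host-side flow
  // matches GemmDequant: fill Arguments (now including ref_E between ref_D
  // and ref_row_vec), then call operator()(args, workspace, stream) or
  // initialize()/run() explicitly. One difference worth noting: here the
  // >48 KB dynamic shared-memory attribute is raised inside initialize()
  // rather than inside run().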
224 | SparseGemmDequant() {} 225 | 226 | /// Determines whether the GEMM can execute the given problem. 227 | static Status can_implement(Arguments const &args) { 228 | if (!kSplitKSerial && args.split_k_slices > 1) { 229 | return Status::kErrorInvalidProblem; 230 | } 231 | 232 | Status status = GemmKernel::can_implement( 233 | args.problem_size, args.ref_A.non_const_ref(), 234 | args.ref_B.non_const_ref(), args.ref_C.non_const_ref(), args.ref_D, 235 | args.ref_E.non_const_ref(), args.ref_row_vec.non_const_ref(), 236 | args.ref_col_vec.non_const_ref(), args.ref_zero_row_vec.non_const_ref(), 237 | args.ref_w_reduce_vec.non_const_ref()); 238 | 239 | if (status != Status::kSuccess) { 240 | return status; 241 | } 242 | 243 | return Status::kSuccess; 244 | } 245 | 246 | /// Gets the workspace size 247 | static size_t get_workspace_size(Arguments const &args) { 248 | size_t bytes = 0; 249 | 250 | // Determine grid shape 251 | ThreadblockSwizzle threadblock_swizzle; 252 | 253 | cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( 254 | args.problem_size, 255 | {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 256 | args.split_k_slices); 257 | 258 | if (kSplitKSerial && args.split_k_slices > 1) { 259 | bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); 260 | } 261 | 262 | return bytes; 263 | } 264 | 265 | /// Initializes GEMM state from arguments. 266 | Status initialize(Arguments const &args, void *workspace = nullptr, 267 | cudaStream_t stream = nullptr) { 268 | // Determine grid shape 269 | ThreadblockSwizzle threadblock_swizzle; 270 | 271 | cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( 272 | args.problem_size, 273 | {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 274 | args.split_k_slices); 275 | 276 | if (kSplitKSerial) { 277 | if (args.split_k_slices > 1) { 278 | if (!workspace) { 279 | return Status::kErrorWorkspaceNull; 280 | } 281 | 282 | size_t bytes = get_workspace_size(args); 283 | 284 | cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); 285 | 286 | if (result != cudaSuccess) { 287 | return Status::kErrorInternal; 288 | } 289 | } 290 | } else { 291 | if (args.split_k_slices > 1) { 292 | return Status::kErrorInvalidProblem; 293 | } 294 | } 295 | 296 | // Initialize the Params structure 297 | params_ = typename GemmKernel::Params{args.problem_size, 298 | grid_shape, 299 | args.ref_A.non_const_ref(), 300 | args.ref_B.non_const_ref(), 301 | args.ref_C.non_const_ref(), 302 | args.ref_D, 303 | args.ref_E.non_const_ref(), 304 | args.ref_row_vec.non_const_ref(), 305 | args.ref_col_vec.non_const_ref(), 306 | args.ref_zero_row_vec.non_const_ref(), 307 | args.ref_w_reduce_vec.non_const_ref(), 308 | args.epilogue, 309 | static_cast(workspace)}; 310 | 311 | int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); 312 | if (smem_size >= (48 << 10)) { 313 | cudaError_t result = cudaFuncSetAttribute( 314 | Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 315 | smem_size); 316 | 317 | if (result != cudaSuccess) { 318 | return Status::kErrorInternal; 319 | } 320 | } 321 | 322 | return Status::kSuccess; 323 | } 324 | 325 | /// Lightweight update given a subset of arguments 326 | Status update(Arguments const &args, void *workspace = nullptr) { 327 | if (kSplitKSerial && args.split_k_slices > 1) { 328 | if (!workspace) { 329 | return Status::kErrorWorkspaceNull; 330 | } 331 | } 332 | 333 | params_.ref_A.reset(args.ref_A.non_const_ref().data()); 334 | 
params_.ref_B.reset(args.ref_B.non_const_ref().data()); 335 | params_.ref_C.reset(args.ref_C.non_const_ref().data()); 336 | params_.ref_D.reset(args.ref_D.data()); 337 | params_.ref_E.reset(args.ref_E.non_const_ref().data()); 338 | params_.ref_row_vec.reset(args.ref_row_vec.non_const_ref().data()); 339 | params_.ref_col_vec.reset(args.ref_col_vec.non_const_ref().data()); 340 | params_.ref_zero_row_vec.reset( 341 | args.ref_zero_row_vec.non_const_ref().data()); 342 | params_.ref_w_reduce_vec.reset( 343 | args.ref_w_reduce_vec.non_const_ref().data()); 344 | params_.output_op = args.epilogue; 345 | params_.semaphore = static_cast(workspace); 346 | 347 | return Status::kSuccess; 348 | } 349 | 350 | /// Runs the kernel using initialized state. 351 | Status run(cudaStream_t stream = nullptr) { 352 | ThreadblockSwizzle threadblock_swizzle; 353 | 354 | dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); 355 | dim3 block(GemmKernel::kThreadCount, 1, 1); 356 | 357 | int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); 358 | 359 | cutlass::Kernel<<>>(params_); 360 | 361 | cudaError_t result = cudaGetLastError(); 362 | 363 | return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; 364 | } 365 | 366 | /// Runs the kernel using initialized state. 367 | Status operator()(cudaStream_t stream = nullptr) { return run(stream); } 368 | 369 | /// Runs the kernel using initialized state. 370 | Status operator()(Arguments const &args, void *workspace = nullptr, 371 | cudaStream_t stream = nullptr) { 372 | Status status = initialize(args, workspace, stream); 373 | 374 | if (status == Status::kSuccess) { 375 | status = run(stream); 376 | } 377 | 378 | return status; 379 | } 380 | }; 381 | 382 | } // namespace asymmetric 383 | } // namespace device 384 | } // namespace gemm 385 | } // namespace cutlass 386 | 387 | //////////////////////////////////////////////////////////////////////////////// 388 | -------------------------------------------------------------------------------- /include/asymmetric/gemm/kernel/default_gemm_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 30 | * 31 | **************************************************************************************************/ 32 | #pragma once 33 | 34 | #include "asymmetric/epilogue/threadblock/default_epilogue_tensor_op_dequant.h" 35 | #include "asymmetric/gemm/kernel/gemm_dequant.h" 36 | #include "cutlass/gemm/kernel/default_gemm.h" 37 | //////////////////////////////////////////////////////////////////////////////// 38 | 39 | namespace cutlass { 40 | namespace gemm { 41 | namespace kernel { 42 | namespace asymmetric { 43 | 44 | template < 45 | /// Element type for A matrix operand 46 | typename ElementA_, 47 | /// Layout type for A matrix operand 48 | typename LayoutA_, 49 | /// Access granularity of A matrix in units of elements 50 | int kAlignmentA, 51 | /// Element type for B matrix operand 52 | typename ElementB_, 53 | /// Layout type for B matrix operand 54 | typename LayoutB_, 55 | /// Access granularity of B matrix in units of elements 56 | int kAlignmentB, 57 | /// Element type for C and D matrix operands 58 | typename ElementC_, 59 | /// Layout type for C and D matrix operands 60 | typename LayoutC_, 61 | /// Element type for internal accumulation 62 | typename ElementAccumulator, 63 | /// Operator class tag 64 | typename OperatorClass, 65 | /// Tag indicating architecture to tune for 66 | typename ArchTag, 67 | /// Threadblock-level tile size (concept: GemmShape) 68 | typename ThreadblockShape, 69 | /// Warp-level tile size (concept: GemmShape) 70 | typename WarpShape, 71 | /// Warp-level tile size (concept: GemmShape) 72 | typename InstructionShape, 73 | /// Epilogue output operator 74 | typename EpilogueOutputOp, 75 | /// Threadblock-level swizzling operator 76 | typename ThreadblockSwizzle, 77 | /// Number of stages used in the pipelined mainloop 78 | int Stages, 79 | /// If true, kernel is configured to support serial reduction in the 80 | /// epilogue 81 | bool SplitKSerial, 82 | /// Operation performed by GEMM 83 | typename Operator, 84 | /// Use zfill or predicate for out-of-bound cp.async 85 | SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, 86 | /// Gather operand A by using an index array 87 | bool GatherA = false, 88 | /// Gather operand B by using an index array 89 | bool GatherB = false, 90 | /// Scatter result D by using an index array 91 | bool ScatterD = false, 92 | /// Permute result D 93 | typename PermuteDLayout = layout::NoPermute, 94 | /// Permute operand A 95 | typename PermuteALayout = layout::NoPermute, 96 | /// Permute operand B 97 | typename PermuteBLayout = layout::NoPermute, 98 | /// 99 | typename Enable = void> 100 | struct DefaultGemmDequant 101 | : public DefaultGemm { 108 | static_assert((platform::is_same::value || 109 | platform::is_same>::value), 110 | "Epilogue in the kernel level must be row major"); 111 | 112 | using DefaultGemm = 113 | DefaultGemm; 120 | 121 | using Epilogue = typename cutlass::epilogue::threadblock::asymmetric:: 122 | DefaultEpilogueTensorOpDequant< 123 | 
ThreadblockShape, typename DefaultGemm::Mma::Operator, 124 | DefaultGemm::kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, 125 | ScatterD, PermuteDLayout>::Epilogue; 126 | 127 | using GemmKernel = 128 | kernel::asymmetric::GemmDequant; 130 | }; 131 | //////////////////////////////////////////////////////////////////////////////// 132 | 133 | } // namespace asymmetric 134 | } // namespace kernel 135 | } // namespace gemm 136 | } // namespace cutlass 137 | -------------------------------------------------------------------------------- /include/asymmetric/gemm/kernel/default_gemm_sparse_dequant.h: -------------------------------------------------------------------------------- 1 | /*************************************************************************************************** 2 | * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights 3 | *reserved. SPDX-License-Identifier: BSD-3-Clause 4 | * 5 | * Redistribution and use in source and binary forms, with or without 6 | * modification, are permitted provided that the following conditions are met: 7 | * 8 | * 1. Redistributions of source code must retain the above copyright notice, 9 | *this list of conditions and the following disclaimer. 10 | * 11 | * 2. Redistributions in binary form must reproduce the above copyright notice, 12 | * this list of conditions and the following disclaimer in the documentation 13 | * and/or other materials provided with the distribution. 14 | * 15 | * 3. Neither the name of the copyright holder nor the names of its 16 | * contributors may be used to endorse or promote products derived from 17 | * this software without specific prior written permission. 18 | * 19 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 22 | *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 23 | *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 24 | *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 25 | *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 26 | *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 28 | *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 29 | *POSSIBILITY OF SUCH DAMAGE. 
30 | * 31 | **************************************************************************************************/ 32 | #pragma once 33 | 34 | #include "asymmetric/epilogue/threadblock/default_epilogue_tensor_op_dequant.h" 35 | #include "asymmetric/gemm/kernel/sparse_gemm_dequant.h" 36 | #include "cutlass/gemm/kernel/default_gemm_sparse.h" 37 | 38 | //////////////////////////////////////////////////////////////////////////////// 39 | 40 | namespace cutlass { 41 | namespace gemm { 42 | namespace kernel { 43 | namespace asymmetric { 44 | 45 | //////////////////////////////////////////////////////////////////////////////// 46 | 47 | template < 48 | /// Element type for A matrix operand 49 | typename ElementA_, 50 | /// Layout type for A matrix operand 51 | typename LayoutA_, 52 | /// Access granularity of A matrix in units of elements 53 | int kAlignmentA, 54 | /// Element type for B matrix operand 55 | typename ElementB_, 56 | /// Layout type for B matrix operand 57 | typename LayoutB_, 58 | /// Access granularity of B matrix in units of elements 59 | int kAlignmentB, 60 | /// Element type for C and D matrix operands 61 | typename ElementC_, 62 | /// Layout type for C and D matrix operands 63 | typename LayoutC_, 64 | /// Element type for internal accumulation 65 | typename ElementAccumulator, 66 | /// Operator class tag 67 | typename OperatorClass, 68 | /// Tag indicating architecture to tune for 69 | typename ArchTag, 70 | /// Threadblock-level tile size (concept: GemmShape) 71 | typename ThreadblockShape, 72 | /// Warp-level tile size (concept: GemmShape) 73 | typename WarpShape, 74 | /// Warp-level tile size (concept: GemmShape) 75 | typename InstructionShape, 76 | /// Epilogue output operator 77 | typename EpilogueOutputOp, 78 | /// Threadblock-level swizzling operator 79 | typename ThreadblockSwizzle, 80 | /// Number of stages used in the pipelined mainloop 81 | int Stages, 82 | /// If true, kernel is configured to support serial reduction in the 83 | /// epilogue 84 | bool SplitKSerial, 85 | /// Operation performed by GEMM 86 | typename Operator> 87 | struct DefaultSparseGemmDequant 88 | : public DefaultSparseGemm< 89 | ElementA_, LayoutA_, kAlignmentA, ElementB_, LayoutB_, kAlignmentB, 90 | ElementC_, layout::RowMajor, ElementAccumulator, 91 | arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, 92 | InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, 93 | SplitKSerial, Operator> { 94 | using DefaultSparseGemm = DefaultSparseGemm< 95 | ElementA_, LayoutA_, kAlignmentA, ElementB_, LayoutB_, kAlignmentB, 96 | ElementC_, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, 97 | arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 98 | EpilogueOutputOp, ThreadblockSwizzle, Stages, SplitKSerial, Operator>; 99 | 100 | using Epilogue = typename cutlass::epilogue::threadblock::asymmetric:: 101 | DefaultEpilogueTensorOpDequant< 102 | ThreadblockShape, typename DefaultSparseGemm::Mma::Operator, 103 | DefaultSparseGemm::kPartitionsK, EpilogueOutputOp, 104 | EpilogueOutputOp::kCount>::Epilogue; 105 | 106 | /// Define the kernel-level GEMM operator. 
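  // (Editor's note, hedged:) as in default_gemm_dequant.h, the mainloop (Mma)
  // is taken unchanged from the stock DefaultSparseGemm; only the epilogue is
  // replaced with DefaultEpilogueTensorOpDequant, which by its name rescales
  // and zero-point corrects the accumulators on their way out to global memory.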
107 | using GemmKernel = 108 | kernel::asymmetric::SparseGemmDequant; 111 | }; 112 | 113 | //////////////////////////////////////////////////////////////////////////////// 114 | 115 | } // namespace asymmetric 116 | } // namespace kernel 117 | } // namespace gemm 118 | } // namespace cutlass 119 | -------------------------------------------------------------------------------- /include/int4.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace cutlass { 5 | 6 | template 10 | class MySubbyteReference { 11 | public: 12 | using Element = Element_; 13 | using Storage = Storage_; 14 | using StoragePointer = Storage *; 15 | 16 | static_assert(sizeof_bits::value <= sizeof_bits::value, 17 | "Size of Element must not be greater than Storage."); 18 | 19 | static_assert(!(sizeof_bits::value % sizeof_bits::value), 20 | "Storage must be divisible by Element"); 21 | 22 | constexpr static int const kElementsPerVector = 23 | sizeof_bits::value / sizeof_bits::value; 24 | 25 | private: 26 | ///! Number of elements per storage vector 27 | 28 | ///! Bit mask 29 | Storage const kMask = 30 | ((sizeof_bits::value < sizeof_bits::value) 31 | ? (Storage(1) << sizeof_bits::value) - Storage(1) 32 | : ~Storage(0)); 33 | 34 | private: 35 | /// Pointer to array containing element 36 | StoragePointer ptr_; 37 | 38 | /// Offset (in units of elements) from pointer. 39 | /// 40 | /// Invariant: must always be in range [0, kElementsPerVector) 41 | int offset_; 42 | 43 | public: 44 | CUTLASS_HOST_DEVICE 45 | MySubbyteReference() : ptr_(nullptr), offset_(0) {} 46 | 47 | /// Constructor 48 | CUTLASS_HOST_DEVICE 49 | MySubbyteReference(Element *ptr, /// pointer to memory 50 | int64_t offset /// logical offset in units of Element 51 | ) 52 | : ptr_(reinterpret_cast(ptr)), offset_(0) { 53 | int64_t offset_in_vectors = offset / kElementsPerVector; 54 | int64_t offset_in_elements = offset % kElementsPerVector; 55 | 56 | ptr_ += offset_in_vectors; 57 | offset_ = int(offset_in_elements); 58 | } 59 | 60 | /// Constructor 61 | CUTLASS_HOST_DEVICE 62 | MySubbyteReference(Element *ptr = nullptr) : MySubbyteReference(ptr, 0) {} 63 | 64 | /// Gets storage pointer 65 | CUTLASS_HOST_DEVICE 66 | StoragePointer storage_pointer() const { return ptr_; } 67 | 68 | /// Gets storage pointer 69 | CUTLASS_HOST_DEVICE 70 | Element *operator&() const { return reinterpret_cast(ptr_); } 71 | 72 | /// Gets element offset within storage vector 73 | CUTLASS_HOST_DEVICE 74 | int element_offset() const { return offset_; } 75 | 76 | /// Unpacks an element from memory 77 | CUTLASS_HOST_DEVICE 78 | Element get() const { 79 | Storage item = 80 | Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); 81 | return reinterpret_cast(item); 82 | } 83 | 84 | /// Stores an element to memory 85 | CUTLASS_HOST_DEVICE 86 | MySubbyteReference &set(Element const &x) { 87 | Storage item = (reinterpret_cast(x) & kMask); 88 | Storage kUpdateMask = 89 | Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); 90 | Storage new_bits = 91 | Storage(item << (offset_ * cutlass::sizeof_bits::value)); 92 | 93 | Storage original = (*ptr_); 94 | Storage updated = Storage((original & kUpdateMask) | new_bits); 95 | *ptr_ = updated; 96 | 97 | return *this; 98 | } 99 | 100 | //// 101 | 102 | /// Unpacks an element from memory 103 | CUTLASS_HOST_DEVICE 104 | operator Element() const { return get(); } 105 | 106 | /// Stores an element to memory 107 | CUTLASS_HOST_DEVICE 108 | MySubbyteReference 
&operator=(Element const &x) { return set(x); } 109 | 110 | /// Stores an element to memory 111 | CUTLASS_HOST_DEVICE 112 | MySubbyteReference &operator=(MySubbyteReference const &x) { 113 | return set(x.get()); 114 | } 115 | 116 | /// Stores an element to memory 117 | CUTLASS_HOST_DEVICE 118 | MySubbyteReference &operator=( 119 | ConstSubbyteReference const &x) { 120 | return set(x.get()); 121 | } 122 | 123 | /// Adds an offset in units of elements to the reference 124 | CUTLASS_HOST_DEVICE 125 | MySubbyteReference &operator+=(int offset) { 126 | offset += offset_; 127 | 128 | int offset_in_vectors = offset / kElementsPerVector; 129 | int offset_in_elements = offset % kElementsPerVector; 130 | 131 | ptr_ += offset_in_vectors; 132 | offset_ = offset_in_elements; 133 | 134 | return *this; 135 | } 136 | 137 | /// Adds an offset in units of elements to the reference 138 | CUTLASS_HOST_DEVICE 139 | MySubbyteReference &operator+=(long long offset) { 140 | offset += offset_; 141 | 142 | long long offset_in_vectors = offset / kElementsPerVector; 143 | int offset_in_elements = int(offset % kElementsPerVector); 144 | 145 | ptr_ += offset_in_vectors; 146 | offset_ = offset_in_elements; 147 | 148 | return *this; 149 | } 150 | 151 | /// Adds an offset in units of elements to the reference 152 | CUTLASS_HOST_DEVICE 153 | MySubbyteReference &operator-=(int offset) { 154 | int offset_in_vectors = offset / kElementsPerVector; 155 | int offset_in_elements = offset % kElementsPerVector; 156 | 157 | ptr_ -= offset_in_vectors; 158 | offset_ -= offset_in_elements; 159 | 160 | if (offset_ < 0) { 161 | offset_ += kElementsPerVector; 162 | --ptr_; 163 | } 164 | 165 | return *this; 166 | } 167 | 168 | /// Adds an offset in units of elements to the reference 169 | CUTLASS_HOST_DEVICE 170 | MySubbyteReference &operator-=(long long offset) { 171 | long long offset_in_vectors = offset / kElementsPerVector; 172 | int offset_in_elements = int(offset % kElementsPerVector); 173 | 174 | ptr_ -= offset_in_vectors; 175 | offset_ -= offset_in_elements; 176 | 177 | if (offset_ < 0) { 178 | offset_ += kElementsPerVector; 179 | --ptr_; 180 | } 181 | 182 | return *this; 183 | } 184 | 185 | /// Returns a reference to an element with a given offset from the current 186 | /// reference 187 | CUTLASS_HOST_DEVICE 188 | MySubbyteReference operator+(int offset) const { 189 | MySubbyteReference ref(ptr_, offset_); 190 | ref += offset; 191 | 192 | return ref; 193 | } 194 | 195 | /// Returns a reference to an element with a given offset from the current 196 | /// reference 197 | CUTLASS_HOST_DEVICE 198 | MySubbyteReference operator+(long long offset) const { 199 | MySubbyteReference ref(ptr_, offset_); 200 | ref += offset; 201 | 202 | return ref; 203 | } 204 | 205 | /// Returns a reference to an element with a given offset from the current 206 | /// reference 207 | CUTLASS_HOST_DEVICE 208 | MySubbyteReference operator-(int offset) const { 209 | MySubbyteReference ref(ptr_, offset_); 210 | ref -= offset; 211 | 212 | return ref; 213 | } 214 | 215 | /// Returns a reference to an element with a given offset from the current 216 | /// reference 217 | CUTLASS_HOST_DEVICE 218 | MySubbyteReference operator-=(long long offset) const { 219 | MySubbyteReference ref(ptr_, offset_); 220 | ref -= offset; 221 | 222 | return ref; 223 | } 224 | 225 | /// Computes the difference in elements between references 226 | CUTLASS_HOST_DEVICE 227 | ptrdiff_t operator-(MySubbyteReference ref) const { 228 | return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - 
ref.offset_); 229 | } 230 | 231 | /// Explicit cast to int 232 | CUTLASS_HOST_DEVICE 233 | explicit operator int() const { return int(get()); } 234 | 235 | /// Explicit cast to signed 64-bit integer 236 | CUTLASS_HOST_DEVICE 237 | explicit operator int64_t() const { return int64_t(get()); } 238 | 239 | /// Explicit cast to unsigned 64-bit integer 240 | CUTLASS_HOST_DEVICE 241 | explicit operator uint64_t() const { return uint64_t(get()); } 242 | 243 | /// Explicit cast to float 244 | CUTLASS_HOST_DEVICE 245 | explicit operator float() const { return float(get()); } 246 | 247 | /// Explicit cast to double 248 | CUTLASS_HOST_DEVICE 249 | explicit operator double() const { return double(get()); } 250 | }; 251 | 252 | } // namespace cutlass 253 | 254 | using Int4Subbyte = cutlass::MySubbyteReference; 255 | using Int4Storage = Int4Subbyte::Storage; 256 | constexpr const auto kElementsPerVector = 257 | cutlass::sizeof_bits::value / 258 | cutlass::sizeof_bits::value; 259 | -------------------------------------------------------------------------------- /include/matmul/matmul.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace TORCHQ::matmul { 5 | void buildSubmodule(pybind11::module &mod); 6 | } 7 | -------------------------------------------------------------------------------- /include/matmul/matmul_internal.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace TORCHQ::matmul { 6 | 7 | torch::Tensor int8MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B); 8 | 9 | torch::Tensor myInt8MatmulCUDA(const torch::Tensor &A, 10 | const torch::Tensor &B, 11 | const torch::Tensor & zp_times_weight_channel_sum, 12 | const torch::Tensor & act_times_weight_delta, 13 | const torch::Tensor & y); 14 | 15 | torch::Tensor int8ConvCUDA(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 16 | const int strideH, const int strideW, const int dilationH, const int dilationW); 17 | 18 | torch::Tensor myInt8ConvCUDA(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 19 | const int strideH, const int strideW, const int dilationH, const int dilationW, 20 | const torch::Tensor & zp_times_weight_channel_sum, 21 | const torch::Tensor & act_times_weight_delta, 22 | const torch::Tensor & y); 23 | } // namespace TORCHQ::matmul 24 | -------------------------------------------------------------------------------- /include/util.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace TORCHQ::util { 9 | template 10 | struct TorchDtypeDispatcher; 11 | 12 | template <> 13 | struct TorchDtypeDispatcher { 14 | constexpr static const auto value = torch::kUInt8; 15 | }; 16 | 17 | template <> 18 | struct TorchDtypeDispatcher { 19 | constexpr static const auto value = torch::kInt8; 20 | }; 21 | 22 | template <> 23 | struct TorchDtypeDispatcher { 24 | constexpr static const auto value = torch::kInt32; 25 | }; 26 | 27 | template <> 28 | struct TorchDtypeDispatcher { 29 | constexpr static const auto value = torch::kFloat16; 30 | }; 31 | 32 | template 33 | struct DtypeTorchDispatcher; 34 | 35 | template <> 36 | struct DtypeTorchDispatcher { 37 | using value = __half; 38 | }; 39 | 40 | template <> 41 | struct DtypeTorchDispatcher { 42 | using value = __nv_bfloat16; 43 | }; 44 | 45 
| template 46 | __device__ inline int type2int_rn(T a) { 47 | return static_cast(a); 48 | } 49 | 50 | template <> 51 | __device__ inline int type2int_rn<__half>(__half input) { 52 | return __half2int_rn(input); 53 | } 54 | 55 | template <> 56 | __device__ inline int type2int_rn<__nv_bfloat16>(__nv_bfloat16 input) { 57 | return __bfloat162int_rn(input); 58 | } 59 | 60 | template 61 | __device__ inline float type2float(T a) { 62 | return static_cast(a); 63 | } 64 | 65 | template <> 66 | __device__ inline float type2float<__half>(__half input) { 67 | return __half2float(input); 68 | } 69 | 70 | template <> 71 | __device__ inline float type2float<__nv_bfloat16>(__nv_bfloat16 input) { 72 | return __bfloat162float(input); 73 | } 74 | 75 | template 76 | __device__ inline T float2type(float a) { 77 | return static_cast(a); 78 | } 79 | 80 | template <> 81 | __device__ inline __half float2type<__half>(float input) { 82 | return __float2half(input); 83 | } 84 | 85 | template <> 86 | __device__ inline __nv_bfloat16 float2type<__nv_bfloat16>(float input) { 87 | return __float2bfloat16_rn(input); 88 | } 89 | 90 | template 91 | struct DtypeDtype2Dispatcher; 92 | 93 | template <> 94 | struct DtypeDtype2Dispatcher<__half> { 95 | using value = __half2; 96 | }; 97 | 98 | template <> 99 | struct DtypeDtype2Dispatcher<__nv_bfloat16> { 100 | using value = __nv_bfloat162; 101 | }; 102 | 103 | __device__ inline __half2 type2type2(__half input, __half input2) { 104 | return __halves2half2(input, input2); 105 | } 106 | 107 | __device__ inline __nv_bfloat162 type2type2(__nv_bfloat16 input, 108 | __nv_bfloat16 input2) { 109 | return __halves2bfloat162(input, input2); 110 | } 111 | 112 | // template 113 | // T div(T a, T b) { 114 | // return a / b; 115 | // } 116 | // 117 | // template <> 118 | //__half div(__half a, __half b) { 119 | // return __hdiv(a, b); 120 | // } 121 | // 122 | // template <> 123 | //__nv_bfloat16 div(__nv_bfloat16 a, __nv_bfloat16 b) { 124 | // return __hdiv(a, b); 125 | // } 126 | 127 | } // namespace TORCHQ::util -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import platform 4 | import re 5 | import shutil 6 | import sys 7 | import sysconfig 8 | 9 | from setuptools import setup, find_packages 10 | import subprocess 11 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 12 | 13 | try: 14 | from pybind11.setup_helpers import Pybind11Extension as Extension 15 | from pybind11.setup_helpers import build_ext 16 | except ImportError: 17 | from setuptools import Extension 18 | from setuptools.command.build_ext import build_ext 19 | 20 | HERE = pathlib.Path(__file__).absolute().parent 21 | VERSION_FILE = HERE / 'torch_quantizer' / 'version.py' 22 | 23 | sys.path.insert(0, str(VERSION_FILE.parent)) 24 | import version # noqa 25 | 26 | import torch 27 | class CustomBdistWheel(_bdist_wheel): 28 | def get_tag(self): 29 | original_tag = super().get_tag() 30 | # Include PyTorch version in the wheel name 31 | pytorch_version = torch.__version__.replace('.', '_') 32 | cuda_version = torch.version.cuda.replace('.', '_') 33 | custom_tag = f'{original_tag[0]}_pytorch_{pytorch_version}_cuda_{cuda_version}' 34 | return (custom_tag, original_tag[1], original_tag[2]) 35 | 36 | class CMakeExtension(Extension): 37 | def __init__(self, name, source_dir='.', target=None, **kwargs): 38 | super().__init__(name, sources=[], **kwargs) 39 | self.source_dir 
= os.path.abspath(source_dir) 40 | self.target = target if target is not None else name.rpartition('.')[-1] 41 | 42 | 43 | class cmake_build_ext(build_ext): 44 | def build_extension(self, ext): 45 | if not isinstance(ext, CMakeExtension): 46 | super().build_extension(ext) 47 | return 48 | 49 | from torch.utils import cpp_extension 50 | 51 | cmake = shutil.which('cmake') 52 | if cmake is None: 53 | raise RuntimeError('Cannot find CMake executable.') 54 | 55 | ext_path = pathlib.Path(self.get_ext_fullpath(ext.name)).absolute() 56 | build_temp = pathlib.Path(self.build_temp).absolute() 57 | build_temp.mkdir(parents=True, exist_ok=True) 58 | 59 | config = 'Debug' if self.debug else 'Release' 60 | 61 | cmake_args = [ 62 | f'-DCMAKE_BUILD_TYPE={config}', 63 | f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{config.upper()}={ext_path.parent}', 64 | f'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY_{config.upper()}={build_temp}', 65 | f'-DPYTHON_EXECUTABLE={sys.executable}', 66 | f'-DPYTHON_INCLUDE_DIR={sysconfig.get_path("platinclude")}', 67 | f'-DTORCH_INCLUDE_PATH={";".join(cpp_extension.include_paths())}', 68 | f'-DTORCH_LIBRARY_PATH={";".join(cpp_extension.library_paths())}', 69 | ] 70 | 71 | if platform.system() == 'Darwin': 72 | # Cross-compile support for macOS - respect ARCHFLAGS if set 73 | archs = re.findall(r'-arch (\S+)', os.environ.get('ARCHFLAGS', '')) 74 | if archs: 75 | cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(archs)}') 76 | 77 | try: 78 | import pybind11 79 | 80 | cmake_args.append(f'-DPYBIND11_CMAKE_DIR={pybind11.get_cmake_dir()}') 81 | except ImportError: 82 | pass 83 | 84 | build_args = ['--config', config] 85 | if ( 86 | 'CMAKE_BUILD_PARALLEL_LEVEL' not in os.environ 87 | and hasattr(self, 'parallel') 88 | and self.parallel 89 | ): 90 | build_args.extend(['--parallel', str(self.parallel)]) 91 | else: 92 | build_args.append('--parallel') 93 | 94 | build_args.extend(['--target', ext.target, '--']) 95 | 96 | try: 97 | os.chdir(build_temp) 98 | 99 | retcode = subprocess.call([cmake, ext.source_dir, *cmake_args]) 100 | if retcode != 0: 101 | sys.stderr.write("Error: CMake configuration failed.\n") 102 | sys.exit(1) 103 | 104 | if not self.dry_run: 105 | retcode = subprocess.call([cmake, '--build', '.', *build_args]) 106 | if retcode != 0: 107 | sys.stderr.write("Error: Building with CMake failed.\n") 108 | sys.exit(1) 109 | finally: 110 | os.chdir(HERE) 111 | 112 | 113 | CIBUILDWHEEL = os.getenv('CIBUILDWHEEL', '0') == '1' 114 | LINUX = platform.system() == 'Linux' 115 | MACOS = platform.system() == 'Darwin' 116 | WINDOWS = platform.system() == 'Windows' 117 | ext_kwargs = { 118 | 'cmdclass': {'build_ext': cmake_build_ext, 'bdist_wheel': CustomBdistWheel}, 119 | 'ext_modules': [ 120 | CMakeExtension( 121 | 'torch_quantizer._C', 122 | source_dir=HERE, 123 | optional=not (LINUX and CIBUILDWHEEL), 124 | ), 125 | ], 126 | } 127 | 128 | VERSION_CONTENT = None 129 | 130 | try: 131 | if not version.__release__: 132 | try: 133 | VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8') 134 | VERSION_FILE.write_text( 135 | data=re.sub( 136 | r"""__version__\s*=\s*('[^']+'|"[^"]+")""", 137 | f'__version__ = {version.__version__!r}', 138 | string=VERSION_CONTENT, 139 | ), 140 | encoding='utf-8', 141 | ) 142 | except OSError: 143 | VERSION_CONTENT = None 144 | setup( 145 | name='torch_quantizer', 146 | version=version.__version__, 147 | packages=find_packages(), 148 | **ext_kwargs, 149 | ) 150 | finally: 151 | if VERSION_CONTENT is not None: 152 | with VERSION_FILE.open(mode='wt', encoding='utf-8', 
newline='') as file: 153 | file.write(VERSION_CONTENT) 154 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if (CUDA_FOUND) 2 | 3 | endif () 4 | 5 | set(_C_LIBRARIES "") 6 | 7 | add_subdirectory(matmul) 8 | add_subdirectory(asymmetric) 9 | 10 | pybind11_add_module(_C MODULE THIN_LTO binding.cpp) 11 | 12 | target_link_libraries( 13 | _C PUBLIC 14 | ${TORCH_LIBRARIES} 15 | ${_C_LIBRARIES} 16 | ) 17 | -------------------------------------------------------------------------------- /src/asymmetric/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(SRCS asymmetric.cpp) 2 | 3 | if (CUDA_FOUND) 4 | list(APPEND SRCS asymmetric.cu) 5 | endif () 6 | 7 | 8 | add_library(_C_LIBRARY_ASYMMETRIC STATIC "${SRCS}") 9 | target_link_libraries(_C_LIBRARY_ASYMMETRIC PRIVATE ${TORCH_LIBRARIES}) 10 | if (CUDA_FOUND) 11 | target_link_libraries(_C_LIBRARY_ASYMMETRIC PRIVATE nvidia::cutlass::cutlass nvidia::cutlass::tools::util) 12 | endif () 13 | 14 | list(APPEND _C_LIBRARIES _C_LIBRARY_ASYMMETRIC) 15 | set(_C_LIBRARIES "${_C_LIBRARIES}" PARENT_SCOPE) 16 | -------------------------------------------------------------------------------- /src/asymmetric/asymmetric.cpp: -------------------------------------------------------------------------------- 1 | #include "asymmetric/asymmetric.h" 2 | 3 | #include 4 | 5 | #include "asymmetric/asymmetric_internal.h" 6 | 7 | namespace TORCHQ::asymmetric { 8 | torch::Tensor myQuantize(const torch::Tensor &src, const torch::Tensor &delta, 9 | const torch::Tensor &zp) { 10 | torch::checkAllContiguous("myQuantize", {{src, "src", 0}, {delta, "delta", 1}, {zp, "zp", 2}}); 11 | torch::checkDeviceType("myQuantize", {src, delta, zp}, at::DeviceType::CUDA); 12 | return myQuantizeCUDA(src, delta, zp); 13 | } 14 | 15 | torch::Tensor myQuantizeNCHW(const torch::Tensor &src, const torch::Tensor &delta, 16 | const torch::Tensor &zp) { 17 | torch::checkAllContiguous("myQuantizeNCHW", {{src, "src", 0}, {delta, "delta", 1}, {zp, "zp", 2}}); 18 | torch::checkDeviceType("myQuantizeNCHW", {src, delta, zp}, at::DeviceType::CUDA); 19 | return myQuantizeNCHWCUDA(src, delta, zp); 20 | } 21 | 22 | void buildSubmodule(py::module &mod) { 23 | py::module m = 24 | mod.def_submodule("asymmetric", "Asymmetric Quantization Functions"); 25 | 26 | m.def("myQuantize", &myQuantize, 27 | "input: (src: torch.Tensor(M x N, FP16, CUDA),\n" 28 | "delta: torch.Tensor(1, FP16, CUDA)\n" 29 | "zp: torch.Tensor(1, INT8, CUDA) \n" 30 | "output: torch.Tensor(M x N, INT8, CUDA)\n" 31 | "output = int{bits}Packing(int{bits}Rounding((source / delta) + zp ", 32 | py::arg("src"), py::arg("delta"), py::arg("zp")); 33 | m.def("myQuantizeNCHW", &myQuantizeNCHW, 34 | "input: (src: torch.Tensor(N x C x H x W, FP16, CUDA),\n" 35 | "delta: torch.Tensor(1, FP16, CUDA)\n" 36 | "zp: torch.Tensor(1, INT8, CUDA) \n" 37 | "output: torch.Tensor(M x N, INT8, CUDA)\n" 38 | "output = int{bits}Packing(int{bits}Rounding((source / delta) + zp ", 39 | py::arg("src"), py::arg("delta"), py::arg("zp")); 40 | } 41 | } // namespace TORCHQ::asymmetric -------------------------------------------------------------------------------- /src/asymmetric/asymmetric.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "asymmetric/asymmetric_internal.h" 4 | #include "int4.h" 5 | #include "util.h" 6 | 7 | 
namespace TORCHQ::asymmetric { 8 | const unsigned MAX_NUMTHREADS = 1024; 9 | const unsigned MAX_NUMBER_BLOCKS = 65535; 10 | unsigned NUM_STRIDES_PER_THREAD_QUANTIZE = 0; 11 | unsigned NUM_STRIDES_PER_THREAD_DEQUANTIZE = 0; 12 | 13 | __global__ void myQuantizeCUDAKernel8Bits(int8_t *__restrict__ dst, 14 | const torch::Half *__restrict__ src, 15 | const torch::Half * __restrict__ delta, 16 | const torch::Half * __restrict__ zp, 17 | const unsigned rows, 18 | const unsigned cols) { 19 | const unsigned thread_id = threadIdx.x + blockIdx.x * blockDim.x; 20 | const unsigned stride = blockDim.x * gridDim.x; 21 | const unsigned num_elems = rows * cols; 22 | 23 | for (unsigned idx = thread_id; idx < num_elems; idx += stride) { 24 | const unsigned row = idx / cols; 25 | 26 | __half data = __hadd(__hdiv(src[idx], delta[0]), zp[0]); 27 | int val = __half2int_rn(data); 28 | // needs to be shifted by 128 to fit int8_t 29 | val = static_cast(min(max(val, -128), 127)); 30 | dst[idx] = val; 31 | } 32 | } 33 | 34 | torch::Tensor myQuantizeCUDA(const torch::Tensor &src, const torch::Tensor &delta, 35 | const torch::Tensor &zp) { 36 | torch::checkAllSameGPU("myQuantizeCUDA", {{src, "src", 0}, {delta, "delta", 1}, {zp, "zp", 2}}); 37 | const at::cuda::CUDAGuard device_guard(src.device()); 38 | 39 | if (NUM_STRIDES_PER_THREAD_QUANTIZE == 0) { 40 | char const *temp = getenv("NUM_STRIDES_PER_THREAD_QUANTIZE"); 41 | if (temp) 42 | NUM_STRIDES_PER_THREAD_QUANTIZE = std::atoi(temp); 43 | else 44 | NUM_STRIDES_PER_THREAD_QUANTIZE = 1; 45 | TORCH_CHECK(NUM_STRIDES_PER_THREAD_QUANTIZE > 0 and 46 | NUM_STRIDES_PER_THREAD_QUANTIZE < 64, 47 | "Quantize: invalid value of NUM_STRIDES_PER_THREAD_QUANTIZE"); 48 | } 49 | 50 | unsigned rows = src.size(0); 51 | unsigned colsSrc = src.size(1); 52 | torch::Tensor dst; 53 | const unsigned num_elems = src.numel(); 54 | const unsigned num_threads = min(num_elems, MAX_NUMTHREADS); 55 | const unsigned num_blocks = 56 | max((num_elems + num_threads - 1) / 57 | (num_threads * NUM_STRIDES_PER_THREAD_QUANTIZE), 58 | 16); 59 | 60 | 61 | dst = torch::empty({rows, colsSrc}, 62 | torch::dtype(util::TorchDtypeDispatcher::value) 63 | .device(src.device())); 64 | myQuantizeCUDAKernel8Bits<<>>( 65 | dst.data_ptr(), src.data_ptr(), 66 | delta.data_ptr(), zp.data_ptr(), rows, colsSrc); 67 | 68 | auto status = cudaGetLastError(); 69 | TORCH_CHECK(status == cudaSuccess, 70 | "Failed quantize: " + std::string(cudaGetErrorString(status))); 71 | return dst; 72 | } 73 | 74 | 75 | __global__ void myQuantizeNCHWCUDAKernel8Bits(int8_t *__restrict__ dst, 76 | const torch::Half *__restrict__ src, 77 | const torch::Half * __restrict__ delta, 78 | const torch::Half * __restrict__ zp, 79 | const unsigned num_elems) { 80 | const unsigned thread_id = threadIdx.x + blockIdx.x * blockDim.x; 81 | const unsigned stride = blockDim.x * gridDim.x; 82 | // const unsigned num_elems = rows * cols; 83 | 84 | for (unsigned idx = thread_id; idx < num_elems; idx += stride) { 85 | // const unsigned row = idx / cols; 86 | 87 | __half data = __hadd(__hdiv(src[idx], delta[0]), zp[0]); 88 | int val = __half2int_rn(data); 89 | // needs to be shifted by 128 to fit int8_t 90 | val = static_cast(min(max(val, -128), 127)); 91 | dst[idx] = val; 92 | } 93 | } 94 | 95 | torch::Tensor myQuantizeNCHWCUDA(const torch::Tensor &src, const torch::Tensor &delta, 96 | const torch::Tensor &zp) { 97 | torch::checkAllSameGPU("myQuantizeNCHWCUDA", {{src, "src", 0}, {delta, "delta", 1}, {zp, "zp", 2}}); 98 | const at::cuda::CUDAGuard 
device_guard(src.device()); 99 | 100 | if (NUM_STRIDES_PER_THREAD_QUANTIZE == 0) { 101 | char const *temp = getenv("NUM_STRIDES_PER_THREAD_QUANTIZE"); 102 | if (temp) 103 | NUM_STRIDES_PER_THREAD_QUANTIZE = std::atoi(temp); 104 | else 105 | NUM_STRIDES_PER_THREAD_QUANTIZE = 1; 106 | TORCH_CHECK(NUM_STRIDES_PER_THREAD_QUANTIZE > 0 and 107 | NUM_STRIDES_PER_THREAD_QUANTIZE < 64, 108 | "Quantize: invalid value of NUM_STRIDES_PER_THREAD_QUANTIZE"); 109 | } 110 | 111 | auto N = src.size(0); 112 | auto C = src.size(1); 113 | auto H = src.size(2); 114 | auto W = src.size(3); 115 | 116 | torch::Tensor dst; 117 | const unsigned num_elems = src.numel(); 118 | const unsigned num_threads = min(num_elems, MAX_NUMTHREADS); 119 | const unsigned num_blocks = 120 | max((num_elems + num_threads - 1) / 121 | (num_threads * NUM_STRIDES_PER_THREAD_QUANTIZE), 122 | 16); 123 | 124 | 125 | dst = torch::empty({N,C,H,W}, 126 | torch::dtype(util::TorchDtypeDispatcher::value) 127 | .device(src.device())); 128 | myQuantizeNCHWCUDAKernel8Bits<<>>( 129 | dst.data_ptr(), src.data_ptr(), 130 | delta.data_ptr(), zp.data_ptr(), num_elems); 131 | 132 | auto status = cudaGetLastError(); 133 | TORCH_CHECK(status == cudaSuccess, 134 | "Failed quantize: " + std::string(cudaGetErrorString(status))); 135 | 136 | // NCHW to NHWC 137 | torch::Tensor dst_nhwc = torch::empty({N, H, W, C}, 138 | torch::dtype(util::TorchDtypeDispatcher::value) 139 | .device(src.device())); 140 | 141 | // Step 3: Define tensor sizes for NHWC and NCHW 142 | auto input_tensor_size = cutlass::Tensor4DCoord({N, C, H, W}); 143 | auto output_tensor_size = cutlass::Tensor4DCoord({N, H, W, C}); 144 | 145 | // Step 4: Create Tensor Refs 146 | cutlass::TensorRef ref_input( 147 | dst.data_ptr(), cutlass::layout::TensorNCHW::packed(cutlass::Tensor4DCoord({N,C,H,W}))); 148 | cutlass::TensorRef ref_output( 149 | dst_nhwc.data_ptr(), cutlass::layout::TensorNHWC::packed(cutlass::Tensor4DCoord({N,H,W,C}))); 150 | 151 | // Call nchw_to_nhwc 152 | nchw_to_nhwc(input_tensor_size, output_tensor_size, ref_input, ref_output, 0); // Assuming default stream 153 | 154 | return dst_nhwc; 155 | // return dst; 156 | } 157 | 158 | 159 | } // namespace TORCHQ::asymmetric 160 | -------------------------------------------------------------------------------- /src/binding.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "asymmetric/asymmetric.h" 4 | #include "matmul/matmul.h" 5 | 6 | PYBIND11_MODULE(_C, mod) { 7 | TORCHQ::matmul::buildSubmodule(mod); 8 | TORCHQ::asymmetric::buildSubmodule(mod); 9 | } 10 | -------------------------------------------------------------------------------- /src/matmul/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(SRCS matmul.cpp) 2 | 3 | if (CUDA_FOUND) 4 | list(APPEND SRCS matmul.cu) 5 | endif () 6 | 7 | 8 | add_library(_C_LIBRARY_MATMUL STATIC "${SRCS}") 9 | target_link_libraries(_C_LIBRARY_MATMUL PRIVATE ${TORCH_LIBRARIES}) 10 | 11 | if (CUDA_FOUND) 12 | target_link_libraries(_C_LIBRARY_MATMUL PRIVATE nvidia::cutlass::cutlass nvidia::cutlass::tools::util) 13 | endif () 14 | 15 | list(APPEND _C_LIBRARIES _C_LIBRARY_MATMUL) 16 | set(_C_LIBRARIES "${_C_LIBRARIES}" PARENT_SCOPE) 17 | -------------------------------------------------------------------------------- /src/matmul/matmul.cpp: -------------------------------------------------------------------------------- 1 | #include "matmul/matmul.h" 2 | 3 | #include 4 | 5 | #include 
"matmul/matmul_internal.h" 6 | 7 | namespace TORCHQ::matmul { 8 | torch::Tensor int8Matmul(const torch::Tensor &A, const torch::Tensor &B) { 9 | torch::checkAllContiguous("int8Matmul", {{A, "A", 0}, {B, "B", 1}}); 10 | // TODO(Tingxuan): support more data type 11 | torch::checkDeviceType("int8Matmul", {A, B}, at::DeviceType::CUDA); 12 | return int8MatmulCUDA(A, B); 13 | } 14 | 15 | torch::Tensor myInt8Matmul(const torch::Tensor &A, 16 | const torch::Tensor &B, 17 | const torch::Tensor & zp_times_weight_channel_sum, 18 | const torch::Tensor & act_times_weight_delta, 19 | const torch::Tensor & y) { 20 | torch::checkAllContiguous("myInt8Matmul", {{A, "A", 0}, {B, "B", 1}, {zp_times_weight_channel_sum, "zp_times_weight_channel_sum", 2}, 21 | {act_times_weight_delta, "act_times_weight_delta", 3}, {y, "y", 4}}); 22 | // TODO(Tingxuan): support more data type 23 | torch::checkDeviceType("myInt8Matmul", {A, B, zp_times_weight_channel_sum, act_times_weight_delta, y}, at::DeviceType::CUDA); 24 | return myInt8MatmulCUDA(A, B, zp_times_weight_channel_sum, act_times_weight_delta, y); 25 | } 26 | 27 | torch::Tensor int8Conv(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 28 | const int strideH, const int strideW, const int dilationH, const int dilationW) { 29 | torch::checkAllContiguous("int8Conv", {{input, "input", 0}, {filter, "filter", 1}}); 30 | // TODO(Tingxuan): support more data type 31 | torch::checkDeviceType("int8Conv", {input, filter}, at::DeviceType::CUDA); 32 | return int8ConvCUDA(input, filter, padH, padW, strideH, strideW, dilationH, dilationW); 33 | 34 | } 35 | torch::Tensor myInt8Conv(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 36 | const int strideH, const int strideW, const int dilationH, const int dilationW, 37 | const torch::Tensor & zp_times_weight_channel_sum, 38 | const torch::Tensor & act_times_weight_delta, 39 | const torch::Tensor & y) { 40 | torch::checkAllContiguous("myInt8Conv", {{input, "input", 0}, {filter, "filter", 1}, {zp_times_weight_channel_sum, "zp_times_weight_channel_sum", 2}, 41 | {act_times_weight_delta, "act_times_weight_delta", 3}, {y, "y", 4}}); 42 | // TODO(Tingxuan): support more data type 43 | torch::checkDeviceType("myInt8Conv", {input, filter, zp_times_weight_channel_sum, act_times_weight_delta, y}, at::DeviceType::CUDA); 44 | return myInt8ConvCUDA(input, filter, padH, padW, strideH, strideW, dilationH, dilationW, 45 | zp_times_weight_channel_sum, act_times_weight_delta, y); 46 | } 47 | 48 | void buildSubmodule(py::module &mod) { 49 | py::module m = mod.def_submodule("matmul", "Matmul Functions"); 50 | m.def("int8Matmul", &int8Matmul, 51 | "input: (A: torch.Tensor(M x K, INT8, CUDA), B: torch.Tensor(N x K, " 52 | "INT8, CUDA))\n" 53 | "output: torch.Tensor(M x N, INT32, CUDA)\n" 54 | "output = A @ B^T", 55 | py::arg("A"), py::arg("B")); 56 | 57 | m.def("myInt8Matmul", &myInt8Matmul, 58 | "input: (A: torch.Tensor(M x K, INT8, CUDA), B: torch.Tensor(N x K, " 59 | "INT8, CUDA)), zp_times_weight_channel_sum: torch.Tensor(N, INT8, CUDA), act_times_weight_delta: torch.Tensor(N, INT8, CUDA)," 60 | "y: torch.Tensor(M x K, FP16, CUDA)\n" 61 | "output: torch.Tensor(M x N, INT32, CUDA).sub_(zp_times_weight_channel_sum).mul_(act_times_weight_delta)+ y\n" 62 | "output = (A @ B^T - zp_times_weight_channel_sum) * act_times_weight_delta + y", 63 | py::arg("A"), py::arg("B"), py::arg("zp_times_weight_channel_sum"), py::arg("act_times_weight_delta"), py::arg("y")); 64 | 65 | 
m.def("int8Conv", &int8Conv, 66 | "input: (input: torch.Tensor(N x H x W x Cin, INT8, CUDA), filter: torch.Tensor(Co x Cin x K x K, " 67 | "INT8, CUDA))\n" 68 | "output: torch.Tensor(N x H' x W'x Co, INT32, CUDA)\n" 69 | "output = conv(input, filter)", 70 | py::arg("input"), py::arg("filter"), py::arg("padH"), py::arg("padW"), py::arg("strideH"), py::arg("strideW"), py::arg("dilationH"), py::arg("dilationW")); 71 | 72 | m.def("myInt8Conv", &myInt8Conv, 73 | "input: (input: torch.Tensor(N x H x W x Cin, INT8, CUDA), filter: torch.Tensor(Co x Cin x K x K, " 74 | "INT8, CUDA)), zp_times_weight_channel_sum: torch.Tensor(N, INT8, CUDA), act_times_weight_delta: torch.Tensor(N, INT8, CUDA)," 75 | "y: torch.Tensor(M x K, FP16, CUDA)\n" 76 | "output: torch.Tensor(N x H' x W'x Co, INT32, CUDA).sub_(zp_times_weight_channel_sum).mul_(act_times_weight_delta)+ y\n" 77 | "output = conv(input, filter)", 78 | py::arg("input"), py::arg("filter"), py::arg("padH"), py::arg("padW"), py::arg("strideH"), py::arg("strideW"), 79 | py::arg("dilationH"), py::arg("dilationW"), py::arg("zp_times_weight_channel_sum"), py::arg("act_times_weight_delta"), py::arg("y")); 80 | } 81 | } // namespace TORCHQ::matmul 82 | -------------------------------------------------------------------------------- /src/matmul/matmul.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "cutlass/conv/kernel/default_conv2d_fprop.h" 4 | #include "cutlass/conv/device/implicit_gemm_convolution.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "int4.h" 11 | #include "matmul/matmul_internal.h" 12 | #include "util.h" 13 | 14 | 15 | namespace TORCHQ::matmul { 16 | torch::Tensor int8MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B) { 17 | torch::checkAllSameGPU("int8Matmul", {{A, "A", 0}, {B, "B", 1}}); 18 | auto M = A.size(0); 19 | auto N = B.size(0); 20 | auto K = A.size(1); // 4bit packing is on the columns 21 | auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device())); 22 | using ElementOutput = int32_t; 23 | using ElementAccumulator = int32_t; 24 | using ElementCompute = int32_t; 25 | 26 | using Gemm = cutlass::gemm::device::Gemm< 27 | int8_t, cutlass::layout::RowMajor, int8_t, 28 | cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, 29 | ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, 30 | cutlass::gemm::GemmShape<64, 64, 128>, 31 | cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, 32 | cutlass::epilogue::thread::LinearCombinationClamp< 33 | ElementOutput, 128 / cutlass::sizeof_bits::value, 34 | ElementAccumulator, ElementCompute>, 35 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; 36 | 37 | Gemm gemmOp; 38 | 39 | using GemmCoord = cutlass::gemm::GemmCoord; 40 | 41 | typename Gemm::Arguments arguments{ 42 | {static_cast(M), static_cast(N), 43 | static_cast(K)}, 44 | {A.data_ptr(), K}, 45 | {B.data_ptr(), K}, 46 | {C.data_ptr(), N}, 47 | {C.data_ptr(), N}, 48 | {1, 0}}; 49 | 50 | auto status = gemmOp(arguments); 51 | 52 | TORCH_CHECK(status == cutlass::Status::kSuccess, 53 | cutlassGetStatusString(status)) 54 | 55 | return C; 56 | } 57 | 58 | template 59 | __global__ void myDequantizationKernel(KTorch *__restrict__ out, 60 | const int *__restrict__ x, 61 | const KTorch *__restrict__ zp_times_weight_channel_sum, 62 | const KTorch *__restrict__ act_times_weight_delta, 63 | const KTorch *__restrict__ y, 64 | const unsigned rows, const 
unsigned cols) { 65 | const unsigned row = threadIdx.y + blockIdx.y * blockDim.y; 66 | const unsigned col = threadIdx.x + blockIdx.x * blockDim.x; 67 | 68 | if (col >= cols || row >= rows) { 69 | return; 70 | } 71 | // using K = typename util::DtypeTorchDispatcher::value; 72 | // Convert int32_t element to float32 first 73 | float xElement = static_cast(x[col + row * cols]); 74 | // float zp_times_weight_channel_sum_element = util::type2float(zp_times_weight_channel_sum[row]); 75 | // float act_times_weight_delta_element = util::type2float(act_times_weight_delta[row]); 76 | 77 | float zp_times_weight_channel_sum_element = zp_times_weight_channel_sum[col]; 78 | float act_times_weight_delta_element = act_times_weight_delta[col]; 79 | 80 | // Subtract zp_times_weight_channel_sum and multiply by act_times_weight_delta 81 | xElement -= zp_times_weight_channel_sum_element; 82 | xElement *= act_times_weight_delta_element; 83 | xElement += y[col]; 84 | 85 | // out[col + row * cols] = util::float2type(xElement); 86 | 87 | out[col + row * cols] = xElement; 88 | } 89 | 90 | torch::Tensor myInt8MatmulCUDA(const torch::Tensor &A, 91 | const torch::Tensor &B, 92 | const torch::Tensor & zp_times_weight_channel_sum, 93 | const torch::Tensor & act_times_weight_delta, 94 | const torch::Tensor & y) { 95 | // Step 1: Perform GEMM Operation 96 | torch::Tensor C = int8MatmulCUDA(A, B); 97 | 98 | // Step 2: Setup for Dequantization 99 | unsigned M = C.size(0); 100 | unsigned N = C.size(1); 101 | auto out = torch::empty_like(C, torch::dtype(torch::kFloat).device(C.device())); 102 | 103 | // Step 3: Dequantization Kernel Call 104 | dim3 blockDim(16, 16); // Adjust block size as needed, how many thread in a block 105 | dim3 gridDim((N + blockDim.x - 1) / blockDim.x, (M + blockDim.y - 1) / blockDim.y); 106 | myDequantizationKernel<<>>(out.data_ptr(), 107 | C.data_ptr(), 108 | zp_times_weight_channel_sum.data_ptr(), 109 | act_times_weight_delta.data_ptr(), 110 | y.data_ptr(), 111 | M, N); 112 | 113 | // Check for errors in kernel launch 114 | cudaError_t err = cudaGetLastError(); 115 | if (err != cudaSuccess) { 116 | // Handle the error appropriately 117 | } 118 | 119 | // Step 4: Return the dequantized output 120 | return out; 121 | } 122 | 123 | torch::Tensor int8ConvCUDA(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 124 | const int strideH, const int strideW, const int dilationH, const int dilationW) { 125 | // Check that tensors are on the same GPU 126 | torch::checkAllSameGPU("int8ConvCUDA", {{input, "input", 0}, {filter, "filter", 1}}); 127 | 128 | // Assuming input tensor layout is NCHW and filter layout is KCRS 129 | auto N = input.size(0); 130 | auto H = input.size(1); 131 | auto W = input.size(2); 132 | auto C = input.size(3); 133 | auto K = filter.size(0); 134 | auto R = filter.size(1); 135 | auto S = filter.size(2); 136 | 137 | /// Conv operation element types for the Gemm equivalent (ImplicitGemm) 138 | using ElementA = int8_t; 139 | using ElementB = int8_t; 140 | using ElementC = int32_t; 141 | using ElementAccumulator = int32_t; 142 | using ElementCompute = float; 143 | 144 | using Conv2dFpropKernel = cutlass::conv::kernel::DefaultConv2dFprop< 145 | ElementA, cutlass::layout::TensorNHWC, 146 | ElementB, cutlass::layout::TensorNHWC, 147 | ElementC, cutlass::layout::TensorNHWC, 148 | ElementAccumulator, 149 | cutlass::arch::OpClassTensorOp, 150 | cutlass::arch::Sm80, 151 | cutlass::gemm::GemmShape<128, 128, 64>, 152 | cutlass::gemm::GemmShape<64, 64, 64>, 153 | 
cutlass::gemm::GemmShape<16, 8, 32>, 154 | cutlass::epilogue::thread::LinearCombination< 155 | ElementC, 156 | 128 / cutlass::sizeof_bits::value, 157 | ElementAccumulator, 158 | ElementCompute 159 | >, 160 | cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 161 | 3, 162 | cutlass::arch::OpMultiplyAddSaturate, 163 | cutlass::conv::IteratorAlgorithm::kOptimized 164 | >::Kernel; 165 | 166 | using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; 167 | 168 | // Define arguments for CUTLASS Convolution 169 | cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; 170 | 171 | // Split K dimension into 1 partitions 172 | int split_k_slices = 1; 173 | 174 | 175 | // Define output tensor 176 | // Assuming stride of 1, no padding, and no dilation for simplicity 177 | auto outputH = (H + padH + padH - R) / strideH + 1; 178 | auto outputW = (W + padW + padW - S) / strideW + 1; 179 | 180 | torch::Tensor output = torch::empty({N, outputH, outputW, K}, torch::dtype(torch::kInt32).device(input.device())); 181 | 182 | auto output_size = cutlass::Tensor4DCoord( 183 | N, 184 | outputH, 185 | outputW, 186 | K); 187 | 188 | // Construct Conv2dProblemSize with user defined output size 189 | cutlass::conv::Conv2dProblemSize problem_size( 190 | cutlass::Tensor4DCoord({N, H, W, C}), 191 | cutlass::Tensor4DCoord({K, R, S, C}), 192 | cutlass::Tensor4DCoord({padH, padH, padW, padW}), 193 | cutlass::MatrixCoord({strideH, strideW}), 194 | cutlass::MatrixCoord({dilationH, dilationW}), 195 | output_size, 196 | mode, 197 | split_k_slices 198 | ); 199 | 200 | cutlass::TensorRef input_ref( 201 | input.data_ptr(), cutlass::layout::TensorNHWC::packed(cutlass::Tensor4DCoord({N, H, W, C}))); 202 | 203 | cutlass::TensorRef filter_ref( 204 | filter.data_ptr(), cutlass::layout::TensorNHWC::packed(cutlass::Tensor4DCoord({K, R, S, C}))); 205 | 206 | cutlass::TensorRef output_ref( 207 | output.data_ptr(), cutlass::layout::TensorNHWC::packed(cutlass::Tensor4DCoord({N, outputH, outputW, K}))); 208 | 209 | // Set up arguments for the convolution operation 210 | typename ImplicitGemm::Arguments arguments{ 211 | problem_size, 212 | input_ref, 213 | filter_ref, 214 | output_ref, 215 | output_ref, 216 | {1, 0}}; 217 | 218 | // Launch the convolution 219 | ImplicitGemm convOp; 220 | auto status = convOp(arguments); 221 | 222 | TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status)); 223 | 224 | return output; 225 | } 226 | 227 | template 228 | __global__ void myDequantizationConvKernel(KTorch *__restrict__ out, 229 | const int *__restrict__ x, 230 | const KTorch *__restrict__ zp_times_weight_channel_sum, 231 | const KTorch *__restrict__ act_times_weight_delta, 232 | const KTorch *__restrict__ y, 233 | const unsigned N, const unsigned H, const unsigned W, const unsigned C) { 234 | // Calculate indices for H, W, and C using block dimensions 235 | const unsigned nhw = blockIdx.x * blockDim.x + threadIdx.x; // Index for H 236 | const unsigned c = blockIdx.y * blockDim.y + threadIdx.y; // Index for W 237 | // const unsigned c = blockIdx.z * blockDim.z + threadIdx.z; // Index for C 238 | 239 | // Check if the current thread is within the bounds for H, W, and C 240 | if (nhw >= N*H*W || c >= C) { 241 | return; 242 | } 243 | 244 | // Split the combined index back into N and H indices 245 | // const unsigned n = nh / H; // Recover N 246 | // const unsigned h = nh % H; // Recover H 247 | 248 | // Calculate the linear index for the 4D array 249 | // unsigned index = n * H * W * C + h * W * C 
+ w * C + c; 250 | unsigned index = nhw * C + c; 251 | 252 | // Convert int32_t element to float32 253 | float xElement = static_cast(x[index]); 254 | 255 | // Get zp_times_weight_channel_sum and act_times_weight_delta for the current channel 256 | float zp_times_weight_channel_sum_element = zp_times_weight_channel_sum[c]; 257 | float act_times_weight_delta_element = act_times_weight_delta[c]; 258 | 259 | xElement -= zp_times_weight_channel_sum_element; 260 | xElement *= act_times_weight_delta_element; 261 | xElement += y[c]; 262 | 263 | // Write the result back 264 | out[index] = xElement; 265 | 266 | } 267 | 268 | torch::Tensor myInt8ConvCUDA(const torch::Tensor &input, const torch::Tensor &filter, const int padH, const int padW, 269 | const int strideH, const int strideW, const int dilationH, const int dilationW, 270 | const torch::Tensor & zp_times_weight_channel_sum, 271 | const torch::Tensor & act_times_weight_delta, 272 | const torch::Tensor & y) { 273 | // Step 1: Perform GEMM Operation 274 | torch::Tensor C = int8ConvCUDA(input, filter, padH, padW, strideH, strideW, dilationH, dilationW); 275 | 276 | // Step 2: Setup for Dequantization 277 | auto N = C.size(0); 278 | auto H = C.size(1); 279 | auto W = C.size(2); 280 | auto Co = C.size(3); 281 | auto out = torch::empty_like(C, torch::dtype(torch::kFloat).device(C.device())); 282 | 283 | // Step 3: Dequantization Kernel Call 284 | dim3 threadsPerBlock(16, 16); // Adjust as necessary 285 | dim3 numBlocks((N*H*W + threadsPerBlock.x - 1) / threadsPerBlock.x, 286 | (Co + threadsPerBlock.y - 1) / threadsPerBlock.y); 287 | myDequantizationConvKernel<<>>(out.data_ptr(), 288 | C.data_ptr(), 289 | zp_times_weight_channel_sum.data_ptr(), 290 | act_times_weight_delta.data_ptr(), 291 | y.data_ptr(), N, H, W, Co); 292 | 293 | // Check for errors in kernel launch 294 | cudaError_t err = cudaGetLastError(); 295 | if (err != cudaSuccess) { 296 | fprintf(stderr, "CUDA Error: %s\n", cudaGetErrorString(err)); 297 | // Handle the error appropriately 298 | } 299 | 300 | // Step 4: Return the dequantized output 301 | // return out; 302 | 303 | // NHWC to NCHW 304 | torch::Tensor dst_nchw = torch::empty({N, Co, H, W}, 305 | torch::dtype(torch::kFloat).device(C.device())); 306 | 307 | // Step 3: Define tensor sizes for NHWC and NCHW 308 | auto input_tensor_size = cutlass::Tensor4DCoord({N, H, W, Co}); 309 | auto output_tensor_size = cutlass::Tensor4DCoord({N, Co, H, W}); 310 | 311 | // Step 4: Create Tensor Refs 312 | cutlass::TensorRef ref_input( 313 | out.data_ptr(), cutlass::layout::TensorNHWC::packed(cutlass::Tensor4DCoord({N,H,W,Co}))); 314 | cutlass::TensorRef ref_output( 315 | dst_nchw.data_ptr(), cutlass::layout::TensorNCHW::packed(cutlass::Tensor4DCoord({N,Co,H,W}))); 316 | 317 | // Call nchw_to_nhwc 318 | nhwc_to_nchw(input_tensor_size, output_tensor_size, ref_input, ref_output, 0); // Assuming default stream 319 | 320 | return dst_nchw; 321 | } 322 | } // namespace TORCHQ::matmul 323 | -------------------------------------------------------------------------------- /torch_quantizer/_C/asymmetric.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | def myQuantize(src: torch.Tensor, act_delta: torch.Tensor, act_zp: torch.Tensor) -> torch.Tensor: ... 5 | 6 | def myQuantizeNCHW(src: torch.Tensor, act_delta: torch.Tensor, act_zp: torch.Tensor) -> torch.Tensor: ... 
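# Illustrative usage only (a minimal sketch; argument shapes follow the pybind11 docstrings
# in src/asymmetric/asymmetric.cpp, and x / act_delta / act_zp are the FP16 CUDA tensors
# prepared by the caller, as in torch_quantizer/src/benchmark.py):
#
#   from torch_quantizer._C import asymmetric
#   int_x  = asymmetric.myQuantize(x, act_delta, act_zp)       # (M, N) FP16 -> (M, N) INT8
#   x_nhwc = asymmetric.myQuantizeNCHW(x, act_delta, act_zp)   # (N, C, H, W) FP16 -> INT8, NHWC layout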
-------------------------------------------------------------------------------- /torch_quantizer/_C/matmul.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def myInt8Matmul(A: torch.Tensor, B: torch.Tensor, zp_times_weight_channel_sum, act_times_weight_delta, bias) -> torch.Tensor: ... 4 | 5 | def myInt8Conv(A: torch.Tensor, B: torch.Tensor, padH, padW, strideH, strideW, dilationH, dilationW, zp_times_weight_channel_sum, act_times_weight_delta, bias) -> torch.Tensor: ... -------------------------------------------------------------------------------- /torch_quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_quantizer._C import matmul, asymmetric 2 | from torch_quantizer.src.converter import fake_quant, real_quant, fake2real 3 | from torch_quantizer.src.benchmark import benchmark_conv2d, benchmark_linear 4 | from torch_quantizer.version import __version__ 5 | 6 | __all__ = ['matmul', 'asymmetric', 'fake_quant', 'real_quant', 'fake2real', 'benchmark_conv2d' , 'benchmark_linear'] -------------------------------------------------------------------------------- /torch_quantizer/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThisisBillhe/torch_quantizer/041908ecdbce24788af5c71ea7cbf1a17a716fb7/torch_quantizer/src/__init__.py -------------------------------------------------------------------------------- /torch_quantizer/src/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch_quantizer 5 | 6 | def calculate_channelwise_quant_params(weight_matrix, n_bits=8): 7 | # Assuming each row is a channel 8 | min_vals = weight_matrix.min(dim=1).values 9 | max_vals = weight_matrix.max(dim=1).values 10 | 11 | # Calculate delta (scale) and zero-point for each channel 12 | deltas = (max_vals - min_vals) / (2 ** n_bits - 1) 13 | zero_points = -min_vals / deltas 14 | zero_points = zero_points.round().clamp(0, 2 ** n_bits - 1) 15 | 16 | return deltas, zero_points 17 | 18 | def calculate_channelwise_symmetric_scale(weight_matrix, n_levels=256): 19 | # Assuming each row is a channel 20 | # Calculate the maximum absolute value in each channel 21 | max_abs_vals = weight_matrix.abs().max(dim=1).values 22 | 23 | # Calculate delta (scale) for each channel 24 | # For symmetric quantization, we use half the quantization levels for positive and half for negative 25 | half_range = (n_levels // 2) - 1 26 | deltas = max_abs_vals / half_range 27 | 28 | return deltas 29 | 30 | def quantize(x, n_bits): 31 | xmax = torch.max(x) 32 | xmin = torch.min(x) 33 | delta = (xmax - xmin) / (2 ** n_bits - 1) 34 | zero_point = (-128 - xmin / delta).round() 35 | 36 | return delta, zero_point 37 | 38 | def calculate_channelwise_symmetric_scale_4D(weights): 39 | """ 40 | Calculate per-channel symmetric quantization scales for a convolution weight tensor. 41 | Weights tensor should have shape [Co, Cin, K, K]. 
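    For int8 this uses half_range = 127, so each output channel i gets
    scale_i = max(abs(weights[i])) / 127 (or 0 if the channel is all zeros);
    callers then quantize that channel with round(weights[i] / scale_i).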
42 | """ 43 | Co, Cin, K, K = weights.shape 44 | half_range = 127 # for int8 quantization 45 | scale_factors = torch.zeros(Co).cuda() 46 | 47 | for i in range(Co): 48 | # Find the maximum absolute value in this output channel 49 | max_abs_val = torch.max(torch.abs(weights[i])) 50 | 51 | # Calculate the scale factor for this channel 52 | scale_factors[i] = max_abs_val / half_range if max_abs_val != 0 else 0 53 | 54 | return scale_factors 55 | 56 | def benchmark_linear(bs=512, cin=960, cout=960): 57 | assert cin % 32 == 0 and cout % 32 == 0, 'cin and cout should be divisible by 32' 58 | 59 | x = 5 * torch.randn(bs,cin).cuda().to(torch.float16) 60 | 61 | linearfp = nn.Linear(cout,cin).cuda().half() 62 | 63 | bias = linearfp.bias.data.to(torch.float32) 64 | weight = linearfp.weight.data.to(torch.float32) 65 | 66 | weight_delta = calculate_channelwise_symmetric_scale(weight).unsqueeze(-1) 67 | # weight_delta = weight_delta.transpose(1,0) 68 | int_weight = (linearfp.weight.data / weight_delta).round().to(torch.int8) 69 | 70 | act_delta, act_zp = quantize(x, n_bits=8) 71 | 72 | 73 | dequantized_weight = int_weight * weight_delta 74 | zp_times_weight_channel_sum = act_zp.to(torch.float32) * int_weight.sum(dim=1) 75 | act_times_weight_delta = act_delta.to(torch.float32) * weight_delta.to(torch.float32).squeeze(0) 76 | 77 | int_act = ((x / act_delta).round() + act_zp).to(torch.int8) 78 | dequantized_act = (int_act - act_zp) * act_delta 79 | 80 | out_int_fake = dequantized_act.to(torch.float32) @ dequantized_weight.T + bias 81 | 82 | ## warm up 83 | out_fp = linearfp(x) 84 | int_x = torch_quantizer.asymmetric.myQuantize(x, act_delta, act_zp) 85 | out_int8 = torch_quantizer.matmul.myInt8Matmul(int_x, int_weight, zp_times_weight_channel_sum, act_times_weight_delta, bias) 86 | 87 | ## start benchmark 88 | import time 89 | 90 | linearfp.float() 91 | start_time = time.perf_counter() 92 | torch.cuda.synchronize() 93 | for i in range(100): 94 | out_fp = linearfp(x.to(torch.float32)) 95 | torch.cuda.synchronize() 96 | end_time = time.perf_counter() 97 | print('average time for FP32: ', (end_time-start_time) / 100) 98 | 99 | linearfp.half() 100 | start_time = time.perf_counter() 101 | torch.cuda.synchronize() 102 | for i in range(100): 103 | out_fp = linearfp(x) 104 | torch.cuda.synchronize() 105 | end_time = time.perf_counter() 106 | print('average time for FP16: ', (end_time-start_time) / 100) 107 | 108 | start_time = time.perf_counter() 109 | torch.cuda.synchronize() 110 | for i in range(100): 111 | # int_x = ((x / act_delta).round() + act_zp).to(torch.int8) 112 | int_x = torch_quantizer.asymmetric.myQuantize(x, act_delta, act_zp) 113 | out_int8 = torch_quantizer.matmul.myInt8Matmul(int_x, int_weight, zp_times_weight_channel_sum, act_times_weight_delta, bias) 114 | torch.cuda.synchronize() 115 | end_time = time.perf_counter() 116 | print('average time for INT8 (Quant+Dequant): ', (end_time-start_time) / 100) 117 | 118 | def benchmark_conv2d(bs=8, cin=384, h=32, w=32, cout=384, k=3, padding=0): 119 | x = (torch.randn((bs,cin,h,w)) + 3).half().cuda() 120 | conv1 = nn.Conv2d(cout,cin,k, padding=padding).half().cuda() 121 | 122 | conv1.weight.requires_grad = False 123 | conv1.bias.requires_grad = False 124 | 125 | weight = conv1.weight.data 126 | weight_delta = calculate_channelwise_symmetric_scale_4D(weight).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 127 | int_weight = (weight / weight_delta).round().to(torch.int8) 128 | weight_nhwc = int_weight.permute(0,2,3,1).contiguous() 129 | dequantized_weight = 
int_weight * weight_delta 130 | 131 | act_delta, act_zp = quantize(x, n_bits=8) 132 | int_act = ((x / act_delta).round() + act_zp).to(torch.int8) 133 | dequantized_act = (int_act - act_zp) * act_delta 134 | 135 | bias = conv1.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 136 | 137 | ## warm up 138 | out_fp = conv1(x) 139 | 140 | import time 141 | 142 | conv1.float() 143 | start_time = time.perf_counter() 144 | torch.cuda.synchronize() 145 | for i in range(100): 146 | out_fp = conv1(x.to(torch.float32)) 147 | torch.cuda.synchronize() 148 | end_time = time.perf_counter() 149 | print('average time for FP32: ', (end_time-start_time) / 100) 150 | 151 | conv1.half() 152 | start_time = time.perf_counter() 153 | torch.cuda.synchronize() 154 | for i in range(100): 155 | out_fp = conv1(x) 156 | torch.cuda.synchronize() 157 | end_time = time.perf_counter() 158 | print('average time for FP16: ', (end_time-start_time) / 100) 159 | 160 | ## int8Conv + fused dequantization 161 | zp_times_weight_channel_sum = act_zp.to(torch.float32) * int_weight.sum(dim=(1,2,3)) 162 | act_times_weight_delta = act_delta.to(torch.float32) * weight_delta.reshape(-1,) 163 | bias = conv1.bias.to(torch.float32) 164 | start_time = time.perf_counter() 165 | torch.cuda.synchronize() 166 | for i in range(100): 167 | x_nhwc = torch_quantizer.asymmetric.myQuantizeNCHW(x, act_delta, act_zp) 168 | if padding != 0: 169 | x_nhwc = F.pad(x_nhwc, pad=(0,0,padding,padding,padding,padding), value=act_zp) 170 | out_int8_fused = torch_quantizer.matmul.myInt8Conv(x_nhwc, weight_nhwc, 0, 0, 1, 1, 1, 1, zp_times_weight_channel_sum, act_times_weight_delta, bias) 171 | torch.cuda.synchronize() 172 | end_time = time.perf_counter() 173 | print('average time for INT8 (Quant+Dequant): ', (end_time-start_time) / 100) 174 | -------------------------------------------------------------------------------- /torch_quantizer/src/converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch_quantizer.src.quant_model import QuantModel, FakeQuantModel 5 | 6 | def fake_quant(model, weight_quant_params, act_quant_params, num_steps=1): 7 | ''' 8 | Quantize a floating-point model to fake quantized model. The fakeq model will be used for calibration. 9 | 10 | Args: 11 | model (YourFloatingPointModelClass): The original floating-point model to be quantized. 12 | weight_quant_params: a dict specifies n_bits, channel_wise and scale_method. For example: {'n_bits': n_bits_w, 'channel_wise': True, 'scale_method': 'max'}. 13 | act_quant_params: same as weight_quant_params. 14 | num_steps (int, optional): The number of quantization steps. Default is 1. 15 | Returns: 16 | FakeQModelClass: The fake quantized model. 17 | ''' 18 | 19 | fakeq_model = FakeQuantModel(model, weight_quant_params, act_quant_params, num_steps).model 20 | setattr(fakeq_model, 'num_steps', num_steps) 21 | setattr(fakeq_model, 'n_bits', 8) 22 | 23 | return fakeq_model 24 | 25 | def real_quant(model, n_bits=8, num_steps=1, ckpt_path=None): 26 | ''' 27 | Quantize a floating-point model to INT8. 28 | 29 | Args: 30 | model (YourFloatingPointModelClass): The original floating-point model to be quantized. 31 | num_steps (int, optional): The number of quantization steps. Default is 1. 32 | ckpt_path (str, optional): The path to a pre-trained INT8 model checkpoint. If not provided, 33 | an INT8 model will be randomly initialized (which can be used to benchmark). 34 | Returns: 35 | YourINT8ModelClass: The quantized INT8 model. 
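    Example (illustrative; the checkpoint name simply follows the pattern written by fake2real):
        realq_model = real_quant(model, n_bits=8, num_steps=1, ckpt_path='MyModel_8bits_1steps.pth')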
36 | ''' 37 | if ckpt_path is not None: 38 | print('Restoring INT8 models from {}'.format(ckpt_path)) 39 | else: 40 | print('Get INT8 models without checkpoint...') 41 | realq_model = QuantModel(model, n_bits=n_bits, n_steps=num_steps) 42 | realq_model.half() 43 | realq_model = realq_model.model ## get rid of the wrapper 44 | 45 | if ckpt_path is not None: 46 | ckpt = torch.load(ckpt_path, map_location='cpu') 47 | realq_model.load_state_dict(ckpt, strict=False) 48 | 49 | setattr(realq_model, 'num_steps', num_steps) 50 | 51 | return realq_model 52 | 53 | 54 | 55 | def fake2real(fakeq_model, save_dir='.'): 56 | n_levels = 2**8 57 | ckpt = fakeq_model.state_dict() 58 | for k in list(ckpt.keys()): 59 | if 'weight_quantizer' in k and len(ckpt[k[:-16]].shape)==2: ## Linear layer 60 | prefix = k[:-22] ## with '.' in the end 61 | weight_k = k[:-16] 62 | weight_delta = ckpt[k].reshape(-1,1) 63 | int_weight = torch.clamp((ckpt[weight_k] / weight_delta).round(), -n_levels//2, n_levels//2 - 1).to(torch.int8) 64 | ckpt[prefix + 'int_weight'] = int_weight 65 | # ckpt[prefix + 'weight_delta'] = weight_delta.transpose(1,0) 66 | del ckpt[weight_k] 67 | del ckpt[k] 68 | 69 | act_delta_list_k = k[:-22]+'act_quantizer.delta_list' 70 | act_zp_list_k = k[:-22]+'act_quantizer.zp_list' 71 | act_zp = ckpt[act_zp_list_k].clone().detach().reshape(-1,) 72 | act_delta = ckpt[act_delta_list_k].clone().detach().reshape(-1,) 73 | ckpt[prefix + 'act_delta'] = act_delta 74 | ckpt[prefix + 'act_zp'] = act_zp 75 | ckpt[prefix + 'zp_times_weight_channel_sum'] = act_zp.unsqueeze(-1) * int_weight.sum(dim=1).unsqueeze(0).to(torch.float32) 76 | ckpt[prefix + 'act_times_weight_delta'] = act_delta.unsqueeze(-1) * weight_delta.reshape(1,-1) 77 | if ckpt[prefix + 'zp_times_weight_channel_sum'].isnan().any() or ckpt[prefix + 'zp_times_weight_channel_sum'].isinf().any(): 78 | print('nan or inf!') 79 | del ckpt[act_delta_list_k] 80 | del ckpt[act_zp_list_k] 81 | # del ckpt[act_delta_list_k[:-5]] 82 | # del ckpt[act_delta_list_k[:-10]+'zero_point'] 83 | 84 | elif 'weight_quantizer' in k and len(ckpt[k[:-16]].shape)==4: ## Conv layer 85 | prefix = k[:-22] ## with '.' 
in the end 86 | weight_k = k[:-16] 87 | weight_delta = ckpt[k] ## (Co, 1, 1, 1) 88 | int_weight = torch.clamp((ckpt[weight_k] / weight_delta).round(), -n_levels//2, n_levels//2 - 1).to(torch.int8) 89 | weight_nhwc = int_weight.permute(0,2,3,1).contiguous() 90 | ckpt[prefix + 'int_weight'] = weight_nhwc 91 | del ckpt[weight_k] 92 | del ckpt[k] 93 | 94 | act_delta_list_k = k[:-22]+'act_quantizer.delta_list' 95 | act_zp_list_k = k[:-22]+'act_quantizer.zp_list' 96 | act_zp = ckpt[act_zp_list_k].clone().detach().reshape(-1,) 97 | act_delta = ckpt[act_delta_list_k].clone().detach().reshape(-1,) 98 | ckpt[prefix + 'act_delta'] = act_delta 99 | ckpt[prefix + 'act_zp'] = act_zp 100 | ckpt[prefix + 'zp_times_weight_channel_sum'] = act_zp.unsqueeze(-1) * int_weight.sum(dim=(1,2,3)).unsqueeze(0).to(torch.float32) 101 | ckpt[prefix + 'act_times_weight_delta'] = act_delta.unsqueeze(-1) * weight_delta.reshape(1,-1) 102 | if ckpt[prefix + 'zp_times_weight_channel_sum'].isnan().any() or ckpt[prefix + 'zp_times_weight_channel_sum'].isinf().any(): 103 | print('nan or inf!') 104 | del ckpt[act_delta_list_k] 105 | del ckpt[act_zp_list_k] 106 | # del ckpt[act_delta_list_k[:-5]] 107 | # del ckpt[act_delta_list_k[:-10]+'zero_point'] 108 | 109 | model_name = fakeq_model.__class__.__name__ 110 | save_path = os.path.join(save_dir, '{}_8bits_{}steps.pth'.format(model_name, fakeq_model.num_steps)) 111 | print('Saving quantized checkpoint to {}'.format(save_path)) 112 | torch.save(ckpt, save_path) 113 | 114 | realq_model = QuantModel(fakeq_model, n_bits=fakeq_model.n_bits, n_steps=fakeq_model.num_steps).to(next(fakeq_model.parameters()).device) 115 | realq_model.half() 116 | realq_model = realq_model.model ## to get rid of redundent fakeq_model 117 | is_compatible = realq_model.load_state_dict(ckpt, strict=False) ## we assume there is bias for every layer, which may cause missing keys. 118 | 119 | return realq_model -------------------------------------------------------------------------------- /torch_quantizer/src/quant_layer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import Union 6 | import math 7 | import torch_quantizer 8 | from .quant_utils import SymmetricQuantizer, naiveTemporalQuantizer 9 | 10 | class FakeQuantModule(nn.Module): 11 | """ 12 | Quantized Module that can perform quantized convolution or normal convolution. 13 | To activate quantization, please use set_quant_state function. 
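    Weights are fake-quantized with SymmetricQuantizer and activations with
    naiveTemporalQuantizer before the original F.conv2d / F.linear is applied.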
14 | """ 15 | def __init__(self, org_module: Union[nn.Conv2d, nn.Linear], weight_quant_params: dict = {}, 16 | act_quant_params: dict = {}, num_steps=1): 17 | super(FakeQuantModule, self).__init__() 18 | if isinstance(org_module, nn.Conv2d): 19 | self.stride = org_module.stride 20 | self.padding = org_module.padding 21 | self.dilation = org_module.dilation 22 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding, 23 | dilation=org_module.dilation, groups=org_module.groups) 24 | self.fwd_func = F.conv2d 25 | else: 26 | self.fwd_kwargs = dict() 27 | self.fwd_func = F.linear 28 | self.weight = org_module.weight 29 | self.org_weight = org_module.weight.data.clone() 30 | if org_module.bias is not None: 31 | self.bias = org_module.bias 32 | self.org_bias = org_module.bias.data.clone() 33 | else: 34 | self.bias = None 35 | self.org_bias = None 36 | 37 | # initialize quantizer 38 | self.weight_quantizer = SymmetricQuantizer(**weight_quant_params) 39 | 40 | self.act_quantizer = naiveTemporalQuantizer(**act_quant_params, num_steps=num_steps) 41 | 42 | def forward(self, input: torch.Tensor): 43 | weight = self.weight_quantizer(self.weight) 44 | bias = self.bias 45 | input = self.act_quantizer(input) 46 | 47 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) 48 | 49 | return out 50 | 51 | class qlinear_8bit(nn.Module): 52 | """ 53 | 8-bit Linear Module. 54 | """ 55 | def __init__(self, org_module: nn.Linear, n_bits=8, num_steps=1): 56 | super(qlinear_8bit, self).__init__() 57 | 58 | ## Copy attributes from org_module 59 | self.in_features = org_module.in_features 60 | self.out_features = org_module.out_features 61 | 62 | self.fwd_kwargs = dict() 63 | # self.org_weight = org_module.weight.data.clone() 64 | 65 | self.ori_shape = org_module.weight.shape 66 | self.n_bits = n_bits 67 | # self.register_buffer('int_weight', torch.randint(-128, 127, (self.ori_shape[0], self.ori_shape[1]), 68 | # dtype=torch.int8, requires_grad=False)) 69 | self.register_buffer('int_weight', torch.zeros((self.ori_shape[0], self.ori_shape[1]), dtype=torch.int8)) 70 | 71 | if org_module.bias is not None: 72 | self.register_buffer('bias', org_module.bias.data) 73 | else: 74 | self.register_buffer('bias', torch.zeros(size=(self.ori_shape[0],))) 75 | 76 | # de-activate the quantized forward default 77 | self.use_weight_quant = False 78 | self.use_act_quant = False 79 | 80 | self.ignore_reconstruction = False 81 | 82 | self.register_buffer('act_delta',torch.randn(size=(num_steps,), dtype=torch.float16)) ## should be float16 83 | self.register_buffer('act_zp',torch.randn(size=(num_steps,), dtype=torch.float16)) ## should be float16 84 | self.register_buffer('zp_times_weight_channel_sum',torch.randn(size=(num_steps, self.ori_shape[0]), dtype=torch.float32)) ## should be float32 85 | self.register_buffer('act_times_weight_delta',torch.randn(size=(num_steps, self.ori_shape[0]), dtype=torch.float32)) ## should be float32 86 | 87 | self.total_steps = num_steps 88 | self.current_step = self.total_steps - 1 89 | 90 | def forward(self, input: torch.Tensor): 91 | ## fetch quantization parameters 92 | act_delta = self.act_delta[self.current_step] 93 | act_zp = self.act_zp[self.current_step] 94 | zp_times_weight_channel_sum = self.zp_times_weight_channel_sum[self.current_step] 95 | act_times_weight_delta = self.act_times_weight_delta[self.current_step] 96 | 97 | self.current_step = self.total_steps - 1 if self.current_step - 1 < 0 else self.current_step - 1 98 | 99 | ## perform linear operation 100 | if 
len(input.shape) != 2: 101 | original_shape = input.shape[:-1] 102 | input = input.view(-1, input.shape[-1]) 103 | else: 104 | original_shape = None 105 | 106 | int_x = torch_quantizer.asymmetric.myQuantize(input, act_delta, act_zp) 107 | output = torch_quantizer.matmul.myInt8Matmul(int_x, self.int_weight, zp_times_weight_channel_sum, act_times_weight_delta, self.bias) 108 | 109 | if original_shape is not None: 110 | output = output.view(original_shape + (-1,)).to(torch.float16) 111 | else: 112 | output = output.to(torch.float16) 113 | 114 | return output 115 | 116 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 117 | self.use_weight_quant = weight_quant 118 | self.use_act_quant = act_quant 119 | 120 | class qconv2d_8bit(nn.Module): 121 | """ 122 | 8-bit Conv2d Module. 123 | """ 124 | def __init__(self, org_module: nn.Conv2d, n_bits=8, num_steps=1): 125 | super(qconv2d_8bit, self).__init__() 126 | self.fwd_kwargs = dict(strideH=org_module.stride[0], strideW=org_module.stride[1], padH=org_module.padding[0], padW=org_module.padding[1], 127 | dilationH=org_module.dilation[0], dilationW=org_module.dilation[1]) 128 | self.ori_shape = org_module.weight.shape 129 | self.weight_nhwc_shape = [self.ori_shape[0], self.ori_shape[2], self.ori_shape[3], self.ori_shape[1]] 130 | 131 | self.n_bits = n_bits 132 | self.register_buffer('int_weight', torch.randint(-128, 127, self.weight_nhwc_shape, 133 | dtype=torch.int8, requires_grad=False)) 134 | 135 | if org_module.bias is not None: 136 | self.register_buffer('bias', org_module.bias.data) 137 | 138 | else: 139 | self.register_buffer('bias', torch.zeros(size=(self.ori_shape[0],))) 140 | 141 | # de-activate the quantized forward default 142 | self.use_weight_quant = False 143 | self.use_act_quant = False 144 | 145 | self.ignore_reconstruction = False 146 | 147 | self.register_buffer('act_delta',torch.randn(size=(num_steps,), dtype=torch.float16)) ## should be float16 148 | self.register_buffer('act_zp',torch.randn(size=(num_steps,), dtype=torch.float16)) ## should be float16 149 | self.register_buffer('zp_times_weight_channel_sum',torch.randn(size=(num_steps, self.ori_shape[0]), dtype=torch.float32)) ## should be float32 150 | self.register_buffer('act_times_weight_delta',torch.randn(size=(num_steps, self.ori_shape[0]), dtype=torch.float32)) ## should be float32 151 | 152 | self.total_steps = num_steps 153 | self.current_step = self.total_steps - 1 154 | 155 | def forward(self, input: torch.Tensor): 156 | ## fetch quantization parameters 157 | act_delta = self.act_delta[self.current_step] 158 | act_zp = self.act_zp[self.current_step] 159 | zp_times_weight_channel_sum = self.zp_times_weight_channel_sum[self.current_step] 160 | act_times_weight_delta = self.act_times_weight_delta[self.current_step] 161 | 162 | self.current_step = self.total_steps - 1 if self.current_step - 1 < 0 else self.current_step - 1 163 | 164 | ## perform conv operation 165 | if not input.is_contiguous(): 166 | input = input.contiguous() 167 | x_nhwc = torch_quantizer.asymmetric.myQuantizeNCHW(input, act_delta, act_zp) ## why some input can be non-contiguous? 
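        # myQuantizeNCHW returns the quantized int8 activation in NHWC layout.
        # Padding is applied explicitly below with the activation zero point as the
        # fill value, so padded positions dequantize to a real value of zero; the
        # int8 convolution is then invoked with zero padding.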
168 | if self.fwd_kwargs['padH'] > 0: 169 | x_nhwc = F.pad(x_nhwc, pad=(0,0,self.fwd_kwargs['padH'],self.fwd_kwargs['padH'],self.fwd_kwargs['padW'],self.fwd_kwargs['padW']), value=act_zp) 170 | output_nchw = torch_quantizer.matmul.myInt8Conv(x_nhwc, self.int_weight, 0,0, self.fwd_kwargs['strideH'],\ 171 | self.fwd_kwargs['strideW'],self.fwd_kwargs['dilationH'],self.fwd_kwargs['dilationW'],\ 172 | zp_times_weight_channel_sum, act_times_weight_delta, self.bias).to(torch.float16) 173 | 174 | return output_nchw 175 | 176 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 177 | self.use_weight_quant = weight_quant 178 | self.use_act_quant = act_quant 179 | -------------------------------------------------------------------------------- /torch_quantizer/src/quant_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .quant_layer import qconv2d_8bit, qlinear_8bit, FakeQuantModule 5 | 6 | class QuantModel(nn.Module): 7 | 8 | def __init__(self, model: nn.Module, n_bits=8, n_steps=1, skip_layer=None): 9 | super().__init__() 10 | self.model = model 11 | self.count = 0 12 | self.skip_layer = [] if skip_layer is None else skip_layer 13 | self.quant_module_refactor(self.model, n_bits=n_bits, n_steps=n_steps) 14 | 15 | def quant_module_refactor(self, module: nn.Module, n_bits=8, n_steps=1): 16 | """ 17 | Recursively replace the normal conv2d and Linear layer to QuantModule 18 | """ 19 | for name, child_module in module.named_children(): 20 | if isinstance(child_module, nn.Conv2d): 21 | self.count += 1 22 | if self.count not in self.skip_layer: 23 | Cout, Cin = child_module.weight.shape[0], child_module.weight.shape[1] 24 | if Cout % 32 == 0 and Cin % 32 ==0: 25 | setattr(module, name, qconv2d_8bit(child_module, n_bits=n_bits, num_steps=n_steps)) 26 | else: 27 | pass 28 | else: 29 | pass 30 | elif isinstance(child_module, nn.Linear): 31 | self.count += 1 32 | if self.count not in self.skip_layer: 33 | Cout, Cin = child_module.weight.shape[0], child_module.weight.shape[1] 34 | if Cout % 32 == 0 and Cin % 32 ==0: 35 | setattr(module, name, qlinear_8bit(child_module, n_bits=n_bits, num_steps=n_steps)) 36 | else: 37 | pass 38 | else: 39 | pass 40 | elif isinstance(child_module, FakeQuantModule): 41 | if child_module.fwd_func is F.linear: 42 | setattr(module, name, qlinear_8bit(child_module, n_bits=n_bits, num_steps=n_steps)) 43 | elif child_module.fwd_func is F.conv2d: 44 | setattr(module, name, qconv2d_8bit(child_module, n_bits=n_bits, num_steps=n_steps)) 45 | else: 46 | self.quant_module_refactor(child_module, n_bits=n_bits, n_steps=n_steps) 47 | 48 | def forward(self, *args, **kwargs): 49 | return self.model(*args, **kwargs) 50 | 51 | def half(self): 52 | for name, param in self.named_parameters(): 53 | param.data = param.data.half() 54 | for name, buf in self.named_buffers(): 55 | if 'zp_times_weight_channel_sum' in name or 'act_times_weight_delta' in name or 'bias' in name: 56 | buf.data = buf.data.float() 57 | elif 'int_weight' not in name: 58 | buf.data = buf.data.half() ## these data is required to be float32 for cuda kernel 59 | 60 | class FakeQuantModel(nn.Module): 61 | 62 | def __init__(self, model: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, num_steps=1): 63 | super().__init__() 64 | self.model = model 65 | self.num_steps = num_steps 66 | self.n_bits = weight_quant_params['n_bits'] 67 | 
self.quant_module_refactor(self.model, weight_quant_params, act_quant_params, num_steps=num_steps) 68 | 69 | def quant_module_refactor(self, module: nn.Module, weight_quant_params: dict = {}, act_quant_params: dict = {}, num_steps=1): 70 | """ 71 | Recursively replace the normal conv2d and Linear layer to FakeQuantModule 72 | """ 73 | for name, child_module in module.named_children(): 74 | if isinstance(child_module, (nn.Conv2d, nn.Linear)): 75 | Cout, Cin = child_module.weight.shape[0], child_module.weight.shape[1] 76 | if Cout % 32 == 0 and Cin % 32 ==0: 77 | setattr(module, name, FakeQuantModule(child_module, weight_quant_params, act_quant_params, num_steps=num_steps)) 78 | else: 79 | pass 80 | else: 81 | self.quant_module_refactor(child_module, weight_quant_params, act_quant_params, num_steps=num_steps) 82 | 83 | def forward(self, *args, **kwargs): 84 | return self.model(*args, **kwargs) -------------------------------------------------------------------------------- /torch_quantizer/src/quant_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def round_ste(x: torch.Tensor): 7 | """ 8 | Implement Straight-Through Estimator for rounding operation. 9 | """ 10 | return (x.round() - x).detach() + x 11 | 12 | 13 | def lp_loss(pred, tgt, p=2.0, reduction='none'): 14 | """ 15 | loss function measured in L_p Norm 16 | """ 17 | if reduction == 'none': 18 | return (pred-tgt).abs().pow(p).sum(1).mean() 19 | else: 20 | return (pred-tgt).abs().pow(p).mean() 21 | 22 | class SymmetricQuantizer(nn.Module): 23 | """ 24 | PyTorch Function that can be used for asymmetric quantization (also called uniform affine 25 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight 26 | through' on the backward pass, ignoring the quantization that occurred. 27 | Based on https://arxiv.org/abs/1806.08342. 
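    Note: despite the generic description above, this variant keeps the zero point
    fixed at 0 and clips integers to the signed range [-2**(n_bits-1), 2**(n_bits-1) - 1];
    it is the quantizer applied to weights in FakeQuantModule.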
28 | 29 | :param n_bits: number of bit for quantization 30 | :param symmetric: if True, the zero_point should always be 0 31 | :param channel_wise: if True, compute scale and zero_point in each channel 32 | :param scale_method: determines the quantization scale and zero point 33 | """ 34 | def __init__(self, n_bits: int = 8, channel_wise: bool = False, scale_method: str = 'max'): 35 | super(SymmetricQuantizer, self).__init__() 36 | assert n_bits == 8, 'bitwidth not supported' 37 | self.n_bits = n_bits 38 | self.n_levels = 2 ** self.n_bits 39 | self.inited = True ## modified here 40 | self.channel_wise = channel_wise 41 | self.scale_method = scale_method 42 | 43 | self.inited = False # use this when quantizing models 44 | self.register_buffer('delta', torch.tensor(0.005)) 45 | 46 | def clipping(self, x, lower, upper): 47 | # clip lower 48 | x = x + F.relu(lower - x) 49 | # clip upper 50 | x = x - F.relu(x - upper) 51 | 52 | return x 53 | 54 | def forward(self, x: torch.Tensor): 55 | 56 | if self.inited is False: 57 | delta= self.init_quantization_scale(x, self.channel_wise) 58 | self.delta = torch.nn.Parameter(delta) 59 | 60 | self.inited = True 61 | 62 | # start quantization 63 | x_int = round_ste(x / self.delta) 64 | x_quant = self.clipping(x_int, -self.n_levels//2, self.n_levels//2 - 1) 65 | # x_quant = torch.clamp(x_int, 0, self.n_levels - 1) 66 | x_dequant = x_quant * self.delta 67 | return x_dequant 68 | 69 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): 70 | delta = None 71 | if channel_wise: 72 | x_clone = x.clone().detach() 73 | n_channels = x_clone.shape[0] 74 | if len(x.shape) == 4: 75 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0] 76 | else: 77 | x_max = x_clone.abs().max(dim=-1)[0] 78 | delta = x_max.clone() 79 | 80 | ## comment below for faster initialization in inference 81 | # determine the scale and zero point channel-by-channel 82 | for c in range(n_channels): 83 | delta[c]= self.init_quantization_scale(x_clone[c], channel_wise=False) 84 | 85 | if len(x.shape) == 4: 86 | delta = delta.view(-1, 1, 1, 1) 87 | else: 88 | delta = delta.view(-1, 1) 89 | else: 90 | if 'max' in self.scale_method: 91 | x_min = min(x.min().item(), 0) 92 | x_max = max(x.max().item(), 0) 93 | ## symmetric 94 | x_absmax = max(abs(x_min), x_max) 95 | x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax 96 | 97 | delta = float(x_max - x_min) / (self.n_levels - 1) 98 | if delta < 1e-8: 99 | warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max)) 100 | delta = 1e-8 101 | 102 | zero_point = round(-x_min / delta) 103 | delta = torch.tensor(delta).type_as(x) 104 | 105 | elif self.scale_method == 'mse': 106 | x_max = x.max() 107 | x_min = x.min() 108 | best_score = 1e+10 109 | # Initial factors 110 | start_factor = 0.95 111 | end_factor = 1.05 112 | 113 | # Calculate the step increase per iteration 114 | step = (end_factor - start_factor) / 80 115 | 116 | for i in range(80): 117 | factor = start_factor + i * step 118 | new_max = x_max * factor 119 | new_min = x_min * factor 120 | x_q = self.quantize(x, new_max, new_min) 121 | # L_p norm minimization as described in LAPQ 122 | # https://arxiv.org/abs/1911.07190 123 | score = lp_loss(x, x_q, p=2.4, reduction='all') 124 | if score < best_score: 125 | best_score = score 126 | delta = (new_max - new_min) / (2 ** self.n_bits - 1) 127 | else: 128 | raise NotImplementedError 129 | 130 | return delta 131 | 132 | def quantize(self, x, max, min): 133 | delta = (max - min) / (2 ** 
self.n_bits - 1) 134 | # we assume weight quantization is always signed 135 | x_int = torch.round(x / delta) 136 | x_quant = torch.clamp(x_int, -self.n_levels//2, self.n_levels//2 - 1) 137 | x_float_q = x_quant * delta 138 | return x_float_q 139 | 140 | class naiveTemporalQuantizer(nn.Module): 141 | """ 142 | PyTorch Function that can be used for asymmetric quantization (also called uniform affine 143 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight 144 | through' on the backward pass, ignoring the quantization that occurred. 145 | Based on https://arxiv.org/abs/1806.08342. 146 | 147 | :param n_bits: number of bit for quantization 148 | :param symmetric: if True, the zero_point should always be 0 149 | :param channel_wise: if True, compute scale and zero_point in each channel 150 | :param scale_method: determines the quantization scale and zero point 151 | """ 152 | def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max', 153 | num_steps = 1): 154 | super(naiveTemporalQuantizer, self).__init__() 155 | self.sym = symmetric 156 | assert n_bits == 8, 'bitwidth not supported' 157 | self.n_bits = n_bits 158 | self.n_levels = 2 ** self.n_bits 159 | 160 | self.total_steps = num_steps 161 | self.current_step = self.total_steps - 1 162 | 163 | self.register_buffer('delta_list', torch.tensor([torch.tensor(0.005) for _ in range(self.total_steps)])) 164 | self.register_buffer('zp_list', torch.tensor([torch.tensor(0.005) for _ in range(self.total_steps)])) 165 | 166 | self.inited = False 167 | self.channel_wise = channel_wise 168 | self.scale_method = scale_method 169 | 170 | def clipping(self, x, lower, upper): 171 | # clip lower 172 | x = x + F.relu(lower - x) 173 | # clip upper 174 | x = x - F.relu(x - upper) 175 | 176 | return x 177 | 178 | def forward(self, x: torch.Tensor): 179 | if self.inited is False: 180 | if self.current_step == 0: 181 | self.inited = True 182 | delta, zero_point = self.init_quantization_scale(x, self.channel_wise) 183 | self.delta_list[self.current_step] = delta 184 | self.zp_list[self.current_step] = zero_point 185 | 186 | x_int = round_ste(x / self.delta_list[self.current_step]) + round_ste(self.zp_list[self.current_step]) 187 | x_quant = self.clipping(x_int, -self.n_levels//2, self.n_levels//2 - 1) ## modified here to replace torch.clamp for gradient prop 188 | x_dequant = (x_quant - round_ste(self.zp_list[self.current_step])) * self.delta_list[self.current_step] 189 | self.current_step = self.total_steps - 1 if self.current_step - 1 < 0 else self.current_step - 1 190 | return x_dequant 191 | else: 192 | x_int = round_ste(x / self.delta_list[self.current_step]) + round_ste(self.zp_list[self.current_step]) 193 | x_quant = self.clipping(x_int, -self.n_levels//2, self.n_levels//2 - 1) ## modified here to replace torch.clamp for gradient prop 194 | x_dequant = (x_quant - round_ste(self.zp_list[self.current_step])) * self.delta_list[self.current_step] 195 | self.current_step = self.total_steps - 1 if self.current_step - 1 < 0 else self.current_step - 1 196 | return x_dequant 197 | 198 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False): 199 | delta, zero_point = None, None 200 | if channel_wise: 201 | x_clone = x.clone().detach() 202 | n_channels = x_clone.shape[0] 203 | if len(x.shape) == 4: 204 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0] 205 | else: 206 | x_max = x_clone.abs().max(dim=-1)[0] 207 | delta = x_max.clone() 
208 | zero_point = x_max.clone() 209 | 210 | ## comment below for faster initialization in inference 211 | # determine the scale and zero point channel-by-channel 212 | for c in range(n_channels): 213 | delta[c], zero_point[c] = self.init_quantization_scale(x_clone[c], channel_wise=False) 214 | 215 | if len(x.shape) == 4: 216 | delta = delta.view(-1, 1, 1, 1) 217 | zero_point = zero_point.view(-1, 1, 1, 1) 218 | else: 219 | delta = delta.view(-1, 1) 220 | zero_point = zero_point.view(-1, 1) 221 | else: 222 | if 'max' in self.scale_method: 223 | x_min = min(x.min().item(), 0) 224 | x_max = max(x.max().item(), 0) 225 | if 'scale' in self.scale_method: 226 | x_min = x_min * (self.n_bits + 2) / 8 227 | x_max = x_max * (self.n_bits + 2) / 8 228 | 229 | x_absmax = max(abs(x_min), x_max) 230 | if self.sym: 231 | x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax 232 | 233 | delta = float(x_max - x_min) / (self.n_levels - 1) 234 | if delta < 1e-8: 235 | warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max)) 236 | delta = 1e-8 237 | 238 | zero_point = round(-x_min / delta) - self.n_levels // 2 239 | delta = torch.tensor(delta).type_as(x) 240 | 241 | elif self.scale_method == 'mse': 242 | x_max = x.max() 243 | x_min = x.min() 244 | 245 | best_score = 1e+10 246 | # Initial factors 247 | start_factor = 0.95 248 | end_factor = 1.05 249 | 250 | # Calculate the step increase per iteration 251 | step = (end_factor - start_factor) / 80 252 | 253 | for i in range(80): 254 | factor = start_factor + i * step 255 | new_max = x_max * factor 256 | new_min = x_min * factor 257 | 258 | x_q = self.quantize(x, new_max, new_min) 259 | # L_p norm minimization as described in LAPQ 260 | # https://arxiv.org/abs/1911.07190 261 | score = lp_loss(x, x_q, p=2.4, reduction='all') 262 | # score = lp_loss(x, x_q, p=torch.sqrt(x_max - x_min), reduction='all') ## adaptive p-norm 263 | # score = lp_loss(x, x_q, p=torch.clamp(-1.22*torch.pow(torch.var(x),-1)+9.42,0.1,10.0), reduction='all') 264 | 265 | if score < best_score: 266 | best_score = score 267 | delta = (new_max - new_min) / (2 ** self.n_bits - 1) 268 | zero_point = (- new_min / delta).round() - self.n_levels // 2 269 | 270 | else: 271 | raise NotImplementedError 272 | return delta, zero_point 273 | 274 | def quantize(self, x, max, min): 275 | delta = (max - min) / (2 ** self.n_bits - 1) 276 | zero_point = (- min / delta).round() 277 | # we assume weight quantization is always signed 278 | x_int = torch.round(x / delta) 279 | x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1) 280 | x_float_q = (x_quant - zero_point) * delta 281 | return x_float_q 282 | 283 | def bitwidth_refactor(self, refactored_bit: int): 284 | assert 2 <= refactored_bit <= 8, 'bitwidth not supported' 285 | self.n_bits = refactored_bit 286 | self.n_levels = 2 ** self.n_bits 287 | 288 | def extra_repr(self): 289 | s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' 290 | return s.format(**self.__dict__) 291 | -------------------------------------------------------------------------------- /torch_quantizer/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | __license__ = '' 3 | __author__ = '' 4 | __release__ = False 5 | 6 | if not __release__: 7 | import os 8 | import subprocess 9 | 10 | try: 11 | prefix, sep, suffix = ( 12 | subprocess.check_output( 13 | ['git', 'describe', '--abbrev=7'], # noqa: S603,S607 14 | 
cwd=os.path.dirname(os.path.abspath(__file__)), 15 | stderr=subprocess.DEVNULL, 16 | text=True, 17 | ) 18 | .strip() 19 | .lstrip('v') 20 | .replace('-', '.dev', 1) 21 | .replace('-', '+', 1) 22 | .partition('.dev') 23 | ) 24 | if sep: 25 | version_prefix, dot, version_tail = prefix.rpartition('.') 26 | prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' 27 | __version__ = sep.join((prefix, suffix)) 28 | del version_prefix, dot, version_tail 29 | else: 30 | __version__ = prefix 31 | del prefix, sep, suffix 32 | except (OSError, subprocess.CalledProcessError): 33 | pass 34 | 35 | del os, subprocess 36 | --------------------------------------------------------------------------------
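Usage sketch (not part of the repository): restoring a real INT8 model from a checkpoint exported by fake2real(), mirroring the helper at the top of converter.py. build_model() and the checkpoint filename are placeholders, not repo code.

# Hedged sketch; names marked as placeholders are assumptions.
import torch
from torch_quantizer.src.quant_model import QuantModel

fp_model = build_model().cuda().eval()                 # placeholder FP model constructor
realq = QuantModel(fp_model, n_bits=8, n_steps=50)     # swaps eligible Conv2d/Linear layers
realq.half()                                           # params -> fp16; cuda-kernel buffers
                                                       # keep the dtypes they require
realq = realq.model                                    # drop the wrapper
ckpt = torch.load('UNet_8bits_50steps.pth', map_location='cpu')  # placeholder name in the
                                                                 # format saved by fake2real()
realq.load_state_dict(ckpt, strict=False)              # bias-only keys may be reported missing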