├── .gitignore ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── .gitkeep ├── load_vector-en.svg ├── m8n8k4.md ├── make_direct_product_fragments-en.svg ├── make_eye-en.svg ├── make_identity-en.svg ├── mma.md ├── mma_f32.md ├── ops.md ├── store_vector-en.svg ├── utils.md └── wmmae.svg ├── include └── wmma_extension │ ├── detail │ ├── common.hpp │ ├── m16n8k16.hpp │ ├── m16n8k16_imma.hpp │ ├── m16n8k32_imma.hpp │ ├── m16n8k8.hpp │ ├── m16n8k8_tf32.hpp │ ├── m8n8k4.hpp │ ├── operators.hpp │ ├── sm_70.hpp │ ├── sm_75.hpp │ ├── sm_80.hpp │ └── sm_80_tf32.hpp │ ├── operators.hpp │ ├── tcec │ ├── complex.hpp │ ├── detail │ │ ├── common.hpp │ │ ├── functions.hpp │ │ ├── functions_simt.hpp │ │ ├── no_cor.hpp │ │ ├── notc.hpp │ │ ├── policy.hpp │ │ ├── policy_simt.hpp │ │ ├── print.hpp │ │ ├── scale.hpp │ │ ├── simt │ │ │ ├── detail │ │ │ │ ├── common.hpp │ │ │ │ ├── fma.hpp │ │ │ │ └── m16n16k16.hpp │ │ │ └── mma_simt.hpp │ │ ├── wmma_extension_include.hpp │ │ └── wmma_extension_simt_include.hpp │ └── tcec.hpp │ ├── utils.hpp │ ├── wmma_extension.hpp │ └── wmma_mma.hpp ├── research ├── bank-conflict │ ├── Makefile │ ├── README.md │ └── main.cu ├── common │ └── utils.hpp ├── fragment_analysis │ ├── Makefile │ └── main.cu ├── fragment_analysis_ij │ ├── Makefile │ └── main.cu └── fragment_analysis_map │ ├── Makefile │ └── main.cu └── test ├── performance ├── Makefile.common ├── batched_m8n8k4 │ ├── Makefile │ └── batched_m8n8k4.cu ├── foreach │ ├── Makefile │ └── matmul.cu ├── givens_rotation │ ├── Makefile │ └── givens.cu ├── householder │ ├── Makefile │ └── householder.cu ├── load_matrix_with_op_sync │ ├── Makefile │ └── matmul.cu ├── load_vector_sync │ ├── Makefile │ ├── batched_direct_product.cu │ └── direct_product.cu └── make_identity │ ├── Makefile │ └── batched_householder.cu ├── primitive ├── Makefile ├── add_eye.cu ├── common.hpp ├── direct_product.cu ├── fill.cu ├── foreach.cu ├── foreach_ij.cu ├── foreach_v.cu ├── foreach_v_acc.cu ├── gevm.cu ├── map.cu ├── mma.cu ├── operators.cu ├── print_fragment.cu ├── vector.cu ├── wmma.load_vector.cu └── wmma.store_vector.cu ├── tcec ├── Makefile ├── batch_gemm.cu ├── elementwise.cu ├── matvec.cu ├── mma.cu ├── mma_complex.cu ├── utils.hpp └── vector.cu └── utils ├── Makefile ├── cast.cu └── cp_async.cu /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/c++,cuda,vim,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=c++,cuda,vim,python 3 | 4 | ### C++ ### 5 | # Prerequisites 6 | *.d 7 | 8 | # Compiled Object files 9 | *.slo 10 | *.lo 11 | *.o 12 | *.obj 13 | 14 | # Precompiled Headers 15 | *.gch 16 | *.pch 17 | 18 | # Compiled Dynamic libraries 19 | *.so 20 | *.dylib 21 | *.dll 22 | 23 | # Fortran module files 24 | *.mod 25 | *.smod 26 | 27 | # Compiled Static libraries 28 | *.lai 29 | *.la 30 | *.a 31 | *.lib 32 | 33 | # Executables 34 | *.exe 35 | *.out 36 | *.app 37 | 38 | ### CUDA ### 39 | *.i 40 | *.ii 41 | *.gpu 42 | *.ptx 43 | *.cubin 44 | *.fatbin 45 | 46 | ### Python ### 47 | # Byte-compiled / optimized / DLL files 48 | __pycache__/ 49 | *.py[cod] 50 | *$py.class 51 | 52 | # C extensions 53 | 54 | # Distribution / packaging 55 | .Python 56 | build/ 57 | develop-eggs/ 58 | dist/ 59 | downloads/ 60 | eggs/ 61 | .eggs/ 62 | lib/ 63 | lib64/ 64 | parts/ 65 | sdist/ 66 | var/ 67 | wheels/ 68 | share/python-wheels/ 69 | *.egg-info/ 70 | .installed.cfg 71 | *.egg 72 | MANIFEST 73 | 
74 | # PyInstaller 75 | # Usually these files are written by a python script from a template 76 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 77 | *.manifest 78 | *.spec 79 | 80 | # Installer logs 81 | pip-log.txt 82 | pip-delete-this-directory.txt 83 | 84 | # Unit test / coverage reports 85 | htmlcov/ 86 | .tox/ 87 | .nox/ 88 | .coverage 89 | .coverage.* 90 | .cache 91 | nosetests.xml 92 | coverage.xml 93 | *.cover 94 | *.py,cover 95 | .hypothesis/ 96 | .pytest_cache/ 97 | cover/ 98 | 99 | # Translations 100 | *.mo 101 | *.pot 102 | 103 | # Django stuff: 104 | *.log 105 | local_settings.py 106 | db.sqlite3 107 | db.sqlite3-journal 108 | 109 | # Flask stuff: 110 | instance/ 111 | .webassets-cache 112 | 113 | # Scrapy stuff: 114 | .scrapy 115 | 116 | # Sphinx documentation 117 | docs/_build/ 118 | 119 | # PyBuilder 120 | .pybuilder/ 121 | target/ 122 | 123 | # Jupyter Notebook 124 | .ipynb_checkpoints 125 | 126 | # IPython 127 | profile_default/ 128 | ipython_config.py 129 | 130 | # pyenv 131 | # For a library or package, you might want to ignore these files since the code is 132 | # intended to run in multiple environments; otherwise, check them in: 133 | # .python-version 134 | 135 | # pipenv 136 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 137 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 138 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 139 | # install all needed dependencies. 140 | #Pipfile.lock 141 | 142 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 143 | __pypackages__/ 144 | 145 | # Celery stuff 146 | celerybeat-schedule 147 | celerybeat.pid 148 | 149 | # SageMath parsed files 150 | *.sage.py 151 | 152 | # Environments 153 | .env 154 | .venv 155 | env/ 156 | venv/ 157 | ENV/ 158 | env.bak/ 159 | venv.bak/ 160 | 161 | # Spyder project settings 162 | .spyderproject 163 | .spyproject 164 | 165 | # Rope project settings 166 | .ropeproject 167 | 168 | # mkdocs documentation 169 | /site 170 | 171 | # mypy 172 | .mypy_cache/ 173 | .dmypy.json 174 | dmypy.json 175 | 176 | # Pyre type checker 177 | .pyre/ 178 | 179 | # pytype static type analyzer 180 | .pytype/ 181 | 182 | # Cython debug symbols 183 | cython_debug/ 184 | 185 | ### Vim ### 186 | # Swap 187 | [._]*.s[a-v][a-z] 188 | !*.svg # comment out if you don't need vector files 189 | [._]*.sw[a-p] 190 | [._]s[a-rt-v][a-z] 191 | [._]ss[a-gi-z] 192 | [._]sw[a-p] 193 | 194 | # Session 195 | Session.vim 196 | Sessionx.vim 197 | 198 | # Temporary 199 | .netrwhist 200 | *~ 201 | # Auto-generated tag files 202 | tags 203 | # Persistent undo 204 | [._]*.un~ 205 | 206 | # End of https://www.toptal.com/developers/gitignore/api/c++,cuda,vim,python 207 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | precommit: 2 | image: python:3.10.2-slim-bullseye 3 | before_script: 4 | - apt update && apt install -y --no-install-recommends git 5 | - pip install pre-commit 6 | script: 7 | - pre-commit run --all-files 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | files: 
^(.*\.(cpp|hpp|cu|md))$
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 |   rev: v3.2.0
7 |   hooks:
8 |   - id: trailing-whitespace
9 |   - id: end-of-file-fixer
10 |   - id: check-yaml
11 |   - id: check-added-large-files
12 | - repo: https://github.com/codespell-project/codespell
13 |   rev: v2.3.0
14 |   hooks:
15 |   - id: codespell
16 | - repo: https://github.com/pre-commit/mirrors-clang-format
17 |   rev: v18.1.8
18 |   hooks:
19 |   - id: clang-format
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 mutsuki

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# WMMA API Extension

This extension provides features for
- mapping between memory and fragment (**primitive functions**)
- operations for vectors
  - loading a vector as a fragment
  - storing a fragment as a vector
- C++ interface for `mma` instructions [[detail](./docs/mma.md)]
- Error Correction (TCEC) for SGEMM emulation [[detail](./docs/mma_f32.md)]
- arithmetic operators for fragments (`+, -, *, /, fma`) [[detail](./docs/ops.md)]
- utils [[detail](./docs/utils.md)]
- etc

without using extra shared memory.

> [!IMPORTANT]
> Please specify an appropriate virtual architecture for the real GPU.
> For instance, a program which is compiled with `-arch=sm_70` will not work correctly on Ampere GPUs.

## Requirements
- CUDA (10.2 or later)
- C++ (17 or later)

## Supported architectures / fragments
- [x] sm_70: ((16, 16, 16), fp16/fp32)
- [x] sm_75: ((16, 16, 16), fp16/fp32)
- [x] sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- [x] sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
- [x] sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (`wgmma` instruction is not supported yet)

# Functions
## Primitive functions
### foreach
This function calculates the mapping between memory and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      const auto m = mem_index % 16;
      const auto n = mem_index / 16;
      for (unsigned i = 0; i < fragment_index_count; i++)
        frag_b.x[frag_index_list[i]] = convert_to<half>(matrix[n * 16 + m]);
    });
```
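For comparison, here is a minimal sketch of what this replaces when only the plain WMMA API is used (assuming the same `frag_b`, `matrix`, and user-defined `convert_to` helper as above): the `compute_t` data must first be converted into an extra `half` staging buffer in shared memory before `nvcuda::wmma::load_matrix_sync` can consume it, which is exactly the footprint this extension avoids.
```cuda
// Hedged sketch: the conventional load path needs an extra staging buffer.
__shared__ half staging[16 * 16];
for (unsigned i = (threadIdx.x & 0x1f); i < 16 * 16; i += 32) {
  staging[i] = convert_to<half>(matrix[i]); // extra shared-memory round trip
}
__syncwarp();
nvcuda::wmma::load_matrix_sync(frag_b, staging, 16);
```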
### foreach_ij
This function calculates the mapping between the matrix element position (i,j) and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t matrix[16 * 16];
mtk::wmma::foreach_ij<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
      for (unsigned f = 0; f < fragment_index_count; f++)
        frag_b.x[frag_index_list[f]] = convert_to<half>(matrix[j * 16 + i]);
    });
```

### foreach_v
#### For matrix A/B
This function calculates the mapping between a given vector and fragment elements.
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_b)>(
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      for (unsigned i = 0; i < fragment_index_count; i++)
        frag_b.x[frag_index_list[i]] = convert_to<half>(vector[mem_index]);
    });
// is equivalent to `load_vector`
```

#### For accumulator
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;
__shared__ compute_t vector[16];
mtk::wmma::foreach_v<decltype(frag_c)>(nvcuda::wmma::mem_col_major,
    [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
      for (unsigned i = 0; i < fragment_index_count; i++)
        vector[mem_index] = convert_to<compute_t>(frag_c.x[frag_index_list[i]]);
    });
// is equivalent to `store_vector`
```

### map
This function returns the mapping between a matrix element (i, j) and fragment elements (tid, fid).
```cuda
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
unsigned tid_list[2];
unsigned fid_list[2];
unsigned list_size;
mtk::wmma::map<decltype(frag_b)>(tid_list, fid_list, list_size, i, j);
for (unsigned k = 0; k < list_size; k++) {
  if ((threadIdx.x & 0x1f) == tid_list[k]) {
    frag_b.x[fid_list[k]] = 3.0f;
  }
}
```


## Functions for vector
## Sample
```cuda
#include <mma.h>
#include <wmma_extension/wmma_extension.hpp>

__global__ void kernel() {
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> frag_b;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

  __shared__ float vec16[16];

  mtk::wmma::load_vector(frag_a, vec16);
  mtk::wmma::load_vector(frag_b, vec16);

  nvcuda::wmma::fill_fragment(frag_c, 0.0f);
  nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

  mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
}
```
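Since every function above is a warp-level collective, the sample kernel has to be launched with at least one full warp. A minimal host-side sketch (the launch configuration is an illustrative assumption, not part of the library):
```cuda
int main() {
  // One warp (32 threads) handles one 16x16x16 fragment operation.
  kernel<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}
```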
## Other functions
### make_identity_matrix / add_eye
![load_matrix](docs/make_eye-en.svg)
- Arguments
  - dst_fragment : Destination fragment (`accumulator`)
  - alpha : diagonal element

A combined usage sketch is shown after the Debugging functions section below.

### fill_zero
- Argument
  - dst_fragment : Destination fragment

## Debugging functions

#### print_fragment
This function outputs the elements of a fragment.
- Arguments
  - frag : Target fragment
  - name : printing name of the fragment (`char*`, optional)
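A hedged sketch combining the helpers above (`make_identity_matrix`, `add_eye`, `fill_zero`, and `print_fragment`); the argument orders follow the argument lists in this README, but treat the exact signatures as assumptions:
```cuda
__global__ void eye_kernel() {
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> frag_c;

  mtk::wmma::fill_zero(frag_c);             // C = 0
  mtk::wmma::make_identity_matrix(frag_c);  // C = I
  mtk::wmma::add_eye(frag_c, 2.0f);         // C = C + 2 * I

  mtk::wmma::print_fragment(frag_c, "frag_c");
}
```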
# Publication
```bibtex
@inproceedings{ootomo_wmmae_2023,
  author = {Ootomo, Hiroyuki and Yokota, Rio},
  title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
  year = {2023},
  series = {HPC Asia '23}
}
```

# LICENSE
MIT
--------------------------------------------------------------------------------
/docs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wmmae/wmma_extension/a9e6b34e40a884cbdd3436f9d5f870bb6085bf6c/docs/.gitkeep
--------------------------------------------------------------------------------
/docs/load_vector-en.svg:
--------------------------------------------------------------------------------
[SVG figure: loading a vector into a fragment; labels: nvcuda::wmma::load_matrix_sync, mtk::wmma::load_vector_sync, fragment]
--------------------------------------------------------------------------------
/docs/m8n8k4.md:
--------------------------------------------------------------------------------
# m8n8k4

CUDA provides an experimental PTX instruction `mma.sync.aligned.m8n8k4` which computes a `(m, n, k) = (8, 8, 4)` matrix FMA.
This library provides its C++ interface.

## Sample
```cpp
constexpr unsigned M = 8;
constexpr unsigned N = 8;
constexpr unsigned K = 4;

__global__ void m8n8k4_test_kernel(float* const d, const half* const a, const half* const b, const float* const c) {
  mtk::wmma::fragment<nvcuda::wmma::matrix_a, M, N, K, half, nvcuda::wmma::col_major> frag_a;
  mtk::wmma::fragment<nvcuda::wmma::matrix_b, M, N, K, half, nvcuda::wmma::col_major> frag_b;
  mtk::wmma::fragment<nvcuda::wmma::accumulator, M, N, K, float> frag_c;
  mtk::wmma::fragment<nvcuda::wmma::accumulator, M, N, K, float> frag_d;

  mtk::wmma::load_matrix_sync(frag_a, a, M);
  mtk::wmma::load_matrix_sync(frag_b, b, K);
  mtk::wmma::load_matrix_sync(frag_c, c, M, nvcuda::wmma::mem_col_major);

  mtk::wmma::mma_sync(frag_d, frag_a, frag_b, frag_c);

  mtk::wmma::store_matrix_sync(d, frag_d, M, nvcuda::wmma::mem_col_major);
}
```
--------------------------------------------------------------------------------
/docs/mma.md:
--------------------------------------------------------------------------------
# C++ interface of `mma` instructions

```cpp
#include <wmma_extension/wmma_mma.hpp>

__global__ void kernel(float* const d, const half* const a, const half* const b, const float* const c) {
  mtk::wmma::mma::fragment<nvcuda::wmma::matrix_a, 16, 8, 16, half, nvcuda::wmma::row_major> frag_a;
  mtk::wmma::mma::fragment<nvcuda::wmma::matrix_b, 16, 8, 16, half, nvcuda::wmma::col_major> frag_b;
  mtk::wmma::mma::fragment<nvcuda::wmma::accumulator, 16, 8, 16, float> frag_c;
  mtk::wmma::mma::fragment<nvcuda::wmma::accumulator, 16, 8, 16, float> frag_d;

  mtk::wmma::mma::load_matrix_sync(frag_a, a, 16);
  mtk::wmma::mma::load_matrix_sync(frag_b, b, 8);
  mtk::wmma::mma::load_matrix_sync(frag_c, c, 16, nvcuda::wmma::mem_col_major);

  mtk::wmma::mma::mma_sync(frag_d, frag_a, frag_b, frag_c);

  mtk::wmma::mma::store_matrix_sync(d, frag_d, 16, nvcuda::wmma::mem_col_major);
}
```

## Supported fragments

| shape    | A,B type             | C, D type        | arch            |
|:---------|:---------------------|:-----------------|:----------------|
| m16n8k16 | `half`               | `float` / `half` | sm_80 or higher |
| m16n8k8  | `half`               | `float` / `half` | sm_75 or higher |
| m16n8k8  | `nvcuda::wmma::tf32` | `float`          | sm_80 or higher |
| m8n8k4   | `half`               | `float` / `half` | sm_70, sm_75    |
| m16n8k16 | `int8` / `uint8`     | `int32`          | sm_80 or higher |
| m16n8k32 | `int8` / `uint8`     | `int32`          | sm_80 or higher |

## Supported functions
- `foreach`
- `foreach_ij`
- `load_matrix_sync`
- `store_matrix_sync`
- `fill_fragment`
- `fill_zero`
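As an illustration of the listed primitives, here is a hedged sketch that uses `foreach_ij` and `fill_zero` to build a scaled identity matrix in an `m16n8k8` accumulator fragment without a shared-memory staging buffer. The fragment-as-first-argument form of `foreach_ij` and the `store_matrix_sync` signature follow the detail headers in this repository, but treat them as assumptions:
```cuda
__global__ void scaled_identity_kernel(float* const d) {
  mtk::wmma::mma::fragment<nvcuda::wmma::accumulator, 16, 8, 8, float> frag_c;
  mtk::wmma::mma::fill_zero(frag_c);

  // Set frag_c(i, j) = 2 on the diagonal, directly in registers.
  mtk::wmma::mma::foreach_ij(
      frag_c, nvcuda::wmma::mem_col_major,
      [&](const unsigned* frag_index_list, const unsigned fragment_index_count,
          const unsigned i, const unsigned j) {
        if (i == j) {
          for (unsigned f = 0; f < fragment_index_count; f++)
            frag_c.x[frag_index_list[f]] = 2.0f;
        }
      });

  mtk::wmma::mma::store_matrix_sync(d, frag_c, 16, nvcuda::wmma::mem_col_major);
}
```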
--------------------------------------------------------------------------------
/docs/mma_f32.md:
--------------------------------------------------------------------------------
# WMMA Extension for single precision matmul using Tensor Cores and error correction technique (TCEC)

## Error correction technique
See [our paper](https://arxiv.org/abs/2203.03341).

## Requirements
- CUDA
  - CUDA >= 10.0 for HMMA-FP16
  - CUDA >= 11.1 for HMMA-TF32
- C++ >= 14

## Installation
1. Clone [wmma_extension](https://github.com/wmmae/wmma_extension)
```bash
git clone https://github.com/wmmae/wmma_extension
```

## Sample code
```cuda
// sample.cu
// nvcc -I./path/to/wmma_extension/include/ -std=c++17 sample.cu ...
//
#include <wmma_extension/tcec/tcec.hpp>

template <unsigned N>
__global__ void mma_kernel(float* const d_ptr, const float* const a_ptr, const float* const b_ptr, const float* const c_ptr) {
  __shared__ float smem[N * N];
  fill_zero(smem, N * N);

  mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_a, N, N, N, half, nvcuda::wmma::col_major> frag_a;
  mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_b, N, N, N, half, nvcuda::wmma::col_major> frag_b;
  mtk::wmma::tcec::fragment<nvcuda::wmma::accumulator, N, N, N, half> frag_c, frag_d;

  // Load A
  // copy_matrix(smem, N, a_ptr, N, N, N);
  mtk::wmma::tcec::load_matrix_sync(frag_a, smem, N);

  // Load B
  // copy_matrix(smem, N, b_ptr, N, N, N);
  mtk::wmma::tcec::load_matrix_sync(frag_b, smem, N);

  // Load C
  // copy_matrix(smem, N, c_ptr, N, N, N);
  mtk::wmma::tcec::load_matrix_sync(frag_c, smem, N, nvcuda::wmma::mem_col_major);

  // Fill D
  mtk::wmma::tcec::fill_fragment(frag_d, 0.0f);

  // mma
  mtk::wmma::tcec::mma_sync(frag_d, frag_a, frag_b, frag_c);

  // Store D
  mtk::wmma::tcec::store_matrix_sync(smem, frag_d, N, nvcuda::wmma::mem_col_major);
  //copy_matrix(d_ptr, N, smem, N, N, N);
}
```

## Fragment
```cpp
template <class Use, int m, int n, int k, class T, class Layout = void,
          class Policy = typename mtk::wmma::tcec::default_policy<T>::type>
struct fragment;
```

### Template arguments
`mtk::wmma::tcec::fragment` is a fragment for this computation.
It contains arrays of `nvcuda::wmma::fragment`.
- `m`, `n` and `k` have to be multiples of `Policy::m`, `Policy::n` and `Policy::k` respectively.
  You can get a default policy using `mtk::wmma::tcec::default_policy<T>::type`.
- `k` has to be a multiple of 16 when `T` is `half` and 8 when `T` is `nvcuda::wmma::precision::tf32`.
- `T` is `half` or `nvcuda::wmma::precision::tf32`. Unlike `nvcuda::wmma::fragment`, this is also the case when `Use` is `nvcuda::wmma::accumulator`.
- `Policy` is a concept of `mtk::wmma::tcec::Policy`.
  - `Op` : `mtk::wmma::tcec::op_mma` / `mtk::wmma::tcec::op_wmma`
  - `ErrorCorrection` : `mtk::wmma::tcec::with_ec` / `mtk::wmma::tcec::without_ec`
  - `fm`, `fn` and `fk` are the sizes of the internal fragments.

### Policy
`default_policy` can make a `Policy` easily.
```cuda
using policy = mtk::wmma::tcec::default_policy<half>::type;
```

## Supported fragment

| fm | fn | fk | LayoutA | LayoutB | Type | Operation      | Supported arch |
| -- | -- | -- | ------- | ------- | ---- | -------------- | -------------- |
| 16 | 16 | 16 | col/row | col/row | half | Arch dependent | sm_70 or later |
| 16 | 16 | 16 | col/row | col/row | tf32 | wmma           | sm_80 or later |
| 16 | 8  | 8  | row     | col     | tf32 | mma            | sm_80 or later |
| 16 | 8  | 8  | row     | col     | half | mma            | sm_75 or later |
| 16 | 8  | 16 | row     | col     | half | mma            | sm_80 or later |

### Note
To get a default policy for `sm_75` and `op_mma`, specify the architecture as follows:
```cuda
using policy = mtk::wmma::tcec::default_policy<half, mtk::wmma::tcec::with_ec, mtk::wmma::tcec::op_mma, mtk::wmma::tcec::sm_75>::type;
```

### Member variables/functions
- Member variable `element_type` is `float`
- Member functions `x(index)` and `dx(index)` return references to the elements.

## Functions
- `mtk::wmma::tcec::fill_fragment`
- `mtk::wmma::tcec::load_matrix_sync`
- `mtk::wmma::tcec::store_matrix_sync`
- `mtk::wmma::tcec::mma_sync`

- `mtk::wmma::tcec::mma_rz_sync`
- `mtk::wmma::tcec::load_vector`
- `mtk::wmma::tcec::store_vector`
- `mtk::wmma::tcec::fill_zero`

### Note
While some `fragment`s support only either `row` or `col`, the `load_matrix_sync` function can load matrices of both memory layouts using an additional template parameter.

```cpp
// e.g.
using policy = mtk::wmma::tcec::default_policy<half, mtk::wmma::tcec::without_ec, mtk::wmma::tcec::op_mma>::type;
mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major, policy> frag_a;

mtk::wmma::tcec::load_matrix_sync<nvcuda::wmma::col_major>(frag_a, matrix_ptr, ldm);
```


## Rounding mode
To specify the rounding mode in the `+C` operation, use the following functions:
- `mtk::wmma::tcec::mma_rn_sync`
- `mtk::wmma::tcec::mma_rz_sync`

### Default rounding mode
| op         | rounding mode |
| ---------- | ------------- |
| with_ec    | RN            |
| without_ec | RZ            |

Read [our paper](https://arxiv.org/abs/2203.03341) for details.

## SIMT Core computation

This library provides fragments and functions for mma operations using CUDA SIMT cores with the same API as the WMMA API.

| fm | fn | fk | LayoutA | LayoutB | Type  | Operation | Supported arch |
| -- | -- | -- | ------- | ------- | ----- | --------- | -------------- |
| 16 | 16 | 16 | col/row | col/row | float | simt      | sm_70 or later |

### Policy
```cuda
using simt_policy = typename mtk::wmma::tcec::default_policy<float, mtk::wmma::tcec::without_ec, mtk::wmma::tcec::op_simt>::type;

mtk::wmma::tcec::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, float, nvcuda::wmma::col_major, simt_policy> frag_a;
```

## Complex type
```cuda
mtk::wmma::tcec::fragment_complex<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a;
// or
using policy = typename mtk::wmma::tcec::default_policy<half>::type;
mtk::wmma::tcec::fragment_complex<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major, policy> frag_a;
```

### Supported functions
- `mtk::wmma::tcec::fill_fragment`
- `mtk::wmma::tcec::load_matrix_sync`
- `mtk::wmma::tcec::store_matrix_sync`
- `mtk::wmma::tcec::mma_sync`

- `mtk::wmma::tcec::mma_rz_sync`
- `mtk::wmma::tcec::fill_zero`

See [test code](../test/tcec/mma_complex.cu) for more details.
--------------------------------------------------------------------------------
/docs/ops.md:
--------------------------------------------------------------------------------
# Arithmetic operators for fragments

## Supported operators

| op | A type | B type | C type |
|:----:|:------:|:------:|:------:|
| `+` | `fragment` | `fragment` ||
| `-` | `fragment` | `fragment` ||
| `*` | `fragment` | `fragment::storage_element_type` ||
| `/` | `fragment` | `fragment::storage_element_type` ||
| `mtk::wmma::fma` | `fragment` | `fragment::storage_element_type` | `fragment` |
| `mtk::wmma::fma` | `fragment::storage_element_type` | `fragment` | `fragment` |

## Example

```cpp
#include <wmma_extension/operators.hpp>

nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> frag_a0, frag_a1;

const auto frag_a2 = frag_a0 + frag_a1 * __float2half(2.0f);
const auto frag_a3 = mtk::wmma::fma(__float2half(2.0f), frag_a0, frag_a1);
```
--------------------------------------------------------------------------------
/docs/store_vector-en.svg:
--------------------------------------------------------------------------------
[SVG figure: storing a fragment as a vector; labels: nvcuda::wmma::store_matrix_sync, mtk::wmma::store_vector_sync, fragment]
--------------------------------------------------------------------------------
/docs/utils.md:
--------------------------------------------------------------------------------
# Utils

## Type conversion
```cpp
const auto dst_val = mtk::wmma::utils::cast<dst_t>(src_val);
```

## Asynchronous D2S data copy
```cpp
mtk::wmma::utils::cp_async::cp_async<N>(dst_ptr, src_ptr);
mtk::wmma::utils::cp_async::commit();
mtk::wmma::utils::cp_async::wait_group();
mtk::wmma::utils::cp_async::wait_all();
```

- `N` is the data size in bytes (4, 8 or 16).
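A hedged end-to-end sketch of the copy API above, assuming `N = 16` (one `float4` per `cp.async` call) and a 16x16 `float` tile; the kernel shape and index arithmetic are illustrative assumptions:
```cuda
__global__ void cp_async_kernel(const float* const src) {
  __shared__ float smem[16 * 16];

  // 32 threads x 4 floats (16 bytes) x 2 iterations = 256 elements.
  for (unsigned i = 0; i < 2; i++) {
    const unsigned offset = (threadIdx.x + i * 32) * 4;
    mtk::wmma::utils::cp_async::cp_async<16>(smem + offset, src + offset);
  }
  mtk::wmma::utils::cp_async::commit();
  mtk::wmma::utils::cp_async::wait_all();
  __syncthreads();

  // ... load fragments from smem and compute ...
}
```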
--------------------------------------------------------------------------------
/docs/wmmae.svg:
--------------------------------------------------------------------------------
[SVG figure: WMMAe logo]
--------------------------------------------------------------------------------
/include/wmma_extension/detail/common.hpp:
--------------------------------------------------------------------------------
1 | #ifndef __WMMAE_DETAIL_COMMON__ 2 | #define __WMMAE_DETAIL_COMMON__ 3 | #include 4 | #include 5 | #include 6 | 7 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 8 | namespace nvcuda { 9 | namespace wmma { 10 | namespace precision { 11 | class tf32; 12 | } // namespace precision 13 | } // namespace wmma 14 | } // namespace nvcuda 15 | #endif 16 | 17 | namespace mtk { 18 | namespace wmma { 19 | 20 | namespace detail { 21 | namespace common { 22 | template struct storage_t { 23 | using type = T; 24 | }; 25 | template 26 | inline __device__ __host__ typename storage_t::type cast(const float v) { 27 | return static_cast::type>(v); 28 | } 29 | template 30 | inline __device__ __host__ typename storage_t::type cast(const half v) { 31 | return static_cast::type>(v); 32 | } 33 | 34 | template <> struct storage_t { 35 | using type = float; 36 | }; 37 | __device__ __host__ inline float to_tf32(const float a) { 38 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 39 | float ret; 40 | asm("{.reg .b32 %mr;\n" 41 | "cvt.rna.tf32.f32 %mr, %1;\n" 42 | "mov.b32 %0, %mr;}\n" 43 | : "=f"(ret) 44 | : "f"(a)); 45 | return ret; 46 | #else 47 | return a; 48 | #endif 49 | } 50 | template <> 51 | inline __device__ __host__ 52 | typename storage_t::type 53 | cast(const float v) { 54 | return to_tf32(v); 55 | } 56 | template <> 57 | inline __device__ __host__ 58 | typename 
storage_t::type 59 | cast(const half v) { 60 | return to_tf32(__half2float(v)); 61 | } 62 | 63 | inline __device__ unsigned get_lane_id() { 64 | unsigned lane_id; 65 | asm(R"({mov.s32 %0, %laneid;})" : "=r"(lane_id)); 66 | return lane_id; 67 | } 68 | 69 | template struct get_M; 70 | template struct get_M { 71 | static const int value = M; 72 | }; 73 | template struct get_M { 74 | static const int value = K; 75 | }; 76 | template 77 | struct get_M { 78 | static const int value = M; 79 | }; 80 | 81 | template struct get_N; 82 | template struct get_N { 83 | static const int value = K; 84 | }; 85 | template struct get_N { 86 | static const int value = N; 87 | }; 88 | template 89 | struct get_N { 90 | static const int value = N; 91 | }; 92 | 93 | template struct layout_switch; 94 | template 95 | struct layout_switch { 96 | static const int value = col_value; 97 | }; 98 | template 99 | struct layout_switch { 100 | static const int value = row_value; 101 | }; 102 | 103 | } // namespace common 104 | 105 | template struct fill_zero_core; 106 | 107 | template struct fill_zero_core<2, T> { 108 | __device__ void operator()(T *const ptr) { 109 | *reinterpret_cast(ptr) = 0; 110 | } 111 | }; 112 | 113 | template struct fill_zero_core<4, T> { 114 | __device__ void operator()(T *const ptr) { 115 | *reinterpret_cast(ptr) = 0; 116 | } 117 | }; 118 | 119 | template struct fill_zero_core<8, T> { 120 | __device__ void operator()(T *const ptr) { 121 | *reinterpret_cast(ptr) = make_int2(0, 0); 122 | } 123 | }; 124 | 125 | template struct fill_zero_core<16, T> { 126 | __device__ void operator()(T *const ptr) { 127 | *reinterpret_cast(ptr) = make_int4(0, 0, 0, 0); 128 | } 129 | }; 130 | 131 | template struct fill_zero_core<32, T> { 132 | __device__ void operator()(T *const ptr) { 133 | *(reinterpret_cast(ptr) + 0) = make_int4(0, 0, 0, 0); 134 | *(reinterpret_cast(ptr) + 1) = make_int4(0, 0, 0, 0); 135 | } 136 | }; 137 | 138 | template struct fill_zero_core<64, T> { 139 | __device__ void operator()(T *const ptr) { 140 | *(reinterpret_cast(ptr) + 0) = make_int4(0, 0, 0, 0); 141 | *(reinterpret_cast(ptr) + 1) = make_int4(0, 0, 0, 0); 142 | *(reinterpret_cast(ptr) + 2) = make_int4(0, 0, 0, 0); 143 | *(reinterpret_cast(ptr) + 3) = make_int4(0, 0, 0, 0); 144 | } 145 | }; 146 | 147 | template struct size_of { 148 | static constexpr unsigned value = 0; 149 | }; 150 | template <> struct size_of { 151 | static constexpr unsigned value = 1; 152 | }; 153 | template <> struct size_of { 154 | static constexpr unsigned value = 1; 155 | }; 156 | template <> struct size_of { 157 | static constexpr unsigned value = 4; 158 | }; 159 | template <> struct size_of { 160 | static constexpr unsigned value = 2; 161 | }; 162 | template <> struct size_of { 163 | static constexpr unsigned value = 4; 164 | }; 165 | } // namespace detail 166 | 167 | namespace mma { 168 | template struct __align__(4) __frag_base { 169 | T x[size]; 170 | enum { num_elements = size }; 171 | }; 172 | 173 | template 174 | __device__ inline void fill_fragment(__frag_base &f, const T v) { 175 | #pragma unroll 176 | for (unsigned i = 0; i < f.num_elements; i++) 177 | f.x[i] = v; 178 | } 179 | template 180 | __device__ inline void fill_fragment(__frag_base &f, const T v) { 181 | #pragma unroll 182 | for (unsigned i = 0; i < f.num_elements; i++) 183 | f.x[i] = v; 184 | } 185 | template 186 | __device__ inline void fill_fragment(__frag_base &f, const T v) { 187 | #pragma unroll 188 | for (unsigned i = 0; i < f.num_elements; i++) 189 | f.x[i] = v; 190 | } 191 | template 192 
| __device__ inline void fill_fragment(__frag_base &f, const T v) { 193 | #pragma unroll 194 | for (unsigned i = 0; i < f.num_elements; i++) 195 | f.x[i] = v; 196 | } 197 | template 198 | __device__ inline void fill_fragment(__frag_base &f, const T v) { 199 | #pragma unroll 200 | for (unsigned i = 0; i < f.num_elements; i++) 201 | f.x[i] = v; 202 | } 203 | template 204 | __device__ inline void fill_fragment(__frag_base &f, 205 | const T v) { 206 | #pragma unroll 207 | for (unsigned i = 0; i < f.num_elements; i++) 208 | f.x[i] = v; 209 | } 210 | template 211 | __device__ inline void fill_fragment(__frag_base &f, 212 | const T v) { 213 | #pragma unroll 214 | for (unsigned i = 0; i < f.num_elements; i++) 215 | f.x[i] = v; 216 | } 217 | template 218 | __device__ inline void fill_fragment(__frag_base &f, 219 | const T v) { 220 | #pragma unroll 221 | for (unsigned i = 0; i < f.num_elements; i++) 222 | f.x[i] = v; 223 | } 224 | template 225 | __device__ inline void fill_fragment(__frag_base &f, 226 | const T v) { 227 | #pragma unroll 228 | for (unsigned i = 0; i < f.num_elements; i++) 229 | f.x[i] = v; 230 | } 231 | template 232 | __device__ inline void fill_fragment(__frag_base &f, 233 | const T v) { 234 | #pragma unroll 235 | for (unsigned i = 0; i < f.num_elements; i++) 236 | f.x[i] = v; 237 | } 238 | 239 | template 240 | class fragment; 241 | 242 | template 243 | __device__ inline void 244 | fill_zero(mtk::wmma::mma::fragment &frag) { 245 | constexpr unsigned size = 246 | detail::size_of::type>::value * 247 | mtk::wmma::mma::fragment::num_elements; 248 | detail::fill_zero_core{}(reinterpret_cast(frag.x)); 249 | } 250 | } // namespace mma 251 | 252 | template 253 | __device__ inline void 254 | fill_zero(nvcuda::wmma::fragment &frag) { 255 | const unsigned size = 256 | 4 * nvcuda::wmma::fragment::num_elements; 257 | detail::fill_zero_core{}(reinterpret_cast(frag.x)); 258 | } 259 | template 260 | __device__ inline void 261 | fill_zero(nvcuda::wmma::fragment &frag) { 262 | const unsigned size = 263 | 2 * nvcuda::wmma::fragment::num_elements; 264 | detail::fill_zero_core{}(reinterpret_cast(frag.x)); 265 | } 266 | 267 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 268 | template 269 | __device__ inline void 270 | fill_zero(nvcuda::wmma::fragment &frag) { 272 | const unsigned size = 273 | 4 * nvcuda::wmma::fragment::num_elements; 275 | detail::fill_zero_core{}(reinterpret_cast(frag.x)); 276 | } 277 | #endif 278 | } // namespace wmma 279 | } // namespace mtk 280 | #endif 281 | -------------------------------------------------------------------------------- /include/wmma_extension/detail/m16n8k8.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_M16N8K8_HPP__ 2 | #define __WMMAE_M16N8K8_HPP__ 3 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 4 | #include "common.hpp" 5 | #include 6 | 7 | namespace mtk { 8 | namespace wmma { 9 | namespace mma { 10 | template <> 11 | class fragment 12 | : public __frag_base {}; 13 | template <> 14 | class fragment 15 | : public __frag_base {}; 16 | template <> 17 | class fragment 18 | : public __frag_base {}; 19 | template <> 20 | class fragment 21 | : public __frag_base {}; 22 | 23 | // foreach 24 | template 25 | __device__ inline void foreach ( 26 | mtk::wmma::mma::fragment &frag, 28 | Func func) { 29 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4) * 2; 30 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() 
/ 4; 31 | 32 | for (unsigned i = 0; i < 2; i++) { 33 | const auto row = row_block_id + i * 8; 34 | { 35 | const unsigned frag_index_list[1] = {(i * 2 + 0)}; 36 | func(frag_index_list, 1, row * 8 + (col + 0)); 37 | } 38 | { 39 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 40 | func(frag_index_list, 1, row * 8 + (col + 1)); 41 | } 42 | } 43 | } 44 | 45 | template 46 | __device__ inline void foreach ( 47 | mtk::wmma::mma::fragment &frag, 49 | Func func) { 50 | const unsigned col = mtk::wmma::detail::common::get_lane_id() / 4; 51 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() % 4; 52 | 53 | const auto row = row_block_id * 2; 54 | { 55 | const unsigned frag_index_list[1] = {0}; 56 | func(frag_index_list, 1, (row + 0) + col * 8); 57 | } 58 | { 59 | const unsigned frag_index_list[1] = {1}; 60 | func(frag_index_list, 1, (row + 1) + col * 8); 61 | } 62 | } 63 | 64 | template 65 | __device__ inline void foreach ( 66 | mtk::wmma::mma::fragment &frag, 67 | const nvcuda::wmma::layout_t layout, Func func) { 68 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4) * 2; 69 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() / 4; 70 | 71 | for (unsigned i = 0; i < 2; i++) { 72 | const auto row = row_block_id + i * 8; 73 | if (layout == nvcuda::wmma::mem_col_major) { 74 | { 75 | const unsigned frag_index_list[1] = {(i * 2 + 0)}; 76 | func(frag_index_list, 1, row + (col + 0) * 16); 77 | } 78 | { 79 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 80 | func(frag_index_list, 1, row + (col + 1) * 16); 81 | } 82 | } else { 83 | { 84 | const unsigned frag_index_list[1] = {(i * 2 + 0)}; 85 | func(frag_index_list, 1, row * 8 + (col + 0)); 86 | } 87 | { 88 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 89 | func(frag_index_list, 1, row * 8 + (col + 1)); 90 | } 91 | } 92 | } 93 | } 94 | 95 | // foreach_ij 96 | template 97 | __device__ inline void 98 | foreach_ij(mtk::wmma::mma::fragment &frag, 100 | Func func) { 101 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4) * 2; 102 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() / 4; 103 | 104 | for (unsigned i = 0; i < 2; i++) { 105 | const auto row = row_block_id + i * 8; 106 | { 107 | const unsigned frag_index_list[1] = {(i * 2 + 0)}; 108 | func(frag_index_list, 1, row, col + 0); 109 | } 110 | { 111 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 112 | func(frag_index_list, 1, row, col + 1); 113 | } 114 | } 115 | } 116 | 117 | template 118 | __device__ inline void 119 | foreach_ij(mtk::wmma::mma::fragment &frag, 121 | Func func) { 122 | const unsigned col = mtk::wmma::detail::common::get_lane_id() / 4; 123 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() % 4; 124 | 125 | const auto row = row_block_id * 2; 126 | { 127 | const unsigned frag_index_list[1] = {0}; 128 | func(frag_index_list, 1, row + 0, col); 129 | } 130 | { 131 | const unsigned frag_index_list[1] = {1}; 132 | func(frag_index_list, 1, row + 1, col); 133 | } 134 | } 135 | 136 | template 137 | __device__ inline void foreach_ij( 138 | mtk::wmma::mma::fragment &frag, 139 | const nvcuda::wmma::layout_t layout, Func func) { 140 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4) * 2; 141 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() / 4; 142 | 143 | for (unsigned i = 0; i < 2; i++) { 144 | const auto row = row_block_id + i * 8; 145 | if (layout == nvcuda::wmma::mem_col_major) { 146 | { 147 | const unsigned 
frag_index_list[1] = {(i * 2 + 0)}; 148 | func(frag_index_list, 1, row, col + 0); 149 | } 150 | { 151 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 152 | func(frag_index_list, 1, row, col + 1); 153 | } 154 | } else { 155 | { 156 | const unsigned frag_index_list[1] = {(i * 2 + 0)}; 157 | func(frag_index_list, 1, row, col + 0); 158 | } 159 | { 160 | const unsigned frag_index_list[1] = {(i * 2 + 1)}; 161 | func(frag_index_list, 1, row, col + 1); 162 | } 163 | } 164 | } 165 | } 166 | 167 | // foreach_v 168 | template 169 | __device__ inline void 170 | foreach_v(mtk::wmma::mma::fragment &frag, 172 | Func func) { 173 | if (mtk::wmma::detail::common::get_lane_id() >= 4) 174 | return; 175 | 176 | { 177 | const unsigned frag_index_list[1] = {0}; 178 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() * 2 + 0); 179 | } 180 | { 181 | const unsigned frag_index_list[1] = {1}; 182 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() * 2 + 1); 183 | } 184 | } 185 | 186 | template 187 | __device__ inline void 188 | foreach_v(mtk::wmma::mma::fragment &frag, 190 | Func func) { 191 | if (mtk::wmma::detail::common::get_lane_id() >= 4) 192 | return; 193 | 194 | { 195 | const unsigned frag_index_list[1] = {0}; 196 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() * 2 + 0); 197 | } 198 | { 199 | const unsigned frag_index_list[1] = {1}; 200 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() * 2 + 1); 201 | } 202 | } 203 | 204 | template 205 | __device__ inline void foreach_v( 206 | mtk::wmma::mma::fragment &frag, 207 | const nvcuda::wmma::layout_t layout, Func func) { 208 | if (layout == nvcuda::wmma::mem_col_major) { 209 | if (mtk::wmma::detail::common::get_lane_id() & 0b11) 210 | return; 211 | { 212 | const unsigned frag_index_list[1] = {0}; 213 | func(frag_index_list, 1, 214 | mtk::wmma::detail::common::get_lane_id() / 4 + 0); 215 | } 216 | { 217 | const unsigned frag_index_list[1] = {2}; 218 | func(frag_index_list, 1, 219 | mtk::wmma::detail::common::get_lane_id() / 4 + 8); 220 | } 221 | } else { 222 | if (mtk::wmma::detail::common::get_lane_id() >= 4) 223 | return; 224 | { 225 | const unsigned frag_index_list[1] = {0}; 226 | func(frag_index_list, 1, 227 | mtk::wmma::detail::common::get_lane_id() * 2 + 0); 228 | } 229 | { 230 | const unsigned frag_index_list[1] = {1}; 231 | func(frag_index_list, 1, 232 | mtk::wmma::detail::common::get_lane_id() * 2 + 1); 233 | } 234 | } 235 | } 236 | 237 | // Mma 238 | __device__ inline void 239 | mma_sync(fragment &d, 240 | const fragment &a, 242 | const fragment &b, 244 | const fragment &c) { 245 | asm(R"({ 246 | mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 247 | {%0, %1, %2, %3}, 248 | {%4, %5}, 249 | {%6}, 250 | {%7, %8, %9, %10}; 251 | })" 252 | : "=f"(d.x[0]), "=f"(d.x[1]), "=f"(d.x[2]), "=f"(d.x[3]) 253 | : "r"(*reinterpret_cast(a.x)), 254 | "r"(*reinterpret_cast(a.x + 2)), 255 | "r"(*reinterpret_cast(b.x)), "f"(c.x[0]), "f"(c.x[1]), 256 | "f"(c.x[2]), "f"(c.x[3])); 257 | } 258 | } // namespace mma 259 | } // namespace wmma 260 | } // namespace mtk 261 | 262 | #endif /* end of include guard */ 263 | -------------------------------------------------------------------------------- /include/wmma_extension/detail/m16n8k8_tf32.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_M16N8K8_TF32_HPP__ 2 | #define __WMMAE_M16N8K8_TF32_HPP__ 3 | // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 4 | 
#include "common.hpp" 5 | #include 6 | 7 | namespace mtk { 8 | namespace wmma { 9 | namespace mma { 10 | template <> 11 | class fragment : public __frag_base {}; 13 | template <> 14 | class fragment : public __frag_base {}; 16 | // The accumulator is same with m16n8k8-float for nvcuda::wmma::precision::tf32 17 | // template <> class fragment : 18 | // public __frag_base{}; 19 | 20 | // foreach 21 | template 22 | __device__ inline void foreach ( 23 | mtk::wmma::mma::fragment &frag, 26 | Func func) { 27 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4); 28 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() / 4; 29 | 30 | { 31 | const unsigned frag_index_list[1] = {0}; 32 | func(frag_index_list, 1, (row_block_id + 0) * 8 + (col + 0)); 33 | } 34 | { 35 | const unsigned frag_index_list[1] = {1}; 36 | func(frag_index_list, 1, (row_block_id + 8) * 8 + (col + 0)); 37 | } 38 | { 39 | const unsigned frag_index_list[1] = {2}; 40 | func(frag_index_list, 1, (row_block_id + 0) * 8 + (col + 4)); 41 | } 42 | { 43 | const unsigned frag_index_list[1] = {3}; 44 | func(frag_index_list, 1, (row_block_id + 8) * 8 + (col + 4)); 45 | } 46 | } 47 | 48 | template 49 | __device__ inline void foreach ( 50 | mtk::wmma::mma::fragment &frag, 53 | Func func) { 54 | const unsigned col = mtk::wmma::detail::common::get_lane_id() / 4; 55 | const unsigned row_start = mtk::wmma::detail::common::get_lane_id() % 4; 56 | 57 | { 58 | const unsigned frag_index_list[1] = {0}; 59 | func(frag_index_list, 1, (row_start + 0) + col * 8); 60 | } 61 | { 62 | const unsigned frag_index_list[1] = {1}; 63 | func(frag_index_list, 1, (row_start + 4) + col * 8); 64 | } 65 | } 66 | 67 | // foreach_ij 68 | template 69 | __device__ inline void 70 | foreach_ij(mtk::wmma::mma::fragment &frag, 73 | Func func) { 74 | const unsigned col = (mtk::wmma::detail::common::get_lane_id() % 4); 75 | const unsigned row_block_id = mtk::wmma::detail::common::get_lane_id() / 4; 76 | 77 | { 78 | const unsigned frag_index_list[1] = {0}; 79 | func(frag_index_list, 1, (row_block_id + 0), (col + 0)); 80 | } 81 | { 82 | const unsigned frag_index_list[1] = {1}; 83 | func(frag_index_list, 1, (row_block_id + 8), (col + 0)); 84 | } 85 | { 86 | const unsigned frag_index_list[1] = {2}; 87 | func(frag_index_list, 1, (row_block_id + 0), (col + 4)); 88 | } 89 | { 90 | const unsigned frag_index_list[1] = {3}; 91 | func(frag_index_list, 1, (row_block_id + 8), (col + 4)); 92 | } 93 | } 94 | 95 | template 96 | __device__ inline void 97 | foreach_ij(mtk::wmma::mma::fragment &frag, 100 | Func func) { 101 | const unsigned col = mtk::wmma::detail::common::get_lane_id() / 4; 102 | const unsigned row_start = mtk::wmma::detail::common::get_lane_id() % 4; 103 | 104 | { 105 | const unsigned frag_index_list[1] = {0}; 106 | func(frag_index_list, 1, (row_start + 0), col); 107 | } 108 | { 109 | const unsigned frag_index_list[1] = {1}; 110 | func(frag_index_list, 1, (row_start + 4), col); 111 | } 112 | } 113 | 114 | // foreach_v 115 | template 116 | __device__ inline void 117 | foreach_v(mtk::wmma::mma::fragment &frag, 120 | Func func) { 121 | if (mtk::wmma::detail::common::get_lane_id() >= 4) 122 | return; 123 | 124 | { 125 | const unsigned frag_index_list[1] = {0}; 126 | func(frag_index_list, 1, (mtk::wmma::detail::common::get_lane_id() + 0)); 127 | } 128 | { 129 | const unsigned frag_index_list[1] = {2}; 130 | func(frag_index_list, 1, (mtk::wmma::detail::common::get_lane_id() + 4)); 131 | } 132 | } 133 | 134 | template 135 | __device__ inline void 136 | 
foreach_v(mtk::wmma::mma::fragment &frag, 139 | Func func) { 140 | if (mtk::wmma::detail::common::get_lane_id() >= 4) 141 | return; 142 | 143 | { 144 | const unsigned frag_index_list[1] = {0}; 145 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() + 0); 146 | } 147 | { 148 | const unsigned frag_index_list[1] = {1}; 149 | func(frag_index_list, 1, mtk::wmma::detail::common::get_lane_id() + 4); 150 | } 151 | } 152 | 153 | // Mma 154 | __device__ inline void mma_sync( 155 | fragment &d, 156 | const fragment &a, 158 | const fragment &b, 160 | const fragment &c) { 161 | asm(R"({ 162 | mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 163 | {%0, %1, %2, %3}, 164 | {%4, %5, %6, %7}, 165 | {%8, %9}, 166 | {%10, %11, %12, %13}; 167 | })" 168 | : "=f"(d.x[0]), "=f"(d.x[1]), "=f"(d.x[2]), "=f"(d.x[3]) 169 | : "r"(*reinterpret_cast(a.x)), 170 | "r"(*reinterpret_cast(a.x + 1)), 171 | "r"(*reinterpret_cast(a.x + 2)), 172 | "r"(*reinterpret_cast(a.x + 3)), 173 | "r"(*reinterpret_cast(b.x)), 174 | "r"(*reinterpret_cast(b.x + 1)), "f"(c.x[0]), 175 | "f"(c.x[1]), "f"(c.x[2]), "f"(c.x[3])); 176 | } 177 | } // namespace mma 178 | } // namespace wmma 179 | } // namespace mtk 180 | 181 | #endif /* end of include guard */ 182 | -------------------------------------------------------------------------------- /include/wmma_extension/detail/operators.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_DETAIL_OPERATORS__ 2 | #define __WMMAE_DETAIL_OPERATORS__ 3 | #include 4 | 5 | namespace mtk { 6 | namespace wmma { 7 | namespace ops { 8 | 9 | // Add 10 | template struct add { 11 | __device__ nvcuda::wmma::fragment 12 | operator()(const nvcuda::wmma::fragment &a, 13 | const nvcuda::wmma::fragment &b) { 14 | nvcuda::wmma::fragment res; 15 | for (unsigned i = 0; i < res.num_elements; i++) { 16 | res.x[i] = a.x[i] + b.x[i]; 17 | } 18 | return res; 19 | } 20 | }; 21 | 22 | template 23 | struct add { 24 | __device__ nvcuda::wmma::fragment 25 | operator()(const nvcuda::wmma::fragment &a, 26 | const nvcuda::wmma::fragment &b) { 27 | nvcuda::wmma::fragment res; 28 | using simd_t = half2; 29 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 30 | reinterpret_cast(res.x)[i] = 31 | __hadd2(reinterpret_cast(a.x)[i], 32 | reinterpret_cast(b.x)[i]); 33 | } 34 | return res; 35 | } 36 | }; 37 | 38 | template 39 | struct add { 40 | __device__ nvcuda::wmma::fragment 41 | operator()( 42 | const nvcuda::wmma::fragment &a, 43 | const nvcuda::wmma::fragment &b) { 44 | nvcuda::wmma::fragment res; 45 | using simd_t = __nv_bfloat162; 46 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 47 | reinterpret_cast(res.x)[i] = 48 | __hadd2(reinterpret_cast(a.x)[i], 49 | reinterpret_cast(b.x)[i]); 50 | } 51 | return res; 52 | } 53 | }; 54 | 55 | // Sub 56 | template struct sub { 57 | __device__ nvcuda::wmma::fragment 58 | operator()(const nvcuda::wmma::fragment &a, 59 | const nvcuda::wmma::fragment &b) { 60 | nvcuda::wmma::fragment res; 61 | for (unsigned i = 0; i < res.num_elements; i++) { 62 | res.x[i] = a.x[i] - b.x[i]; 63 | } 64 | return res; 65 | } 66 | }; 67 | 68 | template 69 | struct sub { 70 | __device__ nvcuda::wmma::fragment 71 | operator()(const nvcuda::wmma::fragment &a, 72 | const nvcuda::wmma::fragment &b) { 73 | nvcuda::wmma::fragment res; 74 | using simd_t = half2; 75 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 76 | reinterpret_cast(res.x)[i] = 77 | __hsub2(reinterpret_cast(a.x)[i], 78 | reinterpret_cast(b.x)[i]); 79 | } 80 | return res; 81 
| } 82 | }; 83 | 84 | template 85 | struct sub { 86 | __device__ nvcuda::wmma::fragment 87 | operator()( 88 | const nvcuda::wmma::fragment &a, 89 | const nvcuda::wmma::fragment &b) { 90 | nvcuda::wmma::fragment res; 91 | using simd_t = __nv_bfloat162; 92 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 93 | reinterpret_cast(res.x)[i] = 94 | __hsub2(reinterpret_cast(a.x)[i], 95 | reinterpret_cast(b.x)[i]); 96 | } 97 | return res; 98 | } 99 | }; 100 | 101 | // Mul 102 | template struct mul { 103 | __device__ nvcuda::wmma::fragment operator()( 104 | const nvcuda::wmma::fragment &a, 105 | const typename nvcuda::wmma::fragment::storage_element_type b) { 107 | nvcuda::wmma::fragment res; 108 | for (unsigned i = 0; i < res.num_elements; i++) { 109 | res.x[i] = a.x[i] * b; 110 | } 111 | return res; 112 | } 113 | }; 114 | 115 | template 116 | struct mul { 117 | __device__ nvcuda::wmma::fragment operator()( 118 | const nvcuda::wmma::fragment &a, 119 | const typename nvcuda::wmma::fragment::storage_element_type b) { 121 | nvcuda::wmma::fragment res; 122 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 123 | reinterpret_cast(res.x)[i] = 124 | __hmul2(reinterpret_cast(a.x)[i], __half2half2(b)); 125 | } 126 | return res; 127 | } 128 | }; 129 | 130 | template 131 | struct mul { 132 | __device__ nvcuda::wmma::fragment 133 | operator()( 134 | const nvcuda::wmma::fragment &a, 135 | const typename nvcuda::wmma::fragment::storage_element_type b) { 137 | nvcuda::wmma::fragment res; 138 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 139 | reinterpret_cast<__nv_bfloat162 *>(res.x)[i] = 140 | __hmul2(reinterpret_cast(a.x)[i], 141 | __bfloat162bfloat162(b)); 142 | } 143 | return res; 144 | } 145 | }; 146 | 147 | // Fma 148 | template 150 | struct fma { 151 | __device__ nvcuda::wmma::fragment 152 | operator()(const A_Type alpha, 153 | const nvcuda::wmma::fragment &a, 154 | const nvcuda::wmma::fragment &b) { 155 | nvcuda::wmma::fragment res; 156 | for (unsigned i = 0; i < res.num_elements; i++) { 157 | res.x[i] = __fmaf_rn(alpha, a.x[i], b.x[i]); 158 | } 159 | return res; 160 | } 161 | }; 162 | 163 | template 164 | struct fma { 165 | __device__ nvcuda::wmma::fragment 166 | operator()(const half alpha, 167 | const nvcuda::wmma::fragment &a, 168 | const nvcuda::wmma::fragment &b) { 169 | nvcuda::wmma::fragment res; 170 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 171 | reinterpret_cast(res.x)[i] = 172 | __hfma2(__half2half2(alpha), reinterpret_cast(a.x)[i], 173 | reinterpret_cast(b.x)[i]); 174 | } 175 | return res; 176 | } 177 | }; 178 | 179 | template 180 | struct fma { 181 | __device__ nvcuda::wmma::fragment 182 | operator()( 183 | const __nv_bfloat16 alpha, 184 | const nvcuda::wmma::fragment &a, 185 | const nvcuda::wmma::fragment &b) { 186 | nvcuda::wmma::fragment res; 187 | for (unsigned i = 0; i < res.num_elements / 2; i++) { 188 | reinterpret_cast<__nv_bfloat162 *>(res.x)[i] = 189 | __hfma2(__bfloat162bfloat162(alpha), 190 | reinterpret_cast(a.x)[i], 191 | reinterpret_cast(b.x)[i]); 192 | } 193 | return res; 194 | } 195 | }; 196 | 197 | // Div 198 | template struct div { 199 | __device__ nvcuda::wmma::fragment operator()( 200 | const nvcuda::wmma::fragment &a, 201 | const typename nvcuda::wmma::fragment::storage_element_type b) { 203 | nvcuda::wmma::fragment res; 204 | for (unsigned i = 0; i < res.num_elements; i++) { 205 | res.x[i] = a.x[i] / b; 206 | } 207 | return res; 208 | } 209 | }; 210 | } // namespace ops 211 | } // namespace wmma 212 | } // namespace mtk 213 | #endif 
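// Illustrative usage of the functors above via <wmma_extension/operators.hpp>
// (a hedged sketch; only the operator overloads defined in that header are used):
//   nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> f0, f1;
//   const auto f_sum    = f0 + f1;                                    // ops::add
//   const auto f_scaled = f0 * __float2half(2.0f);                    // ops::mul
//   const auto f_fma    = mtk::wmma::fma(__float2half(2.0f), f0, f1); // ops::fma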
214 | -------------------------------------------------------------------------------- /include/wmma_extension/operators.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_OPERATORS__ 2 | #define __WMMAE_OPERATORS__ 3 | #include "detail/common.hpp" 4 | #include "detail/operators.hpp" 5 | 6 | template 7 | __device__ nvcuda::wmma::fragment 8 | operator+(const nvcuda::wmma::fragment &a, 9 | const nvcuda::wmma::fragment &b) { 10 | return mtk::wmma::ops::add{}(a, b); 11 | } 12 | 13 | template 14 | __device__ nvcuda::wmma::fragment 15 | operator-(const nvcuda::wmma::fragment &a, 16 | const nvcuda::wmma::fragment &b) { 17 | return mtk::wmma::ops::sub{}(a, b); 18 | } 19 | 20 | template 21 | __device__ nvcuda::wmma::fragment operator*( 22 | const nvcuda::wmma::fragment &a, 23 | const typename nvcuda::wmma::fragment::storage_element_type b) { 25 | return mtk::wmma::ops::mul{}(a, b); 26 | } 27 | 28 | template 29 | __device__ nvcuda::wmma::fragment operator/( 30 | const nvcuda::wmma::fragment &a, 31 | const typename nvcuda::wmma::fragment::storage_element_type b) { 33 | return mtk::wmma::ops::div{}(a, b); 34 | } 35 | 36 | namespace mtk { 37 | namespace wmma { 38 | template 39 | __device__ nvcuda::wmma::fragment 40 | fma(const typename nvcuda::wmma::fragment::storage_element_type alpha, 42 | const nvcuda::wmma::fragment &a, 43 | const nvcuda::wmma::fragment &b) { 44 | using A_Type = typename nvcuda::wmma::fragment::storage_element_type; 46 | return mtk::wmma::ops::fma{}(alpha, a, b); 47 | } 48 | 49 | template 50 | __device__ nvcuda::wmma::fragment 51 | fma(const nvcuda::wmma::fragment &a, 52 | const typename nvcuda::wmma::fragment::storage_element_type alpha, 54 | const nvcuda::wmma::fragment &b) { 55 | using A_Type = typename nvcuda::wmma::fragment::storage_element_type; 57 | return mtk::wmma::ops::fma{}(alpha, b, a); 58 | } 59 | } // namespace wmma 60 | } // namespace mtk 61 | #endif 62 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __MTK_HMMA_F32_F32_DETAIL_HPP__ 2 | #define __MTK_HMMA_F32_F32_DETAIL_HPP__ 3 | #include "wmma_extension_include.hpp" 4 | #include 5 | #include 6 | #include 7 | 8 | namespace mtk { 9 | namespace wmma { 10 | namespace tcec { 11 | namespace detail { 12 | template struct select_value {}; 13 | template 14 | struct select_value { 15 | const static int value = a; 16 | }; 17 | template 18 | struct select_value { 19 | const static int value = b; 20 | }; 21 | template 22 | struct select_value { 23 | const static int value = c; 24 | }; 25 | 26 | template __device__ constexpr int get_fragment_k() { return 16; }; 27 | template <> 28 | __device__ constexpr int get_fragment_k() { 29 | return 8; 30 | } 31 | 32 | template struct compute_mem_offset { 33 | // calculate memory offset from mem_index given by foreach 34 | __device__ unsigned operator()(const unsigned mem_offset, const unsigned ldm, 35 | const unsigned m_offset, 36 | const unsigned n_offset) { 37 | return ((mem_offset % frag_n) + n_offset) + 38 | (mem_offset / frag_n + m_offset) * ldm; 39 | } 40 | // calculate memory offset from matrix position (i,j) given by foreach_ij 41 | __device__ unsigned operator()(const unsigned i, const unsigned j, 42 | const unsigned ldm, const unsigned m_offset, 43 | const unsigned n_offset) { 44 | return (j + n_offset) + (i + m_offset) * ldm; 45 | } 46 | }; 47 | 48 | 
template 49 | struct compute_mem_offset { 50 | // calculate memory offset from mem_index given by foreach 51 | __device__ unsigned operator()(const unsigned mem_offset, const unsigned ldm, 52 | const unsigned m_offset, 53 | const unsigned n_offset) { 54 | return (mem_offset % frag_m + m_offset) + 55 | ((mem_offset / frag_m) + n_offset) * ldm; 56 | } 57 | // calculate memory offset from matrix position (i,j) given by foreach_ij 58 | __device__ unsigned operator()(const unsigned i, const unsigned j, 59 | const unsigned ldm, const unsigned m_offset, 60 | const unsigned n_offset) { 61 | return (i + m_offset) + (j + n_offset) * ldm; 62 | } 63 | }; 64 | 65 | template struct sub_frag_t { 66 | using type = T; 67 | }; 68 | template <> struct sub_frag_t { 69 | using type = float; 70 | }; 71 | template <> 72 | struct sub_frag_t { 73 | using type = float; 74 | }; 75 | 76 | template struct layout_switch; 77 | template struct layout_switch { 78 | const static int value = a; 79 | }; 80 | template struct layout_switch { 81 | const static int value = b; 82 | }; 83 | } // namespace detail 84 | } // namespace tcec 85 | } // namespace wmma 86 | } // namespace mtk 87 | #endif 88 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/functions_simt.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_HMMA_F32_F32_DETAIL_FUNCTIONS_SIMT_HPP__ 2 | #define __WMMAE_HMMA_F32_F32_DETAIL_FUNCTIONS_SIMT_HPP__ 3 | #include "policy_simt.hpp" 4 | #include "wmma_extension_simt_include.hpp" 5 | namespace mtk { 6 | namespace wmma { 7 | namespace tcec { 8 | namespace detail { 9 | // foreach 10 | template 12 | struct foreach_wrapper< 13 | Use, T, Layout, 14 | Policy> { 15 | template __device__ void operator()(Func func) { 16 | mtk::wmma::mma_simt::foreach< 17 | typename mtk::wmma::mma_simt::fragment>( 18 | func); 19 | } 20 | template 21 | __device__ void operator()(const nvcuda::wmma::layout_t layout, Func func) { 22 | mtk::wmma::mma_simt::foreach< 23 | typename mtk::wmma::mma_simt::fragment>( 24 | layout, func); 25 | } 26 | }; 27 | 28 | // foreach_ij 29 | template 31 | struct foreach_ij_wrapper< 32 | Use, T, Layout, 33 | Policy> { 34 | template __device__ void operator()(Func func) { 35 | mtk::wmma::mma_simt::foreach_ij< 36 | typename mtk::wmma::mma_simt::fragment>( 37 | func); 38 | } 39 | template 40 | __device__ void operator()(const nvcuda::wmma::layout_t layout, Func func) { 41 | mtk::wmma::mma_simt::foreach_ij< 42 | typename mtk::wmma::mma_simt::fragment>( 43 | layout, func); 44 | } 45 | }; 46 | 47 | // foreach_v 48 | template 50 | struct foreach_v_wrapper< 51 | Use, T, Layout, 52 | Policy> { 53 | template __device__ void operator()(Func func) { 54 | mtk::wmma::mma_simt::foreach_v< 55 | typename mtk::wmma::mma_simt::fragment>( 56 | func); 57 | } 58 | template 59 | __device__ void operator()(const nvcuda::wmma::layout_t layout, Func func) { 60 | mtk::wmma::mma_simt::foreach_v< 61 | typename mtk::wmma::mma_simt::fragment>( 62 | layout, func); 63 | } 64 | }; 65 | 66 | // fill zero 67 | template 69 | struct fill_zero_wrapper< 70 | Use, T, Layout, 71 | Policy> { 72 | __device__ void 73 | operator()(mtk::wmma::mma_simt::fragment &frag) { 74 | mtk::wmma::mma_simt::fill_zero(frag); 75 | } 76 | }; 77 | 78 | // load_matrix_sync 79 | template 81 | struct load_matrix_sync_wrapper< 82 | Use, T, Layout, 83 | Policy> { 84 | __device__ void 85 | operator()(mtk::wmma::mma_simt::fragment &frag, 86 | const float *const ptr, 
const unsigned ldm, 87 | const nvcuda::wmma::layout_t layout) { 88 | mtk::wmma::mma_simt::load_matrix_sync(frag, ptr, ldm, layout); 89 | } 90 | __device__ void 91 | operator()(mtk::wmma::mma_simt::fragment &frag, 92 | const float *const ptr, const unsigned ldm) { 93 | mtk::wmma::mma_simt::load_matrix_sync(frag, ptr, ldm); 94 | } 95 | }; 96 | 97 | // store_matrix_sync 98 | template 99 | struct store_matrix_sync_wrapper< 100 | T, Policy> { 101 | __device__ void 102 | operator()(float *ptr, 103 | mtk::wmma::mma_simt::fragment &frag, 105 | const unsigned ldm, const nvcuda::wmma::layout_t layout) { 106 | mtk::wmma::mma_simt::store_matrix_sync(ptr, frag, ldm, layout); 107 | } 108 | }; 109 | 110 | // load_vector 111 | template 113 | struct load_vector_wrapper< 114 | Use, T, Layout, 115 | Policy> { 116 | __device__ void 117 | operator()(mtk::wmma::mma_simt::fragment &frag, 118 | const float *const ptr, const nvcuda::wmma::layout_t layout) { 119 | mtk::wmma::mma_simt::load_vector(frag, ptr, layout); 120 | } 121 | __device__ void 122 | operator()(mtk::wmma::mma_simt::fragment &frag, 123 | const float *const ptr) { 124 | mtk::wmma::mma_simt::load_vector(frag, ptr); 125 | } 126 | }; 127 | 128 | // store_vector 129 | template 131 | struct store_vector_wrapper< 132 | Use, T, Layout, 133 | Policy> { 134 | __device__ void 135 | operator()(float *ptr, 136 | mtk::wmma::mma_simt::fragment &frag, 137 | const nvcuda::wmma::layout_t layout) { 138 | mtk::wmma::mma_simt::store_vector(ptr, frag, layout); 139 | } 140 | }; 141 | 142 | // fill_fragment 143 | template 145 | struct fill_fragment_wrapper< 146 | Use, T, Layout, 147 | Policy, VT> { 148 | __device__ void 149 | operator()(mtk::wmma::mma_simt::fragment &frag, 150 | const VT v) { 151 | mtk::wmma::mma_simt::fill_fragment(frag, v); 152 | } 153 | }; 154 | 155 | // mma_sync 156 | template 158 | struct mma_sync_wrapper< 159 | AB_T, A_Layout, B_Layout, CD_T, 160 | Policy> { 161 | using Fragment_A = mtk::wmma::mma_simt::fragment; 163 | using Fragment_B = mtk::wmma::mma_simt::fragment; 165 | using Fragment_C = mtk::wmma::mma_simt::fragment; 167 | __device__ void operator()(Fragment_C &d, const Fragment_A &a, 168 | const Fragment_B &b, const Fragment_C &c) { 169 | mtk::wmma::mma_simt::mma_sync(d, a, b, c); 170 | } 171 | }; 172 | 173 | } // namespace detail 174 | } // namespace tcec 175 | } // namespace wmma 176 | } // namespace mtk 177 | #endif 178 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/policy.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_HMMA_F32_F32_DETAIL_POLICY_HPP__ 2 | #define __WMMAE_HMMA_F32_F32_DETAIL_POLICY_HPP__ 3 | #include "wmma_extension_include.hpp" 4 | #include 5 | 6 | namespace mtk { 7 | namespace wmma { 8 | 9 | namespace tcec { 10 | 11 | // Instruction policy 12 | struct op_mma; 13 | struct op_wmma; 14 | 15 | // Error correction policy 16 | struct with_ec; 17 | struct without_ec; 18 | // Alias for compatibility 19 | using op_with_error_correction = with_ec; 20 | using op_without_error_correction = without_ec; 21 | 22 | struct sm_70; 23 | struct sm_75; 24 | struct sm_80; 25 | struct sm_86; 26 | struct sm_89; 27 | struct sm_90; 28 | struct sm_not_specified; 29 | 30 | #ifdef __CUDA_ARCH__ 31 | #if __CUDA_ARCH__ <= 600 32 | #error "CC <= 6.0 is not supported" 33 | #elif __CUDA_ARCH__ < 750 34 | using sm_auto = sm_70; 35 | #elif __CUDA_ARCH__ < 800 36 | using sm_auto = sm_75; 37 | #elif __CUDA_ARCH__ < 860 38 | using 
sm_auto = sm_80; 39 | #elif __CUDA_ARCH__ < 890 40 | using sm_auto = sm_86; 41 | #elif __CUDA_ARCH__ < 900 42 | using sm_auto = sm_89; 43 | #else 44 | using sm_auto = sm_90; 45 | #endif 46 | #else 47 | using sm_auto = sm_not_specified; 48 | #endif 49 | 50 | template 51 | struct Policy { 52 | using op = Op; 53 | using error_correction = ErrorCorrection; 54 | static const int m = m_; 55 | static const int n = n_; 56 | static const int k = k_; 57 | }; 58 | 59 | namespace detail { 60 | // =================================== 61 | // Default policy selector 62 | // =================================== 63 | template 66 | struct default_policy; 67 | 68 | template 69 | struct default_policy { 70 | using type = mtk::wmma::tcec::Policy; 72 | }; 73 | 74 | template 75 | struct default_policy { 77 | using type = mtk::wmma::tcec::Policy; 79 | }; 80 | 81 | template 82 | struct default_policy { 83 | using type = mtk::wmma::tcec::Policy; 85 | }; 86 | 87 | template 88 | struct default_policy { 90 | using type = mtk::wmma::tcec::Policy; 92 | }; 93 | 94 | template 95 | struct default_policy { 97 | using type = mtk::wmma::tcec::Policy; 99 | }; 100 | 101 | // =================================== 102 | // Default fragment selector 103 | // =================================== 104 | template 105 | struct default_fragment; 106 | 107 | template 109 | struct default_fragment> { 111 | using type = nvcuda::wmma::fragment; 112 | }; 113 | 114 | template 116 | struct default_fragment> { 118 | using type = mtk::wmma::mma::fragment; 119 | }; 120 | } // namespace detail 121 | 122 | template 125 | using default_policy = detail::default_policy; 126 | 127 | } // namespace tcec 128 | } // namespace wmma 129 | } // namespace mtk 130 | #endif 131 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/policy_simt.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_HMMA_F32_F32_DETAIL_POLICY_SIMT_HPP__ 2 | #define __WMMAE_HMMA_F32_F32_DETAIL_POLICY_SIMT_HPP__ 3 | #include "policy.hpp" 4 | #include "wmma_extension_simt_include.hpp" 5 | #include 6 | 7 | namespace mtk { 8 | namespace wmma { 9 | 10 | namespace tcec { 11 | 12 | // Instruction policy 13 | struct op_simt; 14 | 15 | namespace detail { 16 | // =================================== 17 | // Default policy selector 18 | // =================================== 19 | template 20 | struct default_policy { 22 | using type = mtk::wmma::tcec::Policy; 24 | }; 25 | 26 | // =================================== 27 | // Default fragment selector 28 | // =================================== 29 | template 31 | struct default_fragment> { 33 | using type = mtk::wmma::mma_simt::fragment; 34 | }; 35 | } // namespace detail 36 | 37 | } // namespace tcec 38 | } // namespace wmma 39 | } // namespace mtk 40 | #endif 41 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/print.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_TCEC_DETAIL_PRINT_HPP__ 2 | #define __WMMAE_TCEC_DETAIL_PRINT_HPP__ 3 | #include "common.hpp" 4 | #include "policy.hpp" 5 | #include "scale.hpp" 6 | #include 7 | 8 | namespace mtk { 9 | namespace wmma { 10 | namespace tcec { 11 | template 13 | __device__ void 14 | print_fragment(const mtk::wmma::tcec::fragment< 15 | Use, m, n, k, Type, Layout, 16 | Policy> &frag, 17 | const char *const name = "") { 18 | if (*name != '\0' && 
mtk::wmma::detail::common::get_lane_id() == 0) { 19 | printf("%s = \n", name); 20 | } 21 | 22 | __syncwarp(); 23 | for (unsigned i = 0; i < 32; i++) { 24 | if (i == mtk::wmma::detail::common::get_lane_id()) { 25 | for (unsigned i = 0; i < frag.num_elements; i++) { 26 | printf("(%+.3e)+(%+.3e) ", frag.x(i), 27 | detail::correction_scale_1(frag.dx(i))); 28 | } 29 | printf("\n"); 30 | } 31 | __syncwarp(); 32 | } 33 | } 34 | template 36 | __device__ void 37 | print_fragment(const mtk::wmma::tcec::fragment< 38 | Use, m, n, k, Type, Layout, 39 | Policy> &frag, 40 | const char *const name = "") { 41 | if (*name != '\0' && mtk::wmma::detail::common::get_lane_id() == 0) { 42 | printf("%s = \n", name); 43 | } 44 | 45 | __syncwarp(); 46 | for (unsigned i = 0; i < 32; i++) { 47 | if (i == mtk::wmma::detail::common::get_lane_id()) { 48 | for (unsigned i = 0; i < frag.num_elements; i++) { 49 | printf("%+.3e ", frag.x(i)); 50 | } 51 | printf("\n"); 52 | } 53 | __syncwarp(); 54 | } 55 | } 56 | } // namespace tcec 57 | } // namespace wmma 58 | } // namespace mtk 59 | #endif 60 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/scale.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_TCEC_DETAIL_SCALE_HPP__ 2 | #define __WMMAE_TCEC_DETAIL_SCALE_HPP__ 3 | 4 | namespace mtk { 5 | namespace wmma { 6 | namespace tcec { 7 | namespace detail { 8 | template __device__ inline float correction_scale_0(const float v) { 9 | return v; 10 | } 11 | template <> __device__ inline float correction_scale_0(const float v) { 12 | return v * 2048; 13 | } 14 | 15 | template __device__ inline float correction_scale_1(const float v) { 16 | return v; 17 | } 18 | template <> __device__ inline float correction_scale_1(const float v) { 19 | return v / 2048; 20 | } 21 | } // namespace detail 22 | } // namespace tcec 23 | } // namespace wmma 24 | } // namespace mtk 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/simt/detail/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_SIMT_DETAIL_COMMON__ 2 | #define __WMMAE_SIMT_DETAIL_COMMON__ 3 | #include 4 | 5 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800 6 | namespace nvcuda { 7 | namespace wmma { 8 | namespace precision { 9 | class tf32; 10 | } // namespace precision 11 | } // namespace wmma 12 | } // namespace nvcuda 13 | #endif 14 | 15 | namespace mtk { 16 | namespace wmma { 17 | namespace mma_simt { 18 | template 19 | class fragment; 20 | 21 | namespace detail { 22 | template struct __align__(4) __frag_base { 23 | T x[size]; 24 | enum { num_elements = size }; 25 | }; 26 | template struct __align__(2) __frag_base { 27 | half x[size]; 28 | enum { num_elements = size }; 29 | }; 30 | template struct __align__(8) __frag_base { 31 | double x[size]; 32 | enum { num_elements = size }; 33 | }; 34 | 35 | template struct select_value; 36 | template 37 | struct select_value { 38 | static const int value = M; 39 | }; 40 | template 41 | struct select_value { 42 | static const int value = N; 43 | }; 44 | template 45 | struct select_value { 46 | static const int value = K; 47 | }; 48 | 49 | template 50 | struct get_M : public select_value {}; 51 | template 52 | struct get_N : public select_value {}; 53 | 54 | template struct layout_switch; 55 | template 56 | struct layout_switch { 57 | static const int value = col_value; 58 | 
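  // (annotation) layout_switch resolves at compile time: this specialization
  // yields col_value for nvcuda::wmma::col_major, and the one just below
  // yields row_value for nvcuda::wmma::row_major.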
}; 59 | template 60 | struct layout_switch { 61 | static const int value = row_value; 62 | }; 63 | 64 | template struct storage_t { 65 | using type = T; 66 | }; 67 | template 68 | inline __device__ __host__ typename storage_t::type cast(const SRC v) { 69 | return static_cast(v); 70 | } 71 | template <> 72 | inline __device__ __host__ typename storage_t::type 73 | cast(const half v) { 74 | return __half2float(v); 75 | } 76 | template <> 77 | inline __device__ __host__ typename storage_t::type 78 | cast(const float v) { 79 | return __float2half(v); 80 | } 81 | 82 | template <> struct storage_t { 83 | using type = float; 84 | }; 85 | __device__ __host__ inline float to_tf32(const float a) { 86 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 87 | float ret; 88 | asm("{.reg .b32 %mr;\n" 89 | "cvt.rna.tf32.f32 %mr, %1;\n" 90 | "mov.b32 %0, %mr;}\n" 91 | : "=f"(ret) 92 | : "f"(a)); 93 | return ret; 94 | #else 95 | return a; 96 | #endif 97 | } 98 | template <> 99 | inline __device__ __host__ 100 | typename storage_t::type 101 | cast(const float v) { 102 | return to_tf32(v); 103 | } 104 | 105 | } // namespace detail 106 | 107 | template 108 | __device__ inline void fill_fragment(detail::__frag_base &f, 109 | const S v) { 110 | #pragma unroll 111 | for (unsigned i = 0; i < f.num_elements; i++) 112 | f.x[i] = detail::cast(v); 113 | } 114 | 115 | template 116 | __device__ inline void fill_zero(detail::__frag_base &f) { 117 | fill_fragment(f, 0.0f); 118 | } 119 | 120 | } // namespace mma_simt 121 | } // namespace wmma 122 | } // namespace mtk 123 | #endif 124 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/simt/detail/fma.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_MMA_SIMT_DETAIL_FMA_HPP__ 2 | #define __WMMAE_MMA_SIMT_DETAIL_FMA_HPP__ 3 | #include "common.hpp" 4 | 5 | namespace mtk { 6 | namespace wmma { 7 | namespace mma_simt { 8 | namespace detail { 9 | 10 | template struct fma { 11 | virtual __device__ C_T operator()(const A_T a, const B_T b, 12 | const C_T c) const { 13 | const auto fa = cast(a); 14 | const auto fb = cast(b); 15 | const auto fc = cast(c); 16 | return fa * fb + fc; 17 | } 18 | }; 19 | 20 | template <> struct fma { 21 | virtual __device__ double operator()(const double a, const double b, 22 | const double c) const { 23 | return a * b + c; 24 | } 25 | }; 26 | } // namespace detail 27 | } // namespace mma_simt 28 | } // namespace wmma 29 | } // namespace mtk 30 | #endif 31 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/wmma_extension_include.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_MMA_F32_F32_DETAIL_INCLUDE_HPP__ 2 | #define __WMMAE_MMA_F32_F32_DETAIL_INCLUDE_HPP__ 3 | #include "../../wmma_extension.hpp" 4 | #include "../../wmma_mma.hpp" 5 | #endif 6 | -------------------------------------------------------------------------------- /include/wmma_extension/tcec/detail/wmma_extension_simt_include.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_MMA_F32_F32_DETAIL_INCLUDE_SIMT_HPP__ 2 | #define __WMMAE_MMA_F32_F32_DETAIL_INCLUDE_SIMT_HPP__ 3 | #include "./simt/mma_simt.hpp" 4 | #endif 5 | -------------------------------------------------------------------------------- /include/wmma_extension/utils.hpp: 
-------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_UTILS_HPP__ 2 | #define __WMMAE_UTILS_HPP__ 3 | #include "detail/common.hpp" 4 | #include 5 | 6 | namespace mtk { 7 | namespace wmma { 8 | namespace utils { 9 | namespace detail { 10 | __device__ inline uint32_t get_smem_ptr_uint(const void *const ptr) { 11 | uint32_t smem_ptr; 12 | asm volatile("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; " 13 | "cvt.u32.u64 %0, smem_ptr; }\n" 14 | : "=r"(smem_ptr) 15 | : "l"(ptr)); 16 | return smem_ptr; 17 | } 18 | } // namespace detail 19 | 20 | // cast 21 | template 22 | __device__ __host__ inline 23 | typename mtk::wmma::detail::common::storage_t::type 24 | cast(const SRC_T v) { 25 | return mtk::wmma::detail::common::cast(v); 26 | } 27 | 28 | // async copy 29 | namespace cp_async { 30 | template 31 | __device__ inline void cp_async(void *const smem, const void *const gmem) { 32 | #if __CUDA_ARCH__ >= 800 33 | static_assert(Size == 4 || Size == 8 || Size == 16, 34 | "Size must be one of 4, 8 and 16"); 35 | const unsigned smem_int_ptr = detail::get_smem_ptr_uint(smem); 36 | asm volatile( 37 | "{cp.async.ca.shared.global [%0], [%1], %2;}" ::"r"(smem_int_ptr), 38 | "l"(gmem), "n"(Size)); 39 | #else 40 | for (unsigned i = 0; i < Size / 4; i++) { 41 | *(reinterpret_cast(smem) + i) = 42 | *(reinterpret_cast(gmem) + i); 43 | } 44 | #endif 45 | } 46 | 47 | __device__ inline void commit() { 48 | #if __CUDA_ARCH__ >= 800 49 | asm volatile("{cp.async.commit_group;}\n"); 50 | #endif 51 | } 52 | 53 | __device__ inline void wait_all() { 54 | #if __CUDA_ARCH__ >= 800 55 | asm volatile("{cp.async.wait_all;}"); 56 | #endif 57 | } 58 | 59 | template __device__ inline void wait_group() { 60 | #if __CUDA_ARCH__ >= 800 61 | asm volatile("{cp.async.wait_group %0;}" ::"n"(N)); 62 | #endif 63 | } 64 | } // namespace cp_async 65 | } // namespace utils 66 | } // namespace wmma 67 | } // namespace mtk 68 | 69 | #endif 70 | -------------------------------------------------------------------------------- /research/bank-conflict/Makefile: -------------------------------------------------------------------------------- 1 | NVCC=nvcc 2 | NVCCFLAGS=-std=c++17 3 | NVCCFLAGS+=-gencode arch=compute_80,code=sm_80 4 | NVCCFLAGS+=-I../../include 5 | 6 | TARGET=bank-conflict.test 7 | 8 | $(TARGET):main.cu 9 | $(NVCC) $< -o $@ $(NVCCFLAGS) 10 | 11 | clean: 12 | rm -f $(TARGET) 13 | -------------------------------------------------------------------------------- /research/bank-conflict/README.md: -------------------------------------------------------------------------------- 1 | # Shared memory bank conflict investigation for mma::foreach_ij 2 | 3 | ## usage 4 | 1. Change ldm params at line 7, 8 in main.cu 5 | 2. Change arch at line 3 in Makefile 6 | 3. Build 7 | ``` 8 | make 9 | ``` 10 | 11 | 4. Run 12 | ``` 13 | ./bank-conflict.test 14 | ``` 15 | 16 | 5. Check the result 17 | ``` 18 | [, layout = row, ldm = 64] ----------... 19 | 0(000) 0(080) 0(100) 0(180) ... // the bank and smem memory index of loading frag.x[0] = smem[index]; 20 | 0(040) 0(0c0) 0(140) 0(1c0) ... // the bank and smem memory index of loading frag.x[1] = smem[index]; 21 | 0(200) 0(280) 0(300) 0(380) ... // the bank and smem memory index of loading frag.x[2] = smem[index]; 22 | 0(240) 0(2c0) 0(340) 0(3c0) ... // the bank and smem memory index of loading frag.x[3] = smem[index]; 23 | [bank_conflict: 3]: 4 4 4 4 4 4 ... 
// access counter of each bank when loading frag.x[0] = smem[index]; 24 | [bank_conflict: 3]: 4 4 4 4 4 4 ... // access counter of each bank when loading frag.x[1] = smem[index]; 25 | [bank_conflict: 3]: 4 4 4 4 4 4 ... // access counter of each bank when loading frag.x[2] = smem[index]; 26 | [bank_conflict: 3]: 4 4 4 4 4 4 ... // access counter of each bank when loading frag.x[3] = smem[index]; 27 | ``` 28 | -------------------------------------------------------------------------------- /research/bank-conflict/main.cu: -------------------------------------------------------------------------------- 1 | #include "../common/utils.hpp" 2 | #include 3 | #include 4 | 5 | constexpr unsigned bank_size = 32; 6 | constexpr unsigned warp_size = 32; 7 | 8 | constexpr unsigned skew_min = 0; 9 | constexpr unsigned skew_max = 8; 10 | constexpr unsigned ldm_base = 64; 11 | 12 | namespace { 13 | template std::string get_name_str(); 14 | template <> std::string get_name_str() { return "half"; } 15 | template <> std::string get_name_str() { return "float"; } 16 | template <> std::string get_name_str() { 17 | return "tf32"; 18 | } 19 | template <> std::string get_name_str() { 20 | return "matrix_a"; 21 | } 22 | template <> std::string get_name_str() { 23 | return "matrix_b"; 24 | } 25 | template <> std::string get_name_str() { 26 | return "accumulator"; 27 | } 28 | template <> std::string get_name_str() { 29 | return "col_major"; 30 | } 31 | template <> std::string get_name_str() { 32 | return "row_major"; 33 | } 34 | template <> std::string get_name_str() { return "void"; } 35 | } // namespace 36 | 37 | template 38 | __global__ void kernel(unsigned *const bank_array_ptr, 39 | const nvcuda::wmma::layout_t layout, 40 | const std::size_t ldm) { 41 | using FRAG_T = mtk::wmma::mma::fragment; 42 | FRAG_T frag; 43 | 44 | unsigned frag_i = 0; 45 | const auto func = [&](const unsigned *frag_index_list, 46 | const unsigned fragment_index_count, const unsigned i, 47 | const unsigned j) { 48 | const unsigned mem_index = 49 | (layout == nvcuda::wmma::mem_col_major) ? (i + j * ldm) : (j + i * ldm); 50 | const unsigned bank = mem_index % bank_size; 51 | 52 | atomicAdd(&bank_array_ptr[(frag_i++) * bank_size + bank], 1); 53 | 54 | for (unsigned t = 0; t < bank_size; t++) { 55 | if (t == threadIdx.x) { 56 | printf("%3u(%03x) ", bank, mem_index); 57 | } 58 | __syncthreads(); 59 | } 60 | if (threadIdx.x == 0) { 61 | printf("\n"); 62 | } 63 | __syncthreads(); 64 | }; 65 | if constexpr (std::is_same::value) { 66 | mtk::wmma::mma::foreach_ij(layout, func); 67 | } else { 68 | mtk::wmma::mma::foreach_ij(func); 69 | } 70 | } 71 | 72 | template 73 | void print_bank_conflict(const nvcuda::wmma::layout_t layout, 74 | const std::size_t ldm) { 75 | const unsigned num_elements_per_thread = 76 | ((std::is_same::value) ? (m * k) : (k * n)) / 77 | warp_size; 78 | unsigned *bank_array; 79 | CUDA_CHECK_ERROR(cudaMallocHost(&bank_array, sizeof(unsigned) * bank_size * 80 | num_elements_per_thread)); 81 | for (unsigned i = 0; i < bank_size * num_elements_per_thread; i++) 82 | bank_array[i] = 0; 83 | std::printf("[<%s,%d,%d,%d,%s,%s>, layout = %s, ldm = %lu (skew=%lu)] " 84 | "---------------------------------------------------------- \n", 85 | get_name_str().c_str(), m, n, k, get_name_str().c_str(), 86 | get_name_str().c_str(), 87 | (layout == nvcuda::wmma::mem_col_major ? 
"col" : "row"), ldm, 88 | (ldm % 32lu)); 89 | kernel<<<1, 32>>>(bank_array, layout, ldm); 90 | CUDA_CHECK_ERROR(cudaDeviceSynchronize()); 91 | 92 | for (unsigned i = 0; i < num_elements_per_thread; i++) { 93 | const auto bank_array_ptr = bank_array + i * bank_size; 94 | unsigned max_bank_access = 0; 95 | for (unsigned j = 0; j < bank_size; j++) { 96 | max_bank_access = std::max(max_bank_access, bank_array_ptr[j]); 97 | } 98 | std::printf("[bank_conflict:%2u]: ", max_bank_access - 1); 99 | for (unsigned j = 0; j < bank_size; j++) { 100 | std::printf("%2u ", bank_array_ptr[j]); 101 | } 102 | std::printf("\n"); 103 | } 104 | CUDA_CHECK_ERROR(cudaFreeHost(bank_array)); 105 | } 106 | 107 | int main() { 108 | for (std::size_t skew = skew_min; skew <= skew_max; skew++) { 109 | const auto ldm = ldm_base + skew; 110 | print_bank_conflict(nvcuda::wmma::mem_col_major, 112 | ldm); 113 | print_bank_conflict(nvcuda::wmma::mem_row_major, 115 | ldm); 116 | print_bank_conflict(nvcuda::wmma::mem_col_major, 118 | ldm); 119 | print_bank_conflict(nvcuda::wmma::mem_row_major, 121 | ldm); 122 | print_bank_conflict( 123 | nvcuda::wmma::mem_col_major, ldm); 124 | print_bank_conflict( 125 | nvcuda::wmma::mem_row_major, ldm); 126 | 127 | print_bank_conflict( 129 | nvcuda::wmma::mem_col_major, ldm); 130 | print_bank_conflict( 132 | nvcuda::wmma::mem_row_major, ldm); 133 | print_bank_conflict( 135 | nvcuda::wmma::mem_col_major, ldm); 136 | print_bank_conflict( 138 | nvcuda::wmma::mem_row_major, ldm); 139 | print_bank_conflict( 140 | nvcuda::wmma::mem_col_major, ldm); 141 | print_bank_conflict( 142 | nvcuda::wmma::mem_row_major, ldm); 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /research/common/utils.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __WMMAE_RESEARCH_UTILS_HPP__ 2 | #define __WMMAE_RESEARCH_UTILS_HPP__ 3 | #include 4 | 5 | #ifndef CUDA_CHECK_ERROR 6 | #define CUDA_CHECK_ERROR(status) \ 7 | cuda_check_error(status, __FILE__, __LINE__, __func__) 8 | #endif 9 | 10 | inline void cuda_check_error(cudaError_t error, const std::string filename, 11 | const std::size_t line, 12 | const std::string funcname) { 13 | if (error != cudaSuccess) { 14 | std::stringstream ss; 15 | ss << cudaGetErrorString(error); 16 | ss << " [" << filename << ":" << line << " in " << funcname << "]"; 17 | throw std::runtime_error(ss.str()); 18 | } 19 | } 20 | 21 | #endif 22 | -------------------------------------------------------------------------------- /research/fragment_analysis/Makefile: -------------------------------------------------------------------------------- 1 | SM=89 2 | NVCC=nvcc 3 | NVCCFLAGS=-std=c++14 -arch=sm_${SM} -DARCH=${SM} -I../../include 4 | 5 | TARGET=fragment_analysis.out 6 | 7 | $(TARGET):main.cu 8 | $(NVCC) $< -o $@ $(NVCCFLAGS) 9 | 10 | clean: 11 | rm -f $(TARGET) 12 | -------------------------------------------------------------------------------- /research/fragment_analysis_ij/Makefile: -------------------------------------------------------------------------------- 1 | SM=80 2 | NVCC=nvcc 3 | NVCCFLAGS=-std=c++14 -arch=sm_${SM} -DARCH=${SM} -I../../include 4 | 5 | TARGET=fragment_analysis.out 6 | 7 | $(TARGET):main.cu 8 | $(NVCC) $< -o $@ $(NVCCFLAGS) 9 | 10 | clean: 11 | rm -f $(TARGET) 12 | -------------------------------------------------------------------------------- /research/fragment_analysis_map/Makefile: -------------------------------------------------------------------------------- 1 | 
SM=80 2 | NVCC=nvcc 3 | NVCCFLAGS=-std=c++14 -arch=sm_${SM} -DARCH=${SM} -I../../include 4 | 5 | TARGET=fragment_analysis.out 6 | 7 | $(TARGET):main.cu 8 | $(NVCC) $< -o $@ $(NVCCFLAGS) 9 | 10 | clean: 11 | rm -f $(TARGET) 12 | -------------------------------------------------------------------------------- /test/performance/Makefile.common: -------------------------------------------------------------------------------- 1 | TEST_ARCH=70 2 | ROOT_DIR=../../../include 3 | NVCC=nvcc 4 | NVCCFLAGS=-std=c++11 -I$(ROOT_DIR) -arch=sm_$(TEST_ARCH) -DCUDA_ARCH_SM=$(TEST_ARCH) 5 | -------------------------------------------------------------------------------- /test/performance/batched_m8n8k4/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: batched_m8n8k4.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/batched_m8n8k4/batched_m8n8k4.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | constexpr unsigned warp_size = 32; 6 | constexpr unsigned block_size = 256; 7 | 8 | constexpr unsigned M = 8; 9 | constexpr unsigned N = 8; 10 | constexpr unsigned K = 4; 11 | 12 | constexpr std::size_t num_matrices = 1lu << 20; 13 | constexpr unsigned C = 1 << 8; 14 | 15 | __global__ void batched_matmul_kernel(float *const c_ptr, 16 | const half *const a_ptr, 17 | const half *const b_ptr) { 18 | const auto tid = blockIdx.x * blockDim.x + threadIdx.x; 19 | const auto matrix_id = tid / warp_size; 20 | 21 | if (matrix_id >= num_matrices) { 22 | return; 23 | } 24 | 25 | mtk::wmma::fragment 27 | frag_a; 28 | mtk::wmma::fragment 30 | frag_b; 31 | mtk::wmma::fragment frag_c; 32 | 33 | mtk::wmma::load_matrix_sync(frag_a, a_ptr + matrix_id * M * K, M); 34 | mtk::wmma::load_matrix_sync(frag_b, b_ptr + matrix_id * N * K, K); 35 | mtk::wmma::fill_fragment(frag_c, 0.0f); 36 | 37 | mtk::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c); 38 | 39 | mtk::wmma::store_matrix_sync(c_ptr + matrix_id * M * N, frag_c, M, 40 | nvcuda::wmma::mem_col_major); 41 | } 42 | 43 | int main() { 44 | half *da; 45 | half *db; 46 | float *dc; 47 | 48 | cudaMalloc(&da, sizeof(half) * M * K * num_matrices); 49 | cudaMalloc(&db, sizeof(half) * K * N * num_matrices); 50 | cudaMalloc(&dc, sizeof(float) * M * N * num_matrices); 51 | 52 | const auto start_clock = std::chrono::system_clock::now(); 53 | for (unsigned c = 0; c < C; c++) 54 | batched_matmul_kernel<<< 55 | (warp_size * num_matrices + block_size - 1) / block_size, block_size>>>( 56 | dc, da, db); 57 | cudaDeviceSynchronize(); 58 | const auto end_clock = std::chrono::system_clock::now(); 59 | 60 | const auto elapsed_time = 61 | std::chrono::duration_cast(end_clock - 62 | start_clock) 63 | .count() * 64 | 1e-6; 65 | 66 | std::printf("%15s : %e [s]\n", "elapsed time", elapsed_time); 67 | std::printf("%15s : %e [TFlop/s]\n", "performance", 68 | (2 * M * N * K * C * num_matrices) / elapsed_time / (1lu << 40)); 69 | std::printf("%15s : %e [GiB/s]\n", "band width", 70 | static_cast(M * N * sizeof(float) + 71 | 2 * (M * K + N * K) * sizeof(half)) * 72 | num_matrices * C / elapsed_time / (1lu << 30)); 73 | 74 | cudaFree(da); 75 | cudaFree(db); 76 | cudaFree(dc); 77 | } 78 | 
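// ---------------------------------------------------------------------------
// (annotation) A hypothetical host-side spot check for one batch entry of the
// benchmark above. It assumes both A (8x4) and B (4x8) are col-major, which
// matches the leading dimensions M and K passed to load_matrix_sync in the
// kernel; `check_one` and the 2^-10 tolerance are illustrative, not part of
// the library.
#include <cmath>
#include <cuda_fp16.h>

bool check_one(const half *const A, const half *const B,
               const float *const C) {
  for (unsigned m = 0; m < 8; m++) {
    for (unsigned n = 0; n < 8; n++) {
      float ref = 0.f;
      for (unsigned k = 0; k < 4; k++) {
        // col-major indexing: A(m,k) at m + k*8, B(k,n) at k + n*4
        ref += __half2float(A[m + k * 8]) * __half2float(B[k + n * 4]);
      }
      if (std::fabs(ref - C[m + n * 8]) > 1.f / (1 << 10)) {
        return false;
      }
    }
  }
  return true;
}
// ---------------------------------------------------------------------------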
-------------------------------------------------------------------------------- /test/performance/foreach/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: matmul.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/givens_rotation/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: givens.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/householder/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: householder.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/householder/householder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | constexpr unsigned warp_size = 32; 6 | constexpr unsigned block_size = 256; 7 | constexpr unsigned test_count = 1024; 8 | constexpr unsigned warp_dim = 16; 9 | 10 | template 11 | __device__ void cp_matrix(half2 *const smem, const half2 *const gmem) { 12 | for (unsigned i = 0; i < DIM * DIM / 2; i += warp_size) { 13 | const unsigned index = i + (threadIdx.x & 0x1fu); 14 | smem[index] = gmem[index]; 15 | } 16 | } 17 | 18 | template 19 | __global__ void batched_householder_kernel(half *const ptr, 20 | const unsigned batch_size) { 21 | __shared__ half smem_mat[DIM * DIM * block_size / warp_size]; 22 | __shared__ half smem_vec[DIM * block_size / warp_size]; 23 | 24 | half *const smem_mat_ptr = smem_mat + DIM * DIM * (threadIdx.x / warp_size); 25 | half *const smem_vec_ptr = smem_vec + DIM * (threadIdx.x / warp_size); 26 | 27 | const unsigned matrix_id = threadIdx.x + blockIdx.x * blockDim.x / warp_size; 28 | if (matrix_id >= batch_size) 29 | return; 30 | 31 | cp_matrix( 32 | reinterpret_cast(smem_mat_ptr), 33 | reinterpret_cast( 34 | ptr + DIM * DIM * 35 | ((threadIdx.x + block_size / warp_size * blockIdx.x) / 36 | warp_size))); 37 | 38 | nvcuda::wmma::fragment 40 | frag_b[warp_dim * warp_dim / (warp_dim * warp_dim)]; 41 | for (unsigned i = 0; i < DIM / warp_dim; i += 1) { 42 | for (unsigned j = 0; j < DIM / warp_dim; j += 1) { 43 | nvcuda::wmma::load_matrix_sync( 44 | frag_b[i + j * (DIM / warp_dim)], 45 | smem_mat_ptr + j * warp_dim + DIM * warp_dim * i, DIM); 46 | } 47 | } 48 | 49 | if ((threadIdx.x & 0x1f) < DIM) { 50 | smem_vec_ptr[(threadIdx.x & 0x1f)] = smem_mat_ptr[(threadIdx.x & 0x1f)]; 51 | } 52 | __syncwarp(); 53 | 54 | nvcuda::wmma::fragment 56 | frag_a[warp_dim * warp_dim / (warp_dim * warp_dim)]; 57 | HouseholderMatGen{}(frag_a, smem_mat_ptr, smem_vec_ptr); 58 | 59 | for (unsigned i = 0; i < DIM / warp_dim; i += 1) { 60 | for (unsigned j = 0; j < DIM / warp_dim; j += 1) { 61 | nvcuda::wmma::fragment 63 | frag_c; 64 | nvcuda::wmma::fill_fragment(frag_c, 0.f); 65 | for (unsigned k = 0; k < DIM / warp_dim; k += 1) { 66 | 
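        // (annotation) Accumulate the 16x16 tile product C_ij += A_ik * B_kj;
        // the surrounding i/j/k loops implement a blocked DIMxDIM GEMM kept
        // entirely in fragments, with frag_a holding the Householder matrix
        // H = I - 2*v*v^T built by HouseholderMatGen and frag_b the input.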
nvcuda::wmma::mma_sync(frag_c, frag_a[i + k * (DIM / warp_dim)], 67 | frag_b[k + j * (DIM / warp_dim)], frag_c); 68 | } 69 | nvcuda::wmma::store_matrix_sync(smem_mat_ptr + i * warp_dim + 70 | j * warp_dim * DIM, 71 | frag_c, DIM, nvcuda::wmma::mem_col_major); 72 | } 73 | } 74 | cp_matrix( 75 | reinterpret_cast( 76 | ptr + DIM * DIM * 77 | ((threadIdx.x + block_size / warp_size * blockIdx.x) / 78 | warp_size)), 79 | reinterpret_cast(smem_mat_ptr)); 80 | } 81 | 82 | template struct HouseholderMatGenWMMA { 83 | __device__ void operator()( 84 | nvcuda::wmma::fragment *frag, 86 | half *const smem_mat, half *const smem_vec) const { 87 | #pragma unroll 88 | for (unsigned i = 0; i < DIM * DIM; i += warp_size) { 89 | const unsigned index = i + (threadIdx.x & 0x1fu); 90 | const unsigned m = index % DIM; 91 | const unsigned n = index / DIM; 92 | 93 | half v = smem_vec[m] * smem_vec[n] * __float2half(-2.f); 94 | if (m == n) { 95 | v += __float2half(1.f); 96 | } 97 | __syncwarp(); 98 | smem_mat[index] = v; 99 | } 100 | for (unsigned i = 0; i < DIM / warp_dim; i += 1) { 101 | for (unsigned j = 0; j < DIM / warp_dim; j += 1) { 102 | nvcuda::wmma::load_matrix_sync( 103 | frag[i + j * (DIM / warp_dim)], 104 | smem_mat + i * warp_dim + DIM * warp_dim * j, DIM); 105 | } 106 | } 107 | } 108 | }; 109 | 110 | template struct HouseholderMatGenWMMAe { 111 | __device__ void operator()( 112 | nvcuda::wmma::fragment *frag, 114 | half *const, const half *const smem_vec) const { 115 | mtk::wmma::foreach_ij< 116 | nvcuda::wmma::fragment>( 118 | [&](const unsigned *list, const unsigned list_size, const unsigned m, 119 | const unsigned n) { 120 | for (unsigned i = 0; i < DIM; i += warp_dim) { 121 | for (unsigned j = 0; j < DIM; j += warp_dim) { 122 | half v = smem_vec[i + m] * smem_vec[j + n] * __float2half(-2.f); 123 | if (m == n && j == i) { 124 | v += __float2half(1.f); 125 | } 126 | __syncwarp(); 127 | #pragma unroll 128 | for (unsigned f = 0; f < list_size; f++) { 129 | frag[i / warp_dim + j / warp_dim * (DIM / warp_dim)].x[f] = v; 130 | } 131 | } 132 | } 133 | }); 134 | __syncwarp(); 135 | } 136 | }; 137 | 138 | template std::string get_class_name(); 139 | template <> std::string get_class_name>() { 140 | return "wmma_16"; 141 | } 142 | template <> std::string get_class_name>() { 143 | return "wmmae_16"; 144 | } 145 | template <> std::string get_class_name>() { 146 | return "wmma_32"; 147 | } 148 | template <> std::string get_class_name>() { 149 | return "wmmae_32"; 150 | } 151 | 152 | template 153 | void batched_householder(half *const ptr, const unsigned batch_size) { 154 | const unsigned grid_size = 155 | (batch_size * warp_size + block_size - 1) / block_size; 156 | batched_householder_kernel 157 | <<>>(ptr, batch_size); 158 | } 159 | 160 | template 161 | void test_batched_kernel(const unsigned batch_size) { 162 | half *input_matrix; 163 | cudaMalloc(&input_matrix, sizeof(half) * DIM * DIM * batch_size); 164 | const auto start_clock = std::chrono::system_clock::now(); 165 | for (unsigned c = 0; c < test_count; c++) { 166 | batched_householder(input_matrix, batch_size); 167 | } 168 | cudaDeviceSynchronize(); 169 | const auto end_clock = std::chrono::system_clock::now(); 170 | cudaFree(input_matrix); 171 | 172 | const auto elapsed_time = 173 | std::chrono::duration_cast(end_clock - 174 | start_clock) 175 | .count() / 176 | static_cast(test_count) * 1e-6; 177 | 178 | std::printf("%u,%s,%e\n", batch_size, 179 | get_class_name().c_str(), elapsed_time); 180 | } 181 | 182 | int main() { 183 | 
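  // (annotation) Benchmark driver: for batch sizes 2^13..2^21, compare the
  // plain WMMA generator (materializes H in shared memory, then
  // load_matrix_sync) against the WMMAe generator (writes H directly into
  // fragments via mtk::wmma::foreach_ij), for DIM = 32 and 16. Output is CSV
  // with the header printed below.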
std::printf("batch_size,api,time\n"); 184 | for (unsigned i = 13; i <= 21; i++) { 185 | test_batched_kernel<32, HouseholderMatGenWMMA<32>>(1u << i); 186 | test_batched_kernel<32, HouseholderMatGenWMMAe<32>>(1u << i); 187 | test_batched_kernel<16, HouseholderMatGenWMMA<16>>(1u << i); 188 | test_batched_kernel<16, HouseholderMatGenWMMAe<16>>(1u << i); 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /test/performance/load_matrix_with_op_sync/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: matmul.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/load_vector_sync/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: batched_direct_product.test direct_product.test 4 | 5 | %.test : %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/load_vector_sync/batched_direct_product.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | constexpr std::size_t block_size = 256; 8 | constexpr unsigned warp_size = 32; 9 | 10 | #ifndef CUDA_ARCH_SM 11 | #define CUDA_ARCH_SM 0 12 | #endif 13 | 14 | template 15 | __global__ void batched_direct_product_16x16(float *const c_ptr, 16 | const half *const u_ptr); 17 | 18 | template <> 19 | __global__ void batched_direct_product_16x16(float *const c_ptr, 20 | const half *const u_ptr) { 21 | constexpr unsigned DIM = 32; 22 | constexpr unsigned FDIM = 16; 23 | __shared__ half u_smem[block_size]; 24 | __shared__ float c_smem[block_size * DIM]; 25 | 26 | const unsigned warp_id = threadIdx.x >> 5; 27 | half *const u_smem_ptr = u_smem + warp_id * DIM; 28 | float *const c_smem_ptr = c_ptr + warp_id * DIM * DIM; 29 | 30 | u_smem[threadIdx.x] = u_ptr[blockIdx.x * block_size + threadIdx.x]; 31 | 32 | nvcuda::wmma::fragment 34 | a_frag[2]; 35 | nvcuda::wmma::fragment 37 | b_frag[2]; 38 | nvcuda::wmma::fragment 39 | c_frag[4]; 40 | 41 | mtk::wmma::load_vector(a_frag[0], u_smem_ptr); 42 | mtk::wmma::load_vector(a_frag[1], u_smem_ptr + FDIM); 43 | mtk::wmma::load_vector(b_frag[0], u_smem_ptr); 44 | mtk::wmma::load_vector(b_frag[1], u_smem_ptr + FDIM); 45 | 46 | nvcuda::wmma::fill_fragment(c_frag[0], 0.0f); 47 | nvcuda::wmma::mma_sync(c_frag[0], a_frag[0], b_frag[0], c_frag[0]); 48 | nvcuda::wmma::fill_fragment(c_frag[1], 0.0f); 49 | nvcuda::wmma::mma_sync(c_frag[1], a_frag[1], b_frag[0], c_frag[1]); 50 | nvcuda::wmma::fill_fragment(c_frag[2], 0.0f); 51 | nvcuda::wmma::mma_sync(c_frag[2], a_frag[0], b_frag[1], c_frag[2]); 52 | nvcuda::wmma::fill_fragment(c_frag[3], 0.0f); 53 | nvcuda::wmma::mma_sync(c_frag[3], a_frag[1], b_frag[1], c_frag[3]); 54 | 55 | nvcuda::wmma::store_matrix_sync(c_smem_ptr, c_frag[0], DIM, 56 | nvcuda::wmma::mem_col_major); 57 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM, c_frag[1], DIM, 58 | nvcuda::wmma::mem_col_major); 59 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM, c_frag[2], DIM, 60 | nvcuda::wmma::mem_col_major); 61 | 
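  // (annotation) The four 16x16 accumulators tile the 32x32 direct product
  // u * u^T: c_frag[0] is the top-left block, c_frag[1] bottom-left,
  // c_frag[2] top-right, and c_frag[3] (stored next) bottom-right.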
nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM + FDIM, c_frag[3], 62 | DIM, nvcuda::wmma::mem_col_major); 63 | 64 | for (unsigned i = 0; i < DIM * block_size; i += block_size) { 65 | c_ptr[blockIdx.x * DIM + threadIdx.x + i] = c_smem[threadIdx.x + i]; 66 | } 67 | } 68 | 69 | template <> 70 | __global__ void batched_direct_product_16x16(float *const c_ptr, 71 | const half *const u_ptr) { 72 | constexpr unsigned DIM = 32; 73 | constexpr unsigned FDIM = 16; 74 | __shared__ half u_tmp_smem[block_size / warp_size * FDIM * FDIM]; 75 | __shared__ float c_smem[block_size * DIM]; 76 | 77 | const unsigned warp_id = threadIdx.x >> 5; 78 | half *const u_smem_ptr = u_tmp_smem + warp_id * FDIM * FDIM; 79 | float *const c_smem_ptr = c_ptr + warp_id * DIM * DIM; 80 | 81 | for (std::size_t i = 0; i < FDIM * FDIM; i += warp_size) { 82 | u_smem_ptr[i + threadIdx.x] = __float2half(0.0f); 83 | } 84 | 85 | u_tmp_smem[threadIdx.x] = u_ptr[blockIdx.x * block_size + threadIdx.x]; 86 | 87 | nvcuda::wmma::fragment 89 | a_frag[2]; 90 | nvcuda::wmma::fragment 92 | b_frag[2]; 93 | nvcuda::wmma::fragment 94 | c_frag[4]; 95 | 96 | nvcuda::wmma::load_matrix_sync(a_frag[0], u_smem_ptr, DIM); 97 | nvcuda::wmma::load_matrix_sync(a_frag[1], u_smem_ptr + FDIM, DIM); 98 | nvcuda::wmma::load_matrix_sync(b_frag[0], u_smem_ptr, DIM); 99 | nvcuda::wmma::load_matrix_sync(b_frag[1], u_smem_ptr + FDIM, DIM); 100 | 101 | nvcuda::wmma::fill_fragment(c_frag[0], 0.0f); 102 | nvcuda::wmma::mma_sync(c_frag[0], a_frag[0], b_frag[0], c_frag[0]); 103 | nvcuda::wmma::fill_fragment(c_frag[1], 0.0f); 104 | nvcuda::wmma::mma_sync(c_frag[1], a_frag[1], b_frag[0], c_frag[1]); 105 | nvcuda::wmma::fill_fragment(c_frag[2], 0.0f); 106 | nvcuda::wmma::mma_sync(c_frag[2], a_frag[0], b_frag[1], c_frag[2]); 107 | nvcuda::wmma::fill_fragment(c_frag[3], 0.0f); 108 | nvcuda::wmma::mma_sync(c_frag[3], a_frag[1], b_frag[1], c_frag[3]); 109 | 110 | nvcuda::wmma::store_matrix_sync(c_smem_ptr, c_frag[0], DIM, 111 | nvcuda::wmma::mem_col_major); 112 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM, c_frag[1], DIM, 113 | nvcuda::wmma::mem_col_major); 114 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM, c_frag[2], DIM, 115 | nvcuda::wmma::mem_col_major); 116 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM + FDIM, c_frag[3], 117 | DIM, nvcuda::wmma::mem_col_major); 118 | 119 | for (unsigned i = 0; i < DIM * block_size; i += block_size) { 120 | c_ptr[blockIdx.x * DIM + threadIdx.x + i] = c_smem[threadIdx.x + i]; 121 | } 122 | } 123 | 124 | template 125 | void test_batched_direct_product(const unsigned size_power) { 126 | constexpr std::size_t C = 1lu << 6; 127 | const unsigned batch_size = 1lu << size_power; 128 | const std::size_t grid_size = batch_size / (block_size / warp_size); 129 | 130 | half *dU; 131 | float *dC; 132 | cudaMalloc(&dU, sizeof(half) * batch_size * warp_size); 133 | cudaMalloc(&dC, sizeof(float) * batch_size * warp_size * warp_size); 134 | 135 | const auto start_clock = std::chrono::system_clock::now(); 136 | for (std::size_t c = 0; c < C; c++) { 137 | batched_direct_product_16x16<<>>(dC, dU); 138 | } 139 | const auto status = cudaGetLastError(); 140 | cudaDeviceSynchronize(); 141 | if (status != 0) { 142 | std::fprintf(stderr, "%s\n", cudaGetErrorString(status)); 143 | } 144 | const auto end_clock = std::chrono::system_clock::now(); 145 | const auto elapsed_time = 146 | std::chrono::duration_cast(end_clock - 147 | start_clock) 148 | .count() / 149 | 1.e6 / C; 150 | 151 | std::printf("%u,%u,%u,%e\n", 
static_cast(CUDA_ARCH_SM), batch_size, 152 | (UseWMMAe ? 1u : 0u), elapsed_time); 153 | 154 | cudaFree(dU); 155 | cudaFree(dC); 156 | } 157 | 158 | void test_batched_direct_product(const unsigned min_p, const unsigned max_p) { 159 | std::printf("# %s\n", __func__); 160 | std::printf("-- 1\n"); 161 | for (unsigned i = min_p; i <= max_p; i++) { 162 | test_batched_direct_product(i); 163 | } 164 | for (unsigned i = min_p; i <= max_p; i++) { 165 | test_batched_direct_product(i); 166 | } 167 | std::printf("-- 2\n"); 168 | for (unsigned i = min_p; i <= max_p; i++) { 169 | test_batched_direct_product(i); 170 | } 171 | for (unsigned i = min_p; i <= max_p; i++) { 172 | test_batched_direct_product(i); 173 | } 174 | } 175 | int main() { test_batched_direct_product(8, 20); } 176 | -------------------------------------------------------------------------------- /test/performance/make_identity/Makefile: -------------------------------------------------------------------------------- 1 | include ../Makefile.common 2 | 3 | all: batched_householder.test 4 | 5 | %.test: %.cu $(ROOT_DIR)/wmma_extension/wmma_extension.hpp Makefile 6 | $(NVCC) $(NVCCFLAGS) -o $@ $< 7 | 8 | clean: 9 | rm -f *.test 10 | -------------------------------------------------------------------------------- /test/performance/make_identity/batched_householder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | constexpr std::size_t block_size = 256; 8 | constexpr unsigned warp_size = 32; 9 | 10 | #ifndef CUDA_ARCH_SM 11 | #define CUDA_ARCH_SM 0 12 | #endif 13 | 14 | template 15 | __global__ void householder_16x16(float *const c_ptr, const half *const u_ptr); 16 | 17 | template <> 18 | __global__ void householder_16x16(float *const c_ptr, 19 | const half *const u_ptr) { 20 | constexpr unsigned DIM = 32; 21 | constexpr unsigned FDIM = 16; 22 | __shared__ half u_smem[block_size]; 23 | __shared__ float c_smem[block_size * DIM]; 24 | 25 | const unsigned warp_id = threadIdx.x >> 5; 26 | half *const u_smem_ptr = u_smem + warp_id * DIM; 27 | float *const c_smem_ptr = c_ptr + warp_id * DIM * DIM; 28 | 29 | u_smem[threadIdx.x] = u_ptr[blockIdx.x * block_size + threadIdx.x]; 30 | 31 | nvcuda::wmma::fragment 33 | a_frag[2]; 34 | nvcuda::wmma::fragment 36 | b_frag[2]; 37 | nvcuda::wmma::fragment 38 | c_frag[4]; 39 | 40 | mtk::wmma::load_vector(a_frag[0], u_smem_ptr); 41 | mtk::wmma::load_vector(a_frag[1], u_smem_ptr + FDIM); 42 | mtk::wmma::load_vector(b_frag[0], u_smem_ptr); 43 | mtk::wmma::load_vector(b_frag[1], u_smem_ptr + FDIM); 44 | 45 | mtk::wmma::make_identity_matrix(c_frag[0]); 46 | nvcuda::wmma::mma_sync(c_frag[0], a_frag[0], b_frag[0], c_frag[0]); 47 | nvcuda::wmma::fill_fragment(c_frag[1], 0.0f); 48 | nvcuda::wmma::mma_sync(c_frag[1], a_frag[1], b_frag[0], c_frag[1]); 49 | nvcuda::wmma::fill_fragment(c_frag[2], 0.0f); 50 | nvcuda::wmma::mma_sync(c_frag[2], a_frag[0], b_frag[1], c_frag[2]); 51 | mtk::wmma::make_identity_matrix(c_frag[3]); 52 | nvcuda::wmma::mma_sync(c_frag[3], a_frag[1], b_frag[1], c_frag[3]); 53 | 54 | nvcuda::wmma::store_matrix_sync(c_smem_ptr, c_frag[0], DIM, 55 | nvcuda::wmma::mem_col_major); 56 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM, c_frag[1], DIM, 57 | nvcuda::wmma::mem_col_major); 58 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM, c_frag[2], DIM, 59 | nvcuda::wmma::mem_col_major); 60 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM + FDIM, c_frag[3], 61 | DIM, 
nvcuda::wmma::mem_col_major); 62 | 63 | for (unsigned i = 0; i < DIM * block_size; i += block_size) { 64 | c_ptr[blockIdx.x * DIM + threadIdx.x + i] = c_smem[threadIdx.x + i]; 65 | } 66 | } 67 | 68 | template <> 69 | __global__ void householder_16x16(float *const c_ptr, 70 | const half *const u_ptr) { 71 | constexpr unsigned DIM = 32; 72 | constexpr unsigned FDIM = 16; 73 | __shared__ half u_smem[block_size]; 74 | __shared__ float c_smem[block_size * DIM]; 75 | 76 | const unsigned warp_id = threadIdx.x >> 5; 77 | half *const u_smem_ptr = u_smem + warp_id * DIM; 78 | float *const c_smem_ptr = c_ptr + warp_id * DIM * DIM; 79 | 80 | u_smem[threadIdx.x] = u_ptr[blockIdx.x * block_size + threadIdx.x]; 81 | 82 | nvcuda::wmma::fragment 84 | a_frag[2]; 85 | nvcuda::wmma::fragment 87 | b_frag[2]; 88 | nvcuda::wmma::fragment 89 | c_frag[4]; 90 | 91 | mtk::wmma::load_vector(a_frag[0], u_smem_ptr); 92 | mtk::wmma::load_vector(a_frag[1], u_smem_ptr + FDIM); 93 | mtk::wmma::load_vector(b_frag[0], u_smem_ptr); 94 | mtk::wmma::load_vector(b_frag[1], u_smem_ptr + FDIM); 95 | 96 | nvcuda::wmma::fill_fragment(c_frag[1], 0.0f); 97 | nvcuda::wmma::fill_fragment(c_frag[2], 0.0f); 98 | 99 | const unsigned unique_id = threadIdx.x & 0x1f; 100 | for (unsigned i = 0; i < FDIM * FDIM; i += warp_size) { 101 | if (unique_id % (FDIM + 1) == 0) { 102 | c_smem_ptr[unique_id] = 1.0f; 103 | } else { 104 | c_smem_ptr[unique_id] = 0.0f; 105 | } 106 | } 107 | __syncthreads(); 108 | nvcuda::wmma::load_matrix_sync(c_frag[0], c_smem_ptr, FDIM, 109 | nvcuda::wmma::mem_col_major); 110 | nvcuda::wmma::load_matrix_sync(c_frag[3], c_smem_ptr, FDIM, 111 | nvcuda::wmma::mem_col_major); 112 | 113 | nvcuda::wmma::mma_sync(c_frag[0], a_frag[0], b_frag[0], c_frag[0]); 114 | nvcuda::wmma::mma_sync(c_frag[1], a_frag[1], b_frag[0], c_frag[1]); 115 | nvcuda::wmma::mma_sync(c_frag[2], a_frag[0], b_frag[1], c_frag[2]); 116 | nvcuda::wmma::mma_sync(c_frag[3], a_frag[1], b_frag[1], c_frag[3]); 117 | 118 | nvcuda::wmma::store_matrix_sync(c_smem_ptr, c_frag[0], DIM, 119 | nvcuda::wmma::mem_col_major); 120 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM, c_frag[1], DIM, 121 | nvcuda::wmma::mem_col_major); 122 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM, c_frag[2], DIM, 123 | nvcuda::wmma::mem_col_major); 124 | nvcuda::wmma::store_matrix_sync(c_smem_ptr + FDIM * DIM + FDIM, c_frag[3], 125 | DIM, nvcuda::wmma::mem_col_major); 126 | 127 | for (unsigned i = 0; i < DIM * block_size; i += block_size) { 128 | c_ptr[blockIdx.x * DIM + threadIdx.x + i] = c_smem[threadIdx.x + i]; 129 | } 130 | } 131 | 132 | template void test_householder(const unsigned size_power) { 133 | constexpr std::size_t C = 1lu << 6; 134 | const unsigned batch_size = 1lu << size_power; 135 | const std::size_t grid_size = batch_size / (block_size / warp_size); 136 | 137 | half *dU; 138 | float *dC; 139 | cudaMalloc(&dU, sizeof(half) * batch_size * warp_size); 140 | cudaMalloc(&dC, sizeof(float) * batch_size * warp_size * warp_size); 141 | 142 | const auto start_clock = std::chrono::system_clock::now(); 143 | for (std::size_t c = 0; c < C; c++) { 144 | householder_16x16<<>>(dC, dU); 145 | } 146 | const auto status = cudaGetLastError(); 147 | cudaDeviceSynchronize(); 148 | if (status != 0) { 149 | std::fprintf(stderr, "%s\n", cudaGetErrorString(status)); 150 | } 151 | const auto end_clock = std::chrono::system_clock::now(); 152 | const auto elapsed_time = 153 | std::chrono::duration_cast(end_clock - 154 | start_clock) 155 | .count() / 156 | 1.e6 / C; 157 | 158 | 
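  // (annotation) elapsed_time is the mean seconds per kernel launch:
  // accumulated microseconds / 1e6 / C, with C = 64 repetitions.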
std::printf("%u,%u,%u,%e\n", static_cast(CUDA_ARCH_SM), batch_size, 159 | (UseWMMAe ? 1u : 0u), elapsed_time); 160 | 161 | cudaFree(dU); 162 | cudaFree(dC); 163 | } 164 | 165 | void test_householder(const unsigned min_p, const unsigned max_p) { 166 | std::printf("# %s\n", __func__); 167 | std::printf("-- 1\n"); 168 | for (unsigned i = min_p; i <= max_p; i++) { 169 | test_householder(i); 170 | } 171 | for (unsigned i = min_p; i <= max_p; i++) { 172 | test_householder(i); 173 | } 174 | std::printf("-- 2\n"); 175 | for (unsigned i = min_p; i <= max_p; i++) { 176 | test_householder(i); 177 | } 178 | for (unsigned i = min_p; i <= max_p; i++) { 179 | test_householder(i); 180 | } 181 | } 182 | 183 | int main() { test_householder(8, 14); } 184 | -------------------------------------------------------------------------------- /test/primitive/Makefile: -------------------------------------------------------------------------------- 1 | TEST_ARCH=80 2 | ROOT_DIR=../../include 3 | NVCC=nvcc 4 | NVCCFLAGS=-std=c++17 -I$(ROOT_DIR) -arch=sm_$(TEST_ARCH) -DTEST_ARCH=$(TEST_ARCH) --extended-lambda 5 | HEADERS=$(shell find ../../include -name '*.hpp') 6 | 7 | TARGET= 8 | TARGET+=add_eye.test 9 | TARGET+=direct_product.test 10 | TARGET+=foreach.test 11 | TARGET+=foreach_ij.test 12 | TARGET+=foreach_v.test 13 | TARGET+=foreach_v_acc.test 14 | TARGET+=gevm.test 15 | TARGET+=wmma.load_vector.test 16 | TARGET+=wmma.store_vector.test 17 | TARGET+=print_fragment.test 18 | TARGET+=fill.test 19 | TARGET+=mma.test 20 | TARGET+=vector.test 21 | TARGET+=map.test 22 | TARGET+=operators.test 23 | 24 | all: $(TARGET) 25 | 26 | %.test : %.cu Makefile $(HEADERS) 27 | $(NVCC) $(NVCCFLAGS) -o $@ $< 28 | 29 | clean: 30 | rm -f *.test 31 | -------------------------------------------------------------------------------- /test/primitive/add_eye.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #ifndef TEST_ARCH 8 | #define TEST_ARCH (-1) 9 | #endif 10 | 11 | template __device__ __host__ T convert(const S); 12 | template <> __device__ __host__ float convert(const float a) { 13 | return a; 14 | } 15 | template <> __device__ __host__ float convert(const half a) { 16 | return __half2float(a); 17 | } 18 | template <> __device__ __host__ half convert(const float a) { 19 | return __float2half(a); 20 | } 21 | template <> __device__ __host__ half convert(const half a) { 22 | return a; 23 | } 24 | 25 | template 26 | __global__ void make_eye_kernel(T *const eye, const T a) { 27 | nvcuda::wmma::fragment frag_c; 28 | nvcuda::wmma::fill_fragment(frag_c, convert(1.0f)); 29 | mtk::wmma::add_eye(frag_c, a); 30 | nvcuda::wmma::store_matrix_sync(eye, frag_c, N, nvcuda::wmma::mem_col_major); 31 | } 32 | 33 | template void test() { 34 | T *h; 35 | 36 | cudaMallocHost(&h, sizeof(T) * N * N); 37 | 38 | cudaDeviceSynchronize(); 39 | make_eye_kernel<<<1, 32>>>(h, convert(2.0f)); 40 | cudaDeviceSynchronize(); 41 | 42 | double max_error = 0.0; 43 | for (unsigned i = 0; i < N; i++) { 44 | for (unsigned j = 0; j < N; j++) { 45 | const float c = (i == j) ? 
3.0f : 1.0f;
      const double diff = c - convert(h[i * 16 + j]);
      max_error = std::max(max_error, std::abs(diff));
    }
  }
  std::printf("[%s] arch=%d, error=%e, [%s]\n", __FILE__, TEST_ARCH, max_error,
              mtk::test_utils::get_test_result_string(
                  max_error < mtk::test_utils::get_machine_eps()));
}

int main() {
  test();
  test();
#ifdef TEST_TF32
  test();
#endif
}
--------------------------------------------------------------------------------
/test/primitive/common.hpp:
--------------------------------------------------------------------------------
#ifndef __WMMAE_TEST_COMMON_HPP__
#define __WMMAE_TEST_COMMON_HPP__
// NOTE: the #include targets and all template arguments in this header were
// lost to angle-bracket stripping during extraction; the ones below are
// reconstructions inferred from the function bodies and returned strings.
#include <string>
#include <wmma_extension/wmma_extension.hpp>

namespace mtk {
namespace test_utils {
template <class T> std::string get_string();
template <> std::string get_string<float>() { return "float"; }
template <> std::string get_string<half>() { return "half"; }
template <> std::string get_string<__nv_bfloat16>() { return "bfloat16"; }
template <> std::string get_string<std::uint8_t>() { return "uint8"; }
template <> std::string get_string<std::int8_t>() { return "int8"; }
template <> std::string get_string<std::int32_t>() { return "int32"; }
template <> std::string get_string<nvcuda::wmma::precision::tf32>() {
  return "tf32";
}
template <> std::string get_string<nvcuda::wmma::col_major>() {
  return "col_major";
}
template <> std::string get_string<nvcuda::wmma::row_major>() {
  return "row_major";
}
template <> std::string get_string<void>() { return "void"; }
template <> std::string get_string<nvcuda::wmma::matrix_a>() {
  return "matrix_a";
}
template <> std::string get_string<nvcuda::wmma::matrix_b>() {
  return "matrix_b";
}
template <> std::string get_string<nvcuda::wmma::accumulator>() {
  return "accumulator";
}

template <class T> double get_machine_eps();
template <> double get_machine_eps<half>() { return 1. / (1 << 10); }
template <> double get_machine_eps<nvcuda::wmma::precision::tf32>() {
  return 1. / (1 << 10);
}
template <> double get_machine_eps<float>() { return 1. / (1 << 23); }
// The next three specializations forward to an entry above; their original
// template arguments are an assumption (the integer test types used by
// get_string above).
template <> double get_machine_eps<std::uint8_t>() {
  return get_machine_eps<float>();
}
template <> double get_machine_eps<std::int8_t>() {
  return get_machine_eps<float>();
}
template <> double get_machine_eps<std::int32_t>() {
  return get_machine_eps<float>();
}

const char *get_test_result_string(const bool passed) {
  return passed ? "PASSED" : "FAILED";
}

template <class AB_T, class A_Layout, class B_Layout, int M, int N, int K,
          class C_T, class D_T, nvcuda::wmma::layout_t c_layout,
          nvcuda::wmma::layout_t d_layout>
double get_max_relative_error(
    const typename mtk::wmma::detail::common::storage_t<AB_T>::type *const a,
    const typename mtk::wmma::detail::common::storage_t<AB_T>::type *const b,
    const C_T *const c, const D_T *const d) {
  double max_base = 0.0;
  double max_diff = 0.0;

  for (unsigned m = 0; m < M; m++) {
    for (unsigned n = 0; n < N; n++) {
      double c_v = 0.0;
      for (unsigned k = 0; k < K; k++) {
        double a_v, b_v;
        if (std::is_same<A_Layout, nvcuda::wmma::col_major>::value) {
          a_v = mtk::wmma::detail::common::cast<float>(a[k * M + m]);
        } else {
          a_v = mtk::wmma::detail::common::cast<float>(a[k + K * m]);
        }
        if (std::is_same<B_Layout, nvcuda::wmma::col_major>::value) {
          b_v = mtk::wmma::detail::common::cast<float>(b[k + K * n]);
        } else {
          b_v = mtk::wmma::detail::common::cast<float>(b[k * N + n]);
        }
        c_v += a_v * b_v;
      }
      if (c_layout == nvcuda::wmma::mem_col_major) {
        c_v += mtk::wmma::detail::common::cast<float>(c[m + M * n]);
      } else {
        c_v += mtk::wmma::detail::common::cast<float>(c[m * N + n]);
      }

      // compute error
      double d_v;
      if (d_layout == nvcuda::wmma::mem_col_major) {
        d_v = mtk::wmma::detail::common::cast<float>(d[m + M * n]);
      } else {
        d_v = mtk::wmma::detail::common::cast<float>(d[m * N + n]);
      }
      const auto diff = d_v - c_v;

      // accumulate
      max_base = std::max(max_base, std::abs(c_v));
      max_diff = std::max(max_diff, std::abs(diff));
    }
  }
  return max_diff / max_base;
}
} // namespace test_utils
} // namespace mtk
#endif
--------------------------------------------------------------------------------
/test/primitive/direct_product.cu:
--------------------------------------------------------------------------------
#include "common.hpp"
// Reconstructed headers and template arguments in this file (the originals
// were lost to angle-bracket stripping during extraction):
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <wmma_extension/wmma_extension.hpp>

#ifndef TEST_ARCH
#define TEST_ARCH (-1)
#endif

// #define TEST_TF32

#ifndef TEST_TF32
constexpr std::size_t M = 16;
constexpr std::size_t N = 16;
constexpr std::size_t K = 16;
using ab_type = half;
#else
constexpr std::size_t M = 16;
constexpr std::size_t N = 16;
constexpr std::size_t K = 8;
using ab_type = nvcuda::wmma::precision::tf32;
#endif

// #define SMALLER_WORKING_MEMORY

using storage_t = typename mtk::wmma::detail::common::storage_t<ab_type>::type;

template <class T, class S> __device__ __host__ T convert(const S);
template <> __device__ __host__ float convert<float, float>(const float a) {
  return a;
}
template <> __device__ __host__ float convert<float, half>(const half a) {
  return __half2float(a);
}
template <> __device__ __host__ half convert<half, float>(const float a) {
  return __float2half(a);
}
template <> __device__ __host__ half convert<half, half>(const half a) {
  return a;
}

template <int CORRECTION_TERMS>
__global__ void direct_product_kernel(float *const h, const float *const u,
                                      const float *const v) {
  // Fragment layouts are a reconstruction: a direct product u * v^T takes u
  // as a col_major A operand and v as a row_major B operand, consistent with
  // the col_major store below.
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, M, N, K, ab_type,
                         nvcuda::wmma::col_major>
      frag_a;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, M, N, K, ab_type,
                         nvcuda::wmma::row_major>
      frag_b;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, M, N, K, float> frag_c;

  __shared__ storage_t su[N];
  __shared__ storage_t sdu[N];
  __shared__ storage_t sv[N];
  __shared__ storage_t sdv[N];

  if (threadIdx.x < N) {
    const auto fv = u[threadIdx.x];
    const auto hv = mtk::wmma::detail::common::cast<ab_type>(fv);
    su[threadIdx.x] = hv;
    sdu[threadIdx.x] = convert<storage_t>(fv - convert<float>(hv));
  } else {
    const auto fv = v[threadIdx.x - N];
    const auto hv = mtk::wmma::detail::common::cast<ab_type>(fv);
    sv[threadIdx.x - N] = hv;
    sdv[threadIdx.x - N] = convert<storage_t>(fv - convert<float>(hv));
  }

  __syncthreads();

  if (CORRECTION_TERMS == 3) {
#ifdef SMALLER_WORKING_MEMORY
    mtk::wmma::make_direct_product_fragment_c3(frag_a, u);
    mtk::wmma::make_direct_product_fragment_c3(frag_b, v);
#else
    mtk::wmma::make_direct_product_fragment_c3(frag_a, su, sdu);
    mtk::wmma::make_direct_product_fragment_c3(frag_b, sv, sdv);
#endif
  } else {
#ifdef SMALLER_WORKING_MEMORY
    mtk::wmma::make_direct_product_fragment(frag_a, u);
    mtk::wmma::make_direct_product_fragment(frag_b, v);
#else
    mtk::wmma::make_direct_product_fragment(frag_a, su, sdu);
    mtk::wmma::make_direct_product_fragment(frag_b, sv, sdv);
#endif
  }

  nvcuda::wmma::fill_fragment(frag_c, 0.0f);

  nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);

  nvcuda::wmma::store_matrix_sync(h, frag_c, N, nvcuda::wmma::mem_col_major);
}

template <int CORRECTION_TERMS> void test() {
  float *u;
  float *v;
  float *h;

  cudaMallocHost(&u, sizeof(float) * N);
  cudaMallocHost(&v, sizeof(float) * N);
  cudaMallocHost(&h, sizeof(float) * N * N);

  std::mt19937 mt(0);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);

  for (unsigned i = 0; i < N; i++) {
    u[i] = dist(mt);
    v[i] = dist(mt);
  }

  cudaDeviceSynchronize();
  direct_product_kernel<CORRECTION_TERMS><<<1, 32>>>(h, u, v);
  cudaDeviceSynchronize();

  double max_error = 0.0;
  for (unsigned i = 0; i < M; i++) {
    for (unsigned j = 0; j < N; j++) {
      const double diff =
          static_cast<double>(u[i]) * static_cast<double>(v[j]) -
          static_cast<double>(h[i + N * j]);
      max_error = std::max(max_error, std::abs(diff));
    }
  }
  std::printf("[%s] ARCH=%d, c_terms=%u, error=%e [%s]\n", __FILE__, TEST_ARCH,
              CORRECTION_TERMS, max_error,
              mtk::test_utils::get_test_result_string(
                  max_error <
                  mtk::test_utils::get_machine_eps<ab_type>() * 16));

  cudaFreeHost(u);
  cudaFreeHost(v);
  cudaFreeHost(h);
}

int main() {
  test<2>();
  test<3>();
}
--------------------------------------------------------------------------------
/test/primitive/fill.cu:
--------------------------------------------------------------------------------
#include "common.hpp"
// Reconstructed headers and template arguments (originals lost in
// extraction):
#include <cstdio>
#include <wmma_extension/wmma_mma.hpp>

#ifndef TEST_ARCH
#define TEST_ARCH (-1)
#endif

__device__ float to_float(const float a) { return a; }
__device__ float to_float(const half a) { return __half2float(a); }

__device__ float my_fabs(const float a) { return a > 0.f ? a : -a; }

template <class Use, int M, int N, int K, class T, class Layout>
__global__ void fill_test_kernel(float *const g_max_error_a,
                                 float *const g_max_error_z) {
  constexpr float a = 2.f;
  mtk::wmma::mma::fragment<Use, M, N, K, T, Layout> frag_zero, frag_a;
  mtk::wmma::mma::fill_zero(frag_zero);
  mtk::wmma::mma::fill_fragment(frag_a, a);

  float max_error_z = 0.f;
  for (unsigned i = 0; i < frag_zero.num_elements; i++) {
    max_error_z = max(abs(to_float(frag_zero.x[i])), max_error_z);
  }

  float max_error_a = 0.f;
  for (unsigned i = 0; i < frag_a.num_elements; i++) {
    max_error_a = max(abs(to_float(frag_a.x[i]) - a), max_error_a);
  }

  if (threadIdx.x == 0) {
    *g_max_error_a = 0;
    *g_max_error_z = 0;
  }
  __syncthreads();
  for (unsigned i = 0; i < blockDim.x; i++) {
    if (threadIdx.x == i) {
      *g_max_error_a = max(*g_max_error_a, my_fabs(max_error_a));
      *g_max_error_z = max(*g_max_error_z, my_fabs(max_error_z));
    }
    __syncthreads();
  }
}

template <class Use, int M, int N, int K, class T, class Layout> void test() {
  float *max_error_a;
  float *max_error_z;
  cudaMallocHost(&max_error_a, sizeof(float));
  cudaMallocHost(&max_error_z, sizeof(float));
  fill_test_kernel<Use, M, N, K, T, Layout>
      <<<1, 32>>>(max_error_a, max_error_z);
  cudaDeviceSynchronize();
  std::printf("[%s] ARCH=%d, %11s, %2d, %2d, %2d, %5s, %10s, "
              "fill_zero_error=%e [%s], fill_a_error=%e, [%s]\n",
              __FILE__, TEST_ARCH, mtk::test_utils::get_string<Use>().c_str(),
              M, N, K, mtk::test_utils::get_string<T>().c_str(),
              mtk::test_utils::get_string<Layout>().c_str(), *max_error_a,
              mtk::test_utils::get_test_result_string(
                  (*max_error_a) < mtk::test_utils::get_machine_eps<T>()),
              *max_error_z,
              mtk::test_utils::get_test_result_string(
                  (*max_error_z) < mtk::test_utils::get_machine_eps<T>()));
  cudaFreeHost(max_error_a);
  cudaFreeHost(max_error_z);
}

int main() {
#if TEST_ARCH >= 80
  test();
  test();
  test();
#endif
#if TEST_ARCH >= 75
  test();
  test();
  test();
#endif
#if TEST_ARCH == 75 || TEST_ARCH == 70
  test();
  test();
  test();
  test();
  test();
#endif
}
--------------------------------------------------------------------------------
/test/primitive/foreach.cu:
--------------------------------------------------------------------------------
// Reconstructed headers and template arguments in this file (originals lost
// in extraction):
#include <algorithm>
#include <cstdio>
#include <random>
#include <wmma_extension/wmma_extension.hpp>

#ifndef TEST_ARCH
#define TEST_ARCH (-1)
#endif

// #define TEST_TF32

#ifndef TEST_TF32
constexpr std::size_t M = 16;
constexpr std::size_t N = 16;
constexpr std::size_t K = 16;
using ab_type = half;
#else
constexpr std::size_t M = 16;
constexpr std::size_t N = 16;
constexpr std::size_t K = 8;
using ab_type = nvcuda::wmma::precision::tf32;
#endif

__global__ void matmul16x16_kernel(float *const c_ptr, const float *const a_ptr,
                                   const float *const b_ptr) {
  // Both operands are col_major, matching the col_major host verification in
  // test() below.
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, M, N, K, ab_type,
                         nvcuda::wmma::col_major>
      frag_a, frag_da;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, M, N, K, ab_type,
                         nvcuda::wmma::col_major>
      frag_b, frag_db;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, M, N, K, float> frag_c;

  mtk::wmma::foreach<decltype(frag_c)>(
      nvcuda::wmma::mem_col_major,
      [&](const unsigned frag_index_list[], const unsigned frag_index_count,
          const unsigned mem_index) {
        const auto c = c_ptr[mem_index];
        for (unsigned i = 0; i < frag_index_count; i++) {
          const unsigned frag_index = frag_index_list[i];
          frag_c.x[frag_index] = c;
        }
      });

  mtk::wmma::foreach<decltype(frag_a)>([&](const unsigned
frag_index_list[], 46 | const unsigned frag_index_count, 47 | const unsigned mem_index) { 48 | const auto a = a_ptr[mem_index]; 49 | const auto a_rp = mtk::wmma::detail::common::cast(a); 50 | const auto da_rp = mtk::wmma::detail::common::cast( 51 | a - mtk::wmma::detail::common::cast(a_rp)); 52 | for (unsigned i = 0; i < frag_index_count; i++) { 53 | const unsigned frag_index = frag_index_list[i]; 54 | frag_a.x[frag_index] = a_rp; 55 | frag_da.x[frag_index] = da_rp; 56 | } 57 | }); 58 | 59 | mtk::wmma::foreach([&](const unsigned frag_index_list[], 60 | const unsigned frag_index_count, 61 | const unsigned mem_index) { 62 | const auto b = b_ptr[mem_index]; 63 | const auto b_rp = mtk::wmma::detail::common::cast(b); 64 | const auto db_rp = mtk::wmma::detail::common::cast( 65 | b - mtk::wmma::detail::common::cast(b_rp)); 66 | for (unsigned i = 0; i < frag_index_count; i++) { 67 | const unsigned frag_index = frag_index_list[i]; 68 | frag_b.x[frag_index] = b_rp; 69 | frag_db.x[frag_index] = db_rp; 70 | } 71 | }); 72 | 73 | nvcuda::wmma::mma_sync(frag_c, frag_a, frag_db, frag_c); 74 | nvcuda::wmma::mma_sync(frag_c, frag_da, frag_b, frag_c); 75 | nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c); 76 | 77 | nvcuda::wmma::store_matrix_sync(c_ptr, frag_c, N, 78 | nvcuda::wmma::mem_col_major); 79 | } 80 | 81 | void test() { 82 | std::printf("-- test (%s) --\n", __FILE__); 83 | std::printf("arch : %d\n", TEST_ARCH); 84 | 85 | float *a, *b, *c, *d; 86 | cudaMallocHost(&a, sizeof(float) * N * N); 87 | cudaMallocHost(&b, sizeof(float) * N * N); 88 | cudaMallocHost(&c, sizeof(float) * N * N); 89 | cudaMallocHost(&d, sizeof(float) * N * N); 90 | 91 | std::mt19937 mt(std::random_device{}()); 92 | std::uniform_real_distribution dist(-1.0f, 1.0f); 93 | for (unsigned i = 0; i < N * N; i++) { 94 | a[i] = dist(mt); 95 | b[i] = dist(mt); 96 | d[i] = c[i] = dist(mt); 97 | } 98 | 99 | cudaDeviceSynchronize(); 100 | matmul16x16_kernel<<<1, 32>>>(c, a, b); 101 | cudaDeviceSynchronize(); 102 | 103 | double max_error = 0.0; 104 | double max_element = 0.0; 105 | for (unsigned i = 0; i < M; i++) { 106 | for (unsigned j = 0; j < N; j++) { 107 | double sum = d[i + j * M]; 108 | for (unsigned k = 0; k < K; k++) { 109 | sum += static_cast(a[i + M * k]) * 110 | static_cast(b[k + j * K]); 111 | } 112 | const auto error = std::abs(sum - c[i + j * M]); 113 | const auto element = std::abs(sum); 114 | max_error = std::max(max_error, error); 115 | max_element = std::max(max_element, element); 116 | } 117 | } 118 | const auto e = max_error / max_element; 119 | std::printf("{%s} error=%e [", __FILE__, e); 120 | if (e < 1e-6) { 121 | std::printf("PASSED"); 122 | } else { 123 | std::printf("FAILED"); 124 | } 125 | std::printf("]\n"); 126 | } 127 | 128 | int main() { test(); } 129 | -------------------------------------------------------------------------------- /test/primitive/foreach_ij.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #ifndef TEST_ARCH 9 | #define TEST_ARCH (-1) 10 | #endif 11 | 12 | __device__ float myabs(const float a) { 13 | if (a > 0) { 14 | return a; 15 | } else { 16 | return -a; 17 | } 18 | } 19 | 20 | template 22 | __global__ void test_kernel(float *const diff, const float *const src, 23 | const unsigned ld) { 24 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 25 | 26 | __shared__ storage_t smem[MATRIX_DIM * MATRIX_DIM]; 27 | for (unsigned i = 0; i < 
MATRIX_DIM * MATRIX_DIM; i += blockDim.x) { 28 | smem[i + threadIdx.x] = src[i + threadIdx.x]; 29 | } 30 | 31 | nvcuda::wmma::fragment frag_nvcuda; 32 | nvcuda::wmma::fragment frag_mtk; 33 | 34 | nvcuda::wmma::load_matrix_sync(frag_nvcuda, smem, ld); 35 | 36 | mtk::wmma::foreach_ij( 37 | [&](const unsigned *frag_index_list, const unsigned num_indeces, 38 | const unsigned i, const unsigned j) { 39 | unsigned mem_index; 40 | if (std::is_same::value) { 41 | mem_index = i + j * ld; 42 | } else { 43 | mem_index = i * ld + j; 44 | } 45 | for (unsigned f = 0; f < num_indeces; f++) { 46 | frag_mtk.x[frag_index_list[f]] = smem[mem_index]; 47 | } 48 | }); 49 | 50 | float max_diff = 0.f; 51 | for (unsigned i = 0; i < frag_mtk.num_elements; i++) { 52 | max_diff = max(max_diff, myabs(frag_mtk.x[i] - frag_nvcuda.x[i])); 53 | } 54 | diff[threadIdx.x] = max_diff; 55 | } 56 | 57 | template 59 | __global__ void test_kernel_acc(float *const diff, const float *const src, 60 | const unsigned ld) { 61 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 62 | 63 | __shared__ storage_t smem[MATRIX_DIM * MATRIX_DIM]; 64 | for (unsigned i = 0; i < MATRIX_DIM * MATRIX_DIM; i += blockDim.x) { 65 | smem[i + threadIdx.x] = src[i + threadIdx.x]; 66 | } 67 | 68 | nvcuda::wmma::fragment frag_nvcuda; 69 | nvcuda::wmma::fragment frag_mtk; 70 | 71 | nvcuda::wmma::load_matrix_sync(frag_nvcuda, smem, ld, layout); 72 | 73 | mtk::wmma::foreach_ij( 74 | layout, [&](const unsigned *frag_index_list, const unsigned num_indeces, 75 | const unsigned i, const unsigned j) { 76 | unsigned mem_index; 77 | if (layout == nvcuda::wmma::mem_col_major) { 78 | mem_index = i + j * ld; 79 | } else { 80 | mem_index = i * ld + j; 81 | } 82 | for (unsigned f = 0; f < num_indeces; f++) { 83 | frag_mtk.x[frag_index_list[f]] = smem[mem_index]; 84 | } 85 | }); 86 | 87 | float max_diff = 0.f; 88 | for (unsigned i = 0; i < frag_mtk.num_elements; i++) { 89 | max_diff = max(max_diff, myabs(frag_mtk.x[i] - frag_nvcuda.x[i])); 90 | } 91 | diff[threadIdx.x] = max_diff; 92 | } 93 | 94 | template 95 | void test() { 96 | constexpr unsigned MATRIX_DIM = 32; 97 | constexpr unsigned warp_size = 32; 98 | float *src_matrix; 99 | float *diff; 100 | cudaMallocHost(&src_matrix, sizeof(float) * MATRIX_DIM * MATRIX_DIM); 101 | cudaMallocHost(&diff, sizeof(float) * warp_size); 102 | 103 | for (unsigned i = 0; i < MATRIX_DIM * MATRIX_DIM; i++) { 104 | src_matrix[i] = static_cast(i) / (MATRIX_DIM); 105 | } 106 | 107 | test_kernel 108 | <<<1, warp_size>>>(diff, src_matrix, MATRIX_DIM); 109 | cudaDeviceSynchronize(); 110 | 111 | bool passed = true; 112 | for (unsigned i = 0; i < warp_size; i++) { 113 | if (diff[i] > (1.f / MATRIX_DIM / 2)) { 114 | passed = false; 115 | } 116 | } 117 | 118 | std::printf("%s{SM=%2d,Use=%15s,M=%2d,N=%2d,K=%2d,Type=%5s,Layout=%8s}:", 119 | __FILE__, TEST_ARCH, mtk::test_utils::get_string().c_str(), 120 | M, N, K, mtk::test_utils::get_string().c_str(), 121 | mtk::test_utils::get_string().c_str()); 122 | if (passed) { 123 | std::printf("PASSED"); 124 | } else { 125 | std::printf("FAILED"); 126 | } 127 | std::printf("\n"); 128 | 129 | cudaFreeHost(diff); 130 | cudaFreeHost(src_matrix); 131 | } 132 | 133 | template 134 | void test_acc() { 135 | constexpr unsigned MATRIX_DIM = 32; 136 | constexpr unsigned warp_size = 32; 137 | float *src_matrix; 138 | float *diff; 139 | cudaMallocHost(&src_matrix, sizeof(float) * MATRIX_DIM * MATRIX_DIM); 140 | cudaMallocHost(&diff, sizeof(float) * warp_size); 141 | 142 | for (unsigned i = 0; i < 
MATRIX_DIM * MATRIX_DIM; i++) { 143 | src_matrix[i] = static_cast(i) / (MATRIX_DIM); 144 | } 145 | 146 | if (std::is_same::value) { 147 | test_kernel_acc 148 | <<<1, warp_size>>>(diff, src_matrix, MATRIX_DIM); 149 | } else { 150 | test_kernel_acc 151 | <<<1, warp_size>>>(diff, src_matrix, MATRIX_DIM); 152 | } 153 | cudaDeviceSynchronize(); 154 | 155 | bool passed = true; 156 | for (unsigned i = 0; i < warp_size; i++) { 157 | if (diff[i] > (1.f / MATRIX_DIM / 2)) { 158 | passed = false; 159 | } 160 | } 161 | 162 | std::printf("%s{SM=%2d,Use=%15s,M=%2d,N=%2d,K=%2d,Type=%5s,Layout=%8s}:", 163 | __FILE__, TEST_ARCH, mtk::test_utils::get_string().c_str(), 164 | M, N, K, mtk::test_utils::get_string().c_str(), 165 | mtk::test_utils::get_string().c_str()); 166 | if (passed) { 167 | std::printf("PASSED"); 168 | } else { 169 | std::printf("FAILED"); 170 | } 171 | std::printf("\n"); 172 | 173 | cudaFreeHost(diff); 174 | cudaFreeHost(src_matrix); 175 | } 176 | 177 | int main() { 178 | test(); 179 | test(); 180 | test(); 181 | test(); 182 | test_acc(); 184 | test_acc(); 186 | test_acc(); 188 | test_acc(); 190 | #ifdef TEST_TF32 191 | test(); 193 | test(); 195 | test(); 197 | test(); 199 | test_acc(); 201 | test_acc(); 203 | #endif 204 | } 205 | -------------------------------------------------------------------------------- /test/primitive/foreach_v.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #ifndef TEST_ARCH 6 | #define TEST_ARCH (-1) 7 | #endif 8 | 9 | // #define TEST_TF32 10 | 11 | #ifndef TEST_TF32 12 | constexpr std::size_t M = 16; 13 | constexpr std::size_t N = 16; 14 | constexpr std::size_t K = 16; 15 | using ab_type = half; 16 | #else 17 | constexpr std::size_t M = 16; 18 | constexpr std::size_t N = 16; 19 | constexpr std::size_t K = 8; 20 | using ab_type = nvcuda::wmma::precision::tf32; 21 | #endif 22 | 23 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 24 | 25 | template 26 | __device__ __host__ typename mtk::wmma::detail::common::storage_t::type 27 | convert(const S); 28 | template <> 29 | __device__ __host__ typename mtk::wmma::detail::common::storage_t::type 30 | convert(const float a) { 31 | return a; 32 | } 33 | template <> 34 | __device__ __host__ typename mtk::wmma::detail::common::storage_t::type 35 | convert(const half a) { 36 | return __half2float(a); 37 | } 38 | template <> 39 | __device__ __host__ typename mtk::wmma::detail::common::storage_t::type 40 | convert(const float a) { 41 | return __float2half(a); 42 | } 43 | template <> 44 | __device__ __host__ typename mtk::wmma::detail::common::storage_t::type 45 | convert(const half a) { 46 | return a; 47 | } 48 | 49 | template __device__ T m_abs(const T a) { 50 | if (a >= convert(0)) 51 | return a; 52 | return -a; 53 | } 54 | 55 | template 56 | __global__ void test_foreach_v_kernel(const storage_t *const src, 57 | const storage_t *const cor) { 58 | nvcuda::wmma::fragment vec_frag; 59 | mtk::wmma::fill_zero(vec_frag); 60 | mtk::wmma::foreach_v([&](const unsigned frag_index_list[], 61 | const unsigned frag_index_count, 62 | const unsigned mem_index) { 63 | for (unsigned i = 0; i < frag_index_count; i++) { 64 | vec_frag.x[frag_index_list[i]] = convert(src[mem_index]); 65 | } 66 | }); 67 | 68 | nvcuda::wmma::fragment cor_frag; 69 | nvcuda::wmma::load_matrix_sync(cor_frag, cor, M); 70 | 71 | storage_t error = convert(0.0f); 72 | for (unsigned i = 0; i < vec_frag.num_elements; i++) { 73 | error += m_abs(vec_frag.x[i] - 
cor_frag.x[i]); 74 | } 75 | printf("[%2u] error = %e\n", threadIdx.x, convert(error)); 76 | } 77 | 78 | template void test() { 79 | std::printf("-- test (%s) --\n", __FILE__); 80 | std::size_t cor_size = 0; 81 | std::size_t vec_length = 0; 82 | std::printf("arch : %d\n", TEST_ARCH); 83 | if (std::is_same::value) { 84 | std::printf("layout : col_major\n"); 85 | } else { 86 | std::printf("layout : row_major\n"); 87 | } 88 | if (std::is_same::value) 89 | std::printf("type : float\n"); 90 | if (std::is_same::value) 91 | std::printf("type : half\n"); 92 | if (std::is_same::value) 93 | std::printf("type : tf32\n"); 94 | 95 | if (std::is_same::value) { 96 | std::printf("use : a\n"); 97 | cor_size = M * K; 98 | if (std::is_same::value) { 99 | vec_length = M; 100 | } else { 101 | vec_length = K; 102 | } 103 | } 104 | if (std::is_same::value) { 105 | std::printf("use : b\n"); 106 | cor_size = N * K; 107 | if (std::is_same::value) { 108 | vec_length = K; 109 | } else { 110 | vec_length = N; 111 | } 112 | } 113 | std::printf("size : %lu, %lu, %lu\n", M, N, K); 114 | 115 | storage_t *src_mem; 116 | storage_t *cor_mem; 117 | 118 | cudaMallocHost(&src_mem, M * sizeof(storage_t)); 119 | cudaMallocHost(&cor_mem, cor_size * sizeof(storage_t)); 120 | 121 | for (std::size_t i = 0; i < cor_size; i++) { 122 | cor_mem[i] = convert(0); 123 | } 124 | 125 | for (std::size_t i = 0; i < vec_length; i++) { 126 | const float v = i / 3.0f; 127 | src_mem[i] = convert(v); 128 | cor_mem[i] = convert(v); 129 | } 130 | 131 | cudaDeviceSynchronize(); 132 | test_foreach_v_kernel<<<1, 32>>>(src_mem, cor_mem); 133 | cudaDeviceSynchronize(); 134 | } 135 | 136 | int main() { 137 | test(); 138 | test(); 139 | 140 | test(); 141 | test(); 142 | } 143 | -------------------------------------------------------------------------------- /test/primitive/foreach_v_acc.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #ifndef TEST_ARCH 6 | #define TEST_ARCH (-1) 7 | #endif 8 | 9 | // #define TEST_TF32 10 | 11 | #ifndef TEST_TF32 12 | constexpr std::size_t M = 16; 13 | constexpr std::size_t N = 16; 14 | constexpr std::size_t K = 16; 15 | #else 16 | constexpr std::size_t M = 16; 17 | constexpr std::size_t N = 16; 18 | constexpr std::size_t K = 8; 19 | #endif 20 | 21 | template __device__ __host__ T convert(const S); 22 | template <> __device__ __host__ float convert(const float a) { 23 | return a; 24 | } 25 | template <> __device__ __host__ float convert(const half a) { 26 | return __half2float(a); 27 | } 28 | template <> __device__ __host__ half convert(const float a) { 29 | return __float2half(a); 30 | } 31 | template <> __device__ __host__ half convert(const half a) { 32 | return a; 33 | } 34 | 35 | __global__ void test_foreach_v_acc_kernel(float *const dst, 36 | const float *const src, 37 | const nvcuda::wmma::layout_t layout) { 38 | nvcuda::wmma::fragment frag_c; 39 | nvcuda::wmma::load_matrix_sync(frag_c, src, M, layout); 40 | mtk::wmma::foreach_v( 41 | layout, 42 | [&](const unsigned *frag_index_list, const unsigned fragment_index_count, 43 | const unsigned mem_index) { 44 | for (unsigned i = 0; i < fragment_index_count; i++) { 45 | dst[mem_index] = frag_c.x[frag_index_list[i]]; 46 | } 47 | }); 48 | } 49 | 50 | void test(const nvcuda::wmma::layout_t layout) { 51 | std::printf("-- test (%s) --\n", __FILE__); 52 | std::printf("arch : %d\n", TEST_ARCH); 53 | if (layout == nvcuda::wmma::mem_col_major) { 54 | std::printf("layout : col_major\n"); 55 | } else { 
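    // Illustrative note, not part of the original test: mtk::wmma::foreach_v
    // enumerates only the fragment elements that correspond to the leading
    // vector of the matrix. Its lambda receives
    //   frag_index_list  - positions in frag.x[] owned by the calling thread,
    //   frag_index_count - the number of valid entries in that list, and
    //   mem_index        - the matching element offset in memory,
    // so the kernel above can move an accumulator's first column/row
    // directly,
    //   for (unsigned i = 0; i < frag_index_count; i++)
    //     dst[mem_index] = frag_c.x[frag_index_list[i]];
    // without a full store_matrix_sync round trip through shared memory.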
    std::printf("layout : row_major\n");
  }
  std::printf("size : %lu, %lu, %lu\n", M, N, K);
  float *src_mem;
  float *dst_mem;

  cudaMallocHost(&src_mem, M * N * sizeof(float));
  cudaMallocHost(&dst_mem, M * sizeof(float));

  for (std::size_t i = 0; i < M * N; i++) {
    src_mem[i] = static_cast<float>(i);
  }

  cudaDeviceSynchronize();
  test_foreach_v_acc_kernel<<<1, 32>>>(dst_mem, src_mem, layout);
  cudaDeviceSynchronize();

  float error = 0.f;
  for (std::size_t i = 0; i < M; i++) {
    error = std::max(std::abs(dst_mem[i] - src_mem[i]), error);
  }
  std::printf("error = %e\n", error);
}

int main() {
  test(nvcuda::wmma::mem_row_major);
  test(nvcuda::wmma::mem_col_major);
  test(nvcuda::wmma::mem_row_major);
  test(nvcuda::wmma::mem_col_major);
}
--------------------------------------------------------------------------------
/test/primitive/gevm.cu:
--------------------------------------------------------------------------------
#include "common.hpp"
// Reconstructed headers and template arguments in this file (the originals
// were lost to angle-bracket stripping during extraction):
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <wmma_extension/wmma_extension.hpp>

#ifndef TEST_ARCH
#define TEST_ARCH (-1)
#endif

template <class T, class S> __device__ __host__ T convert(const S);
template <> __device__ __host__ float convert<float, float>(const float a) {
  return a;
}
template <> __device__ __host__ float convert<float, half>(const half a) {
  return __half2float(a);
}
template <> __device__ __host__ half convert<half, float>(const float a) {
  return __float2half(a);
}
template <> __device__ __host__ half convert<half, half>(const half a) {
  return a;
}

__global__ void test_gevm_kernel(float *const dst, const half *const src,
                                 const half *const eye) {
  // Fragment shapes reconstructed as the standard 16x16x16 half/float
  // configuration; the col_major layouts are an assumption consistent with
  // the col_major store_vector below.
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
                         nvcuda::wmma::col_major>
      eye_frag;
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half,
                         nvcuda::wmma::col_major>
      vec_frag;
  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float>
      result_frag;
  nvcuda::wmma::load_matrix_sync(eye_frag, eye, 16);
  nvcuda::wmma::fill_fragment(result_frag, 0.0f);

  mtk::wmma::load_vector(vec_frag, src);

  nvcuda::wmma::mma_sync(result_frag, eye_frag, vec_frag, result_frag);

  mtk::wmma::store_vector(dst, result_frag, nvcuda::wmma::mem_col_major);
}

void test() {
  half *src_mem;
  float *dst_mem;
  half *eye_mem;

  cudaMallocHost(&src_mem, 16 * 16 * sizeof(half));
  cudaMallocHost(&dst_mem, 16 * sizeof(float));
  cudaMallocHost(&eye_mem, 16 * 16 * sizeof(half));

  for (std::size_t i = 0; i < 16 * 16; i++) {
    src_mem[i] = convert<half>((i < 16) ? static_cast<float>(i) : 0.0f);
    eye_mem[i] = convert<half>((i % 17 == 0) ? 1.0f : 0.0f);
  }

  cudaDeviceSynchronize();
  test_gevm_kernel<<<1, 32>>>(dst_mem, src_mem, eye_mem);
  cudaDeviceSynchronize();

  double error = 0.;
  for (std::size_t i = 0; i < 16; i++) {
    const double diff =
        convert<float>(dst_mem[i]) - convert<float>(src_mem[i]);
    error = std::max(error, std::abs(diff));
  }
  std::printf("[%s] ARCH=%d, error=%e [%s]\n", __FILE__, TEST_ARCH, error,
              mtk::test_utils::get_test_result_string(
                  error < mtk::test_utils::get_machine_eps<half>() * 16));
}

int main() { test(); }
--------------------------------------------------------------------------------
/test/primitive/map.cu:
--------------------------------------------------------------------------------
#include "common.hpp"
// Reconstructed headers (originals lost in extraction):
#include <cstdio>
#include <wmma_extension/wmma_extension.hpp>

// #define TEST_TF32

constexpr unsigned warp_size = 32;

__device__ inline float my_abs(float a) { return (a < 0) ?
(-a) : a; } 10 | 11 | template 12 | __global__ void map_test(float *const error_ptr) { 13 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 14 | constexpr unsigned mat_m = 15 | mtk::wmma::detail::common::get_M::value; 16 | constexpr unsigned mat_n = 17 | mtk::wmma::detail::common::get_N::value; 18 | __shared__ storage_t smem[mat_m * mat_n]; 19 | 20 | for (unsigned i = 0; i < mat_m * mat_n; i += warp_size) { 21 | const unsigned index = i + threadIdx.x; 22 | const auto m = std::is_same::value 23 | ? (index / mat_n) 24 | : (index % mat_m); 25 | const auto n = std::is_same::value 26 | ? (index % mat_n) 27 | : (index / mat_m); 28 | smem[index] = mtk::wmma::detail::common::cast(m + n * mat_m); 29 | } 30 | 31 | nvcuda::wmma::fragment frag_ref; 32 | nvcuda::wmma::load_matrix_sync( 33 | frag_ref, smem, 34 | std::is_same::value ? mat_n : mat_m); 35 | 36 | nvcuda::wmma::fragment frag_map; 37 | for (unsigned i = 0; i < mat_m; i++) { 38 | for (unsigned j = 0; j < mat_n; j++) { 39 | unsigned tid_list[2]; 40 | unsigned fid_list[2]; 41 | unsigned list_size; 42 | mtk::wmma::map(tid_list, fid_list, list_size, i, j); 43 | 44 | for (unsigned k = 0; k < list_size; k++) { 45 | if (threadIdx.x == tid_list[k]) { 46 | frag_map.x[fid_list[k]] = 47 | mtk::wmma::detail::common::cast(i + j * mat_m); 48 | } 49 | __syncwarp(); 50 | } 51 | } 52 | } 53 | float error = 0.f; 54 | for (unsigned i = 0; i < frag_map.num_elements; i++) { 55 | error += my_abs(frag_map.x[i] - frag_ref.x[i]); 56 | } 57 | 58 | atomicAdd(error_ptr, error); 59 | } 60 | 61 | template 62 | __global__ void map_test(float *const error_ptr, 63 | const nvcuda::wmma::layout_t layout) { 64 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 65 | constexpr unsigned mat_m = 66 | mtk::wmma::detail::common::get_M::value; 67 | constexpr unsigned mat_n = 68 | mtk::wmma::detail::common::get_N::value; 69 | __shared__ storage_t smem[mat_m * mat_n]; 70 | 71 | for (unsigned i = 0; i < mat_m * mat_n; i += warp_size) { 72 | const unsigned index = i + threadIdx.x; 73 | const float v = (layout == nvcuda::wmma::mem_row_major) 74 | ? ((index / mat_n) + (index % mat_n) * mat_m) 75 | : index; 76 | smem[index] = mtk::wmma::detail::common::cast(v); 77 | } 78 | for (unsigned i = 0; i < mat_m * mat_n; i += warp_size) { 79 | const unsigned index = i + threadIdx.x; 80 | const auto m = (layout == nvcuda::wmma::mem_row_major) ? (index / mat_n) 81 | : (index % mat_m); 82 | const auto n = (layout == nvcuda::wmma::mem_row_major) ? (index % mat_n) 83 | : (index / mat_m); 84 | smem[index] = mtk::wmma::detail::common::cast(m + n * mat_m); 85 | } 86 | 87 | nvcuda::wmma::fragment frag_ref; 88 | nvcuda::wmma::load_matrix_sync( 89 | frag_ref, smem, (layout == nvcuda::wmma::mem_row_major) ? 
mat_n : mat_m, 90 | layout); 91 | 92 | nvcuda::wmma::fragment frag_map; 93 | for (unsigned i = 0; i < mat_m; i++) { 94 | for (unsigned j = 0; j < mat_n; j++) { 95 | unsigned tid_list[2]; 96 | unsigned fid_list[2]; 97 | unsigned list_size; 98 | mtk::wmma::map(tid_list, fid_list, list_size, i, j); 99 | 100 | for (unsigned k = 0; k < list_size; k++) { 101 | if (threadIdx.x == tid_list[k]) { 102 | frag_map.x[fid_list[k]] = 103 | mtk::wmma::detail::common::cast(i + j * mat_m); 104 | } 105 | __syncwarp(); 106 | } 107 | } 108 | } 109 | float error = 0.f; 110 | for (unsigned i = 0; i < frag_map.num_elements; i++) { 111 | error += my_abs(frag_map.x[i] - frag_ref.x[i]); 112 | } 113 | 114 | atomicAdd(error_ptr, error); 115 | } 116 | 117 | template void test() { 118 | float *error; 119 | cudaMallocHost(&error, sizeof(float)); 120 | map_test<<<1, warp_size>>>(error); 121 | cudaDeviceSynchronize(); 122 | 123 | std::printf("%s<%12s,%2d,%2d,%2d,%7s,%10s>:Error=%e [", __FILE__, 124 | mtk::test_utils::get_string().c_str(), M, N, K, 125 | mtk::test_utils::get_string().c_str(), 126 | mtk::test_utils::get_string().c_str(), *error); 127 | if (*error < 1) { 128 | std::printf("PASSED"); 129 | } else { 130 | std::printf("FAILED"); 131 | } 132 | std::printf("]\n"); 133 | cudaFreeHost(error); 134 | } 135 | 136 | template 137 | void test(const nvcuda::wmma::layout_t layout) { 138 | float *error; 139 | cudaMallocHost(&error, sizeof(float)); 140 | map_test<<<1, warp_size>>>(error, layout); 141 | cudaDeviceSynchronize(); 142 | 143 | std::printf( 144 | "%s<%12s,%2d,%2d,%2d,%7s,%10s>:Error=%e [", __FILE__, 145 | mtk::test_utils::get_string().c_str(), M, N, K, 146 | mtk::test_utils::get_string().c_str(), 147 | (layout == nvcuda::wmma::mem_col_major) 148 | ? mtk::test_utils::get_string().c_str() 149 | : mtk::test_utils::get_string().c_str(), 150 | *error); 151 | if (*error < 1) { 152 | std::printf("PASSED"); 153 | } else { 154 | std::printf("FAILED"); 155 | } 156 | std::printf("]\n"); 157 | cudaFreeHost(error); 158 | } 159 | 160 | int main() { 161 | test(); 162 | test(); 163 | test(); 164 | test(); 165 | test( 166 | nvcuda::wmma::mem_col_major); 167 | test( 168 | nvcuda::wmma::mem_col_major); 169 | test( 170 | nvcuda::wmma::mem_row_major); 171 | test( 172 | nvcuda::wmma::mem_row_major); 173 | 174 | #ifdef TEST_TF32 175 | test(); 177 | test(); 179 | test(); 181 | test(); 183 | test( 184 | nvcuda::wmma::mem_col_major); 185 | test( 186 | nvcuda::wmma::mem_row_major); 187 | #endif 188 | } 189 | -------------------------------------------------------------------------------- /test/primitive/print_fragment.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #ifndef TEST_ARCH 6 | #define TEST_ARCH (-1) 7 | #endif 8 | 9 | template __device__ __host__ T convert(const S); 10 | template <> __device__ __host__ float convert(const float a) { 11 | return a; 12 | } 13 | template <> __device__ __host__ float convert(const half a) { 14 | return __half2float(a); 15 | } 16 | template <> __device__ __host__ half convert(const float a) { 17 | return __float2half(a); 18 | } 19 | template <> __device__ __host__ half convert(const half a) { 20 | return a; 21 | } 22 | 23 | template 24 | __global__ void test_load_vector_kernel(const half *const src) { 25 | nvcuda::wmma::fragment frag; 26 | nvcuda::wmma::load_matrix_sync(frag, src, 16); 27 | 28 | mtk::wmma::print_fragment(frag); 29 | } 30 | 31 | template void test() { 32 | std::printf("-- test (%s) --\n", __FILE__); 33 | 
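  // Illustrative note, not part of the original test:
  // mtk::wmma::print_fragment dumps every lane's private frag.x[] elements
  // from device code, so running this test shows which matrix elements each
  // of the 32 threads in the warp owns for a given fragment type - the same
  // thread-to-element mapping that foreach, foreach_ij and map expose
  // programmatically.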
std::printf("arch : %d\n", TEST_ARCH); 34 | if (std::is_same::value) { 35 | std::printf("layout : col_major\n"); 36 | } else { 37 | std::printf("layout : row_major\n"); 38 | } 39 | half *src_mem; 40 | 41 | cudaMallocHost(&src_mem, 16 * sizeof(half)); 42 | 43 | for (std::size_t i = 0; i < 16 * 16; i++) { 44 | src_mem[i] = convert(i); 45 | } 46 | 47 | cudaDeviceSynchronize(); 48 | test_load_vector_kernel<<<1, 32>>>(src_mem); 49 | cudaDeviceSynchronize(); 50 | } 51 | 52 | int main() { 53 | test(); 54 | test(); 55 | } 56 | -------------------------------------------------------------------------------- /test/primitive/vector.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #ifndef TEST_ARCH 7 | #define TEST_ARCH (-1) 8 | #endif 9 | 10 | // #define TEST_TF32 11 | // #define TF32_ROUNDING 12 | 13 | template __device__ __host__ T convert(const S); 14 | template <> __device__ __host__ float convert(const float a) { 15 | return a; 16 | } 17 | template <> __device__ __host__ float convert(const half a) { 18 | return __half2float(a); 19 | } 20 | template <> __device__ __host__ half convert(const float a) { 21 | return __float2half(a); 22 | } 23 | template <> __device__ __host__ half convert(const half a) { 24 | return a; 25 | } 26 | 27 | template __device__ T m_abs(const T a) { 28 | if (a >= convert(0)) 29 | return a; 30 | return -a; 31 | } 32 | 33 | template struct fragment_layout { 34 | using type = Layout; 35 | }; 36 | template <> 37 | struct fragment_layout { 38 | using type = void; 39 | }; 40 | template <> 41 | struct fragment_layout { 42 | using type = void; 43 | }; 44 | 45 | template __device__ float error_threshold(); 46 | template <> __device__ float error_threshold() { return 1e-6f; }; 47 | template <> __device__ float error_threshold() { return 1e-3f; }; 48 | 49 | template 50 | __global__ void test_load_vector_ab_kernel( 51 | const typename mtk::wmma::detail::common::storage_t::type *const src, 52 | const typename mtk::wmma::detail::common::storage_t::type *const cor) { 53 | mtk::wmma::mma::fragment vec_frag; 54 | mtk::wmma::mma::fill_zero(vec_frag); 55 | mtk::wmma::mma::load_vector(vec_frag, src); 56 | 57 | mtk::wmma::mma::fragment::type> 59 | cor_frag; 60 | mtk::wmma::mma::load_matrix_sync(cor_frag, cor, m); 61 | 62 | auto error = 63 | convert::type, float>( 64 | 0.0f); 65 | for (unsigned i = 0; i < vec_frag.num_elements; i++) { 66 | error += m_abs(vec_frag.x[i] - cor_frag.x[i]); 67 | } 68 | printf("[%2u] error = %e (%s)\n", threadIdx.x, convert(error), 69 | (convert(error) < error_threshold() ? 
"PASSED" : "FAILED")); 70 | } 71 | 72 | template 73 | __global__ void test_load_vector_acc_kernel(const T *const src, 74 | const T *const cor) { 75 | mtk::wmma::mma::fragment vec_frag; 76 | mtk::wmma::mma::fragment cor_frag; 77 | mtk::wmma::mma::fill_zero(vec_frag); 78 | if (std::is_same::value) { 79 | mtk::wmma::mma::load_vector(vec_frag, src, nvcuda::wmma::mem_col_major); 80 | mtk::wmma::mma::load_matrix_sync(cor_frag, cor, m, 81 | nvcuda::wmma::mem_col_major); 82 | } else { 83 | mtk::wmma::mma::load_vector(vec_frag, src, nvcuda::wmma::mem_row_major); 84 | mtk::wmma::mma::load_matrix_sync(cor_frag, cor, n, 85 | nvcuda::wmma::mem_row_major); 86 | } 87 | 88 | auto error = 89 | convert::type, float>( 90 | 0.0f); 91 | for (unsigned i = 0; i < vec_frag.num_elements; i++) { 92 | error += m_abs(vec_frag.x[i] - cor_frag.x[i]); 93 | } 94 | printf("[%2u] error = %e (%s)\n", threadIdx.x, convert(error), 95 | (convert(error) < error_threshold() ? "PASSED" : "FAILED")); 96 | } 97 | 98 | template void test() { 99 | std::printf("-- test (%s) --\n", __FILE__); 100 | std::size_t cor_size = 0; 101 | std::size_t vec_length = 0; 102 | std::printf("arch : %d\n", TEST_ARCH); 103 | if (std::is_same::value) { 104 | std::printf("layout : col_major\n"); 105 | } else if (std::is_same::value) { 106 | std::printf("layout : row_major\n"); 107 | } else { 108 | std::printf("layout : void\n"); 109 | } 110 | if (std::is_same::value) 111 | std::printf("type : float\n"); 112 | if (std::is_same::value) 113 | std::printf("type : half\n"); 114 | if (std::is_same::value) 115 | std::printf("type : tf32\n"); 116 | 117 | if (std::is_same::value) { 118 | std::printf("use : a\n"); 119 | cor_size = m * k; 120 | if (std::is_same::value) { 121 | vec_length = m; 122 | } else { 123 | vec_length = k; 124 | } 125 | } 126 | if (std::is_same::value) { 127 | std::printf("use : b\n"); 128 | cor_size = n * k; 129 | if (std::is_same::value) { 130 | vec_length = k; 131 | } else { 132 | vec_length = n; 133 | } 134 | } 135 | if (std::is_same::value) { 136 | std::printf("use : acc\n"); 137 | cor_size = n * m; 138 | if (std::is_same::value) { 139 | vec_length = m; 140 | } else { 141 | vec_length = n; 142 | } 143 | } 144 | std::printf("size : %d, %d, %d\n", m, n, k); 145 | 146 | using storage_t = typename mtk::wmma::detail::common::storage_t::type; 147 | storage_t *src_mem; 148 | storage_t *cor_mem; 149 | 150 | cudaMallocHost(&src_mem, m * sizeof(storage_t)); 151 | cudaMallocHost(&cor_mem, cor_size * sizeof(storage_t)); 152 | 153 | for (std::size_t i = 0; i < cor_size; i++) { 154 | cor_mem[i] = convert(0); 155 | } 156 | 157 | for (std::size_t i = 0; i < vec_length; i++) { 158 | const float v = i * 1.f; 159 | src_mem[i] = convert(v); 160 | cor_mem[i] = convert(v); 161 | } 162 | 163 | cudaDeviceSynchronize(); 164 | if constexpr (std::is_same::value) { 165 | test_load_vector_acc_kernel 166 | <<<1, 32>>>(src_mem, cor_mem); 167 | } else { 168 | test_load_vector_ab_kernel 169 | <<<1, 32>>>(src_mem, cor_mem); 170 | } 171 | cudaDeviceSynchronize(); 172 | } 173 | 174 | int main() { 175 | #if TEST_ARCH >= 80 176 | test(); 177 | test(); 178 | test(); 179 | test(); 180 | #endif 181 | 182 | #if TEST_ARCH >= 75 183 | test(); 184 | test(); 185 | test(); 186 | test(); 187 | #endif 188 | 189 | #if TEST_ARCH >= 70 && TEST_ARCH <= 75 190 | test(); 191 | test(); 192 | test(); 193 | test(); 194 | test(); 195 | test(); 196 | test(); 197 | test(); 198 | #endif 199 | } 200 | -------------------------------------------------------------------------------- 
/test/primitive/wmma.load_vector.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | #ifndef TEST_ARCH 7 | #define TEST_ARCH (-1) 8 | #endif 9 | 10 | // #define TEST_TF32 11 | // #define TF32_ROUNDING 12 | 13 | template __device__ __host__ T convert(const S); 14 | template <> __device__ __host__ float convert(const float a) { 15 | return a; 16 | } 17 | template <> __device__ __host__ float convert(const half a) { 18 | return __half2float(a); 19 | } 20 | template <> __device__ __host__ half convert(const float a) { 21 | return __float2half(a); 22 | } 23 | template <> __device__ __host__ half convert(const half a) { 24 | return a; 25 | } 26 | 27 | template __device__ T m_abs(const T a) { 28 | if (a >= convert(0)) 29 | return a; 30 | return -a; 31 | } 32 | 33 | template struct fragment_layout { 34 | using type = Layout; 35 | }; 36 | template <> 37 | struct fragment_layout { 38 | using type = void; 39 | }; 40 | template <> 41 | struct fragment_layout { 42 | using type = void; 43 | }; 44 | 45 | template 46 | __global__ void test_load_vector_ab_kernel( 47 | float *const error, 48 | const typename mtk::wmma::detail::common::storage_t::type *const src, 49 | const typename mtk::wmma::detail::common::storage_t::type *const cor) { 50 | nvcuda::wmma::fragment vec_frag; 51 | mtk::wmma::load_vector(vec_frag, src); 52 | 53 | nvcuda::wmma::fragment::type> 55 | cor_frag; 56 | nvcuda::wmma::load_matrix_sync(cor_frag, cor, m); 57 | 58 | auto e = 59 | convert::type, float>( 60 | 0.0f); 61 | for (unsigned i = 0; i < vec_frag.num_elements; i++) { 62 | e += m_abs(vec_frag.x[i] - cor_frag.x[i]); 63 | } 64 | if (threadIdx.x == 0) { 65 | *error = 0; 66 | } 67 | __syncthreads(); 68 | for (unsigned i = 0; i < blockDim.x; i++) { 69 | *error = max(m_abs(e), *error); 70 | __syncthreads(); 71 | } 72 | } 73 | 74 | template 75 | __global__ void test_load_vector_acc_kernel(float *const error, 76 | const T *const src, 77 | const T *const cor) { 78 | nvcuda::wmma::fragment vec_frag; 79 | nvcuda::wmma::fragment cor_frag; 80 | if (std::is_same::value) { 81 | mtk::wmma::load_vector(vec_frag, src, nvcuda::wmma::mem_col_major); 82 | nvcuda::wmma::load_matrix_sync(cor_frag, cor, m, 83 | nvcuda::wmma::mem_col_major); 84 | } else { 85 | mtk::wmma::load_vector(vec_frag, src, nvcuda::wmma::mem_row_major); 86 | nvcuda::wmma::load_matrix_sync(cor_frag, cor, m, 87 | nvcuda::wmma::mem_row_major); 88 | } 89 | 90 | auto e = 91 | convert::type, float>( 92 | 0.0f); 93 | for (unsigned i = 0; i < vec_frag.num_elements; i++) { 94 | e += m_abs(vec_frag.x[i] - cor_frag.x[i]); 95 | } 96 | if (threadIdx.x == 0) { 97 | *error = 0; 98 | } 99 | __syncthreads(); 100 | for (unsigned i = 0; i < blockDim.x; i++) { 101 | *error = max(m_abs(e), *error); 102 | __syncthreads(); 103 | } 104 | } 105 | 106 | template void test() { 107 | std::size_t cor_size = 0; 108 | std::size_t vec_length = 0; 109 | 110 | if (std::is_same::value) { 111 | cor_size = m * k; 112 | if (std::is_same::value) { 113 | vec_length = m; 114 | } else { 115 | vec_length = k; 116 | } 117 | } 118 | if (std::is_same::value) { 119 | cor_size = n * k; 120 | if (std::is_same::value) { 121 | vec_length = k; 122 | } else { 123 | vec_length = n; 124 | } 125 | } 126 | if (std::is_same::value) { 127 | cor_size = n * m; 128 | if (std::is_same::value) { 129 | vec_length = m; 130 | } else { 131 | vec_length = n; 132 | } 133 | } 134 | 135 | using storage_t = typename 
mtk::wmma::detail::common::storage_t::type; 136 | storage_t *src_mem; 137 | storage_t *cor_mem; 138 | 139 | cudaMallocHost(&src_mem, m * sizeof(storage_t)); 140 | cudaMallocHost(&cor_mem, cor_size * sizeof(storage_t)); 141 | 142 | for (std::size_t i = 0; i < cor_size; i++) { 143 | cor_mem[i] = convert(0); 144 | } 145 | 146 | for (std::size_t i = 0; i < vec_length; i++) { 147 | const float v = i / 3.0f; 148 | src_mem[i] = convert(v); 149 | cor_mem[i] = convert(v); 150 | } 151 | 152 | float *error; 153 | cudaMallocHost(&error, sizeof(float)); 154 | cudaDeviceSynchronize(); 155 | if constexpr (std::is_same::value) { 156 | test_load_vector_acc_kernel 157 | <<<1, 32>>>(error, src_mem, cor_mem); 158 | } else { 159 | test_load_vector_ab_kernel 160 | <<<1, 32>>>(error, src_mem, cor_mem); 161 | } 162 | cudaDeviceSynchronize(); 163 | std::printf("[%s] ARCH=%d, <%2d, %2d, %2d>, %10s, %10s, error=%e [%s]\n", 164 | __FILE__, TEST_ARCH, m, n, k, 165 | mtk::test_utils::get_string().c_str(), 166 | mtk::test_utils::get_string().c_str(), (*error), 167 | mtk::test_utils::get_test_result_string( 168 | (*error) < mtk::test_utils::get_machine_eps() * 16)); 169 | cudaFreeHost(error); 170 | } 171 | 172 | int main() { 173 | test(); 174 | test(); 175 | 176 | test(); 177 | test(); 178 | 179 | test(); 180 | test(); 181 | #ifdef TEST_TF32 182 | test(); 184 | test(); 186 | 187 | test(); 189 | test(); 191 | 192 | test(); 193 | test(); 194 | #endif 195 | } 196 | -------------------------------------------------------------------------------- /test/primitive/wmma.store_vector.cu: -------------------------------------------------------------------------------- 1 | #include "common.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | #ifndef TEST_ARCH 7 | #define TEST_ARCH (-1) 8 | #endif 9 | 10 | // #define TEST_TF32 11 | 12 | #ifndef TEST_TF32 13 | constexpr int M = 16; 14 | constexpr int N = 16; 15 | constexpr int K = 16; 16 | #else 17 | constexpr int M = 16; 18 | constexpr int N = 16; 19 | constexpr int K = 8; 20 | #endif 21 | 22 | template __device__ __host__ T convert(const S); 23 | template <> __device__ __host__ float convert(const float a) { 24 | return a; 25 | } 26 | template <> __device__ __host__ float convert(const half a) { 27 | return __half2float(a); 28 | } 29 | template <> __device__ __host__ half convert(const float a) { 30 | return __float2half(a); 31 | } 32 | template <> __device__ __host__ half convert(const half a) { 33 | return a; 34 | } 35 | 36 | __global__ void test_store_vector_kernel(float *const dst, 37 | const float *const src, 38 | const nvcuda::wmma::layout_t layout) { 39 | nvcuda::wmma::fragment frag_c; 40 | nvcuda::wmma::load_matrix_sync(frag_c, src, M, layout); 41 | mtk::wmma::store_vector(dst, frag_c, layout); 42 | } 43 | 44 | void test(const nvcuda::wmma::layout_t layout) { 45 | float *src_mem; 46 | float *dst_mem; 47 | 48 | cudaMallocHost(&src_mem, M * N * sizeof(float)); 49 | cudaMallocHost(&dst_mem, M * sizeof(float)); 50 | 51 | for (std::size_t i = 0; i < M * N; i++) { 52 | src_mem[i] = static_cast(i); 53 | } 54 | 55 | cudaDeviceSynchronize(); 56 | test_store_vector_kernel<<<1, 32>>>(dst_mem, src_mem, layout); 57 | cudaDeviceSynchronize(); 58 | 59 | double error = 0; 60 | for (std::size_t i = 0; i < M; i++) { 61 | const double diff = src_mem[i] - dst_mem[i]; 62 | error = std::max(error, std::abs(diff)); 63 | } 64 | 65 | cudaFreeHost(src_mem); 66 | cudaFreeHost(dst_mem); 67 | 68 | std::printf("[%s] ARCH=%d, <%2d, %2d, %2d>, error=%e, [%s]\n", __FILE__, 69 | TEST_ARCH, M, N, K, error, 70 
| mtk::test_utils::get_test_result_string( 71 | error < mtk::test_utils::get_machine_eps())); 72 | } 73 | 74 | int main() { 75 | test(nvcuda::wmma::mem_row_major); 76 | test(nvcuda::wmma::mem_col_major); 77 | test(nvcuda::wmma::mem_row_major); 78 | test(nvcuda::wmma::mem_col_major); 79 | } 80 | -------------------------------------------------------------------------------- /test/tcec/Makefile: -------------------------------------------------------------------------------- 1 | NVCC=nvcc 2 | 3 | INCDIR=../../include 4 | 5 | TEST_TF32=NO 6 | TEST_SIMT=YES 7 | SM_ARCH=Ampere 8 | 9 | NVCCFLAGS=-std=c++14 -I$(INCDIR) -Xcompiler="-fopenmp" --ptxas-options=-v -lcublas 10 | 11 | ifeq ($(SM_ARCH), Ada) 12 | NVCCFLAGS+=-gencode arch=compute_89,code=sm_89 13 | ifeq ($(TEST_TF32), YES) 14 | NVCCFLAGS+=-DTEST_TF32 15 | endif 16 | endif 17 | 18 | ifeq ($(SM_ARCH), Ampere) 19 | NVCCFLAGS+=-gencode arch=compute_86,code=sm_86 20 | NVCCFLAGS+=-gencode arch=compute_80,code=sm_80 21 | ifeq ($(TEST_TF32), YES) 22 | NVCCFLAGS+=-DTEST_TF32 23 | endif 24 | endif 25 | 26 | ifeq ($(SM_ARCH), Turing) 27 | NVCCFLAGS+=-gencode arch=compute_75,code=sm_75 -DSM_ARCH=75 28 | endif 29 | 30 | ifeq ($(SM_ARCH), Volta) 31 | NVCCFLAGS+=-gencode arch=compute_70,code=sm_70 -DSM_ARCH=70 32 | endif 33 | 34 | ifeq ($(TEST_SIMT), YES) 35 | NVCCFLAGS+=-DTEST_SIMT 36 | endif 37 | 38 | TARGET=batch_gemm.test mma.test matvec.test elementwise.test mma_complex.test vector.test 39 | 40 | all: $(TARGET) 41 | 42 | %.test:%.cu 43 | $(NVCC) $< $(OBJS) $(NVCCFLAGS) -o $@ 44 | 45 | clean: 46 | rm -f *.test 47 | -------------------------------------------------------------------------------- /test/tcec/elementwise.cu: -------------------------------------------------------------------------------- 1 | #include "utils.hpp" 2 | #include 3 | #include 4 | 5 | constexpr unsigned warp_size = 32; 6 | 7 | template 8 | __global__ void test_elementwise_kernel(float *const ptr) { 9 | __shared__ float smem[N * N]; 10 | 11 | mtk::wmma::tcec::fragment 12 | frag; 13 | mtk::wmma::tcec::fill_fragment(frag, 0.0f); 14 | 15 | for (unsigned i = 0; i < frag.num_elements; i++) { 16 | frag.x(i) = threadIdx.x * 100 + i; 17 | } 18 | 19 | mtk::wmma::tcec::store_matrix_sync(smem, frag, N, 20 | nvcuda::wmma::mem_col_major); 21 | 22 | for (unsigned i = 0; i < N * N; i += warp_size) { 23 | const auto index = i + threadIdx.x; 24 | ptr[index] = smem[index]; 25 | } 26 | } 27 | 28 | template void test_elementwise() { 29 | std::printf("[%s, N = %u, T = %s, Policy = <%7s,%9s,%2u,%2u,%2u>]\n", 30 | __func__, N, mtk::test_utils::to_string().c_str(), 31 | mtk::test_utils::to_string().c_str(), 32 | std::is_same::value 34 | ? 
"{w/ ec}" 35 | : "{w/o ec}", 36 | Policy::m, Policy::n, Policy::k); 37 | float *hC; 38 | cudaMallocHost(&hC, N * N * sizeof(float)); 39 | 40 | test_elementwise_kernel<<<1, warp_size>>>(hC); 41 | cudaDeviceSynchronize(); 42 | 43 | for (unsigned i = 0; i < N; i++) { 44 | for (unsigned j = 0; j < N; j++) { 45 | std::printf("%e ", hC[i + j * N]); 46 | } 47 | std::printf("\n"); 48 | } 49 | } 50 | 51 | int main() { 52 | test_elementwise< 53 | 32, half, 54 | typename mtk::wmma::tcec::detail::default_policy< 55 | half, mtk::wmma::tcec::with_ec, mtk::wmma::tcec::op_wmma>::type>(); 56 | test_elementwise< 57 | 32, half, 58 | typename mtk::wmma::tcec::detail::default_policy< 59 | half, mtk::wmma::tcec::without_ec, mtk::wmma::tcec::op_wmma>::type>(); 60 | test_elementwise< 61 | 32, half, 62 | typename mtk::wmma::tcec::detail::default_policy< 63 | half, mtk::wmma::tcec::with_ec, mtk::wmma::tcec::op_mma>::type>(); 64 | test_elementwise< 65 | 32, half, 66 | typename mtk::wmma::tcec::detail::default_policy< 67 | half, mtk::wmma::tcec::without_ec, mtk::wmma::tcec::op_mma>::type>(); 68 | test_elementwise<32, float, 69 | typename mtk::wmma::tcec::detail::default_policy< 70 | float, mtk::wmma::tcec::without_ec, 71 | mtk::wmma::tcec::op_simt>::type>(); 72 | 73 | #ifdef TEST_SIMT 74 | test_elementwise<32, float, 75 | typename mtk::wmma::tcec::detail::default_policy< 76 | float, mtk::wmma::tcec::without_ec, 77 | mtk::wmma::tcec::op_simt>::type>(); 78 | #endif 79 | #ifdef TEST_TF32 80 | test_elementwise<32, nvcuda::wmma::precision::tf32, 81 | typename mtk::wmma::tcec::detail::default_policy< 82 | nvcuda::wmma::precision::tf32, mtk::wmma::tcec::with_ec, 83 | mtk::wmma::tcec::op_wmma>::type>(); 84 | test_elementwise< 85 | 32, nvcuda::wmma::precision::tf32, 86 | typename mtk::wmma::tcec::detail::default_policy< 87 | nvcuda::wmma::precision::tf32, mtk::wmma::tcec::without_ec, 88 | mtk::wmma::tcec::op_wmma>::type>(); 89 | #endif 90 | } 91 | -------------------------------------------------------------------------------- /test/tcec/matvec.cu: -------------------------------------------------------------------------------- 1 | #include "utils.hpp" 2 | #include 3 | #include 4 | 5 | template 6 | constexpr double error_threshold = 0.0; 7 | template <> 8 | constexpr double error_threshold = 1e-5; 9 | template <> 10 | constexpr double 11 | error_threshold = 12 | 1e-5; 13 | template <> 14 | constexpr double error_threshold = 1e-2; 15 | template <> 16 | constexpr double error_threshold = 1e-2; 18 | template <> 19 | constexpr double error_threshold = 1e-6; 20 | 21 | template 22 | __global__ void matvec_kernel(float *const y_ptr, const float *const a_ptr, 23 | const float *const x_ptr) { 24 | __shared__ float smem[N * N]; 25 | mtk::test_utils::fill_zero(smem, N * N); 26 | 27 | mtk::wmma::tcec::fragment 29 | frag_a; 30 | mtk::wmma::tcec::fragment 32 | frag_x; 33 | mtk::wmma::tcec::fragment 34 | frag_y; 35 | // Load A 36 | mtk::test_utils::copy_matrix(smem, N, a_ptr, N, N, N); 37 | mtk::wmma::tcec::load_matrix_sync(frag_a, smem, N); 38 | 39 | // Load X 40 | mtk::test_utils::copy_matrix(smem, N, x_ptr, N, N, 1); 41 | mtk::wmma::tcec::fill_zero(frag_x); 42 | mtk::wmma::tcec::load_vector(frag_x, smem); 43 | 44 | // mma 45 | mtk::wmma::tcec::mma_sync(frag_y, frag_a, frag_x); 46 | 47 | // Store D 48 | mtk::wmma::tcec::store_vector(smem, frag_y, nvcuda::wmma::mem_col_major); 49 | mtk::test_utils::copy_matrix(y_ptr, N, smem, N, N, 1); 50 | } 51 | 52 | template void test_matvec() { 53 | float *hX, *hY, *hA; 54 | cudaMallocHost(&hX, N * 
sizeof(float)); 55 | cudaMallocHost(&hY, N * sizeof(float)); 56 | cudaMallocHost(&hA, N * N * sizeof(float)); 57 | 58 | std::mt19937 mt(std::random_device{}()); 59 | std::uniform_real_distribution dist(-1.0f, 1.0f); 60 | 61 | for (unsigned i = 0; i < N * N; i++) { 62 | hA[i] = dist(mt); 63 | } 64 | for (unsigned i = 0; i < N; i++) { 65 | hX[i] = dist(mt); 66 | } 67 | cudaDeviceSynchronize(); 68 | 69 | matvec_kernel<<<1, mtk::test_utils::warp_size>>>(hY, hA, hX); 70 | 71 | cudaDeviceSynchronize(); 72 | 73 | double max_error = 0.; 74 | for (unsigned n = 0; n < N; n++) { 75 | double cor_d = 0.; 76 | for (unsigned k = 0; k < N; k++) { 77 | cor_d += static_cast(hA[k * N + n]) * static_cast(hX[k]); 78 | } 79 | 80 | max_error = std::max(max_error, std::abs(cor_d - hY[n])); 81 | } 82 | 83 | std::printf( 84 | "[Type:%5s, N:%3u, Policy<%7s,%9s,%2u,%2u,%2u>] max_error: %e (%6s)\n", 85 | mtk::test_utils::to_string().c_str(), N, 86 | mtk::test_utils::to_string().c_str(), 87 | std::is_same::value 89 | ? "{w/ ec}" 90 | : "{w/o ec}", 91 | Policy::m, Policy::n, Policy::k, max_error, 92 | (max_error < error_threshold 93 | ? "PASSED" 94 | : "FAILED")); 95 | 96 | cudaFreeHost(hA); 97 | cudaFreeHost(hX); 98 | cudaFreeHost(hY); 99 | } 100 | 101 | int main() { 102 | // wmma FP16 test 103 | test_matvec< 104 | 32, half, 105 | typename mtk::wmma::tcec::detail::default_policy< 106 | half, mtk::wmma::tcec::with_ec, mtk::wmma::tcec::op_wmma>::type>(); 107 | test_matvec< 108 | 32, half, 109 | typename mtk::wmma::tcec::detail::default_policy< 110 | half, mtk::wmma::tcec::without_ec, mtk::wmma::tcec::op_wmma>::type>(); 111 | 112 | #ifdef TEST_SIMT 113 | // simt test 114 | test_matvec<32, float, 115 | typename mtk::wmma::tcec::detail::default_policy< 116 | float, mtk::wmma::tcec::without_ec, 117 | mtk::wmma::tcec::op_simt>::type>(); 118 | #endif 119 | 120 | #ifdef TEST_TF32 121 | // wmma TF32 test 122 | test_matvec<32, nvcuda::wmma::precision::tf32, 123 | typename mtk::wmma::tcec::detail::default_policy< 124 | nvcuda::wmma::precision::tf32, mtk::wmma::tcec::with_ec, 125 | mtk::wmma::tcec::op_wmma>::type>(); 126 | test_matvec<32, nvcuda::wmma::precision::tf32, 127 | typename mtk::wmma::tcec::detail::default_policy< 128 | nvcuda::wmma::precision::tf32, mtk::wmma::tcec::without_ec, 129 | mtk::wmma::tcec::op_wmma>::type>(); 130 | #endif 131 | } 132 | -------------------------------------------------------------------------------- /test/tcec/utils.hpp: -------------------------------------------------------------------------------- 1 | #ifndef __HMMA_F32_F32_TEST_UTILS_HPP__ 2 | #define __HMMA_F32_F32_TEST_UTILS_HPP__ 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #ifndef WMMAE_CUDA_CHECK_ERROR 11 | #define WMMAE_CUDA_CHECK_ERROR(status) \ 12 | cuda_check_error(status, __FILE__, __LINE__, __func__) 13 | #endif 14 | #ifndef WMMAE_CUDA_CHECK_ERROR_M 15 | #define WMMAE_CUDA_CHECK_ERROR_M(status, message) \ 16 | cuda_check_error(status, __FILE__, __LINE__, __func__, message) 17 | #endif 18 | 19 | inline void cuda_check_error(cudaError_t error, const std::string filename, 20 | const std::size_t line, const std::string funcname, 21 | const std::string message = "") { 22 | if (error != cudaSuccess) { 23 | std::stringstream ss; 24 | ss << cudaGetErrorString(error); 25 | if (message.length() != 0) { 26 | ss << " : " << message; 27 | } 28 | ss << " [" << filename << ":" << line << " in " << funcname << "]"; 29 | throw std::runtime_error(ss.str()); 30 | } 31 | } 32 | 33 | namespace mtk { 34 | 
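// Illustrative usage sketch, not part of the original header: the
// WMMAE_CUDA_CHECK_ERROR(_M) macros defined above route any cudaError_t
// through cuda_check_error, which throws std::runtime_error carrying
// file/line/function context. The helper name below is hypothetical.
inline float *checked_device_alloc_example(const std::size_t count) {
  float *ptr = nullptr;
  WMMAE_CUDA_CHECK_ERROR_M(cudaMalloc(&ptr, count * sizeof(float)),
                           "allocating a device test buffer");
  return ptr;
}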
namespace test_utils {

// Template arguments of the specializations below are reconstructions (the
// originals were lost to angle-bracket stripping during extraction),
// inferred from the returned strings.
template <class T> std::string to_string();
template <> std::string to_string<nvcuda::wmma::accumulator>() {
  return "acc";
}
template <> std::string to_string<nvcuda::wmma::matrix_a>() {
  return "matrix_a";
}
template <> std::string to_string<nvcuda::wmma::matrix_b>() {
  return "matrix_b";
}
template <> std::string to_string<nvcuda::wmma::col_major>() {
  return "col_major";
}
template <> std::string to_string<nvcuda::wmma::row_major>() {
  return "row_major";
}
template <> std::string to_string<float>() { return "float"; }
template <> std::string to_string<half>() { return "half"; }
template <> std::string to_string<nvcuda::wmma::precision::tf32>() {
  return "tf32";
}
template <> std::string to_string<mtk::wmma::tcec::op_wmma>() {
  return "op_wmma";
}
template <> std::string to_string<mtk::wmma::tcec::op_mma>() {
  return "op_mma";
}
#ifdef TEST_SIMT
template <> std::string to_string<mtk::wmma::tcec::op_simt>() {
  return "op_simt";
}
#endif

constexpr unsigned warp_size = 32;

template <class T>
__device__ void copy_matrix(T *const dst, const unsigned ldd,
                            const T *const src, const unsigned lds,
                            const unsigned m, const unsigned n) {
  for (unsigned i = 0; i < m * n; i += warp_size) {
    const auto j = i + threadIdx.x;
    if (j >= m * n)
      return;
    const auto mm = j % m;
    const auto mn = j / m;
    dst[mm + mn * ldd] = src[mm + mn * lds];
  }
}

__device__ void fill_zero(float *const dst, const unsigned size) {
  for (unsigned i = 0; i < size; i += warp_size) {
    const auto j = i + threadIdx.x;
    if (j >= size)
      return;
    dst[j] = 0.0f;
  }
}

__device__ void fill_zero(cuComplex *const dst, const unsigned size) {
  for (unsigned i = 0; i < size; i += warp_size) {
    const auto j = i + threadIdx.x;
    if (j >= size)
      return;
    dst[j].x = 0.0f;
    dst[j].y = 0.0f;
  }
}
} // namespace test_utils
} // namespace mtk
#endif
--------------------------------------------------------------------------------
/test/utils/Makefile:
--------------------------------------------------------------------------------
TEST_ARCH=80
ROOT_DIR=../../include
NVCC=nvcc
NVCCFLAGS=-std=c++17 -I$(ROOT_DIR) -arch=sm_$(TEST_ARCH) -DTEST_ARCH=$(TEST_ARCH) --extended-lambda
HEADERS=$(shell find ../../include -name '*.hpp')

TARGET=
TARGET+=cast.test cp_async.test

all: $(TARGET)

%.test : %.cu Makefile $(HEADERS)
	$(NVCC) $(NVCCFLAGS) -o $@ $<

clean:
	rm -f *.test
--------------------------------------------------------------------------------
/test/utils/cast.cu:
--------------------------------------------------------------------------------
// Reconstructed header and cast targets (originals lost in extraction); the
// <half>/<float> arguments are an assumption consistent with a
// float -> half -> float round trip.
#include <wmma_extension/utils.hpp>

int main() {
  const auto a = mtk::wmma::utils::cast<half>(1.0f);
  const auto b = mtk::wmma::utils::cast<float>(a);
}
--------------------------------------------------------------------------------
/test/utils/cp_async.cu:
--------------------------------------------------------------------------------
// Reconstructed headers and template arguments (originals lost in
// extraction); float/float2/float4 as the 4/8/16-byte element types and the
// block size of 32 are assumptions.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <wmma_extension/utils.hpp>

template <class T> __host__ __device__ constexpr unsigned get_size_in_byte();
template <> __host__ __device__ constexpr unsigned get_size_in_byte<float>() {
  return 4;
};
template <> __host__ __device__ constexpr unsigned get_size_in_byte<float2>() {
  return 8;
};
template <> __host__ __device__ constexpr unsigned get_size_in_byte<float4>() {
  return 16;
};

template <class T, unsigned block_size>
__global__ void cp_async_test_kernel(T *const dst_ptr, const T *const src_ptr) {
  __shared__ T smem[block_size];

  mtk::wmma::utils::cp_async::cp_async<get_size_in_byte<T>()>(
      smem + threadIdx.x, src_ptr + threadIdx.x);
  mtk::wmma::utils::cp_async::commit();

  mtk::wmma::utils::cp_async::wait_all();
  dst_ptr[threadIdx.x] = smem[threadIdx.x];
}

template <class T, unsigned block_size> void cp_async_test() {
  T *d_input;
  T *d_output;
  T *h_input;
  T *h_output;

  cudaMalloc(&d_input, sizeof(T) * block_size);
  cudaMalloc(&d_output, sizeof(T) * block_size);
  cudaMallocHost(&h_input, sizeof(T) * block_size);
  cudaMallocHost(&h_output, sizeof(T) * block_size);

  for (unsigned i = 0; i < block_size * get_size_in_byte<T>() / 4; i++) {
    reinterpret_cast<float *>(h_input)[i] = i;
  }

  cudaMemcpy(d_input, h_input, block_size * sizeof(T), cudaMemcpyDefault);

  cp_async_test_kernel<T, block_size><<<1, block_size>>>(d_output, d_input);

  cudaMemcpy(h_output, d_output, block_size * sizeof(T), cudaMemcpyDefault);

  double max_error = 0;
  for (unsigned i = 0; i < block_size * get_size_in_byte<T>() / 4; i++) {
    const double diff = reinterpret_cast<float *>(h_output)[i] -
                        reinterpret_cast<float *>(h_input)[i];
    max_error = std::max(std::abs(diff), max_error);
  }

  std::printf("%s[%2u Byte] error = %e\n", __func__, get_size_in_byte<T>(),
              max_error);

  cudaFree(d_input);
  cudaFree(d_output);
  cudaFreeHost(h_input);
  cudaFreeHost(h_output);
}

int main() {
  cp_async_test<float, 32>();
  cp_async_test<float2, 32>();
  cp_async_test<float4, 32>();
}
--------------------------------------------------------------------------------
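As a closing illustration (not a file from this repository), the cp_async
wrappers exercised by test/utils/cp_async.cu compose into the usual
stage-commit-wait pattern for asynchronous global-to-shared copies. The
kernel name and tile size below are hypothetical, and behavior on pre-sm_80
devices depends on the wrapper implementation:

#include <wmma_extension/utils.hpp>

// Launch as stage_tile_sketch<TILE><<<grid, TILE>>>(dst, src).
template <unsigned TILE>
__global__ void stage_tile_sketch(float4 *const dst, const float4 *const src) {
  __shared__ float4 smem[TILE];

  // Issue a 16-byte asynchronous global->shared copy per thread, then fence:
  // commit() closes the current cp.async group and wait_all() blocks the
  // issuing thread until every committed group has landed.
  mtk::wmma::utils::cp_async::cp_async<16>(smem + threadIdx.x,
                                           src + threadIdx.x);
  mtk::wmma::utils::cp_async::commit();
  mtk::wmma::utils::cp_async::wait_all();
  __syncthreads(); // make the staged tile visible block-wide

  dst[threadIdx.x] = smem[threadIdx.x];
}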