├── .gitignore
├── .gitlab-ci.yml
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs
├── .gitkeep
├── load_vector-en.svg
├── m8n8k4.md
├── make_direct_product_fragments-en.svg
├── make_eye-en.svg
├── make_identity-en.svg
├── mma.md
├── mma_f32.md
├── ops.md
├── store_vector-en.svg
├── utils.md
└── wmmae.svg
├── include
└── wmma_extension
│ ├── detail
│ ├── common.hpp
│ ├── m16n8k16.hpp
│ ├── m16n8k16_imma.hpp
│ ├── m16n8k32_imma.hpp
│ ├── m16n8k8.hpp
│ ├── m16n8k8_tf32.hpp
│ ├── m8n8k4.hpp
│ ├── operators.hpp
│ ├── sm_70.hpp
│ ├── sm_75.hpp
│ ├── sm_80.hpp
│ └── sm_80_tf32.hpp
│ ├── operators.hpp
│ ├── tcec
│ ├── complex.hpp
│ ├── detail
│ │ ├── common.hpp
│ │ ├── functions.hpp
│ │ ├── functions_simt.hpp
│ │ ├── no_cor.hpp
│ │ ├── notc.hpp
│ │ ├── policy.hpp
│ │ ├── policy_simt.hpp
│ │ ├── print.hpp
│ │ ├── scale.hpp
│ │ ├── simt
│ │ │ ├── detail
│ │ │ │ ├── common.hpp
│ │ │ │ ├── fma.hpp
│ │ │ │ └── m16n16k16.hpp
│ │ │ └── mma_simt.hpp
│ │ ├── wmma_extension_include.hpp
│ │ └── wmma_extension_simt_include.hpp
│ └── tcec.hpp
│ ├── utils.hpp
│ ├── wmma_extension.hpp
│ └── wmma_mma.hpp
├── research
├── bank-conflict
│ ├── Makefile
│ ├── README.md
│ └── main.cu
├── common
│ └── utils.hpp
├── fragment_analysis
│ ├── Makefile
│ └── main.cu
├── fragment_analysis_ij
│ ├── Makefile
│ └── main.cu
└── fragment_analysis_map
│ ├── Makefile
│ └── main.cu
└── test
├── performance
├── Makefile.common
├── batched_m8n8k4
│ ├── Makefile
│ └── batched_m8n8k4.cu
├── foreach
│ ├── Makefile
│ └── matmul.cu
├── givens_rotation
│ ├── Makefile
│ └── givens.cu
├── householder
│ ├── Makefile
│ └── householder.cu
├── load_matrix_with_op_sync
│ ├── Makefile
│ └── matmul.cu
├── load_vector_sync
│ ├── Makefile
│ ├── batched_direct_product.cu
│ └── direct_product.cu
└── make_identity
│ ├── Makefile
│ └── batched_householder.cu
├── primitive
├── Makefile
├── add_eye.cu
├── common.hpp
├── direct_product.cu
├── fill.cu
├── foreach.cu
├── foreach_ij.cu
├── foreach_v.cu
├── foreach_v_acc.cu
├── gevm.cu
├── map.cu
├── mma.cu
├── operators.cu
├── print_fragment.cu
├── vector.cu
├── wmma.load_vector.cu
└── wmma.store_vector.cu
├── tcec
├── Makefile
├── batch_gemm.cu
├── elementwise.cu
├── matvec.cu
├── mma.cu
├── mma_complex.cu
├── utils.hpp
└── vector.cu
└── utils
├── Makefile
├── cast.cu
└── cp_async.cu
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/c++,cuda,vim,python
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=c++,cuda,vim,python
3 |
4 | ### C++ ###
5 | # Prerequisites
6 | *.d
7 |
8 | # Compiled Object files
9 | *.slo
10 | *.lo
11 | *.o
12 | *.obj
13 |
14 | # Precompiled Headers
15 | *.gch
16 | *.pch
17 |
18 | # Compiled Dynamic libraries
19 | *.so
20 | *.dylib
21 | *.dll
22 |
23 | # Fortran module files
24 | *.mod
25 | *.smod
26 |
27 | # Compiled Static libraries
28 | *.lai
29 | *.la
30 | *.a
31 | *.lib
32 |
33 | # Executables
34 | *.exe
35 | *.out
36 | *.app
37 |
38 | ### CUDA ###
39 | *.i
40 | *.ii
41 | *.gpu
42 | *.ptx
43 | *.cubin
44 | *.fatbin
45 |
46 | ### Python ###
47 | # Byte-compiled / optimized / DLL files
48 | __pycache__/
49 | *.py[cod]
50 | *$py.class
51 |
52 | # C extensions
53 |
54 | # Distribution / packaging
55 | .Python
56 | build/
57 | develop-eggs/
58 | dist/
59 | downloads/
60 | eggs/
61 | .eggs/
62 | lib/
63 | lib64/
64 | parts/
65 | sdist/
66 | var/
67 | wheels/
68 | share/python-wheels/
69 | *.egg-info/
70 | .installed.cfg
71 | *.egg
72 | MANIFEST
73 |
74 | # PyInstaller
75 | # Usually these files are written by a python script from a template
76 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
77 | *.manifest
78 | *.spec
79 |
80 | # Installer logs
81 | pip-log.txt
82 | pip-delete-this-directory.txt
83 |
84 | # Unit test / coverage reports
85 | htmlcov/
86 | .tox/
87 | .nox/
88 | .coverage
89 | .coverage.*
90 | .cache
91 | nosetests.xml
92 | coverage.xml
93 | *.cover
94 | *.py,cover
95 | .hypothesis/
96 | .pytest_cache/
97 | cover/
98 |
99 | # Translations
100 | *.mo
101 | *.pot
102 |
103 | # Django stuff:
104 | *.log
105 | local_settings.py
106 | db.sqlite3
107 | db.sqlite3-journal
108 |
109 | # Flask stuff:
110 | instance/
111 | .webassets-cache
112 |
113 | # Scrapy stuff:
114 | .scrapy
115 |
116 | # Sphinx documentation
117 | docs/_build/
118 |
119 | # PyBuilder
120 | .pybuilder/
121 | target/
122 |
123 | # Jupyter Notebook
124 | .ipynb_checkpoints
125 |
126 | # IPython
127 | profile_default/
128 | ipython_config.py
129 |
130 | # pyenv
131 | # For a library or package, you might want to ignore these files since the code is
132 | # intended to run in multiple environments; otherwise, check them in:
133 | # .python-version
134 |
135 | # pipenv
136 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
137 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
138 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
139 | # install all needed dependencies.
140 | #Pipfile.lock
141 |
142 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
143 | __pypackages__/
144 |
145 | # Celery stuff
146 | celerybeat-schedule
147 | celerybeat.pid
148 |
149 | # SageMath parsed files
150 | *.sage.py
151 |
152 | # Environments
153 | .env
154 | .venv
155 | env/
156 | venv/
157 | ENV/
158 | env.bak/
159 | venv.bak/
160 |
161 | # Spyder project settings
162 | .spyderproject
163 | .spyproject
164 |
165 | # Rope project settings
166 | .ropeproject
167 |
168 | # mkdocs documentation
169 | /site
170 |
171 | # mypy
172 | .mypy_cache/
173 | .dmypy.json
174 | dmypy.json
175 |
176 | # Pyre type checker
177 | .pyre/
178 |
179 | # pytype static type analyzer
180 | .pytype/
181 |
182 | # Cython debug symbols
183 | cython_debug/
184 |
185 | ### Vim ###
186 | # Swap
187 | [._]*.s[a-v][a-z]
188 | !*.svg # comment out if you don't need vector files
189 | [._]*.sw[a-p]
190 | [._]s[a-rt-v][a-z]
191 | [._]ss[a-gi-z]
192 | [._]sw[a-p]
193 |
194 | # Session
195 | Session.vim
196 | Sessionx.vim
197 |
198 | # Temporary
199 | .netrwhist
200 | *~
201 | # Auto-generated tag files
202 | tags
203 | # Persistent undo
204 | [._]*.un~
205 |
206 | # End of https://www.toptal.com/developers/gitignore/api/c++,cuda,vim,python
207 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 | precommit:
2 | image: python:3.10.2-slim-bullseye
3 | before_script:
4 | - apt update && apt install -y --no-install-recommends git
5 | - pip install pre-commit
6 | script:
7 | - pre-commit run --all-files
8 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | files: ^(.*\.(cpp|hpp|cu|md))$
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v3.2.0
7 | hooks:
8 | - id: trailing-whitespace
9 | - id: end-of-file-fixer
10 | - id: check-yaml
11 | - id: check-added-large-files
12 | - repo: https://github.com/codespell-project/codespell
13 | rev: v2.3.0
14 | hooks:
15 | - id: codespell
16 | - repo: https://github.com/pre-commit/mirrors-clang-format
17 | rev: v18.1.8
18 | hooks:
19 | - id: clang-format
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 mutsuki
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # WMMA API Extension
4 |
5 | This extension provides features for
6 | - mapping between memory and fragment (**primitive functions**)
7 | - operations for vectors
8 | - loading a vector as a fragment
9 | - storing a fragment as a vector
10 | - C++ interface for `mma` instructions [[detail](./docs/mma.md)]
11 | - Error Correction (TCEC) for SGEMM emulation [[detail](./docs/mma_f32.md)]
12 | - arithmetic operators for fragments (`+, -, *, /, fma`) [[detail](./docs/ops.md)]
13 | - utils [[detail](./docs/utils.md)]
14 | - etc
15 |
16 | without using extra shared memory.
17 |
18 | > [!IMPORTANT]
19 | > Please specify an appropriate virtual architecture for real GPU.
20 | > For instance, a program which is compiled with `-arch=sm_70` will not work correctly on Ampere GPUs.
21 |
22 | ## Requirements
23 | - CUDA (10.2 or later)
24 | - C++ (17 or later)
25 |
26 | ## Supported architectures / fragment
27 | - [x] sm_70: ((16, 16, 16), fp16/fp32)
28 | - [x] sm_75: ((16, 16, 16), fp16/fp32)
29 | - [x] sm_80: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
30 | - [x] sm_89: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32)
31 | - [x] sm_90: ((16, 16, 16), fp16/fp32), ((16, 16, 8), tf32/fp32) (`wgmma` instruction is not supported yet)
32 |
33 | # Functions
34 | ## Primitive functions
35 | ### foreach
36 | This function calculates the mapping of the memory and fragment elements.
37 | ```cuda
38 | nvcuda::wmma::fragment frag_b;
39 | __shared__ compute_t matrix[16 * 16];
40 | mtk::wmma::foreach(
41 | [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
42 | const auto m = mem_index % 16;
43 | const auto n = mem_index / 16;
44 | for (unsigned i = 0; i < fragment_index_count; i++)
45 | frag_b.x[frag_index_list[i]] = convert_to(matrix[n * 16 + m]);
46 | });
47 | ```
48 |
49 | ### foreach_ij
50 | This function calculates the mapping of the matrix element position (i,j) and fragment elements.
51 | ```cuda
52 | nvcuda::wmma::fragment frag_b;
53 | __shared__ compute_t matrix[16 * 16];
54 | mtk::wmma::foreach_ij(
55 | [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned i, const unsigned j) {
56 | for (unsigned f = 0; f < fragment_index_count; f++)
57 | frag_b.x[frag_index_list[f]] = convert_to(matrix[j * 16 + i]);
58 | });
59 | ```
60 |
61 | ### foreach_v
62 | #### For matrix A/B
63 | This function calculates the mapping of a given vector and fragment elements.
64 | ```cuda
65 | nvcuda::wmma::fragment frag_b;
66 | __shared__ compute_t vector[16];
67 | mtk::wmma::foreach_v(
68 | [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
69 | for (unsigned i = 0; i < fragment_index_count; i++)
70 | frag_b.x[frag_index_list[i]] = convert_to(vector[mem_index]);
71 | });
72 | // is equivalent to `load_vector`
73 | ```
74 |
75 | #### For accumulator
76 | ```cuda
77 | nvcuda::wmma::fragment frag_c;
78 | __shared__ compute_t vector[16];
79 | mtk::wmma::foreach_v(nvcuda::wmma::mem_col_major,
80 | [&](const unsigned* frag_index_list, const unsigned fragment_index_count, const unsigned mem_index) {
81 | for (unsigned i = 0; i < fragment_index_count; i++)
82 | vector[mem_index] = convert_to(frag_c.x[frag_index_list[i]]);
83 | });
84 | // is equivalent to `store_vector`
85 | ```
86 |
87 | ### map
88 | This function returns the mapping of matrix element (i, j) and fragment element (tid, fid)
89 | ```cuda
90 | nvcuda::wmma::fragment frag_b;
91 | unsigned tid_list[2];
92 | unsigned fid_list[2];
93 | unsigned list_size;
94 | mtk::wmma::map(tid_list, fid_list, list_size, i, j);
95 | for (unsigned k = 0; k < list_size; k++) {
96 | if ((threadIdx.x & 0x1f) == tid_list[k]) {
97 | frag_b.x[fid_list[k]] = 3.0f;
98 | }
99 | }
100 | ```
101 |
102 |
103 | ## Functions for vector
104 | ## Sample
105 | ```cuda
106 | #include
107 | #include
108 |
109 | __global__ void kernel() {
110 | nvcuda::wmma::fragment frag_a;
111 | nvcuda::wmma::fragment frag_b;
112 | nvcuda::wmma::fragment frag_c;
113 |
114 | __shared__ float vec16[16];
115 |
116 | mtk::wmma::load_vector(frag_a, vec16);
117 | mtk::wmma::load_vector(frag_b, vec16);
118 |
119 | nvcuda::wmma::fill_fragment(frag_c, 0.0f);
120 | nvcuda::wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
121 |
122 | mtk::wmma::store_vector(vec16, frag_c, nvcuda::wmma::mem_col_major);
123 | }
124 | ```
125 |
126 | ## Other functions
127 | ### make_identity_matrix / add_eye
128 | 
129 | - Arguments
130 | - dst_fragment : Destination fragment (`accumulator`)
131 | - alpha : diagonal element
132 |
133 | ### fill_zero
134 | - Argument
135 | - dst_fragment : Destination fragment
136 |
137 | ## Debugging functions
138 |
139 | #### print_fragment
140 | This function outputs the elements of a fragment.
141 | - Arguments
142 | - frag : Target fragment
143 | - name : printing name of fragment (`char*`, optional)
144 |
145 | # Publication
146 | ```bibtex
147 | @inproceedings{ootomo_wmmae_2023,
148 | author = {Ootomo, Hiroyuki and Yokota, Rio},
149 | title = {Reducing Shared Memory Footprint to Leverage High Throughput on Tensor Cores and Its Flexible API Extension Library},
150 | year = {2023},
151 | series = {HPC Asia '23}
152 | }
153 | ```
154 |
155 | # LICENSE
156 | MIT
157 |
--------------------------------------------------------------------------------
/docs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wmmae/wmma_extension/a9e6b34e40a884cbdd3436f9d5f870bb6085bf6c/docs/.gitkeep
--------------------------------------------------------------------------------
/docs/load_vector-en.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/docs/m8n8k4.md:
--------------------------------------------------------------------------------
1 | # m8n8k4
2 |
3 | CUDA provides an experimental PTX instruction `mma.sync.aligned.m8n8k4` which computes `(m, n, k) = (8, 8, 4)` matrix FMA.
4 | This library provides its C++ interface.
5 |
6 | ## Sample
7 | ```cpp
8 | constexpr unsigned M = 8;
9 | constexpr unsigned N = 8;
10 | constexpr unsigned K = 4;
11 |
12 | __global__ void m8n8k4_test_kernel(float* const d, const half* const a, const half* const b, const float* const c) {
13 | mtk::wmma::fragment frag_a;
14 | mtk::wmma::fragment frag_b;
15 | mtk::wmma::fragment frag_c;
16 | mtk::wmma::fragment frag_d;
17 |
18 | mtk::wmma::load_matrix_sync(frag_a, a, M);
19 | mtk::wmma::load_matrix_sync(frag_b, b, K);
20 | mtk::wmma::load_matrix_sync(frag_c, c, M, nvcuda::wmma::mem_col_major);
21 |
22 | mtk::wmma::mma_sync(frag_d, frag_a, frag_b, frag_c);
23 |
24 | mtk::wmma::store_matrix_sync(d, frag_d, M, nvcuda::wmma::mem_col_major);
25 | }
26 | ```
27 |
--------------------------------------------------------------------------------
/docs/mma.md:
--------------------------------------------------------------------------------
1 | # C++ interface of `mma` instructions
2 |
3 | ```cpp
4 | #include
5 |
6 | __global__ void kernel(float* const d, const half* const a, const half* const b, const float* const c) {
7 | mtk::wmma::mma::fragment frag_a;
8 | mtk::wmma::mma::fragment frag_b;
9 | mtk::wmma::mma::fragment frag_c;
10 | mtk::wmma::mma::fragment frag_d;
11 |
12 | mtk::wmma::mma::load_matrix_sync(frag_a, a, 16);
13 | mtk::wmma::mma::load_matrix_sync(frag_b, b, 8);
14 | mtk::wmma::mma::load_matrix_sync(frag_c, c, 16, nvcuda::wmma::mem_col_major);
15 |
16 | mtk::wmma::mma::mma_sync(frag_d, frag_a, frag_b, frag_c);
17 |
18 | mtk::wmma::mma::store_matrix_sync(d, frag_d, 16, nvcuda::wmma::mem_col_major);
19 | }
20 | ```
21 |
22 | ## Supported fragments
23 |
24 | | shape | A,B type | C, D type | arch |
25 | |:-------- |:-------------------- |:-------------------- |:--------------- |
26 | | m16n8k16 | `half` | `float` / `half` | sm_80 or higher |
27 | | m16n8k8 | `half` | `float` / `half` | sm_75 or higher |
28 | | m16n8k8 | `nvcuda::wmma::tf32` | `float` | sm_80 or higher |
29 | | m8n8k4 | `half` | `float` / `half` | sm_70, sm_75 |
30 | | m16n8k16 | `int8` / `uint8` | `int32` | sm_80 or higher |
31 | | m16n8k32 | `int8` / `uint8` | `int32` | sm_80 or higher |
32 |
33 | ## Supported functions
34 | - `foreach`
35 | - `foreach_ij`
36 | - `load_matrix_sync`
37 | - `store_matrix_sync`
38 | - `fill_fragment`
39 | - `fill_zero`
40 |
--------------------------------------------------------------------------------
/docs/mma_f32.md:
--------------------------------------------------------------------------------
1 | # WMMA Extension for single precision matmul using Tensor Cores and error correction technique (TCEC)
2 |
3 | ## Error correction technique
4 | See [our paper](https://arxiv.org/abs/2203.03341).
5 |
6 | ## Requirements
7 | - CUDA
8 | - CUDA >= 10.0 for HMMA-FP16
9 | - CUDA >= 11.1 for HMMA-TF32
10 |
11 | - C++ >= 14
12 |
13 | ## Installation
14 | 1. Clone [wmma_extension](https://github.com/wmmae/wmma_extension)
15 | ```bash
16 | git clone https://github.com/wmmae/wmma_extension
17 | ```
18 |
19 | ## Sample code
20 | ```cuda
21 | // sample.cu
22 | // nvcc -I./path/to/wmma_extension/include/ -std=c++17 sample.cu ...
23 | //
24 | #include
25 |
26 | template
27 | __global__ void mma_kernel(float* const d_ptr, const float* const a_ptr, const float* const b_ptr, const float* const c_ptr) {
28 | __shared__ float smem[N * N];
29 | fill_zero(smem, N * N);
30 |
31 | mtk::wmma::tcec::fragment frag_a;
32 | mtk::wmma::tcec::fragment frag_b;
33 | mtk::wmma::tcec::fragment frag_c, frag_d;
34 |
35 | // Load A
36 | // copy_matrix(smem, N, a_ptr, N, N, N);
37 | mtk::wmma::tcec::load_matrix_sync(frag_a, smem, N);
38 |
39 | // Load B
40 | // copy_matrix(smem, N, b_ptr, N, N, N);
41 | mtk::wmma::tcec::load_matrix_sync(frag_b, smem, N);
42 |
43 | // Load C
44 | // copy_matrix(smem, N, c_ptr, N, N, N);
45 | mtk::wmma::tcec::load_matrix_sync(frag_c, smem, N, nvcuda::wmma::mem_col_major);
46 |
47 | // Fill D
48 | mtk::wmma::tcec::fill_fragment(frag_d, 0.0f);
49 |
50 | // mma
51 | mtk::wmma::tcec::mma_sync(frag_d, frag_a, frag_b, frag_c);
52 |
53 | // Store D
54 | mtk::wmma::tcec::store_matrix_sync(smem, frag_d, N, nvcuda::wmma::mem_col_major);
55 | //copy_matrix(d_ptr, N, smem, N, N, N);
56 | }
57 | ```
58 |
59 | ## Fragment
60 | ```cpp
61 | template ::type>
62 | struct fragment;
63 | ```
64 |
65 | ### Template arguments
66 | `mtk::wmma::tcec::fragment` is a fragment for this computation.
67 | It contains arrays of `nvcuda::wmma::fragment`.
68 | - `m`, `n` and `k` have to be a multiple of `Policy::m`, `Policy::n` and `Policy::k` respectively.
69 | You can get a default policy using `mtk::wmma::tcec::default_policy::type`.
70 | - `k` has to be a multiple of 16 when `T` is `half` and 8 when `T` is `nvcuda::wmma::precision::tf32`.
71 | - `T` is `half` or `nvcuda::wmma::precision::tf32`. Unlike `nvcuda::wmma::fragment`, even if `Use` is `nvcuda::wmma::accumulator`, the same is true.
72 | - `Policy` is a concept of `mtk::wmma::tcec::Policy`.
73 | - `Op` : `mtk::wmma::tcec::op_mma` / `mtk::wmma::tcec::op_wmma`
74 | - `ErrorCorrection` : `mtk::wmma::tcec::with_ec` / `mtk::wmma::tcec::without_ec`
75 | - `fm`, `fn`, `fk` is a size of internal fragments.
76 |
77 | ### Policy
78 | `default_policy` can make `Policy` easily.
79 | ```cuda
80 | using policy = mtk::wmma::tcec::default_policy::type;
81 | ```
82 |
83 | ## Supported fragment
84 |
85 | | fm | fn | fk | LayoutA | LayoutB | Type | Operation | Supported arch |
86 | | -- | -- | -- | ------- | ------- | ----- | -------------- | ---------------|
87 | | 16 | 16 | 16 | col/row | col/row | half | Arch dependent | sm_70 or later |
88 | | 16 | 16 | 16 | col/row | col/row | tf32 | wmma | sm_80 or later |
89 | | 16 | 8 | 8 | row | col | tf32 | mma | sm_80 or later |
90 | | 16 | 8 | 8 | row | col | half | mma | sm_75 or later |
91 | | 16 | 8 | 16 | row | col | half | mma | sm_80 or later |
92 |
93 | ### Note
94 | To get default policy for `sm_75` and `op_mma`, specify the architecture as follows:
95 | ```cuda
96 | using policy = mtk::wmma::tcec::default_policy::type;
97 | ```
98 |
99 | ### Member variables/functions
100 | - Member variable `element_type` is `float`
101 | - Member functions `x(index)` and `dx(index)` return references to the elements.
102 |
103 | ## Functions
104 | - `mtk::wmma::tcec::fill_fragment`
105 | - `mtk::wmma::tcec::load_matrix_sync`
106 | - `mtk::wmma::tcec::store_matrix_sync`
107 | - `mtk::wmma::tcec::mma_sync`
108 |
109 | - `mtk::wmma::tcec::mma_rz_sync`
110 | - `mtk::wmma::tcec::load_vector`
111 | - `mtk::wmma::tcec::store_vector`
112 | - `mtk::wmma::tcec::fill_zero`
113 |
114 | ### Note
115 | While some `fragment` only supports either `row` or `col`, `load_matrix_sync` function can load both memory layout matrices using an additional template parameter.
116 |
117 | ```cpp
118 | // e.g.
119 | using policy = mtk::wmma::tcec::default_policy::type;
120 | mtk::wmma::tcec::fragment frag_a;
121 |
122 | mtk::wmma::tcec::load_matrix_sync(frag_a, matrix_ptr, ldm);
123 | ```
124 |
125 |
126 | ## Rounding mode
127 | To specify the rounding mode in `+C` operation, use functions as follows.
128 | - `mtk::wmma::tcec::mma_rn_sync`
129 | - `mtk::wmma::tcec::mma_rz_sync`
130 |
131 | ### Default rounding mode
132 | | op | rounding mode |
133 | | ---------- | ------------- |
134 | | with_ec | RN |
135 | | without_ec | RZ |
136 |
137 | Read [our paper](https://arxiv.org/abs/2203.03341) for detail.
138 |
139 | ## SIMT Core computation
140 |
141 | This library provides fragments and functions for mma operations using CUDA SIMT Cores with the same API as the WMMA API.
142 |
143 | | fm | fn | fk | LayoutA | LayoutB | Type | Operation | Supported arch |
144 | | -- | -- | -- | ------- | ------- | ----- | -------------- | ---------------|
145 | | 16 | 16 | 16 | col/row | col/row | float | simt | sm_70 or later |
146 |
147 | ### Policy
148 | ```cuda
149 | using simt_policy = typename mtk::wmma::tcec::default_policy::type;
150 |
151 | mtk::wmma::tcec::fragment frag_a;
152 | ```
153 |
154 | ## Complex type
155 | ```cuda
156 | mtk::wmma::tcec::fragment_complex frag_a;
157 | // or
158 | using policy = typename mtk::wmma::tcec::default_policy::type;
159 | mtk::wmma::tcec::fragment_complex frag_a;
160 | ```
161 |
162 | ### Supported functions
163 | - `mtk::wmma::tcec::fill_fragment`
164 | - `mtk::wmma::tcec::load_matrix_sync`
165 | - `mtk::wmma::tcec::store_matrix_sync`
166 | - `mtk::wmma::tcec::mma_sync`
167 |
168 | - `mtk::wmma::tcec::mma_rz_sync`
169 | - `mtk::wmma::tcec::fill_zero`
170 |
171 | See [test code](../test/tcec/mma_complex.cu) for more detail.
172 |
--------------------------------------------------------------------------------
/docs/ops.md:
--------------------------------------------------------------------------------
1 | # Arithmetic operators for fragments
2 |
3 | ## Supported operators
4 |
5 | | op | A type | B type | C type |
6 | |:----:|:------:|:------:|:------:|
7 | | `+` | `fragment` | `fragment` ||
8 | | `-` | `fragment` | `fragment` ||
9 | | `*` | `fragment` | `fragment::storage_element_t` ||
10 | | `/` | `fragment` | `fragment::storage_element_t` ||
11 | | `mtk::wmma::fma` | `fragment` | `fragment::storage_element_t` | `fragment` |
12 | | `mtk::wmma::fma` | `fragment::storage_element_t` | `fragment` | `fragment` |
13 |
14 | ## Example
15 |
16 | ```cpp
17 | #include
18 |
19 | nvcuda::wmma::fragment frag_a0, frag_a1;
20 |
21 | const auto frag_a0 = frag_a0 + frag_a1 * __float2half(2.0f);
22 | ```
23 |
--------------------------------------------------------------------------------
/docs/store_vector-en.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/docs/utils.md:
--------------------------------------------------------------------------------
1 | # Utils
2 |
3 | ## Type conversion
4 | ```cpp
5 | const auto dst_val = mtk::wmma::utils::cast(src_val);
6 | ```
7 |
8 | ## Asynchronous D2S data copy
9 | ```cpp
10 | mtk::wmma::utils::cp_async::cp_async(dst_ptr, src_ptr);
11 | mtk::wmma::utils::cp_async::commit();
12 | mtk::wmma::utils::cp_async::wait_group();
13 | mtk::wmma::utils::cp_async::wait_all();
14 | ```
15 |
16 | - `N` is data size in byte. (4, 8, 16)
17 |
--------------------------------------------------------------------------------
/docs/wmmae.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/include/wmma_extension/detail/common.hpp:
--------------------------------------------------------------------------------
1 | #ifndef __WMMAE_DETAIL_COMMON__
2 | #define __WMMAE_DETAIL_COMMON__
3 | #include
4 | #include
5 | #include
6 |
7 | #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
8 | namespace nvcuda {
9 | namespace wmma {
10 | namespace precision {
11 | class tf32;
12 | } // namespace precision
13 | } // namespace wmma
14 | } // namespace nvcuda
15 | #endif
16 |
17 | namespace mtk {
18 | namespace wmma {
19 |
20 | namespace detail {
21 | namespace common {
22 | template struct storage_t {
23 | using type = T;
24 | };
25 | template
26 | inline __device__ __host__ typename storage_t::type cast(const float v) {
27 | return static_cast::type>(v);
28 | }
29 | template
30 | inline __device__ __host__ typename storage_t::type cast(const half v) {
31 | return static_cast::type>(v);
32 | }
33 |
34 | template <> struct storage_t {
35 | using type = float;
36 | };
37 | __device__ __host__ inline float to_tf32(const float a) {
38 | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
39 | float ret;
40 | asm("{.reg .b32 %mr;\n"
41 | "cvt.rna.tf32.f32 %mr, %1;\n"
42 | "mov.b32 %0, %mr;}\n"
43 | : "=f"(ret)
44 | : "f"(a));
45 | return ret;
46 | #else
47 | return a;
48 | #endif
49 | }
50 | template <>
51 | inline __device__ __host__
52 | typename storage_t::type
53 | cast(const float v) {
54 | return to_tf32(v);
55 | }
56 | template <>
57 | inline __device__ __host__
58 | typename storage_t::type
59 | cast(const half v) {
60 | return to_tf32(__half2float(v));
61 | }
62 |
63 | inline __device__ unsigned get_lane_id() {
64 | unsigned lane_id;
65 | asm(R"({mov.s32 %0, %laneid;})" : "=r"(lane_id));
66 | return lane_id;
67 | }
68 |
69 | template struct get_M;
70 | template struct get_M {
71 | static const int value = M;
72 | };
73 | template struct get_M {
74 | static const int value = K;
75 | };
76 | template
77 | struct get_M {
78 | static const int value = M;
79 | };
80 |
81 | template struct get_N;
82 | template struct get_N {
83 | static const int value = K;
84 | };
85 | template struct get_N {
86 | static const int value = N;
87 | };
88 | template
89 | struct get_N {
90 | static const int value = N;
91 | };
92 |
93 | template struct layout_switch;
94 | template
95 | struct layout_switch {
96 | static const int value = col_value;
97 | };
98 | template
99 | struct layout_switch {
100 | static const int value = row_value;
101 | };
102 |
103 | } // namespace common
104 |
105 | template struct fill_zero_core;
106 |
107 | template struct fill_zero_core<2, T> {
108 | __device__ void operator()(T *const ptr) {
109 | *reinterpret_cast(ptr) = 0;
110 | }
111 | };
112 |
113 | template struct fill_zero_core<4, T> {
114 | __device__ void operator()(T *const ptr) {
115 | *reinterpret_cast(ptr) = 0;
116 | }
117 | };
118 |
119 | template struct fill_zero_core<8, T> {
120 | __device__ void operator()(T *const ptr) {
121 | *reinterpret_cast(ptr) = make_int2(0, 0);
122 | }
123 | };
124 |
125 | template struct fill_zero_core<16, T> {
126 | __device__ void operator()(T *const ptr) {
127 | *reinterpret_cast(ptr) = make_int4(0, 0, 0, 0);
128 | }
129 | };
130 |
131 | template struct fill_zero_core<32, T> {
132 | __device__ void operator()(T *const ptr) {
133 | *(reinterpret_cast(ptr) + 0) = make_int4(0, 0, 0, 0);
134 | *(reinterpret_cast(ptr) + 1) = make_int4(0, 0, 0, 0);
135 | }
136 | };
137 |
138 | template struct fill_zero_core<64, T> {
139 | __device__ void operator()(T *const ptr) {
140 | *(reinterpret_cast(ptr) + 0) = make_int4(0, 0, 0, 0);
141 | *(reinterpret_cast(ptr) + 1) = make_int4(0, 0, 0, 0);
142 | *(reinterpret_cast(ptr) + 2) = make_int4(0, 0, 0, 0);
143 | *(reinterpret_cast(ptr) + 3) = make_int4(0, 0, 0, 0);
144 | }
145 | };
146 |
147 | template struct size_of {
148 | static constexpr unsigned value = 0;
149 | };
150 | template <> struct size_of {
151 | static constexpr unsigned value = 1;
152 | };
153 | template <> struct size_of {
154 | static constexpr unsigned value = 1;
155 | };
156 | template <> struct size_of {
157 | static constexpr unsigned value = 4;
158 | };
159 | template <> struct size_of {
160 | static constexpr unsigned value = 2;
161 | };
162 | template <> struct size_of {
163 | static constexpr unsigned value = 4;
164 | };
165 | } // namespace detail
166 |
167 | namespace mma {
168 | template struct __align__(4) __frag_base {
169 | T x[size];
170 | enum { num_elements = size };
171 | };
172 |
173 | template
174 | __device__ inline void fill_fragment(__frag_base &f, const T v) {
175 | #pragma unroll
176 | for (unsigned i = 0; i < f.num_elements; i++)
177 | f.x[i] = v;
178 | }
179 | template
180 | __device__ inline void fill_fragment(__frag_base &f, const T v) {
181 | #pragma unroll
182 | for (unsigned i = 0; i < f.num_elements; i++)
183 | f.x[i] = v;
184 | }
185 | template
186 | __device__ inline void fill_fragment(__frag_base &f, const T v) {
187 | #pragma unroll
188 | for (unsigned i = 0; i < f.num_elements; i++)
189 | f.x[i] = v;
190 | }
191 | template
192 | __device__ inline void fill_fragment(__frag_base &f, const T v) {
193 | #pragma unroll
194 | for (unsigned i = 0; i < f.num_elements; i++)
195 | f.x[i] = v;
196 | }
197 | template
198 | __device__ inline void fill_fragment(__frag_base &f, const T v) {
199 | #pragma unroll
200 | for (unsigned i = 0; i < f.num_elements; i++)
201 | f.x[i] = v;
202 | }
203 | template
204 | __device__ inline void fill_fragment(__frag_base &f,
205 | const T v) {
206 | #pragma unroll
207 | for (unsigned i = 0; i < f.num_elements; i++)
208 | f.x[i] = v;
209 | }
210 | template
211 | __device__ inline void fill_fragment(__frag_base &f,
212 | const T v) {
213 | #pragma unroll
214 | for (unsigned i = 0; i < f.num_elements; i++)
215 | f.x[i] = v;
216 | }
217 | template
218 | __device__ inline void fill_fragment(__frag_base &f,
219 | const T v) {
220 | #pragma unroll
221 | for (unsigned i = 0; i < f.num_elements; i++)
222 | f.x[i] = v;
223 | }
224 | template
225 | __device__ inline void fill_fragment(__frag_base &f,
226 | const T v) {
227 | #pragma unroll
228 | for (unsigned i = 0; i < f.num_elements; i++)
229 | f.x[i] = v;
230 | }
231 | template
232 | __device__ inline void fill_fragment(__frag_base &f,
233 | const T v) {
234 | #pragma unroll
235 | for (unsigned i = 0; i < f.num_elements; i++)
236 | f.x[i] = v;
237 | }
238 |
239 | template
240 | class fragment;
241 |
242 | template
243 | __device__ inline void
244 | fill_zero(mtk::wmma::mma::fragment