├── .gitignore
├── BUILD
├── LICENSE
├── Makefile
├── README.md
├── WORKSPACE
├── benchmark.py
├── gen_kernels.py
├── include
├── c_interface.h
├── kernel_headers.h
├── kernels
│ ├── hgemm_128x128x8_NN_sm_50.h
│ ├── hgemm_128x128x8_NN_sm_60.h
│ ├── hgemm_128x128x8_NN_vec_sm_50.h
│ ├── hgemm_128x128x8_NN_vec_sm_60.h
│ ├── hgemm_128x128x8_NT_sm_50.h
│ ├── hgemm_128x128x8_NT_sm_60.h
│ ├── hgemm_128x128x8_NT_vec_sm_50.h
│ ├── hgemm_128x128x8_NT_vec_sm_60.h
│ ├── hgemm_128x128x8_TN_sm_50.h
│ ├── hgemm_128x128x8_TN_sm_60.h
│ ├── hgemm_128x128x8_TN_vec_sm_50.h
│ ├── hgemm_128x128x8_TN_vec_sm_60.h
│ ├── hgemm_128x128x8_TT_sm_50.h
│ ├── hgemm_128x128x8_TT_sm_60.h
│ ├── hgemm_128x128x8_TT_vec_sm_50.h
│ ├── hgemm_128x128x8_TT_vec_sm_60.h
│ ├── hgemm_16x64x64_NN_sm_50.h
│ ├── hgemm_16x64x64_NN_sm_60.h
│ ├── hgemm_16x64x64_NN_vec_sm_50.h
│ ├── hgemm_16x64x64_NN_vec_sm_60.h
│ ├── hgemm_16x64x64_NT_sm_50.h
│ ├── hgemm_16x64x64_NT_sm_60.h
│ ├── hgemm_16x64x64_NT_vec_sm_50.h
│ ├── hgemm_16x64x64_NT_vec_sm_60.h
│ ├── hgemm_32x32x32_NN_sm_50.h
│ ├── hgemm_32x32x32_NN_sm_60.h
│ ├── hgemm_32x32x32_NN_vec_sm_50.h
│ ├── hgemm_32x32x32_NN_vec_sm_60.h
│ ├── hgemm_32x32x32_NT_sm_50.h
│ ├── hgemm_32x32x32_NT_sm_60.h
│ ├── hgemm_32x32x32_NT_vec_sm_50.h
│ ├── hgemm_32x32x32_NT_vec_sm_60.h
│ ├── hgemm_32x32x32_TN_sm_50.h
│ ├── hgemm_32x32x32_TN_sm_60.h
│ ├── hgemm_32x32x32_TN_vec_sm_50.h
│ ├── hgemm_32x32x32_TN_vec_sm_60.h
│ ├── hgemm_32x32x32_TT_sm_50.h
│ ├── hgemm_32x32x32_TT_sm_60.h
│ ├── hgemm_32x32x32_TT_vec_sm_50.h
│ ├── hgemm_32x32x32_TT_vec_sm_60.h
│ ├── hgemm_32x32x64_NT_sm_50.h
│ ├── hgemm_32x32x64_NT_sm_60.h
│ ├── hgemm_32x32x64_NT_vec_sm_50.h
│ ├── hgemm_32x32x64_NT_vec_sm_60.h
│ ├── hgemm_32x64x32_NN_sm_50.h
│ ├── hgemm_32x64x32_NN_sm_60.h
│ ├── hgemm_32x64x32_NN_vec_sm_50.h
│ ├── hgemm_32x64x32_NN_vec_sm_60.h
│ ├── sgemm_128x128x8_NN_sm_50.h
│ ├── sgemm_128x128x8_NN_sm_60.h
│ ├── sgemm_128x128x8_NN_vec_sm_50.h
│ ├── sgemm_128x128x8_NN_vec_sm_60.h
│ ├── sgemm_128x128x8_NT_sm_50.h
│ ├── sgemm_128x128x8_NT_sm_60.h
│ ├── sgemm_128x128x8_NT_vec_sm_50.h
│ ├── sgemm_128x128x8_NT_vec_sm_60.h
│ ├── sgemm_128x128x8_TN_sm_50.h
│ ├── sgemm_128x128x8_TN_sm_60.h
│ ├── sgemm_128x128x8_TN_vec_sm_50.h
│ ├── sgemm_128x128x8_TN_vec_sm_60.h
│ ├── sgemm_128x128x8_TT_sm_50.h
│ ├── sgemm_128x128x8_TT_sm_60.h
│ ├── sgemm_128x128x8_TT_vec_sm_50.h
│ ├── sgemm_128x128x8_TT_vec_sm_60.h
│ ├── sgemm_32x32x32_NN_sm_50.h
│ ├── sgemm_32x32x32_NN_sm_60.h
│ ├── sgemm_32x32x32_NN_vec_sm_50.h
│ ├── sgemm_32x32x32_NN_vec_sm_60.h
│ ├── sgemm_32x32x32_NT_sm_50.h
│ ├── sgemm_32x32x32_NT_sm_60.h
│ ├── sgemm_32x32x32_NT_vec_sm_50.h
│ ├── sgemm_32x32x32_NT_vec_sm_60.h
│ ├── sgemm_32x32x32_TN_sm_50.h
│ ├── sgemm_32x32x32_TN_sm_60.h
│ ├── sgemm_32x32x32_TN_vec_sm_50.h
│ ├── sgemm_32x32x32_TN_vec_sm_60.h
│ ├── sgemm_32x32x32_TT_sm_50.h
│ ├── sgemm_32x32x32_TT_sm_60.h
│ ├── sgemm_32x32x32_TT_vec_sm_50.h
│ └── sgemm_32x32x32_TT_vec_sm_60.h
└── static_kernel_information.h
├── lib
└── .gitignore
├── maxas
├── MaxAs
│ ├── Cubin.pm
│ ├── MaxAs.pm
│ └── MaxAsGrammar.pm
└── maxas.pl
├── openai_gemm.py
├── sass
├── hgemm_16x64x64_NN.sass
├── hgemm_16x64x64_NT.sass
├── hgemm_32x32x64_NT.sass
├── hgemm_32x64x32_NN.sass
├── xgemm_128x128x8.sass
└── xgemm_32x32x32.sass
├── src
├── c_interface.cpp
└── test.cu
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | temp/
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *,cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # IPython Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # dotenv
80 | .env
81 |
82 | # virtualenv
83 | venv/
84 | ENV/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 |
92 | #test executable
93 | test
94 |
95 | bazel-*
96 |
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//visibility:public"])
2 |
3 | cc_library(
4 | name = "openai_gemm",
5 | srcs = ["src/c_interface.cpp",
6 | "include/kernel_headers.h",
7 | "include/static_kernel_information.h",
8 | "include/kernels/hgemm_128x128x8_NN_sm_50.h",
9 | "include/kernels/hgemm_128x128x8_NN_sm_60.h",
10 | "include/kernels/hgemm_128x128x8_NN_vec_sm_50.h",
11 | "include/kernels/hgemm_128x128x8_NN_vec_sm_60.h",
12 | "include/kernels/hgemm_128x128x8_NT_sm_50.h",
13 | "include/kernels/hgemm_128x128x8_NT_vec_sm_50.h",
14 | "include/kernels/hgemm_128x128x8_NT_sm_60.h",
15 | "include/kernels/hgemm_128x128x8_NT_vec_sm_60.h",
16 | "include/kernels/hgemm_128x128x8_TN_sm_50.h",
17 | "include/kernels/hgemm_128x128x8_TN_vec_sm_50.h",
18 | "include/kernels/hgemm_128x128x8_TN_sm_60.h",
19 | "include/kernels/hgemm_128x128x8_TN_vec_sm_60.h",
20 | "include/kernels/hgemm_128x128x8_TT_sm_50.h",
21 | "include/kernels/hgemm_128x128x8_TT_vec_sm_50.h",
22 | "include/kernels/hgemm_128x128x8_TT_sm_60.h",
23 | "include/kernels/hgemm_128x128x8_TT_vec_sm_60.h",
24 | "include/kernels/hgemm_16x64x64_NN_sm_50.h",
25 | "include/kernels/hgemm_16x64x64_NN_vec_sm_50.h",
26 | "include/kernels/hgemm_16x64x64_NN_sm_60.h",
27 | "include/kernels/hgemm_16x64x64_NN_vec_sm_60.h",
28 | "include/kernels/hgemm_16x64x64_NT_sm_50.h",
29 | "include/kernels/hgemm_16x64x64_NT_vec_sm_50.h",
30 | "include/kernels/hgemm_16x64x64_NT_sm_60.h",
31 | "include/kernels/hgemm_16x64x64_NT_vec_sm_60.h",
32 | "include/kernels/hgemm_32x32x32_NN_sm_50.h",
33 | "include/kernels/hgemm_32x32x32_NN_sm_60.h",
34 | "include/kernels/hgemm_32x32x32_NN_vec_sm_50.h",
35 | "include/kernels/hgemm_32x32x32_NN_vec_sm_60.h",
36 | "include/kernels/hgemm_32x32x32_NT_sm_50.h",
37 | "include/kernels/hgemm_32x32x32_NT_vec_sm_50.h",
38 | "include/kernels/hgemm_32x32x32_NT_sm_60.h",
39 | "include/kernels/hgemm_32x32x32_NT_vec_sm_60.h",
40 | "include/kernels/hgemm_32x32x32_TN_sm_50.h",
41 | "include/kernels/hgemm_32x32x32_TN_vec_sm_50.h",
42 | "include/kernels/hgemm_32x32x32_TN_sm_60.h",
43 | "include/kernels/hgemm_32x32x32_TN_vec_sm_60.h",
44 | "include/kernels/hgemm_32x32x32_TT_sm_50.h",
45 | "include/kernels/hgemm_32x32x32_TT_vec_sm_50.h",
46 | "include/kernels/hgemm_32x32x32_TT_sm_60.h",
47 | "include/kernels/hgemm_32x32x32_TT_vec_sm_60.h",
48 | "include/kernels/hgemm_32x32x64_NT_sm_50.h",
49 | "include/kernels/hgemm_32x32x64_NT_vec_sm_50.h",
50 | "include/kernels/hgemm_32x32x64_NT_sm_60.h",
51 | "include/kernels/hgemm_32x32x64_NT_vec_sm_60.h",
52 | "include/kernels/hgemm_32x64x32_NN_sm_50.h",
53 | "include/kernels/hgemm_32x64x32_NN_vec_sm_50.h",
54 | "include/kernels/hgemm_32x64x32_NN_sm_60.h",
55 | "include/kernels/hgemm_32x64x32_NN_vec_sm_60.h",
56 | "include/kernels/sgemm_128x128x8_NN_sm_50.h",
57 | "include/kernels/sgemm_128x128x8_NN_sm_60.h",
58 | "include/kernels/sgemm_128x128x8_NN_vec_sm_50.h",
59 | "include/kernels/sgemm_128x128x8_NN_vec_sm_60.h",
60 | "include/kernels/sgemm_128x128x8_NT_sm_50.h",
61 | "include/kernels/sgemm_128x128x8_NT_vec_sm_50.h",
62 | "include/kernels/sgemm_128x128x8_NT_sm_60.h",
63 | "include/kernels/sgemm_128x128x8_NT_vec_sm_60.h",
64 | "include/kernels/sgemm_128x128x8_TN_sm_50.h",
65 | "include/kernels/sgemm_128x128x8_TN_vec_sm_50.h",
66 | "include/kernels/sgemm_128x128x8_TN_sm_60.h",
67 | "include/kernels/sgemm_128x128x8_TN_vec_sm_60.h",
68 | "include/kernels/sgemm_128x128x8_TT_sm_50.h",
69 | "include/kernels/sgemm_128x128x8_TT_vec_sm_50.h",
70 | "include/kernels/sgemm_128x128x8_TT_sm_60.h",
71 | "include/kernels/sgemm_128x128x8_TT_vec_sm_60.h",
72 | "include/kernels/sgemm_32x32x32_NN_sm_50.h",
73 | "include/kernels/sgemm_32x32x32_NN_sm_60.h",
74 | "include/kernels/sgemm_32x32x32_NN_vec_sm_50.h",
75 | "include/kernels/sgemm_32x32x32_NN_vec_sm_60.h",
76 | "include/kernels/sgemm_32x32x32_NT_sm_50.h",
77 | "include/kernels/sgemm_32x32x32_NT_vec_sm_50.h",
78 | "include/kernels/sgemm_32x32x32_NT_sm_60.h",
79 | "include/kernels/sgemm_32x32x32_NT_vec_sm_60.h",
80 | "include/kernels/sgemm_32x32x32_TN_sm_50.h",
81 | "include/kernels/sgemm_32x32x32_TN_vec_sm_50.h",
82 | "include/kernels/sgemm_32x32x32_TN_sm_60.h",
83 | "include/kernels/sgemm_32x32x32_TN_vec_sm_60.h",
84 | "include/kernels/sgemm_32x32x32_TT_sm_50.h",
85 | "include/kernels/sgemm_32x32x32_TT_vec_sm_50.h",
86 | "include/kernels/sgemm_32x32x32_TT_sm_60.h",
87 | "include/kernels/sgemm_32x32x32_TT_vec_sm_60.h",
88 | ],
89 | hdrs = ["include/c_interface.h"],
90 | )
91 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2016 OpenAI (http://openai.com), 2016 Google Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | gen_kernels: gen_kernels.py include/kernel_headers.h $(wildcard sass/*.sass)
2 | python gen_kernels.py
3 |
4 | lib/c_interface.o: gen_kernels src/c_interface.cpp include/kernel_headers.h include/static_kernel_information.h
5 | nvcc -c src/c_interface.cpp -o lib/c_interface.o -std=c++11 -I .
6 |
7 | test: src/test.cu lib/c_interface.o
8 | nvcc -o test src/test.cu lib/c_interface.o -std=c++11 -I . -lcuda
9 |
10 | clean:
11 | rm -f include/kernels/*
12 | rm -rf temp/
13 | rm -f lib/c_interface.o
14 | rm -f test
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 |
3 | # openai-gemm
4 | Open-source single- and half-precision GEMM implementations. The main speedups over cuBLAS are for small minibatch sizes and for fp16 data.
5 |
6 | ## Quick Install
7 |
8 | The demonstration code currently depends on [Nervana neon](https://github.com/NervanaSystems/neon):
9 |
10 | git clone git@github.com:NervanaSystems/neon.git
11 | cd neon
12 | make
13 | . .venv/bin/activate
14 |
15 | Clone and run this repo:
16 |
17 | git clone git@github.com:openai/openai-gemm.git
18 |
19 | Run the benchmark:
20 | ./benchmark.py
21 |
22 | Run the unit test:
23 | ./test.py
24 |
25 |
26 | ### DeepBench on Pascal TITAN X
27 | ( https://github.com/baidu-research/DeepBench )
28 |
29 |
30 | | M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16|
31 | |------|------|------|---|---------|---------|--------|---------|---------|--------|
32 | | 16| 1760| 1760| NN| 2557| 2195| 1.2| 3507| 346| 10.1|
33 | | 32| 1760| 1760| NN| 5010| 1128| 4.4| 6814| 526| 13.0|
34 | | 64| 1760| 1760| NN| 6486| 4112| 1.6| 8235| 2801| 2.9|
35 | | 128| 1760| 1760| NN| 7068| 6931| 1.0| 9400| 5307| 1.8|
36 | | 7000| 1760| 1760| NN| 9968| 9584| 1.0| 10515| 9807| 1.1|
37 | | 16| 2048| 2048| NN| 2569| 1516| 1.7| 3619| 242| 15.0|
38 | | 32| 2048| 2048| NN| 5034| 1356| 3.7| 6576| 606| 10.8|
39 | | 64| 2048| 2048| NN| 6636| 2815| 2.4| 8285| 3241| 2.6|
40 | | 128| 2048| 2048| NN| 7316| 6373| 1.1| 9066| 5334| 1.7|
41 | | 7000| 2048| 2048| NN| 10081| 9900| 1.0| 11275| 9948| 1.1|
42 | | 16| 2560| 2560| NN| 2718| 1312| 2.1| 4312| 251| 17.2|
43 | | 32| 2560| 2560| NN| 5370| 1660| 3.2| 7525| 749| 10.0|
44 | | 64| 2560| 2560| NN| 7331| 2687| 2.7| 8436| 951| 8.9|
45 | | 128| 2560| 2560| NN| 8007| 5238| 1.5| 9277| 6123| 1.5|
46 | | 7000| 2560| 2560| NN| 10282| 10131| 1.0| 11027| 9974| 1.1|
47 | | 16| 4096| 4096| NN| 2695| 1110| 2.4| 4442| 266| 16.7|
48 | | 32| 4096| 4096| NN| 5266| 2264| 2.3| 7723| 758| 10.2|
49 | | 64| 4096| 4096| NN| 6942| 3922| 1.8| 8904| 1055| 8.4|
50 | | 128| 4096| 4096| NN| 8127| 5686| 1.4| 9711| 5681| 1.7|
51 | | 7000| 4096| 4096| NN| 10462| 10082| 1.0| 11152| 9991| 1.1|
52 | | 16| 1760| 1760| NT| 1719| 1095| 1.6| 2692| 290| 9.3|
53 | | 32| 1760| 1760| NT| 3316| 1312| 2.5| 5068| 447| 11.3|
54 | | 64| 1760| 1760| NT| 5247| 1955| 2.7| 7621| 1797| 4.2|
55 | | 128| 1760| 1760| NT| 6720| 3393| 2.0| 8886| 3342| 2.7|
56 | | 7000| 1760| 1760| NT| 9341| 8513| 1.1| 10085| 9635| 1.0|
57 | | 16| 2048| 2048| NT| 2442| 1231| 2.0| 3641| 299| 12.2|
58 | | 32| 2048| 2048| NT| 4801| 1251| 3.8| 5849| 468| 12.5|
59 | | 64| 2048| 2048| NT| 6317| 1967| 3.2| 7825| 3128| 2.5|
60 | | 128| 2048| 2048| NT| 7176| 5041| 1.4| 8616| 4843| 1.8|
61 | | 7000| 2048| 2048| NT| 9975| 9173| 1.1| 10741| 9560| 1.1|
62 | | 16| 2560| 2560| NT| 1834| 1208| 1.5| 3154| 297| 10.6|
63 | | 32| 2560| 2560| NT| 3610| 1436| 2.5| 5418| 584| 9.3|
64 | | 64| 2560| 2560| NT| 6083| 2815| 2.2| 8331| 1042| 8.0|
65 | | 128| 2560| 2560| NT| 7702| 3246| 2.4| 8857| 5259| 1.7|
66 | | 7000| 2560| 2560| NT| 9257| 7829| 1.2| 10659| 9548| 1.1|
67 | | 16| 4096| 4096| NT| 2546| 1297| 2.0| 4164| 309| 13.5|
68 | | 32| 4096| 4096| NT| 4992| 2290| 2.2| 8156| 775| 10.5|
69 | | 64| 4096| 4096| NT| 6746| 4157| 1.6| 8429| 1381| 6.1|
70 | | 128| 4096| 4096| NT| 7843| 5425| 1.4| 9298| 5527| 1.7|
71 | | 7000| 4096| 4096| NT| 9925| 6879| 1.4| 10630| 9784| 1.1|
72 | | 7133| 1760| 1760| TN| 9752| 10186| 1.0| 10517| 8912| 1.2|
73 | | 7133| 2048| 2048| TN| 10485| 10319| 1.0| 10674| 9608| 1.1|
74 | | 7133| 2560| 2560| TN| 10743| 11057| 1.0| 11195| 10059| 1.1|
75 | | 7133| 4096| 4096| TN| 10384| 10290| 1.0| 10980| 10558| 1.0|
76 | | 9124| 5124| 1760| NN| 9920| 9480| 1.0| 10580| 9743| 1.1|
77 | | 9124| 5124| 2048| NN| 10008| 9415| 1.1| 10602| 9796| 1.1|
78 | | 9124| 5124| 2560| NN| 9925| 9426| 1.1| 10586| 9850| 1.1|
79 | | 9124| 5124| 4096| NN| 9982| 9489| 1.1| 10580| 9472| 1.1|
80 | | 9124| 5124| 1760| NT| 9093| 3497| 2.6| 9302| 8692| 1.1|
81 | | 9124| 5124| 2048| NT| 9506| 6512| 1.5| 9506| 8883| 1.1|
82 | | 9124| 5124| 2560| NT| 8704| 3364| 2.6| 9855| 7733| 1.3|
83 | | 9124| 5124| 4096| NT| 9733| 6109| 1.6| 10278| 8760| 1.2|
84 | | 8457| 35| 1760| NN| 3343| 1020| 3.3| 3841| 736| 5.2|
85 | | 8457| 35| 2048| NN| 3419| 1996| 1.7| 4782| 803| 6.0|
86 | | 8457| 35| 2560| NN| 3415| 1072| 3.2| 3868| 789| 4.9|
87 | | 8457| 35| 4096| NN| 3743| 2009| 1.9| 4741| 804| 5.9|
88 | | 8457| 35| 1760| NT| 3574| 1970| 1.8| 4176| 1243| 3.4|
89 | | 8457| 35| 2048| NT| 4564| 3069| 1.5| 4818| 1255| 3.8|
90 | | 8457| 35| 2560| NT| 3598| 2062| 1.7| 3597| 1135| 3.2|
91 | | 8457| 35| 4096| NT| 4311| 2990| 1.4| 4927| 1303| 3.8|
92 | | 16| 7680| 2560| NN| 2683| 718| 3.7| 4449| 289| 15.4|
93 | | 32| 7680| 2560| NN| 5304| 3660| 1.4| 7837| 979| 8.0|
94 | | 64| 7680| 2560| NN| 7311| 4979| 1.5| 9310| 1274| 7.3|
95 | | 128| 7680| 2560| NN| 7931| 6109| 1.3| 9390| 6591| 1.4|
96 | | 16| 7680| 2560| NT| 1885| 1191| 1.6| 3401| 290| 11.7|
97 | | 32| 7680| 2560| NT| 3731| 1808| 2.1| 6373| 1004| 6.3|
98 | | 64| 7680| 2560| NT| 6274| 3509| 1.8| 8809| 1655| 5.3|
99 | | 128| 7680| 2560| NT| 7957| 2988| 2.7| 9246| 4695| 2.0|
100 | | 16| 3072| 1024| NN| 2277| 1295| 1.8| 3373| 282| 12.0|
101 | | 32| 3072| 1024| NN| 4494| 1798| 2.5| 6011| 807| 7.4|
102 | | 64| 3072| 1024| NN| 6272| 3046| 2.1| 6790| 917| 7.4|
103 | | 128| 3072| 1024| NN| 7364| 5436| 1.4| 7768| 5749| 1.4|
104 | | 16| 3072| 1024| NT| 2285| 1077| 2.1| 3439| 244| 14.1|
105 | | 32| 3072| 1024| NT| 4597| 1540| 3.0| 5645| 677| 8.3|
106 | | 64| 3072| 1024| NT| 6392| 2969| 2.2| 7555| 1204| 6.3|
107 | | 128| 3072| 1024| NT| 7460| 5058| 1.5| 8586| 5535| 1.6|
108 | | 7435| 3072| 1024| TN| 9829| 8804| 1.1| 10123| 9365| 1.1|
109 | | 5481| 7680| 2560| TN| 9448| 9309| 1.0| 9466| 9394| 1.0|
110 |
111 |
112 | ### DeepBench on DGX1 (P100)
113 | Note that the OpenAI kernels do not yet implement fp16x2 instructions. Even so, it seems the current cuBLAS hgemm implementation is only good for large dimensions. There are also accuracy considerations when accumulating large reductions in fp16.
114 |
115 | | M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16|
116 | |------|------|------|---|---------|---------|--------|---------|---------|--------|
117 | | 16| 1760| 1760| NN| 2595| 2048| 1.3| 2935| 463| 6.3|
118 | | 32| 1760| 1760| NN| 4963| 864| 5.7| 5766| 895| 6.4|
119 | | 64| 1760| 1760| NN| 7565| 3909| 1.9| 7760| 1711| 4.5|
120 | | 128| 1760| 1760| NN| 8140| 6053| 1.3| 8422| 4089| 2.1|
121 | | 7000| 1760| 1760| NN| 9653| 8722| 1.1| 9617| 16143| 0.6|
122 | | 16| 2048| 2048| NN| 2255| 1746| 1.3| 3211| 546| 5.9|
123 | | 32| 2048| 2048| NN| 4467| 1012| 4.4| 4533| 1019| 4.4|
124 | | 64| 2048| 2048| NN| 6618| 4198| 1.6| 6591| 2018| 3.3|
125 | | 128| 2048| 2048| NN| 8059| 5921| 1.4| 7936| 4667| 1.7|
126 | | 7000| 2048| 2048| NN| 9761| 9346| 1.0| 9910| 18715| 0.5|
127 | | 16| 2560| 2560| NN| 2883| 2108| 1.4| 4210| 685| 6.1|
128 | | 32| 2560| 2560| NN| 5701| 1279| 4.5| 5820| 1297| 4.5|
129 | | 64| 2560| 2560| NN| 8100| 6054| 1.3| 8099| 2558| 3.2|
130 | | 128| 2560| 2560| NN| 8308| 6799| 1.2| 8790| 5901| 1.5|
131 | | 7000| 2560| 2560| NN| 9740| 9538| 1.0| 9845| 18499| 0.5|
132 | | 16| 4096| 4096| NN| 3449| 1342| 2.6| 4299| 1069| 4.0|
133 | | 32| 4096| 4096| NN| 6863| 2045| 3.4| 6907| 2103| 3.3|
134 | | 64| 4096| 4096| NN| 8404| 4059| 2.1| 8248| 4183| 2.0|
135 | | 128| 4096| 4096| NN| 8224| 8039| 1.0| 8853| 8669| 1.0|
136 | | 7000| 4096| 4096| NN| 9818| 9519| 1.0| 10011| 18588| 0.5|
137 | | 16| 1760| 1760| NT| 2579| 1324| 1.9| 2763| 428| 6.4|
138 | | 32| 1760| 1760| NT| 5089| 878| 5.8| 5382| 857| 6.3|
139 | | 64| 1760| 1760| NT| 7501| 3017| 2.5| 7695| 1695| 4.5|
140 | | 128| 1760| 1760| NT| 8043| 5494| 1.5| 8192| 3426| 2.4|
141 | | 7000| 1760| 1760| NT| 9477| 7571| 1.3| 9355| 16113| 0.6|
142 | | 16| 2048| 2048| NT| 2267| 1276| 1.8| 3171| 504| 6.3|
143 | | 32| 2048| 2048| NT| 4484| 1026| 4.4| 4489| 1009| 4.4|
144 | | 64| 2048| 2048| NT| 6567| 3986| 1.6| 6551| 2018| 3.2|
145 | | 128| 2048| 2048| NT| 8019| 5825| 1.4| 7968| 4496| 1.8|
146 | | 7000| 2048| 2048| NT| 9625| 9373| 1.0| 9713| 17878| 0.5|
147 | | 16| 2560| 2560| NT| 2870| 1460| 2.0| 4256| 638| 6.7|
148 | | 32| 2560| 2560| NT| 5614| 1299| 4.3| 5705| 1271| 4.5|
149 | | 64| 2560| 2560| NT| 8014| 4402| 1.8| 8085| 2521| 3.2|
150 | | 128| 2560| 2560| NT| 8219| 5640| 1.5| 8240| 5137| 1.6|
151 | | 7000| 2560| 2560| NT| 9534| 9091| 1.0| 9735| 18025| 0.5|
152 | | 16| 4096| 4096| NT| 3366| 1547| 2.2| 4354| 1047| 4.2|
153 | | 32| 4096| 4096| NT| 6714| 2055| 3.3| 6859| 2093| 3.3|
154 | | 64| 4096| 4096| NT| 8297| 3445| 2.4| 8289| 4178| 2.0|
155 | | 128| 4096| 4096| NT| 8335| 7450| 1.1| 7911| 7973| 1.0|
156 | | 7000| 4096| 4096| NT| 9578| 9214| 1.0| 9877| 18073| 0.5|
157 | | 7133| 1760| 1760| TN| 9704| 9267| 1.0| 9506| 15605| 0.6|
158 | | 7133| 2048| 2048| TN| 9747| 9836| 1.0| 10012| 19110| 0.5|
159 | | 7133| 2560| 2560| TN| 9742| 9748| 1.0| 9805| 19107| 0.5|
160 | | 7133| 4096| 4096| TN| 9807| 9733| 1.0| 10122| 19559| 0.5|
161 | | 9124| 5124| 1760| NN| 9326| 9076| 1.0| 9631| 17496| 0.6|
162 | | 9124| 5124| 2048| NN| 9414| 9054| 1.0| 9602| 17523| 0.5|
163 | | 9124| 5124| 2560| NN| 9353| 9041| 1.0| 9698| 17380| 0.6|
164 | | 9124| 5124| 4096| NN| 9370| 9051| 1.0| 9689| 17617| 0.5|
165 | | 9124| 5124| 1760| NT| 9124| 8746| 1.0| 9524| 16777| 0.6|
166 | | 9124| 5124| 2048| NT| 9294| 8817| 1.1| 9641| 16935| 0.6|
167 | | 9124| 5124| 2560| NT| 9221| 8499| 1.1| 9637| 16820| 0.6|
168 | | 9124| 5124| 4096| NT| 9270| 8961| 1.0| 9568| 17080| 0.6|
169 | | 8457| 35| 1760| NN| 3301| 2233| 1.5| 4505| 3154| 1.4|
170 | | 8457| 35| 2048| NN| 3265| 3066| 1.1| 4501| 3335| 1.3|
171 | | 8457| 35| 2560| NN| 3127| 2300| 1.4| 4516| 3135| 1.4|
172 | | 8457| 35| 4096| NN| 3257| 3272| 1.0| 4729| 3485| 1.4|
173 | | 8457| 35| 1760| NT| 4563| 3142| 1.5| 4612| 2998| 1.5|
174 | | 8457| 35| 2048| NT| 4554| 3202| 1.4| 4601| 3109| 1.5|
175 | | 8457| 35| 2560| NT| 4567| 3144| 1.5| 4654| 3039| 1.5|
176 | | 8457| 35| 4096| NT| 4353| 3415| 1.3| 4457| 3257| 1.4|
177 | | 16| 7680| 2560| NN| 3668| 1200| 3.1| 5020| 1236| 4.1|
178 | | 32| 7680| 2560| NN| 7245| 3385| 2.1| 7519| 2465| 3.1|
179 | | 64| 7680| 2560| NN| 8440| 5210| 1.6| 8349| 4910| 1.7|
180 | | 128| 7680| 2560| NN| 8765| 4872| 1.8| 9131| 11349| 0.8|
181 | | 16| 7680| 2560| NT| 3229| 1515| 2.1| 5032| 1157| 4.3|
182 | | 32| 7680| 2560| NT| 6640| 2721| 2.4| 6810| 2307| 3.0|
183 | | 64| 7680| 2560| NT| 8282| 5113| 1.6| 8362| 4494| 1.9|
184 | | 128| 7680| 2560| NT| 8763| 4646| 1.9| 8617| 9159| 0.9|
185 | | 16| 3072| 1024| NN| 2929| 1717| 1.7| 3335| 750| 4.4|
186 | | 32| 3072| 1024| NN| 5801| 1399| 4.1| 6116| 1420| 4.3|
187 | | 64| 3072| 1024| NN| 6958| 4340| 1.6| 6923| 2814| 2.5|
188 | | 128| 3072| 1024| NN| 8047| 6492| 1.2| 7769| 6302| 1.2|
189 | | 16| 3072| 1024| NT| 2990| 1068| 2.8| 3384| 705| 4.8|
190 | | 32| 3072| 1024| NT| 5834| 1429| 4.1| 6021| 1411| 4.3|
191 | | 64| 3072| 1024| NT| 6921| 3500| 2.0| 6893| 2819| 2.4|
192 | | 128| 3072| 1024| NT| 7918| 6034| 1.3| 7876| 5760| 1.4|
193 | | 7435| 3072| 1024| TN| 9367| 9391| 1.0| 9559| 17234| 0.6|
194 | | 5481| 7680| 2560| TN| 9672| 9520| 1.0| 9967| 18832| 0.5|
195 |
--------------------------------------------------------------------------------
/WORKSPACE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/openai-gemm/db5da9e6656f6dfa55a18df77cee2e2f95d7ee9c/WORKSPACE
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 | import pycuda.driver as drv
5 | from neon.backends.nervanagpu import NervanaGPU
6 |
7 | from openai_gemm import matmul
8 |
9 |
# Instantiate the neon GPU backend used to allocate the benchmark tensors.
ng = NervanaGPU()
# Report which CUDA device the numbers below were measured on.  The
# parenthesised single-argument form is valid on both Python 2 and Python 3.
print(drv.Context.get_current().get_device().name())
12 |
# Benchmark problem sizes.  Each entry is (m, n, k, AT, BT): C is m x n and k
# is the reduction dimension; AT / BT select whether A / B is allocated with
# its dimensions swapped and then passed as a transposed view (see the loop
# that builds dimA / dimB from these flags).
config = (
    #    m,    n,    k,    AT,    BT  (row order)
    (   16, 1760, 1760, False, False),
    (   32, 1760, 1760, False, False),
    (   64, 1760, 1760, False, False),
    (  128, 1760, 1760, False, False),
    ( 7000, 1760, 1760, False, False),
    (   16, 2048, 2048, False, False),
    (   32, 2048, 2048, False, False),
    (   64, 2048, 2048, False, False),
    (  128, 2048, 2048, False, False),
    ( 7000, 2048, 2048, False, False),
    (   16, 2560, 2560, False, False),
    (   32, 2560, 2560, False, False),
    (   64, 2560, 2560, False, False),
    (  128, 2560, 2560, False, False),
    ( 7000, 2560, 2560, False, False),
    (   16, 4096, 4096, False, False),
    (   32, 4096, 4096, False, False),
    (   64, 4096, 4096, False, False),
    (  128, 4096, 4096, False, False),
    ( 7000, 4096, 4096, False, False),
    (   16, 1760, 1760, False, True),
    (   32, 1760, 1760, False, True),
    (   64, 1760, 1760, False, True),
    (  128, 1760, 1760, False, True),
    ( 7000, 1760, 1760, False, True),
    (   16, 2048, 2048, False, True),
    (   32, 2048, 2048, False, True),
    (   64, 2048, 2048, False, True),
    (  128, 2048, 2048, False, True),
    ( 7000, 2048, 2048, False, True),
    (   16, 2560, 2560, False, True),
    (   32, 2560, 2560, False, True),
    (   64, 2560, 2560, False, True),
    (  128, 2560, 2560, False, True),
    ( 7000, 2560, 2560, False, True),
    (   16, 4096, 4096, False, True),
    (   32, 4096, 4096, False, True),
    (   64, 4096, 4096, False, True),
    (  128, 4096, 4096, False, True),
    ( 7000, 4096, 4096, False, True),
    ( 7133, 1760, 1760, True , False),
    ( 7133, 2048, 2048, True , False),
    ( 7133, 2560, 2560, True , False),
    ( 7133, 4096, 4096, True , False),
    ( 9124, 5124, 1760, False, False),
    ( 9124, 5124, 2048, False, False),
    ( 9124, 5124, 2560, False, False),
    ( 9124, 5124, 4096, False, False),
    ( 9124, 5124, 1760, False, True),
    ( 9124, 5124, 2048, False, True),
    ( 9124, 5124, 2560, False, True),
    ( 9124, 5124, 4096, False, True),
    ( 8457,   35, 1760, False, False),
    ( 8457,   35, 2048, False, False),
    ( 8457,   35, 2560, False, False),
    ( 8457,   35, 4096, False, False),
    ( 8457,   35, 1760, False, True),
    ( 8457,   35, 2048, False, True),
    ( 8457,   35, 2560, False, True),
    ( 8457,   35, 4096, False, True),
    (   16, 7680, 2560, False, False),
    (   32, 7680, 2560, False, False),
    (   64, 7680, 2560, False, False),
    (  128, 7680, 2560, False, False),
    (   16, 7680, 2560, False, True),
    (   32, 7680, 2560, False, True),
    (   64, 7680, 2560, False, True),
    (  128, 7680, 2560, False, True),
    (   16, 3072, 1024, False, False),
    (   32, 3072, 1024, False, False),
    (   64, 3072, 1024, False, False),
    (  128, 3072, 1024, False, False),
    (   16, 3072, 1024, False, True),
    (   32, 3072, 1024, False, True),
    (   64, 3072, 1024, False, True),
    (  128, 3072, 1024, False, True),
    ( 7435, 3072, 1024, True , False),
    ( 5481, 7680, 2560, True , False),

    # Extra shapes, currently disabled:
    # (60000, 32, 32, True , False),
    # (60000, 256, 256, True , False),

    # ( 4096, 4096, 32, True , False),
    # ( 3456, 3456, 32, True , False),
    # ( 896, 896, 32, True , False),
)
101 |
# Header of the Markdown results table; one row is printed per config entry.
# Parenthesised single-argument print works on both Python 2 and Python 3.
print("| M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16|")
print("|------|------|------|---|---------|---------|--------|---------|---------|--------|")

for m, n, k, at, bt in config:

    # A transposed operand is allocated with its dimensions swapped and then
    # handed to matmul as a transposed view (see the .T below).
    dimA = (k, m) if at else (m, k)
    dimB = (n, k) if bt else (k, n)
    dimC = (m, n)

    # Two-letter op code for the table, e.g. "NT" = A normal, B transposed.
    op = ('T' if at else 'N') + ('T' if bt else 'N')

    # One formatted "openai|cublas|ratio" cell group per dtype, fp32 first.
    dtype_data = []

    for dtype in (np.float32, np.float16):

        A = ng.empty(dimA, dtype=dtype)
        B = ng.empty(dimB, dtype=dtype)
        C = ng.empty(dimC, dtype=dtype)

        if at:
            A = A.T
        if bt:
            B = B.T

        # With bench=True, matmul returns a list of result tuples.
        # NOTE(review): the tuple layout (first field presumably a time,
        # second the throughput shown in the table) is defined in
        # openai_gemm.matmul -- confirm there.
        data = matmul(A, B, C, bench=True)

        # Debugging aid: dump every candidate kernel for the fp16 run.
        # if dtype is np.float16:
        #     print("")
        #     for d in sorted(data):
        #         print("%7.3f %5.0f %22s %5d" % d)

        # The last entry is taken as the cuBLAS reference; of the remaining
        # entries, the one that sorts first is reported for OpenAI.
        cublas = data.pop()
        openai = min(data)

        text = "%9.0f|%9.0f|%8.1f" % (openai[1], cublas[1], openai[1] / cublas[1])

        dtype_data.append(text)

    print("|%6d|%6d|%6d|%3s|%s|" % (m, n, k, op, "|".join(dtype_data)))
142 |
--------------------------------------------------------------------------------
/gen_kernels.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path
3 | import re
4 | import subprocess
5 | import sys
6 | import time
7 |
# Directory containing this script, and the sibling "maxas" / "sass"
# directories derived from it.
base_dir = os.path.dirname(__file__)
maxas_dir = os.path.join(base_dir, "maxas")
sass_dir = os.path.join(base_dir, "sass")

# Tile sizes: m, n, k, vA,vB,vC div, op (dynamic shared options)
# Each tuple has nine elements in that order; the last one is itself a tuple
# of dynamic-shared options.
# NOTE(review): vA/vB/vC are presumably vector widths for A, B and C -- confirm.
k128x128x8 = (128, 128, 8, 4, 4, 1, 2, 0, (0,))
k32x32x32 = ( 32, 32, 32, 4, 4, 1, 4, 0, (0, 2**14))
k32x64x32_NN = ( 32, 64, 32, 8, 4, 4, 4, 1, (0, 2**13))
k32x32x64_NT = ( 32, 32, 64, 8, 8, 4, 4, 1, (0,))
k16x64x64_NN = ( 16, 64, 64, 8, 4, 4, 4, 1, (0,))
k16x64x64_NT = ( 16, 64, 64, 8, 8, 4, 4, 1, (0,))

# Candidate tiles, keyed first by precision prefix ("s" / "h", matching the
# sgemm_ / hgemm_ kernel names below) and then by the two-letter transpose op.
selections = {
    "s" : {
        "TN" : (k128x128x8, k32x32x32),
        "NN" : (k128x128x8, k32x32x32),
        "NT" : (k128x128x8, k32x32x32),
        "TT" : (k128x128x8, k32x32x32),
    },
    "h" : {
        "TN" : (k128x128x8, k32x32x32),
        "NN" : (k128x128x8, k32x32x32, k32x64x32_NN, k16x64x64_NN),
        "NT" : (k128x128x8, k32x32x32, k32x32x64_NT, k16x64x64_NT),
        "TT" : (k128x128x8, k32x32x32),
    },
}

# Per-kernel build specification, keyed by kernel base name:
#   "threads": value emitted as .reqntid in the PTX stub
#   "sass":    base name of the source file in the sass/ directory
#   "params":  key into _params selecting the kernel's parameter list
#   "share":   Python expression (string) evaluated to size the share[] array
#   "args":    optional static options; passed to maxas as -D definitions
kernels = {
    # Generic gemm tiles
    "sgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "s"} },
    "hgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "h"} },
    "sgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "s"} },
    "hgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "h"} },

    # Custom hgemm tiles designed for small minibatch RNNs
    "hgemm_32x64x32_NN": {"threads": 128, "sass": "hgemm_32x64x32_NN", "params": "xgemm", "share": "32*33*2 + 64*32*2 + 4" },
    "hgemm_32x32x64_NT": {"threads": 128, "sass": "hgemm_32x32x64_NT", "params": "xgemm", "share": "32*65*4 + 4" },
    "hgemm_16x64x64_NN": {"threads": 128, "sass": "hgemm_16x64x64_NN", "params": "xgemm", "share": "(16*64 + 32)*2 + 64*64*2 + 4" },
    "hgemm_16x64x64_NT": {"threads": 128, "sass": "hgemm_16x64x64_NT", "params": "xgemm", "share": "(16*64 + 32)*2 + (64*64 + 32)*2 + 4" },
}

# Kernel parameter lists as C-style "type name" strings.  get_ptx_file splits
# each on whitespace and maps the type to a PTX .param type
# (pointer -> .u64, float -> .f32, anything else -> .u32).
_params = {
    "xgemm": [
        "float* param_C",
        "float* param_A",
        "float* param_B",
        "float param_alpha",
        "float param_beta",
        "unsigned param_cda",
        "unsigned param_cdb",
        "unsigned param_cdc",
        "unsigned param_m",
        "unsigned param_n",
        "unsigned param_k",
        "unsigned param_blk_a",
        "unsigned param_blk_b",
    ],
}
66 |
# Matches a run of whitespace; used to split "type name" parameter strings.
_space_re = re.compile(r"\s+")

# PTX declaration of the shared array; {0} receives the evaluated "share"
# expression from the kernel spec.
_share_template = r"""
.shared .align 4 .b32 share[{0}];
"""

# PTX stub for one kernel entry point.  Format fields, as filled in by
# get_ptx_file: {0} target arch, {1} kernel name, {2} parameter declarations,
# {3} thread count for .reqntid, {4} shared declaration (may be empty),
# {5} static args (emitted as a comment only), {6} PTX version.
# The entry body holds only the shared declaration and "ret".
# NOTE(review): presumably the real instructions come from the .sass sources
# via maxas at a later step -- confirm.
_kernel_template = r"""
.version {6}
.target {0}
.address_size 64

// args: {5}

.visible .entry {1}(
{2}
)
.reqntid {3}
{{
{4}
ret;
}}
"""
89 |
90 |
def _get_cache_dir(subdir=None):
    """Return the cache directory below ``temp/``, creating it if needed.

    Args:
        subdir: Optional location below ``temp/``. Either a single path
            component (string) or a list/tuple of components that are
            joined in order. A falsy value selects ``temp/`` itself.

    Returns:
        The path of the cache directory, relative to the current working
        directory. The directory exists when this function returns.

    Raises:
        OSError: If the directory cannot be created (for example, a regular
            file already occupies the path).
    """
    cache_dir = 'temp/'

    if subdir:
        # Accept tuples as well as lists; a bare string is one component.
        parts = subdir if isinstance(subdir, (list, tuple)) else [subdir]
        cache_dir = os.path.join(cache_dir, *parts)

    # Create unconditionally and tolerate "already exists".  Checking with
    # os.path.exists() first would race with another process creating the
    # same directory between the check and the makedirs() call.
    try:
        os.makedirs(cache_dir)
    except OSError:
        if not os.path.isdir(cache_dir):
            raise

    return cache_dir
102 |
103 |
def get_ptx_file(kernel_spec, kernel_name, arch, ptx_ver):
    """Generate the PTX stub for a kernel in the cache and return its path.

    The stub text is rendered from ``_kernel_template`` and stored as
    ``<cache>/<arch>/ptx/<kernel_name>.ptx``. The file is rewritten only
    when the rendered text differs from what is already on disk.

    Args:
        kernel_spec: Kernel description dict. Reads "threads" and "params"
            (required) and "args" and "share" (optional).
        kernel_name: Entry-point name; also used as the file base name.
        arch: Target architecture string placed after ``.target``; also the
            cache sub-directory name.
        ptx_ver: PTX version placed after ``.version``.

    Returns:
        Path of the ``.ptx`` file.

    Raises:
        KeyError: If "threads" or "params" is missing from ``kernel_spec``,
            or "params" names an unknown parameter list.
    """
    ptx_dir = _get_cache_dir([arch, 'ptx'])

    thread_spec = kernel_spec["threads"]
    args_spec = str(kernel_spec.get("args", ""))
    param_spec = _params[kernel_spec["params"]]

    # Translate each C-style "type name" declaration into a PTX .param line.
    param_lines = []
    for p in param_spec:
        ptype, pname = _space_re.split(p)

        if ptype.endswith('*'):
            ptype = '.u64'      # pointers are passed as 64-bit values
        elif ptype == 'float':
            ptype = '.f32'
        else:
            ptype = '.u32'

        param_lines.append(" .param %s %s" % (ptype, pname))

    kernel_params = ",\n".join(param_lines)

    if "share" in kernel_spec:
        # NOTE: eval() is applied to the arithmetic expression string from the
        # kernel spec.  Those strings come from this module's own `kernels`
        # table; never pass specs built from untrusted input.
        share = _share_template.format(eval(kernel_spec["share"]))
    else:
        share = ""

    kernel_text = _kernel_template.format(
        arch, kernel_name, kernel_params, thread_spec, share, args_spec, ptx_ver)
    kernel_ptx = os.path.join(ptx_dir, kernel_name + ".ptx")

    current_text = ""
    if os.path.exists(kernel_ptx):
        # `with` guarantees the handle is closed even if read() raises.
        with open(kernel_ptx, "r") as f:
            current_text = f.read()

    # Only write out the kernel if the text has changed.
    if kernel_text != current_text:
        with open(kernel_ptx, "w") as f:
            f.write(kernel_text)

    return kernel_ptx
146 |
147 |
# NOTE(review): this pattern matches only the start of a string, which looks
# like a truncated regex (possibly lost in extraction) - confirm against the
# upstream source. It is not referenced in the visible portion of this file.
148 | include_re = re.compile(r'^')
149 |
150 |
def run_command(cmdlist):
    """Run a shell command and raise if it exits with a non-zero status.

    Args:
        cmdlist: List of command fragments. They are joined with single
            spaces and executed through the shell, so fragments may contain
            shell syntax (separators, environment assignments, ...).

    Raises:
        RuntimeError: If the command's exit status is non-zero. The message
            contains the status, the command line and the captured stderr.
    """
    cmd = " ".join(cmdlist)
    proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    _, err = proc.communicate()
    if proc.returncode:
        # On Python 3 the pipes yield bytes; decode so the message shows the
        # tool's output rather than a b'...' repr.
        if not isinstance(err, str):
            err = err.decode("utf-8", "replace")
        raise RuntimeError("Error(%d):\n%s\n%s" % (proc.returncode, cmd, err))
157 |
158 |
def get_kernel(base_name, major, minor, options=None):
    """Build one kernel cubin and emit it as a C header.

    Generates the PTX stub, compiles it with ptxas, lets maxas insert the
    hand-written sass into the resulting cubin, and finally dumps the cubin
    bytes as a ``const uint8_t <kernel>_<arch>[]`` array under
    ``include/kernels``.

    Args:
        base_name: Key into the module-level ``kernels`` spec table.
        major: Compute capability major version (only 5 and 6 supported).
        minor: Compute capability minor version.
        options: Optional iterable of build options. A tuple ``(name, value)``
            becomes ``-Dname value``; any other item ``name`` becomes
            ``-Dname 1``. Every option is appended to the kernel name.

    Raises:
        RuntimeError: For unsupported architectures, a missing sass source
            file, or a failing ptxas/maxas invocation.
    """
    if major < 5:
        raise RuntimeError("sass kernels require Maxwell or greater class hardware")
    elif major >= 7:
        raise RuntimeError("sm version 7 or greater is not supported")

    arch = "sm_%d%d" % (major, minor)

    libprefix = "PERL5LIB=%s" % maxas_dir
    maxas_i = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -i -w"]

    kernel_spec = kernels[base_name]
    kernel_name = base_name

    # static options
    if "args" in kernel_spec:
        for pair in kernel_spec["args"].items():
            maxas_i.append("-D%s %s" % pair)

    # dynamic options
    if options is not None:
        for opt in options:
            if isinstance(opt, tuple):
                maxas_i.append("-D%s %s" % opt)
                kernel_name += "_%s%s" % opt
            else:
                maxas_i.append("-D%s 1" % opt)
                kernel_name += "_%s" % opt

    maxas_i.insert(2, "-k " + kernel_name)

    sass_name = kernel_spec["sass"] + ".sass"
    cubin_name = kernel_name + ".cubin"
    cubin_dir = _get_cache_dir([arch, 'cubin'])
    header_dir = os.path.join(base_dir, "include/kernels")

    ptx_version = "4.2" if major < 6 else "5.0"
    ptx_file = get_ptx_file(kernel_spec, kernel_name, arch, ptx_version)
    cubin_file = os.path.join(cubin_dir, cubin_name)
    sass_file = os.path.join(sass_dir, sass_name)
    symbol = kernel_name + "_" + arch
    header_file = os.path.join(header_dir, symbol + ".h")

    if not os.path.exists(sass_file):
        raise RuntimeError("Missing: %s for kernel: %s" % (sass_name, kernel_name))

    # build the cubin and run maxas in the same command
    # we don't want the chance of a generated cubin not processed by maxas (in case user hits ^C in between these steps)
    # "&&" (rather than ";") makes a ptxas failure surface as an error
    # instead of letting maxas run against a stale cubin from an earlier build.
    command_string = ["ptxas -v -arch", arch, "-o", cubin_file, ptx_file, "&&"] + maxas_i + [sass_file, cubin_file]
    run_command(command_string)

    # now also generate the associated header file containing the cubin
    with open(cubin_file, 'rb') as input_file:
        # bytearray yields integers on both Python 2 and Python 3.
        cubin_bytes = bytearray(input_file.read())

    pieces = ['const uint8_t %s[] = {' % symbol]
    for count, value in enumerate(cubin_bytes):
        # twelve bytes per output line
        if count % 12 == 0:
            pieces.append('\n ')
        pieces.append(' 0x%02x,' % value)
    pieces.append('\n};')

    # Binary mode keeps "\n" line endings on every platform; the text is
    # pure ASCII so the explicit encode is safe.
    with open(header_file, 'wb') as output_file:
        output_file.write(''.join(pieces).encode('ascii'))
227 |
228 |
def gen_kernels():
    """Build every kernel variant listed in ``selections``.

    For each precision prefix ('s', 'h') and operand layout, every selected
    tile is built with and without the "vec" option, once for sm_50 and
    once for sm_60, by calling ``get_kernel``.
    """
    archs = [(5, 0), (6, 0)]

    for prefix in ['s', 'h']:
        for op in ['NN', 'NT', 'TN', 'TT']:
            for (tile_m, tile_n, tile_k, _vec_a, _vec_b, _vec_c, _div,
                 op_in_name, _dyn_shared) in selections[prefix][op]:
                for use_vec in [False, True]:
                    for major, minor in archs:
                        if op_in_name:
                            # The op is part of the base kernel name
                            base = "%sgemm_%dx%dx%d_%s" % (prefix, tile_m, tile_n, tile_k, op)
                            opts = ("vec",) if use_vec else ()
                        else:
                            # The op is an option passed to a more generic kernel
                            base = "%sgemm_%dx%dx%d" % (prefix, tile_m, tile_n, tile_k)
                            opts = (op, "vec") if use_vec else (op,)

                        get_kernel(base, major, minor, opts)
245 |
246 |
def main():
    """Script entry point: generate all kernels."""
    gen_kernels()


if __name__ == "__main__":
    main()
253 |
--------------------------------------------------------------------------------
/include/c_interface.h:
--------------------------------------------------------------------------------
#pragma once

/*
 * C-callable interface of the GEMM library.
 *
 * The header names of the two #include directives were missing; they are
 * restored from what the declarations need: `bool` in C (stdbool.h) and
 * `uint16_t` (stdint.h).
 */

#ifdef __cplusplus
extern "C" {
#else
#include <stdbool.h>
#endif

#include <stdint.h>

/* Opaque stream handle, declared here so this header does not need cuda.h. */
typedef struct CUstream_st *CUstream;

/*
 * GEMM on float buffers. a_t / b_t select the transposed form of A / B.
 * NOTE(review): the meaning of `grid` and `shared` and of the bool result
 * (presumably success) is defined by the implementation - confirm there.
 */
bool openai_sgemm(float *A, float *B, float *C,
                  bool a_t, bool b_t,
                  int m, int n, int k,
                  int lda, int ldb, int ldc,
                  float alpha, float beta,
                  CUstream stream, unsigned int grid, unsigned int shared);

/*
 * Same interface as openai_sgemm, with matrices passed as 16-bit words
 * (presumably IEEE half-precision bit patterns - confirm).
 */
bool openai_hgemm(uint16_t *A, uint16_t *B, uint16_t *C,
                  bool a_t, bool b_t,
                  int m, int n, int k,
                  int lda, int ldb, int ldc,
                  float alpha, float beta,
                  CUstream stream, unsigned int grid, unsigned int shared);

/*
 * Query helpers; the result is written through the trailing pointer
 * argument (presumably the number of valid `grid` / `shared` choices).
 */
bool get_grid_limits(char precision, bool a_t, bool b_t, unsigned int *grid);
bool get_shared_limits(char precision, bool a_t, bool b_t, unsigned int grid, unsigned int *shared);

#ifdef __cplusplus
}
#endif
34 |
--------------------------------------------------------------------------------
/include/kernel_headers.h:
--------------------------------------------------------------------------------
1 | #include "include/kernels/hgemm_16x64x64_NN_sm_50.h"
2 | #include "include/kernels/hgemm_16x64x64_NN_vec_sm_50.h"
3 | #include "include/kernels/hgemm_16x64x64_NT_sm_50.h"
4 | #include "include/kernels/hgemm_16x64x64_NT_vec_sm_50.h"
5 | #include "include/kernels/hgemm_32x32x64_NT_sm_50.h"
6 | #include "include/kernels/hgemm_32x32x64_NT_vec_sm_50.h"
7 | #include "include/kernels/hgemm_32x64x32_NN_sm_50.h"
8 | #include "include/kernels/hgemm_32x64x32_NN_vec_sm_50.h"
9 | #include "include/kernels/hgemm_128x128x8_NN_sm_50.h"
10 | #include "include/kernels/hgemm_128x128x8_NT_sm_50.h"
11 | #include "include/kernels/hgemm_128x128x8_TN_sm_50.h"
12 | #include "include/kernels/hgemm_128x128x8_TT_sm_50.h"
13 | #include "include/kernels/hgemm_128x128x8_NN_vec_sm_50.h"
14 | #include "include/kernels/hgemm_128x128x8_NT_vec_sm_50.h"
15 | #include "include/kernels/hgemm_128x128x8_TN_vec_sm_50.h"
16 | #include "include/kernels/hgemm_128x128x8_TT_vec_sm_50.h"
17 | #include "include/kernels/hgemm_32x32x32_NN_sm_50.h"
18 | #include "include/kernels/hgemm_32x32x32_NT_sm_50.h"
19 | #include "include/kernels/hgemm_32x32x32_TN_sm_50.h"
20 | #include "include/kernels/hgemm_32x32x32_TT_sm_50.h"
21 | #include "include/kernels/hgemm_32x32x32_NN_vec_sm_50.h"
22 | #include "include/kernels/hgemm_32x32x32_NT_vec_sm_50.h"
23 | #include "include/kernels/hgemm_32x32x32_TN_vec_sm_50.h"
24 | #include "include/kernels/hgemm_32x32x32_TT_vec_sm_50.h"
25 | #include "include/kernels/sgemm_128x128x8_NN_sm_50.h"
26 | #include "include/kernels/sgemm_128x128x8_NT_sm_50.h"
27 | #include "include/kernels/sgemm_128x128x8_TN_sm_50.h"
28 | #include "include/kernels/sgemm_128x128x8_TT_sm_50.h"
29 | #include "include/kernels/sgemm_128x128x8_NN_vec_sm_50.h"
30 | #include "include/kernels/sgemm_128x128x8_NT_vec_sm_50.h"
31 | #include "include/kernels/sgemm_128x128x8_TN_vec_sm_50.h"
32 | #include "include/kernels/sgemm_128x128x8_TT_vec_sm_50.h"
33 | #include "include/kernels/sgemm_32x32x32_NN_sm_50.h"
34 | #include "include/kernels/sgemm_32x32x32_NT_sm_50.h"
35 | #include "include/kernels/sgemm_32x32x32_TN_sm_50.h"
36 | #include "include/kernels/sgemm_32x32x32_TT_sm_50.h"
37 | #include "include/kernels/sgemm_32x32x32_NN_vec_sm_50.h"
38 | #include "include/kernels/sgemm_32x32x32_NT_vec_sm_50.h"
39 | #include "include/kernels/sgemm_32x32x32_TN_vec_sm_50.h"
40 | #include "include/kernels/sgemm_32x32x32_TT_vec_sm_50.h"
41 |
42 | #include "include/kernels/hgemm_16x64x64_NN_sm_60.h"
43 | #include "include/kernels/hgemm_16x64x64_NN_vec_sm_60.h"
44 | #include "include/kernels/hgemm_16x64x64_NT_sm_60.h"
45 | #include "include/kernels/hgemm_16x64x64_NT_vec_sm_60.h"
46 | #include "include/kernels/hgemm_32x32x64_NT_sm_60.h"
47 | #include "include/kernels/hgemm_32x32x64_NT_vec_sm_60.h"
48 | #include "include/kernels/hgemm_32x64x32_NN_sm_60.h"
49 | #include "include/kernels/hgemm_32x64x32_NN_vec_sm_60.h"
50 | #include "include/kernels/hgemm_128x128x8_NN_sm_60.h"
51 | #include "include/kernels/hgemm_128x128x8_NT_sm_60.h"
52 | #include "include/kernels/hgemm_128x128x8_TN_sm_60.h"
53 | #include "include/kernels/hgemm_128x128x8_TT_sm_60.h"
54 | #include "include/kernels/hgemm_128x128x8_NN_vec_sm_60.h"
55 | #include "include/kernels/hgemm_128x128x8_NT_vec_sm_60.h"
56 | #include "include/kernels/hgemm_128x128x8_TN_vec_sm_60.h"
57 | #include "include/kernels/hgemm_128x128x8_TT_vec_sm_60.h"
58 | #include "include/kernels/hgemm_32x32x32_NN_sm_60.h"
59 | #include "include/kernels/hgemm_32x32x32_NT_sm_60.h"
60 | #include "include/kernels/hgemm_32x32x32_TN_sm_60.h"
61 | #include "include/kernels/hgemm_32x32x32_TT_sm_60.h"
62 | #include "include/kernels/hgemm_32x32x32_NN_vec_sm_60.h"
63 | #include "include/kernels/hgemm_32x32x32_NT_vec_sm_60.h"
64 | #include "include/kernels/hgemm_32x32x32_TN_vec_sm_60.h"
65 | #include "include/kernels/hgemm_32x32x32_TT_vec_sm_60.h"
66 | #include "include/kernels/sgemm_128x128x8_NN_sm_60.h"
67 | #include "include/kernels/sgemm_128x128x8_NT_sm_60.h"
68 | #include "include/kernels/sgemm_128x128x8_TN_sm_60.h"
69 | #include "include/kernels/sgemm_128x128x8_TT_sm_60.h"
70 | #include "include/kernels/sgemm_128x128x8_NN_vec_sm_60.h"
71 | #include "include/kernels/sgemm_128x128x8_NT_vec_sm_60.h"
72 | #include "include/kernels/sgemm_128x128x8_TN_vec_sm_60.h"
73 | #include "include/kernels/sgemm_128x128x8_TT_vec_sm_60.h"
74 | #include "include/kernels/sgemm_32x32x32_NN_sm_60.h"
75 | #include "include/kernels/sgemm_32x32x32_NT_sm_60.h"
76 | #include "include/kernels/sgemm_32x32x32_TN_sm_60.h"
77 | #include "include/kernels/sgemm_32x32x32_TT_sm_60.h"
78 | #include "include/kernels/sgemm_32x32x32_NN_vec_sm_60.h"
79 | #include "include/kernels/sgemm_32x32x32_NT_vec_sm_60.h"
80 | #include "include/kernels/sgemm_32x32x32_TN_vec_sm_60.h"
81 | #include "include/kernels/sgemm_32x32x32_TT_vec_sm_60.h"
82 |
--------------------------------------------------------------------------------
/include/static_kernel_information.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
// Static description of one GEMM kernel configuration.
// Field order matters: kernel_properties_ below initializes it positionally.
3 | struct kernel_properties {
// Tile dimensions; they match the numbers in tile_string ("MxNxK").
4 | int tile_m;
5 | int tile_n;
6 | int tile_k;
// NOTE(review): vA/vB/vC/div/op are presumably the vecA/vecB/vecC/div/base_op
// values used by the kernel generator script - confirm against the users.
7 | int vA;
8 | int vB;
9 | int vC;
10 | int div;
11 | bool op;
12 | int threads;
// NOTE(review): the template argument of std::vector appears to have been
// lost in extraction; the element type cannot be determined from here.
13 | std::vector shared_sizes;
14 | std::string tile_string;
15 | };
16 |
// Table of the six known kernel configurations, in kernel_properties field
// order: tile_m, tile_n, tile_k, vA, vB, vC, div, op, threads,
// shared_sizes, tile_string.
17 | kernel_properties kernel_properties_[6] = {
18 | {128, 128, 8, 4, 4, 1, 2, false, 256, {0}, "128x128x8"},
19 | {32, 32, 32, 4, 4, 1, 4, false, 128, {0, 1 << 14}, "32x32x32"},
20 | {32, 64, 32, 8, 4, 4, 4, true, 128, {0, 1 << 13}, "32x64x32"},
21 | {32, 32, 64, 8, 8, 4, 4, true, 128, {0}, "32x32x64"},
22 | {16, 64, 64, 8, 4, 4, 4, true, 128, {0}, "16x64x64"},
23 | {16, 64, 64, 8, 8, 4, 4, true, 128, {0}, "16x64x64"},
24 | };
25 |
// Candidate configurations per precision prefix ("s" / "h") and per operand
// layout ("TN", "NN", "NT", "TT"). Only the "h" NN and NT layouts list the
// extra small-tile configurations (indices 2-5).
// NOTE(review): the template arguments of the map type appear to have been
// lost in extraction; judging by the initializer it is a string -> string ->
// list-of-kernel_properties mapping - confirm against the original header.
26 | std::unordered_map>>
27 | selections = {
28 | {"s", {
29 | {"TN", {kernel_properties_[0], kernel_properties_[1]}},
30 | {"NN", {kernel_properties_[0], kernel_properties_[1]}},
31 | {"NT", {kernel_properties_[0], kernel_properties_[1]}},
32 | {"TT", {kernel_properties_[0], kernel_properties_[1]}}
33 | }
34 | },
35 | {"h", {
36 | {"TN", {kernel_properties_[0], kernel_properties_[1]}},
37 | {"NN", {kernel_properties_[0], kernel_properties_[1], kernel_properties_[2], kernel_properties_[4]}},
38 | {"NT", {kernel_properties_[0], kernel_properties_[1], kernel_properties_[3], kernel_properties_[5]}},
39 | {"TT", {kernel_properties_[0], kernel_properties_[1]}}
40 | }
41 | }
42 | };
43 |
44 | std::unordered_map kernels_60 = {
45 | {"hgemm_16x64x64_NN_sm_60", hgemm_16x64x64_NN_sm_60},
46 | {"hgemm_16x64x64_NN_vec_sm_60", hgemm_16x64x64_NN_vec_sm_60},
47 | {"hgemm_16x64x64_NT_sm_60", hgemm_16x64x64_NT_sm_60},
48 | {"hgemm_16x64x64_NT_vec_sm_60", hgemm_16x64x64_NT_vec_sm_60},
49 | {"hgemm_32x32x64_NT_sm_60", hgemm_32x32x64_NT_sm_60},
50 | {"hgemm_32x32x64_NT_vec_sm_60", hgemm_32x32x64_NT_vec_sm_60},
51 | {"hgemm_32x64x32_NN_sm_60", hgemm_32x64x32_NN_sm_60},
52 | {"hgemm_32x64x32_NN_vec_sm_60", hgemm_32x64x32_NN_vec_sm_60},
53 | {"hgemm_128x128x8_NN_sm_60", hgemm_128x128x8_NN_sm_60},
54 | {"hgemm_128x128x8_TN_sm_60", hgemm_128x128x8_TN_sm_60},
55 | {"hgemm_128x128x8_NT_sm_60", hgemm_128x128x8_NT_sm_60},
56 | {"hgemm_128x128x8_TT_sm_60", hgemm_128x128x8_TT_sm_60},
57 | {"hgemm_128x128x8_NN_vec_sm_60", hgemm_128x128x8_NN_vec_sm_60},
58 | {"hgemm_128x128x8_TN_vec_sm_60", hgemm_128x128x8_TN_vec_sm_60},
59 | {"hgemm_128x128x8_NT_vec_sm_60", hgemm_128x128x8_NT_vec_sm_60},
60 | {"hgemm_128x128x8_TT_vec_sm_60", hgemm_128x128x8_TT_vec_sm_60},
61 | {"hgemm_32x32x32_NN_sm_60", hgemm_32x32x32_NN_sm_60},
62 | {"hgemm_32x32x32_TN_sm_60", hgemm_32x32x32_TN_sm_60},
63 | {"hgemm_32x32x32_NT_sm_60", hgemm_32x32x32_NT_sm_60},
64 | {"hgemm_32x32x32_TT_sm_60", hgemm_32x32x32_TT_sm_60},
65 | {"hgemm_32x32x32_NN_vec_sm_60", hgemm_32x32x32_NN_vec_sm_60},
66 | {"hgemm_32x32x32_TN_vec_sm_60", hgemm_32x32x32_TN_vec_sm_60},
67 | {"hgemm_32x32x32_NT_vec_sm_60", hgemm_32x32x32_NT_vec_sm_60},
68 | {"hgemm_32x32x32_TT_vec_sm_60", hgemm_32x32x32_TT_vec_sm_60},
69 | {"sgemm_128x128x8_NN_sm_60", sgemm_128x128x8_NN_sm_60},
70 | {"sgemm_128x128x8_TN_sm_60", sgemm_128x128x8_TN_sm_60},
71 | {"sgemm_128x128x8_NT_sm_60", sgemm_128x128x8_NT_sm_60},
72 | {"sgemm_128x128x8_TT_sm_60", sgemm_128x128x8_TT_sm_60},
73 | {"sgemm_128x128x8_NN_vec_sm_60", sgemm_128x128x8_NN_vec_sm_60},
74 | {"sgemm_128x128x8_TN_vec_sm_60", sgemm_128x128x8_TN_vec_sm_60},
75 | {"sgemm_128x128x8_NT_vec_sm_60", sgemm_128x128x8_NT_vec_sm_60},
76 | {"sgemm_128x128x8_TT_vec_sm_60", sgemm_128x128x8_TT_vec_sm_60},
77 | {"sgemm_32x32x32_NN_sm_60", sgemm_32x32x32_NN_sm_60},
78 | {"sgemm_32x32x32_TN_sm_60", sgemm_32x32x32_TN_sm_60},
79 | {"sgemm_32x32x32_NT_sm_60", sgemm_32x32x32_NT_sm_60},
80 | {"sgemm_32x32x32_TT_sm_60", sgemm_32x32x32_TT_sm_60},
81 | {"sgemm_32x32x32_NN_vec_sm_60", sgemm_32x32x32_NN_vec_sm_60},
82 | {"sgemm_32x32x32_TN_vec_sm_60", sgemm_32x32x32_TN_vec_sm_60},
83 | {"sgemm_32x32x32_NT_vec_sm_60", sgemm_32x32x32_NT_vec_sm_60},
84 | {"sgemm_32x32x32_TT_vec_sm_60", sgemm_32x32x32_TT_vec_sm_60},
85 | };
86 |
87 | std::unordered_map kernels_50 = {
88 | {"hgemm_16x64x64_NN_sm_50", hgemm_16x64x64_NN_sm_50},
89 | {"hgemm_16x64x64_NN_vec_sm_50", hgemm_16x64x64_NN_vec_sm_50},
90 | {"hgemm_16x64x64_NT_sm_50", hgemm_16x64x64_NT_sm_50},
91 | {"hgemm_16x64x64_NT_vec_sm_50", hgemm_16x64x64_NT_vec_sm_50},
92 | {"hgemm_32x32x64_NT_sm_50", hgemm_32x32x64_NT_sm_50},
93 | {"hgemm_32x32x64_NT_vec_sm_50", hgemm_32x32x64_NT_vec_sm_50},
94 | {"hgemm_32x64x32_NN_sm_50", hgemm_32x64x32_NN_sm_50},
95 | {"hgemm_32x64x32_NN_vec_sm_50", hgemm_32x64x32_NN_vec_sm_50},
96 | {"hgemm_128x128x8_NN_sm_50", hgemm_128x128x8_NN_sm_50},
97 | {"hgemm_128x128x8_TN_sm_50", hgemm_128x128x8_TN_sm_50},
98 | {"hgemm_128x128x8_NT_sm_50", hgemm_128x128x8_NT_sm_50},
99 | {"hgemm_128x128x8_TT_sm_50", hgemm_128x128x8_TT_sm_50},
100 | {"hgemm_128x128x8_NN_vec_sm_50", hgemm_128x128x8_NN_vec_sm_50},
101 | {"hgemm_128x128x8_TN_vec_sm_50", hgemm_128x128x8_TN_vec_sm_50},
102 | {"hgemm_128x128x8_NT_vec_sm_50", hgemm_128x128x8_NT_vec_sm_50},
103 | {"hgemm_128x128x8_TT_vec_sm_50", hgemm_128x128x8_TT_vec_sm_50},
104 | {"hgemm_32x32x32_NN_sm_50", hgemm_32x32x32_NN_sm_50},
105 | {"hgemm_32x32x32_TN_sm_50", hgemm_32x32x32_TN_sm_50},
106 | {"hgemm_32x32x32_NT_sm_50", hgemm_32x32x32_NT_sm_50},
107 | {"hgemm_32x32x32_TT_sm_50", hgemm_32x32x32_TT_sm_50},
108 | {"hgemm_32x32x32_NN_vec_sm_50", hgemm_32x32x32_NN_vec_sm_50},
109 | {"hgemm_32x32x32_TN_vec_sm_50", hgemm_32x32x32_TN_vec_sm_50},
110 | {"hgemm_32x32x32_NT_vec_sm_50", hgemm_32x32x32_NT_vec_sm_50},
111 | {"hgemm_32x32x32_TT_vec_sm_50", hgemm_32x32x32_TT_vec_sm_50},
112 | {"sgemm_128x128x8_NN_sm_50", sgemm_128x128x8_NN_sm_50},
113 | {"sgemm_128x128x8_TN_sm_50", sgemm_128x128x8_TN_sm_50},
114 | {"sgemm_128x128x8_NT_sm_50", sgemm_128x128x8_NT_sm_50},
115 | {"sgemm_128x128x8_TT_sm_50", sgemm_128x128x8_TT_sm_50},
116 | {"sgemm_128x128x8_NN_vec_sm_50", sgemm_128x128x8_NN_vec_sm_50},
117 | {"sgemm_128x128x8_TN_vec_sm_50", sgemm_128x128x8_TN_vec_sm_50},
118 | {"sgemm_128x128x8_NT_vec_sm_50", sgemm_128x128x8_NT_vec_sm_50},
119 | {"sgemm_128x128x8_TT_vec_sm_50", sgemm_128x128x8_TT_vec_sm_50},
120 | {"sgemm_32x32x32_NN_sm_50", sgemm_32x32x32_NN_sm_50},
121 | {"sgemm_32x32x32_TN_sm_50", sgemm_32x32x32_TN_sm_50},
122 | {"sgemm_32x32x32_NT_sm_50", sgemm_32x32x32_NT_sm_50},
123 | {"sgemm_32x32x32_TT_sm_50", sgemm_32x32x32_TT_sm_50},
124 | {"sgemm_32x32x32_NN_vec_sm_50", sgemm_32x32x32_NN_vec_sm_50},
125 | {"sgemm_32x32x32_TN_vec_sm_50", sgemm_32x32x32_TN_vec_sm_50},
126 | {"sgemm_32x32x32_NT_vec_sm_50", sgemm_32x32x32_NT_vec_sm_50},
127 | {"sgemm_32x32x32_TT_vec_sm_50", sgemm_32x32x32_TT_vec_sm_50},
128 | };
129 |
--------------------------------------------------------------------------------
/lib/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/maxas/MaxAs/Cubin.pm:
--------------------------------------------------------------------------------
1 | package MaxAs::Cubin;
2 |
3 | use strict;
4 | use Data::Dumper;
5 |
# ELF structure layouts for the 32-bit and 64-bit file classes.
# Each list alternates a pack/unpack template code with the hash key that
# receives the unpacked value. Template codes used here (see perldoc -f pack):
#   H<n> hex string of n nibbles, C unsigned 8-bit, S unsigned 16-bit,
#   L unsigned 32-bit, Q unsigned 64-bit.
# All template codes are at most 3 characters and all key names are longer;
# the code following these tables separates the two by length.
6 | my @Elf32_Hdr = qw(
7 | H8 magic
8 | C fileClass
9 | C encoding
10 | C fileVersion
11 | H18 padding
12 | S type
13 | S machine
14 | L version
15 | L entry
16 | L phOffset
17 | L shOffset
18 | L flags
19 | S ehSize
20 | S phEntSize
21 | S phNum
22 | S shEntSize
23 | S shNum
24 | S shStrIndx
25 | );
26 | my @Elf64_Hdr = qw(
27 | H8 magic
28 | C fileClass
29 | C encoding
30 | C fileVersion
31 | H18 padding
32 | S type
33 | S machine
34 | L version
35 | Q entry
36 | Q phOffset
37 | Q shOffset
38 | L flags
39 | S ehSize
40 | S phEntSize
41 | S phNum
42 | S shEntSize
43 | S shNum
44 | S shStrIndx
45 | );
# Program header entries.
46 | my @Elf32_PrgHdr = qw(
47 | L type
48 | L offset
49 | L vaddr
50 | L paddr
51 | L fileSize
52 | L memSize
53 | L flags
54 | L align
55 | );
56 | my @Elf64_PrgHdr = qw(
57 | L type
58 | L flags
59 | Q offset
60 | Q vaddr
61 | Q paddr
62 | Q fileSize
63 | Q memSize
64 | Q align
65 | );
# Section header entries.
66 | my @Elf32_SecHdr = qw(
67 | L name
68 | L type
69 | L flags
70 | L addr
71 | L offset
72 | L size
73 | L link
74 | L info
75 | L align
76 | L entSize
77 | );
78 | my @Elf64_SecHdr = qw(
79 | L name
80 | L type
81 | Q flags
82 | Q addr
83 | Q offset
84 | Q size
85 | L link
86 | L info
87 | Q align
88 | Q entSize
89 | );
# Symbol table entries.
90 | my @Elf32_SymEnt = qw(
91 | L name
92 | L value
93 | L size
94 | C info
95 | C other
96 | S shIndx
97 | );
98 | my @Elf64_SymEnt = qw(
99 | L name
100 | C info
101 | C other
102 | S shIndx
103 | Q value
104 | Q size
105 | );
# Symbol binding names, indexed by the upper nibble of a symbol's info byte.
106 | my @symBind = qw(LOCAL GLOBAL WEAK);
107 |
108 | # Split the Elf Header defs into template strings (T) and corresponding hash keys columns (C)
# Both families of arrays are indexed by the ELF fileClass value:
# index 1 holds the 32-bit layout, index 2 the 64-bit layout.
109 | my (@elfHdrT, @prgHdrT, @secHdrT, @symHdrT, @elfHdrC, @prgHdrC, @secHdrC, @symHdrC);
110 |
# Entries of length <= 3 are the pack template codes; joined they form the
# unpack template for the whole structure.
111 | $elfHdrT[1] = join '', grep { length($_) <= 3} @Elf32_Hdr;
112 | $prgHdrT[1] = join '', grep { length($_) <= 3} @Elf32_PrgHdr;
113 | $secHdrT[1] = join '', grep { length($_) <= 3} @Elf32_SecHdr;
114 | $symHdrT[1] = join '', grep { length($_) <= 3} @Elf32_SymEnt;
115 |
116 | $elfHdrT[2] = join '', grep { length($_) <= 3} @Elf64_Hdr;
117 | $prgHdrT[2] = join '', grep { length($_) <= 3} @Elf64_PrgHdr;
118 | $secHdrT[2] = join '', grep { length($_) <= 3} @Elf64_SecHdr;
119 | $symHdrT[2] = join '', grep { length($_) <= 3} @Elf64_SymEnt;
120 |
# Entries longer than 3 characters are the field names, kept in the same
# order as the template codes so they can be used as a hash slice.
121 | $elfHdrC[1] = [ grep { length($_) > 3} @Elf32_Hdr ];
122 | $prgHdrC[1] = [ grep { length($_) > 3} @Elf32_PrgHdr ];
123 | $secHdrC[1] = [ grep { length($_) > 3} @Elf32_SecHdr ];
124 | $symHdrC[1] = [ grep { length($_) > 3} @Elf32_SymEnt ];
125 |
126 | $elfHdrC[2] = [ grep { length($_) > 3} @Elf64_Hdr ];
127 | $prgHdrC[2] = [ grep { length($_) > 3} @Elf64_PrgHdr ];
128 | $secHdrC[2] = [ grep { length($_) > 3} @Elf64_SecHdr ];
129 | $symHdrC[2] = [ grep { length($_) > 3} @Elf64_SymEnt ];
130 |
131 | # Load a cubin ELF file
# Constructor. Parses the cubin ELF file $file and returns a blessed hash
# holding the ELF header (elfHdr), program headers (prgHdrs), section
# headers with their raw data (secHdrs, also reachable by section name),
# plus a Kernels hash (FUNC symbols) and a Symbols hash (other GLOBALs).
# Dies if the file cannot be opened or the cubin is older than sm_50.
132 | sub new
133 | {
134 | my ($package, $file) = @_;
135 |
136 | my $cubin = bless { fileName => $file }, $package;
137 |
# NOTE(review): two-argument open interprets mode characters embedded in
# $file; a three-argument open with '<' would be safer.
138 | open my $fh, $file or die "$file: $!";
139 | binmode($fh);
140 |
141 | # Read in assuming 32 bit header
142 | my $data;
143 | read $fh, $data, 0x34;
144 | my $elfHdr = $cubin->{elfHdr} = {};
145 | @{$elfHdr}{@{$elfHdrC[1]}} = unpack $elfHdrT[1], $data;
146 |
147 | # 1: 32bit, 2: 64bit
148 | my $class = $elfHdr->{fileClass};
149 |
150 | # re-read in with 64 bit header if needed
151 | if ($class == 2)
152 | {
153 | seek $fh, 0, 0;
154 | read $fh, $data, 0x46;
155 | @{$elfHdr}{@{$elfHdrC[$class]}} = unpack $elfHdrT[$class], $data;
156 |
157 | $cubin->{Class} = 64;
158 | }
159 | else
160 | {
161 | $cubin->{Class} = 32;
162 | }
163 |
164 | # verify sm_50 cubin
# The low byte of the ELF flags is taken as the SM architecture number.
165 | $cubin->{Arch} = $elfHdr->{flags} & 0xFF;
166 | die "Cubin not in sm_50 or greater format. Found: sm_$cubin->{Arch}\n" if $cubin->{Arch} < 50;
167 |
# Flag bit 0x400 selects 64-bit device addresses.
168 | $cubin->{AddressSize} = $elfHdr->{flags} & 0x400 ? 64 : 32;
169 |
170 | # Read in Program Headers
171 | seek $fh, $elfHdr->{phOffset}, 0;
172 | foreach (1 .. $elfHdr->{phNum})
173 | {
174 | read $fh, $data, $elfHdr->{phEntSize};
175 |
# Indx is the zero-based position of the header in the file.
176 | my %prgHdr = (Indx => $_ - 1);
177 | @prgHdr{@{$prgHdrC[$class]}} = unpack $prgHdrT[$class], $data;
178 | push @{$cubin->{prgHdrs}}, \%prgHdr;
179 | }
180 |
181 | # Read in Section Headers
182 | seek $fh, $elfHdr->{shOffset}, 0;
183 | foreach (1 .. $elfHdr->{shNum})
184 | {
185 | read $fh, $data, $elfHdr->{shEntSize};
186 |
187 | my %secHdr = (Indx => $_ - 1);
188 | @secHdr{@{$secHdrC[$class]}} = unpack $secHdrT[$class], $data;
189 | push @{$cubin->{secHdrs}}, \%secHdr;
190 | }
191 |
192 | # Read in Section data
193 | foreach my $secHdr (@{$cubin->{secHdrs}})
194 | {
195 | $data = '';
196 | # Skip sections with no data (type NULL or NOBITS)
197 | if ($secHdr->{size} && $secHdr->{type} != 8)
198 | {
199 | seek $fh, $secHdr->{offset}, 0;
200 | read $fh, $data, $secHdr->{size};
201 | }
202 | # Convert string tables to maps
203 | if ($secHdr->{type} == 3) # STRTAB
204 | {
# Keys of StrTab are byte offsets into the table, values the strings.
205 | my $strTab = $secHdr->{StrTab} = {};
206 | my $indx = 0;
207 | foreach my $str (split "\0", $data)
208 | {
209 | $strTab->{$indx} = $str;
210 | $indx += 1 + length($str);
211 | }
212 | }
213 | # Read in Symbol data
214 | if ($secHdr->{type} == 2) # SYMTAB
215 | {
216 | my $offset = 0;
217 | while ($offset < $secHdr->{size})
218 | {
219 | my $symEnt = {};
220 | @{$symEnt}{@{$symHdrC[$class]}} = unpack $symHdrT[$class], substr($data, $offset, $secHdr->{entSize});
221 | $offset += $secHdr->{entSize};
222 |
223 | push @{$secHdr->{SymTab}}, $symEnt;
224 | }
225 | }
226 | # Cache raw data for further processing and writing
# Stored as a hex string; consumers pack "H*" it back into bytes.
227 | $secHdr->{Data} = unpack 'H*', $data;
228 | }
229 | close $fh;
230 |
231 | # Update section headers with their names. Map names directly to headers.
232 | my $shStrTab = $cubin->{secHdrs}[$elfHdr->{shStrIndx}]{StrTab};
233 | foreach my $secHdr (@{$cubin->{secHdrs}})
234 | {
235 | $secHdr->{Name} = $shStrTab->{$secHdr->{name}};
236 | $cubin->{$secHdr->{Name}} = $secHdr;
237 | }
238 |
239 | # Update symbols with their names
240 | # For the Global functions, extract kernel meta data
241 | # Populate the kernel hash
242 | my $strTab = $cubin->{'.strtab'}{StrTab};
243 | foreach my $symEnt (@{$cubin->{'.symtab'}{SymTab}})
244 | {
245 | $symEnt->{Name} = $strTab->{$symEnt->{name}};
246 |
247 | # Attach symbol to section
248 | my $secHdr = $cubin->{secHdrs}[$symEnt->{shIndx}];
249 | $secHdr->{SymbolEnt} = $symEnt;
250 |
251 | # Look for symbols tagged FUNC
252 | if (($symEnt->{info} & 0x0f) == 0x02)
253 | {
254 | # Create a hash of kernels for output
255 | my $kernelSec = $cubin->{Kernels}{$symEnt->{Name}} = $secHdr;
256 |
257 | # Extract local/global/weak binding info
258 | $kernelSec->{Linkage} = $symBind[($symEnt->{info} & 0xf0) >> 4];
259 |
260 | # Extract the kernel instructions
# The section's hex data is re-packed and read as 64-bit words.
261 | $kernelSec->{KernelData} = [ unpack "Q*", pack "H*", $kernelSec->{Data} ];
262 |
263 | # Extract the max barrier resource identifier used and add 1. Should be 0-16.
264 | # If a register is used as a barrier resource id, then this value is the max of 16.
265 | $kernelSec->{BarCnt} = ($kernelSec->{flags} & 0x01f00000) >> 20;
266 |
267 | # Extract the number of allocated registers for this kernel.
268 | $kernelSec->{RegCnt} = ($kernelSec->{info} & 0xff000000) >> 24;
269 |
270 | # Extract the size of shared memory this kernel uses.
271 | my $sharedSec = $kernelSec->{SharedSec} = $cubin->{".nv.shared.$symEnt->{Name}"};
272 | $kernelSec->{SharedSize} = $sharedSec ? $sharedSec->{size} : 0;
273 |
274 | # Attach constant0 section
275 | $kernelSec->{ConstantSec} = $cubin->{".nv.constant0.$symEnt->{Name}"};
276 |
277 | # Extract the kernel parameter data.
278 | my $paramSec = $kernelSec->{ParamSec} = $cubin->{".nv.info.$symEnt->{Name}"};
279 | if ($paramSec)
280 | {
281 | # Extract raw param data
282 | my @data = unpack "L*", pack "H*", $paramSec->{Data};
283 |
284 | $paramSec->{ParamData} = \@data;
285 | $paramSec->{ParamHex} = [ map { sprintf '0x%08x', $_ } @data ];
286 |
287 | # Find the first param delimiter
# NOTE(review): if the 0x00080a04 word is absent, $idx ends at @data and the
# reads below index past the end of the array - confirm it is always present.
288 | my $idx = 0;
289 | $idx++ while $idx < @data && $data[$idx] != 0x00080a04;
290 |
291 | my $first = $data[$idx+2] & 0xFFFF;
292 | #my $size = $data[$idx+2] >> 16;
293 | $idx += 4;
294 |
# Each parameter record is four words long and starts with 0x000c1704.
295 | my @params;
296 | while ($idx < @data && $data[$idx] == 0x000c1704)
297 | {
298 | # Get the ordinal, offset, size and pointer alignment for each param
299 | my $ord = $data[$idx+2] & 0xFFFF;
300 | my $offset = sprintf '0x%02x', $first + ($data[$idx+2] >> 16);
301 | my $psize = $data[$idx+3] >> 18;
302 | my $align = $data[$idx+3] & 0x400 ? 1 << ($data[$idx+3] & 0x3ff) : 0;
# unshift reverses the order in which the records appear in the section.
303 | unshift @params, "$ord:$offset:$psize:$align";
304 | $idx += 4;
305 | }
# Everything up to here is kept verbatim; modifyKernel re-emits it unchanged.
306 | my @staticParams = @data[0 .. ($idx-1)];
307 |
# The remaining words are attribute records: the low 16 bits of the first
# word are the attribute code, the high 16 bits a size (or inline value).
308 | my ($maxregCount, @exitOffsets, @ctaidOffsets, $ctaidzUsed, @reqntid, @maxntid, @stackSize);
309 | while ($idx < @data)
310 | {
311 | my $code = $data[$idx] & 0xffff;
312 | my $size = $data[$idx] >> 16;
313 | $idx++;
314 |
315 | # EIATTR_MAXREG_COUNT
316 | if ($code == 0x1b03)
317 | {
# For this attribute the upper half-word is the value itself.
318 | $maxregCount = $size;
319 | }
320 | # EIATTR_S2RCTAID_INSTR_OFFSETS
321 | elsif ($code == 0x1d04)
322 | {
# $size is in bytes; consume one 32-bit word per 4 bytes.
323 | while ($size > 0)
324 | {
325 | push @ctaidOffsets, $data[$idx++];
326 | $size -= 4;
327 | }
328 | }
329 | # EIATTR_EXIT_INSTR_OFFSETS
330 | elsif ($code == 0x1c04)
331 | {
332 | while ($size > 0)
333 | {
334 | push @exitOffsets, $data[$idx++];
335 | $size -= 4;
336 | }
337 | }
338 | # EIATTR_CTAIDZ_USED
339 | elsif ($code == 0x0401)
340 | {
341 | $ctaidzUsed = 1;
342 | }
343 | # EIATTR_REQNTID
344 | elsif ($code == 0x1004)
345 | {
346 | while ($size > 0)
347 | {
348 | push @reqntid, $data[$idx++];
349 | $size -= 4;
350 | }
351 | }
352 | # EIATTR_MAX_THREADS
353 | elsif ($code == 0x0504)
354 | {
355 | while ($size > 0)
356 | {
357 | push @maxntid, $data[$idx++];
358 | $size -= 4;
359 | }
360 | }
361 | # EIATTR_CRS_STACK_SIZE
362 | elsif ($code == 0x1e04)
363 | {
364 | while ($size > 0)
365 | {
366 | push @stackSize, $data[$idx++];
367 | $size -= 4;
368 | }
369 | }
370 | else
371 | {
# Unknown attributes are only reported; their payload words are not skipped.
372 | printf STDERR "Unknown Code 0x%02x (size:%d)\n", $code, $size;
373 | }
374 | }
375 | $kernelSec->{Params} = \@params;
376 | $kernelSec->{ParamCnt} = scalar @params;
377 |
378 | $paramSec->{StaticParams} = \@staticParams;
379 | $paramSec->{MAXREG_COUNT} = $maxregCount;
380 | $paramSec->{ExitOffsets} = \@exitOffsets;
381 | $paramSec->{CTAIDOffsets} = \@ctaidOffsets;
382 | $paramSec->{CTAIDZUsed} = $ctaidzUsed;
383 | $paramSec->{REQNTID} = \@reqntid;
384 | $paramSec->{MAXNTID} = \@maxntid;
385 | $paramSec->{STACKSIZE} = \@stackSize;
386 | }
387 | # print Dumper($paramSec);
388 | # exit();
389 | }
390 | # Note GLOBALs found in this cubin
391 | elsif (($symEnt->{info} & 0x10) == 0x10)
392 | {
393 | $cubin->{Symbols}{$symEnt->{Name}} = $symEnt;
394 | }
395 | }
396 |
397 | # print "phOffset: $elfHdr->{phOffset}\n";
398 | # print "shOffset: $elfHdr->{shOffset}\n";
399 | # foreach my $secHdr (@{$cubin->{secHdrs}})
400 | # {
401 | # print "secHdr($secHdr->{Indx}): $secHdr->{offset}, $secHdr->{size}, $secHdr->{align} ($secHdr->{Name})\n";
402 | # }
403 | # my $p = 0;
404 | # foreach my $prgHdr (@{$cubin->{prgHdrs}})
405 | # {
406 | # print "prgHdr($p): type: $prgHdr->{type}, offset: $prgHdr->{offset}, fileSize: $prgHdr->{fileSize}, memSize: $prgHdr->{memSize}, align: $prgHdr->{align}\n";
407 | # $p++;
408 | # }
409 | # exit();
410 |
411 | #print map { sprintf "%016x\n", $_ } @{$cubin->{Kernels}{microbench}{KernelData}};
412 |
413 | #print Dumper($cubin->{Kernels}{test}{KernelData});
414 | #exit();
415 | return $cubin;
416 | }
# Return the ELF class of the loaded cubin (32 or 64).
sub class
{
    my ($cubin) = @_;
    return $cubin->{Class};
}
# Return the SM architecture number read from the ELF header flags.
sub arch
{
    my ($cubin) = @_;
    return $cubin->{Arch};
}
# Return the device address size of the loaded cubin (32 or 64).
sub address_size
{
    my ($cubin) = @_;
    return $cubin->{AddressSize};
}
# Return the hash of kernel sections keyed by kernel (FUNC symbol) name.
sub listKernels
{
    my ($cubin) = @_;
    return $cubin->{Kernels};
}
# Return the hash of non-function GLOBAL symbols keyed by symbol name.
sub listSymbols
{
    my ($cubin) = @_;
    return $cubin->{Symbols};
}
# Return the kernel section registered under the name $kernel,
# or undef when no kernel with that name exists.
sub getKernel
{
    my $cubin  = shift;
    my $kernel = shift;
    return $cubin->{Kernels}{$kernel};
}
442 |
# Replace a kernel's instructions and refresh the metadata derived from them.
# Named parameters (%params):
#   Kernel       - kernel section hash (as stored in $cubin->{Kernels})
#   RegCnt       - new register count (max 255)
#   BarCnt       - new barrier count (max 16)
#   ExitOffsets  - array ref of exit instruction offsets
#   CTAIDOffsets - array ref of CTAID instruction offsets
#   CTAIDZUsed   - true if CTAID.Z is used
#   KernelData   - array ref of 64-bit instruction words
# Dies when a limit is exceeded or the code size is not a multiple of 64 bytes.
# Prints a line for every value that changed and calls updateSize when the
# kernel or its .nv.info section changes size.
443 | sub modifyKernel
444 | {
445 | my ($cubin, %params) = @_;
446 |
447 | my $kernelSec = $params{Kernel};
448 | my $newReg = $params{RegCnt};
449 | my $newBar = $params{BarCnt};
450 | my $exitOffsets = $params{ExitOffsets};
451 | my $ctaidOffsets = $params{CTAIDOffsets};
452 | my $ctaidzUsed = $params{CTAIDZUsed};
453 | my $newData = $params{KernelData};
# Each KernelData element is one 64-bit word, i.e. 8 bytes.
454 | my $newSize = @$newData * 8;
455 |
456 | die "255 register max" if $newReg > 255;
457 | die "new kernel size must be multiple of 8 instructions (64 bytes)" if $newSize & 63;
458 | die "16 is max barrier count" if $newBar > 16;
459 |
460 | my $paramSec = $kernelSec->{ParamSec};
461 | my $kernelName = $kernelSec->{SymbolEnt}{Name};
462 | my $maxregCount = $paramSec->{MAXREG_COUNT};
463 | my $stackSize = $paramSec->{STACKSIZE};
464 |
465 | # update the kernel
466 | $kernelSec->{KernelData} = $newData;
467 | $kernelSec->{Data} = unpack "H*", pack "Q*", @$newData;
468 |
# The register count lives in the top byte of the section header's info field.
469 | if ($newReg != $kernelSec->{RegCnt})
470 | {
471 | print "Modified $kernelName RegCnt: $kernelSec->{RegCnt} => $newReg\n";
472 | $kernelSec->{RegCnt} = $newReg;
473 | $kernelSec->{info} &= ~0xff000000;
474 | $kernelSec->{info} |= $newReg << 24;
475 | }
# The barrier count lives in bits 20-24 of the section header's flags field.
476 | if ($newBar != $kernelSec->{BarCnt})
477 | {
478 | print "Modified $kernelName BarCnt: $kernelSec->{BarCnt} => $newBar\n";
479 | $kernelSec->{BarCnt} = $newBar;
480 | $kernelSec->{flags} &= ~0x01f00000;
481 | $kernelSec->{flags} |= $newBar << 20;
482 | }
483 |
# Rebuild the .nv.info data: start from the unchanged leading words, then
# append each attribute record again. A record header is
# (payload size in bytes << 16) | code, hence "count << 18" for 4-byte items.
484 | my @paramData = @{$paramSec->{StaticParams}};
485 |
486 | if (defined $maxregCount)
487 | {
488 | push @paramData, ($maxregCount << 16) | 0x1b03;
489 | }
490 |
491 | my $newCTAIDs = join ',', map { sprintf '%04x', $_ } @$ctaidOffsets;
492 | my $oldCTAIDs = join ',', map { sprintf '%04x', $_ } @{$paramSec->{CTAIDOffsets}};
493 |
494 | if ($newCTAIDs ne $oldCTAIDs)
495 | {
496 | print "Modified $kernelName CTAID Offsets: '$oldCTAIDs' => '$newCTAIDs'\n";
497 | }
498 | if (@$ctaidOffsets)
499 | {
500 | push @paramData, (scalar(@$ctaidOffsets) << 18) | 0x1d04;
501 | push @paramData, @$ctaidOffsets;
502 | }
503 |
504 | my $newExits = join ',', map { sprintf '%04x', $_ } @$exitOffsets;
505 | my $oldExits = join ',', map { sprintf '%04x', $_ } @{$paramSec->{ExitOffsets}};
506 |
507 | if ($newExits ne $oldExits)
508 | {
509 | print "Modified $kernelName Exit Offsets: '$oldExits' => '$newExits'\n";
510 | }
511 | if (@$exitOffsets)
512 | {
513 | push @paramData, (scalar(@$exitOffsets) << 18) | 0x1c04;
514 | push @paramData, @$exitOffsets;
515 | }
516 |
517 | if ($ctaidzUsed != $paramSec->{CTAIDZUsed})
518 | {
519 | print "Modified $kernelName CTAID.Z Used: '$paramSec->{CTAIDZUsed}' => '$ctaidzUsed'\n";
520 | }
521 | if ($ctaidzUsed)
522 | {
523 | push @paramData, 0x0401;
524 | }
525 |
# REQNTID, MAXNTID and the stack size are carried over from the loaded cubin.
526 | if (@{$paramSec->{REQNTID}})
527 | {
528 | push @paramData, (scalar(@{$paramSec->{REQNTID}}) << 18) | 0x1004;
529 | push @paramData, @{$paramSec->{REQNTID}};
530 | }
531 | if (@{$paramSec->{MAXNTID}})
532 | {
533 | push @paramData, (scalar(@{$paramSec->{MAXNTID}}) << 18) | 0x0504;
534 | push @paramData, @{$paramSec->{MAXNTID}};
535 | }
536 |
537 | if (@$stackSize)
538 | {
539 | push @paramData, (scalar(@$stackSize) << 18) | 0x1e04;
540 | push @paramData, @$stackSize;
541 | }
542 |
543 | my $newParamSize = scalar(@paramData)*4;
544 | $paramSec->{Data} = unpack "H*", pack "L*", @paramData;
545 | if ($newParamSize != $paramSec->{size})
546 | {
547 | print "Modified $kernelName ParamSecSize: $paramSec->{size} => $newParamSize\n";
548 | $cubin->updateSize($paramSec, $newParamSize);
549 | }
550 |
551 | if ($newSize != $kernelSec->{size})
552 | {
553 | print "Modified $kernelName KernelSize: $kernelSec->{size} => $newSize\n";
# The extra true argument is updateSize's $updatePrgSize flag.
554 | $cubin->updateSize($kernelSec, $newSize, 1);
555 | }
556 | }
557 |
# Record a new byte size for one section and re-lay-out the whole file.
#
#   $sec           - section header hash whose size changed
#   $newSize       - new size in bytes
#   $updatePrgSize - when true, the type-1 program header whose range
#                    contains $sec has its fileSize/memSize adjusted by
#                    the size delta (type 1 is PT_LOAD in the ELF spec)
#
# Every section offset, the section-header and program-header table
# offsets in the ELF header, and every program-header offset are
# recomputed from scratch.
sub updateSize
{
    my ($cubin, $sec, $newSize, $updatePrgSize) = @_;

    my $elfHdr = $cubin->{elfHdr};
    # fileClass indexes the per-class pack templates (@symHdrT, @symHdrC)
    my $class = $elfHdr->{fileClass};

    # update section header
    my $delta = $newSize - $sec->{size};
    $sec->{size} = $newSize;

    # update symtab section
    if ($sec->{SymbolEnt})
    {
        $sec->{SymbolEnt}{size} = $newSize;
        # re-serialize the entire symbol table so the changed entry is picked up
        my $symSection = $cubin->{'.symtab'};
        $symSection->{Data} = '';
        foreach my $symEnt (@{$symSection->{SymTab}})
        {
            $symSection->{Data} .= unpack "H*", pack $symHdrT[$class], @{$symEnt}{@{$symHdrC[$class]}};
        }
    }

    # section data starts right after the ELF header
    my $pos = $elfHdr->{ehSize};
    # old file offset => new file offset, consulted for program headers below
    my %sizeMap;

    # update section header offsets
    foreach my $secHdr (@{$cubin->{secHdrs}})
    {
        # skip first header
        next if $secHdr->{align} == 0;

        # NOBITS data sections are size 0
        my $size = $secHdr->{type} == 8 ? 0 : $secHdr->{size};

        # Add any needed padding between sections
        my $pad = $pos % $secHdr->{align};
        if ($pad > 0)
        {
            $pos += $secHdr->{align} - $pad;
        }
        # map old offset to new
        $sizeMap{$secHdr->{offset}} = $pos;

        # update offset
        $secHdr->{offset} = $pos;

        # advance position by size
        $pos += $size;
    }

    # compute total section header size
    # (relies on the program header table directly following the section
    # header table, which is also the order write() emits them in)
    my $shSize = $elfHdr->{phOffset} - $elfHdr->{shOffset};

    # map old offset to new
    $sizeMap{$elfHdr->{shOffset}} = $pos;
    $sizeMap{$elfHdr->{phOffset}} = $pos + $shSize;

    $elfHdr->{shOffset} = $pos;
    $elfHdr->{phOffset} = $pos + $shSize;

    # update program header offsets and sizes
    foreach my $prgHdr (@{$cubin->{prgHdrs}})
    {
        # Not sure how best to adjust these so just assume they'll track other offsets.
        # NOTE(review): an offset with no entry in %sizeMap becomes undef here.
        $prgHdr->{offset} = $sizeMap{$prgHdr->{offset}};

        # If the kernel sizes changes, also update the associated ProgramHeader.
        # Note that this size is the kernel size plus any constant section sizes.
        # The upper bound includes $delta because fileSize still holds the old size.
        if ($updatePrgSize && $prgHdr->{type} == 1 &&
            $sec->{offset} >= $prgHdr->{offset} &&
            $sec->{offset} < $prgHdr->{offset} + $prgHdr->{fileSize} + $delta)
        {
            $prgHdr->{fileSize} += $delta;
            $prgHdr->{memSize} += $delta;
        }
    }
}
636 |
637 | # Write out the cubin after modifying it.
# Serialize the (possibly modified) cubin to $file.
#
# Output order: ELF header, section data (each section padded up to its
# alignment), section headers, program headers. Dies if the file cannot be
# opened or if the final close reports a write error.
sub write
{
    my ($cubin, $file) = @_;

    # Three-argument open: the file name is only ever treated as a name,
    # never as a mode prefix or pipe specification.
    open my $fh, '>', $file or die "Error: could not open $file for writing: $!";
    binmode($fh);

    my $elfHdr = $cubin->{elfHdr};
    my $class = $elfHdr->{fileClass};

    # write elf header
    print $fh pack $elfHdrT[$class], @{$elfHdr}{@{$elfHdrC[$class]}};
    my $pos = $elfHdr->{ehSize};

    # write section data
    foreach my $secHdr (@{$cubin->{secHdrs}})
    {
        # Skip NULL and NOBITS data sections
        next if $secHdr->{size} == 0 || $secHdr->{type} == 8;

        # Add any needed padding between sections
        my $pad = $pos % $secHdr->{align};
        if ($pad > 0)
        {
            $pad = $secHdr->{align} - $pad;
            print $fh "\0" x $pad;
            $pos += $pad;
        }

        # section data is stored as a hex string
        print $fh pack 'H*', $secHdr->{Data};
        $pos += $secHdr->{size};
    }

    # write section headers
    foreach my $secHdr (@{$cubin->{secHdrs}})
    {
        print $fh pack $secHdrT[$class], @{$secHdr}{@{$secHdrC[$class]}};
    }

    # write program headers
    foreach my $prgHdr (@{$cubin->{prgHdrs}})
    {
        print $fh pack $prgHdrT[$class], @{$prgHdr}{@{$prgHdrC[$class]}};
    }

    # Buffered write errors (e.g. disk full) only surface at close time.
    close $fh or die "Error: could not finish writing $file: $!";
}
684 |
685 | __END__
686 |
687 |
--------------------------------------------------------------------------------
/maxas/maxas.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl
2 | use strict;
3 | use MaxAs::Cubin;
4 | use MaxAs::MaxAs;
5 | use Data::Dumper;
6 | use File::Spec;
7 |
8 | require 5.10.0;
9 |
10 | $Data::Dumper::Sortkeys = 1;
11 |
# The first command line argument selects the mode; each branch below
# consumes its own remaining arguments from @ARGV.
my $mode = shift;

# List cubin contents
if ($mode =~ /^\-?\-l/i)
{
    my $cubinFile = shift or usage();

    my $cubin = MaxAs::Cubin->new($cubinFile);

    my $arch    = $cubin->arch;
    my $class   = $cubin->class;
    my $asize   = $cubin->address_size;
    my $kernels = $cubin->listKernels;
    my $symbols = $cubin->listSymbols;

    printf "%s: arch:sm_%d machine:%dbit address_size:%dbit\n", $cubinFile, $arch, $class, $asize;

    foreach my $ker (sort keys %$kernels)
    {
        printf "Kernel: %s (Linkage: %s, Params: %d, Size: %d, Registers: %d, SharedMem: %d, Barriers: %d)\n", $ker, @{$kernels->{$ker}}{qw(Linkage ParamCnt size RegCnt SharedSize BarCnt)};
    }
    foreach my $sym (sort keys %$symbols)
    {
        printf "Symbol: %s\n", $sym;
    }
}
# Test that the assembler can reproduce the op codes this cubin or sass contains
elsif ($mode =~ /^\-?\-t/i)
{
    # Declare first, assign conditionally: "my $x = ... if COND" has
    # undefined behaviour according to perlsyn.
    my ($reg, $all);
    $reg = shift if $ARGV[0] =~ /^\-?\-r/i;
    $all = shift if $ARGV[0] =~ /^\-?\-a/i;
    my $file = shift or usage();
    my $fh;
    # sass file
    if (-T $file)
    {
        open $fh, '<', $file or die "$file: $!";
    }
    # cubin file
    else
    {
        my $cubin = MaxAs::Cubin->new($file);
        my $arch  = $cubin->arch;

        open $fh, "cuobjdump -arch sm_$arch -sass $file |" or die "cuobjdump -arch sm_$arch -sass $file: $!";
        # cuobjdump reports failures on its first output line
        my $first = <$fh>;
        if ($first =~ /cuobjdump fatal/)
        {
            print $first;
            exit(1);
        }
    }
    exit(MaxAs::MaxAs::Test($fh, $reg, $all) ? 1 : 0);
}
# Extract an asm file containing the desired kernel
elsif ($mode =~ /^\-?\-e/i)
{
    my $kernelName;
    if ($ARGV[0] =~ /^\-?\-k/i)
    {
        shift;
        $kernelName = shift or usage();
    }
    my $cubinFile = shift or usage();
    my $asmFile   = shift;
    my $cubin     = MaxAs::Cubin->new($cubinFile);
    my $arch      = $cubin->arch;
    my $kernels   = $cubin->listKernels;

    #default the kernel name if not specified.
    $kernelName ||= (sort keys %$kernels)[0];

    my $kernel = $kernels->{$kernelName} or die "bad kernel: $kernelName";

    # The error message reports the same architecture the command was run with.
    open my $in, "cuobjdump -arch sm_$arch -sass -fun $kernelName $cubinFile |" or die "cuobjdump -arch sm_$arch -sass -fun $kernelName $cubinFile: $!";
    my $first = <$in>;
    if ($first =~ /cuobjdump fatal/)
    {
        print $first;
        exit(1);
    }
    my $out;
    if ($asmFile)
    {
        open $out, '>', $asmFile or die "$asmFile: $!";
    }
    else
    {
        $out = \*STDOUT;
    }

    print $out "# Kernel: $kernelName\n# Arch: sm_$arch\n";

    print $out "# $_: $kernel->{$_}\n" foreach (qw(InsCnt RegCnt SharedSize BarCnt));

    print $out "# Params($kernel->{ParamCnt}):\n#\tord:addr:size:align\n";

    print $out join('', map "#\t$_\n", @{$kernel->{Params}}) if $kernel->{Params};

    print $out "#\n# Instructions:\n\n";

    MaxAs::MaxAs::Extract($in, $out, $kernel->{Params});

    close $out if $asmFile;
    close $in;
}
# Extract a kernel from a sass dump
elsif ($mode =~ /^\-?\-s/i)
{
    my $sassFile = shift or usage();
    my $asmFile  = shift;

    open my $in, '<', $sassFile or die "$sassFile: $!";

    my $out;
    if ($asmFile)
    {
        open $out, '>', $asmFile or die "$asmFile: $!";
    }
    else
    {
        $out = \*STDOUT;
    }

    MaxAs::MaxAs::Extract($in, $out, []);

    close $out if $asmFile;
    close $in;
}
# Insert the kernel asm back into the cubin:
elsif ($mode =~ /^\-?\-i/i)
{
    my $nowarn;
    if ($ARGV[0] =~ /^\-?\-w/i)
    {
        $nowarn = shift;
    }
    my $kernelName;
    if ($ARGV[0] =~ /^\-?\-k/i)
    {
        shift;
        $kernelName = shift or usage();
    }
    my $noReuse;
    $noReuse = shift if $ARGV[0] =~ /^\-?\-n/i;
    while ($ARGV[0] =~ /^\-?\-D(\w+)/)
    {
        shift;
        my $name  = $1;
        my $value = shift;
        # Set the package variable directly rather than generating Perl source
        # for string eval: a value containing a quote used to make the eval
        # fail silently, leaving the variable unset.
        no strict 'refs';
        ${"MaxAs::MaxAs::CODE::$name"} = defined $value ? $value : '';
    }

    my $asmFile   = shift or usage();
    my $cubinFile = shift or usage();
    my $newCubin  = shift || $cubinFile;

    my $file;
    if (open my $fh, '<', $asmFile)
    {
        # slurp the whole file
        local $/;
        $file = <$fh>;
        close $fh;
    }
    else { die "$asmFile: $!" }

    my ($vol,$dir) = File::Spec->splitpath($asmFile);
    my $include = [$vol, $dir];

    # extract the kernel name from the file
    ($kernelName) = $file =~ /^# Kernel: (\w+)/ unless $kernelName;
    die "asm file missing kernel name or is badly formatted" unless $kernelName;

    my $kernel = MaxAs::MaxAs::Assemble($file, $include, !$noReuse, $nowarn);

    my $cubin = MaxAs::Cubin->new($cubinFile);
    $kernel->{Kernel} = $cubin->getKernel($kernelName) or die "cubin does not contain kernel: $kernelName";

    $cubin->modifyKernel(%$kernel);

    $cubin->write($newCubin);

    # "%%" emits the literal percent sign; the kernel name is passed as an
    # argument so it can never be interpreted as part of the format.
    printf "Kernel: %s, Instructions: %d, Register Count: %d, Bank Conflicts: %d, Reuse: %.1f%% (%d/%d)\n",
        $kernelName, @{$kernel}{qw(InsCnt RegCnt ConflictCnt ReusePct ReuseCnt ReuseTot)};

}
# Preprocessing:
elsif ($mode =~ /^\-?\-p/i)
{
    while ($ARGV[0] =~ /^\-?\-D(\w+)/)
    {
        shift;
        my $name  = $1;
        my $value = shift;
        # see the insert mode above: direct assignment instead of string eval
        no strict 'refs';
        ${"MaxAs::MaxAs::CODE::$name"} = defined $value ? $value : '';
    }
    my $debug;
    $debug = shift if $ARGV[0] =~ /^\-?\-d/i;
    my $asmFile  = shift or usage();
    my $asmFile2 = shift;

    die "source and destination probably shouldn't be the same file\n" if $asmFile eq $asmFile2;

    open my $fh, '<', $asmFile or die "$asmFile: $!";
    local $/;
    my $file = <$fh>;
    close $fh;

    my ($vol,$dir) = File::Spec->splitpath($asmFile);
    my $include = [$vol, $dir];

    if ($asmFile2)
    {
        open $fh, '>', $asmFile2 or die "$asmFile2: $!";
    }
    else
    {
        $fh = \*STDOUT;
    }
    print $fh MaxAs::MaxAs::Preprocess($file, $include, $debug);
    close $fh;
}
# get version information
elsif ($mode =~ /^\-?\-v/i)
{
    print "$MaxAs::MaxAs::VERSION\n";
}
else
{
    print "$mode\n";
    usage();
}

exit(0);
244 |
245 |
246 |
# Print the command line help for every mode and exit with status 1.
# The positional arguments shown match what each mode shifts off @ARGV.
sub usage
{
    print <<EOF;
Usage:

  List kernels and symbols:

    maxas.pl --list|-l <cubin_file>

  Test a cubin or sass file to to see if the assembler can reproduce all of the contained opcodes.
  Also useful for extending the missing grammar rules. Defaults to only showing failures without --all.
  With the --reg flag it will show register bank conflicts not hidden by reuse flags.

    maxas.pl --test|-t [--reg|-r] [--all|-a] <cubin_file | cuobjdump_sass_file>

  Extract a single kernel into an asm file from a cubin.
  Works much like cuobjdump but outputs in a format that can be re-assembled back into the cubin.

    maxas.pl --extract|-e [--kernel|-k kernel_name] <cubin_file> [asm_file]

  Preprocess the asm: expand CODE sections, perform scheduling. Mainly used for debugging purposes.
  Include the debug flag to print out detailed scheduler info.

    maxas.pl --pre|-p [--debug|-d] <asm_file> [new_asm_file]

  Insert the kernel asm back into the cubin. Overwrite existing or create new cubin.
  Optionally you can skip register reuse flag auto insertion. This allows you to observe
  performance without any reuse or you can use it to set the flags manually in your sass.

    maxas.pl --insert|-i [--noreuse|-n] <asm_file> <cubin_file> [new_cubin_file]

  Display version information and exit:

    maxas.pl --version|-v

EOF
    exit(1);
}
286 | __END__
287 |
--------------------------------------------------------------------------------
/openai_gemm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | import ctypes
4 | import appdirs
5 | import os.path
6 | import subprocess
7 | import numpy as np
8 | import pycuda.driver as drv
9 | from operator import mul
10 | from struct import unpack_from
11 | from pycuda.tools import context_dependent_memoize
12 | from scikits.cuda import cublas
13 |
def matmul(A, B, C, alpha=1.0, beta=0.0, stream=None, bench=False):
    """
    C = alpha * A   . B   + beta * C
    C = alpha * A.T . B   + beta * C
    C = alpha * A   . B.T + beta * C
    C = alpha * A.T . B.T + beta * C

    A, B and C must share one dtype (float32 or float16) and expose
    ``shape``, ``strides`` (in elements), ``gpudata`` and, for A and B,
    ``is_trans``.  The kernel is picked by ``_get_gemm_kernel`` for this
    exact problem and launched asynchronously on ``stream``.

    bench: return benchmark data for all available tiles + cublas

    Returns C, or the cached benchmark list when ``bench`` is true.
    Raises TypeError for an unsupported dtype; shape and stride
    requirements are checked with assertions.
    """

    # this could be relaxed, kernels are capable of mixed precision (with minor tweaks)
    # the s/h prefix would then go away and each type would be specified with kernel build option
    assert A.dtype.type == B.dtype.type == C.dtype.type

    if C.dtype.type is np.float32:
        prefix = "s"
    elif C.dtype.type is np.float16:
        prefix = "h"
    else:
        raise TypeError("Only floating point dot currently supported.")

    # (m,n) = (m,k) . (k,n)
    m = A.shape[0]
    n = B.shape[1]
    k = A.shape[1]
    assert m == C.shape[0]
    assert n == C.shape[1]
    assert k == B.shape[0]

    # Extract the operations and contiguous dimension sizes (cda, cdb, cdc).
    # Note that these can be the same as from the shape unless the non-contiguous dimension is sliced.
    # One dimension must be contiguous (DRAM efficiency demands this).
    # Note that the strides here do not include the datatype size as they would in numpy.
    # A transpose op (.T) on a GPUTensor reverses the shape and strides then flags the tensor as transposed (is_trans=True) -
    # The underlying data is unchanged.
    if A.is_trans:
        opA = 'T'
        cda = A.strides[1]
        assert A.strides[0] == 1
    else:
        opA = 'N'
        cda = A.strides[0]
        assert A.strides[1] == 1

    if B.is_trans:
        opB = 'T'
        cdb = B.strides[1]
        assert B.strides[0] == 1
    else:
        opB = 'N'
        cdb = B.strides[0]
        assert B.strides[1] == 1

    cdc = C.strides[0]
    assert C.strides[1] == 1

    op = opA + opB

    # get and autotune the kernel selection
    kernel, params, dynamic_shared = _get_gemm_kernel(prefix, op, cda, cdb, cdc, m, n, k)

    # bind dynamic params
    params[2:8] = (stream, C.gpudata, A.gpudata, B.gpudata, alpha, beta)
    try:
        # call the kernel
        kernel.prepared_async_call(*params, shared_size=dynamic_shared)
    finally:
        # unbind dynamic params: `params` is the memoized list shared by every
        # call with this problem shape, so it must not keep the caller's
        # buffers or stream alive -- even when the launch raises.
        params[2:8] = (None,) * 6

    # return benchmark data if requested
    if bench:
        return _get_bench_data()[(prefix, op, cda, cdb, cdc, m, n, k)]

    return C
89 |
90 |
91 |
92 | ####################################################################################################
93 |
94 |
95 | # scikits.cuda doesn't expose cublasSgemmEx or cublasHgemm
# ctypes prototype for cublasSgemmEx: 17 arguments, integer status result.
cublas._libcublas.cublasSgemmEx.restype = int
cublas._libcublas.cublasSgemmEx.argtypes = [
    cublas._types.handle,   # handle
    ctypes.c_int,           # transa
    ctypes.c_int,           # transb
    ctypes.c_int,           # m
    ctypes.c_int,           # n
    ctypes.c_int,           # k
    ctypes.c_void_p,        # alpha (pointer)
    ctypes.c_void_p,        # A
    ctypes.c_int,           # A data type
    ctypes.c_int,           # lda
    ctypes.c_void_p,        # B
    ctypes.c_int,           # B data type
    ctypes.c_int,           # ldb
    ctypes.c_void_p,        # beta (pointer)
    ctypes.c_void_p,        # C
    ctypes.c_int,           # C data type
    ctypes.c_int,           # ldc
]

def cublasSgemmEx(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Call cublasSgemmEx through ctypes and check the returned status.

    alpha and beta are passed by reference as C floats; A, B and C are
    converted to integer device addresses with int().
    """
    # Data-type code handed to cuBLAS for A, B and C alike.
    # NOTE(review): presumably the half-precision enum value -- confirm
    # against the cuBLAS headers in use.
    data_type = 2

    op_a = cublas._CUBLAS_OP[transa]
    op_b = cublas._CUBLAS_OP[transb]
    alpha_ref = ctypes.byref(ctypes.c_float(alpha))
    a_ptr = int(A)
    b_ptr = int(B)
    beta_ref = ctypes.byref(ctypes.c_float(beta))
    c_ptr = int(C)

    status = cublas._libcublas.cublasSgemmEx(
        handle, op_a, op_b, m, n, k,
        alpha_ref,
        a_ptr, data_type, lda,
        b_ptr, data_type, ldb,
        beta_ref,
        c_ptr, data_type, ldc)
    cublas.cublasCheckStatus(status)
117 |
# ctypes prototype for cublasHgemm: 14 arguments, integer status result.
cublas._libcublas.cublasHgemm.restype = int
cublas._libcublas.cublasHgemm.argtypes = [
    cublas._types.handle,   # handle
    ctypes.c_int,           # transa
    ctypes.c_int,           # transb
    ctypes.c_int,           # m
    ctypes.c_int,           # n
    ctypes.c_int,           # k
    ctypes.c_void_p,        # alpha (pointer)
    ctypes.c_void_p,        # A
    ctypes.c_int,           # lda
    ctypes.c_void_p,        # B
    ctypes.c_int,           # ldb
    ctypes.c_void_p,        # beta (pointer)
    ctypes.c_void_p,        # C
    ctypes.c_int,           # ldc
]

h_dtype = np.dtype(np.float16)

def cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc):
    """
    Call cublasHgemm through ctypes and check the returned status.

    alpha and beta are first converted to float16 and passed by reference as
    their raw 16-bit patterns; A, B and C are converted to integer device
    addresses with int().
    """
    # Reinterpret each float16 scalar's bytes as an unsigned 16-bit integer.
    alpha_bits, beta_bits = [unpack_from('H', h_dtype.type(scalar))[0]
                             for scalar in (alpha, beta)]

    op_a = cublas._CUBLAS_OP[transa]
    op_b = cublas._CUBLAS_OP[transb]
    alpha_ref = ctypes.byref(ctypes.c_uint16(alpha_bits))
    a_ptr = int(A)
    b_ptr = int(B)
    beta_ref = ctypes.byref(ctypes.c_uint16(beta_bits))
    c_ptr = int(C)

    status = cublas._libcublas.cublasHgemm(
        handle, op_a, op_b, m, n, k,
        alpha_ref,
        a_ptr, lda,
        b_ptr, ldb,
        beta_ref,
        c_ptr, ldc)
    cublas.cublasCheckStatus(status)
145 |
# cuBLAS reference routine per kernel prefix.  "s" and "h" are looked up by
# the dtype prefix; "h2" is selected explicitly by _get_gemm_kernel for
# half precision on compute capability 6.0.
cublasXgemm = {
    "s" : cublas.cublasSgemm,
    "h" : cublasSgemmEx,
    "h2" : cublasHgemm,
}
151 |
152 |
@context_dependent_memoize
def _get_sm_count():
    """Return the multiprocessor count of the current context's device."""
    device = drv.Context.get_device()
    return device.get_attributes()[drv.device_attribute.MULTIPROCESSOR_COUNT]
157 |
@context_dependent_memoize
def _get_compute_capability():
    """Return (major, minor) compute capability of the current context's device."""
    attrs = drv.Context.get_device().get_attributes()
    keys = drv.device_attribute
    return attrs[keys.COMPUTE_CAPABILITY_MAJOR], attrs[keys.COMPUTE_CAPABILITY_MINOR]
165 |
@context_dependent_memoize
def _get_events():
    """Return a memoized (start, end) pair of driver events used for timing."""
    start_event = drv.Event()
    end_event = drv.Event()
    return (start_event, end_event)
169 |
@context_dependent_memoize
def _get_cublas():
    """Create a cuBLAS handle once and return the memoized handle afterwards."""
    handle = cublas.cublasCreate()
    return handle
173 |
@context_dependent_memoize
def _get_bench_data():
    """Return the memoized dict that stores benchmark results per problem key."""
    return {}
177 |
178 | def _ceil_div(x, y):
179 | return -(-x // y)
180 |
181 | def _closest_divisor(val, div):
182 | divisors = sorted([(abs(i - div), i) for i in range(2, 8) if val % i == 0])
183 | if len(divisors):
184 | return (divisors[0][1], val // divisors[0][1])
185 | else:
186 | return (1, val)
187 |
188 |
# Tile sizes: m, n, k, vA,vB,vC div, op (dynamic shared options)
#
# Each tuple is unpacked by the tile loops as
#   (tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared):
#   vecA, vecB - 4 or 8: which alignment test (vec4*/vec8*) A and B must pass
#                for the "vec" kernel variant
#   vecC       - 1: no requirement on C; otherwise C must pass the vec4C test
#   div        - target passed to _closest_divisor when splitting the grid
#   base_op    - nonzero: the op (NN/NT/...) is part of the kernel base name
#                instead of being passed as a build option
#   dyn_shared - shared_size values benchmarked for the tile
k128x128x8 = (128, 128, 8, 4, 4, 1, 2, 0, (0,))
k32x32x32 = ( 32, 32, 32, 4, 4, 1, 4, 0, (0, 2**14))
k32x64x32_NN = ( 32, 64, 32, 8, 4, 4, 4, 1, (0, 2**13))
k32x32x64_NT = ( 32, 32, 64, 8, 8, 4, 4, 1, (0,))
k16x64x64_NN = ( 16, 64, 64, 8, 4, 4, 4, 1, (0,))
k16x64x64_NT = ( 16, 64, 64, 8, 8, 4, 4, 1, (0,))

# Candidate tiles per dtype prefix ("s"/"h") and op; every listed tile is
# benchmarked by _get_gemm_kernel and checked by matmul_test.
selections = {
    "s" : {
        "TN" : (k128x128x8, k32x32x32),
        "NN" : (k128x128x8, k32x32x32),
        "NT" : (k128x128x8, k32x32x32),
        "TT" : (k128x128x8, k32x32x32),
    },
    "h" : {
        "TN" : (k128x128x8, k32x32x32),
        "NN" : (k128x128x8, k32x32x32, k32x64x32_NN, k16x64x64_NN),
        "NT" : (k128x128x8, k32x32x32, k32x32x64_NT, k16x64x64_NT),
        "TT" : (k128x128x8, k32x32x32),
    },
}
211 |
212 | # Autotune kernel selection
@context_dependent_memoize
def _get_gemm_kernel(prefix, op, cda, cdb, cdc, m, n, k):
    """
    Benchmark every candidate tile for this exact problem and return the
    fastest configuration.

    prefix:        "s" (float32) or "h" (float16)
    op:            two letters, each 'N' or 'T', for A and B
    cda, cdb, cdc: contiguous-dimension strides of A, B, C in elements
    m, n, k:       problem dimensions

    Returns (kernel, params, dynamic_shared).  ``params`` is the launch
    argument list with slots 2..7 (stream, C, A, B, alpha, beta) left as
    None for the caller to bind.  Timings for every tile plus a cuBLAS
    reference run are stored in ``_get_bench_data()`` under the argument
    tuple.  The result is memoized per context by the decorator.
    """

    if op[0] == 'T':
        vec4A = (cda & 3) == 0 and (m & 3) == 0
        vec8A = (cda & 7) == 0 and (m & 7) == 0
        dimA = (k,cda)
    else:
        vec4A = (cda & 3) == 0 and (k & 3) == 0
        vec8A = (cda & 7) == 0 and (k & 7) == 0
        dimA = (m,cda)

    if op[1] == 'T':
        vec4B = (cdb & 3) == 0 and (k & 3) == 0
        vec8B = (cdb & 7) == 0 and (k & 7) == 0
        dimB = (n,cdb)
    else:
        vec4B = (cdb & 3) == 0 and (n & 3) == 0
        vec8B = (cdb & 7) == 0 and (n & 7) == 0
        dimB = (k,cdb)

    vec4C = (cdc & 3) == 0 and (n & 3) == 0
    dimC = (m,cdc)

    dtype = np.dtype(np.float32 if prefix == 's' else np.float16)

    # scratch device buffers used only for timing
    A = drv.mem_alloc(mul(*dimA) * dtype.itemsize)
    B = drv.mem_alloc(mul(*dimB) * dtype.itemsize)
    C = drv.mem_alloc(mul(*dimC) * dtype.itemsize)

    # TODO: use curand
    dataA = np.random.uniform(-1.0, 1.0, dimA).astype(dtype)
    dataB = np.random.uniform(-1.0, 1.0, dimB).astype(dtype)
    drv.memcpy_htod(int(A), dataA)
    drv.memcpy_htod(int(B), dataB)

    # Using random data gets you more accurate autotune results
    # drv.memset_d8(int(A), 0, mul(*dimA) * dtype.itemsize)
    # drv.memset_d8(int(B), 0, mul(*dimB) * dtype.itemsize)

    timings = []
    cache = []

    # scale the repeat count to amount of work
    repeat = min(max(int(5e11 * 28 / (m*n*k * 2.0 * _get_sm_count()) ), 10), 5000)
    warmup = repeat
    #print repeat

    start, end = _get_events()
    flops = m * n * k * 2.0

    for tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared in selections[prefix][op]:

        # the "vec" variant is only usable when A, B and C all pass the
        # alignment test this tile asks for
        vecA = (vecA == 4 and vec4A) or (vecA == 8 and vec8A)
        vecB = (vecB == 4 and vec4B) or (vecB == 8 and vec8B)
        vecC = vecC == 1 or vec4C
        vec = vecA and vecB and vecC

        if base_op:
            # The op is part of the base kernel name
            base = "%sgemm_%dx%dx%d_%s" % (prefix, tileM, tileN, tileK, op)
            opts = ( "vec", ) if vec else ()
        else:
            # The op is an option passed to a more generic kernel
            base = "%sgemm_%dx%dx%d" % (prefix, tileM, tileN, tileK)
            opts = ( op, "vec" ) if vec else (op,)

        kernel = get_kernel(base, opts)

        blk_A = _ceil_div(m, tileM)
        blk_B = _ceil_div(n, tileN)

        # TODO: perhaps autotune all possible small divisors
        blk_a, blk_A = _closest_divisor(blk_A, div)
        blk_b, blk_B = _closest_divisor(blk_B, div)
        if blk_a == 1:
            blk_a, blk_A = (blk_A, 1)

        for dynamic_shared in dyn_shared:

            params = [
                (blk_a * blk_b, blk_B, blk_A), (kernel.threads, 1, 1), None,
                C, A, B, 1.0, 0.0,
                cda, cdb, cdc, m, n, k, blk_a, blk_b ]

            #print kernel.name, params, dynamic_shared

            # Warmup (once per config)
            for r in range(warmup):
                kernel.prepared_async_call(*params)
            warmup = 0

            # Benchmark
            start.record()
            for r in range(repeat):
                kernel.prepared_async_call(*params, shared_size=dynamic_shared)
            end.record()
            end.synchronize()
            msecs = end.time_since(start) / float(repeat)
            gflops = flops / (msecs * 1000000.0)

            # drop the scratch buffers and scalars from the cached argument list
            params[3:8] = (None,) * 5

            timings.append((msecs, gflops, kernel, params, dynamic_shared))
            cache.append((msecs, gflops, kernel.name, dynamic_shared))

    major, minor = _get_compute_capability()
    if prefix == "h" and major == 6 and minor == 0:
        cublas_gemm = cublasXgemm["h2"]
    else:
        cublas_gemm = cublasXgemm[prefix]

    # record a cublas time for reference
    cublas_handle = _get_cublas()
    start.record()
    for r in range(repeat):
        # convert row order to col order
        cublas_gemm(cublas_handle, op[1], op[0], n, m, k, 1.0, B, cdb, A, cda, 0.0, C, cdc)
    end.record()
    end.synchronize()
    msecs = end.time_since(start) / float(repeat)
    gflops = flops / (msecs * 1000000.0)
    cache.append( (msecs, gflops, "cuBLAS", 0) )

    # cache complete timing data for benchmark comparisons
    # this data could be cached to disk for quicker autotuning on future runs
    _get_bench_data()[(prefix, op, cda, cdb, cdc, m, n, k)] = cache

    # return the fastest kernel
    # Select on elapsed time only: sorting whole tuples would fall through to
    # comparing kernel objects and parameter lists whenever timings tie.
    fastest = min(timings, key=lambda entry: entry[0])
    return tuple(fastest[2:5])
343 |
344 |
345 |
346 | # Utility function to test all tiles for the given dimensions and dtype
def matmul_test(ng, dtype, op, m, n, k, ones=False, out=False):
    """
    Run every candidate tile for (dtype, op, m, n, k) and compare its result
    with the cuBLAS routine registered for the dtype prefix.

    ng:    backend object providing empty/rand/min/max/sum/finite
           (presumably a neon GPU backend -- confirm with callers)
    dtype: np.float32 selects the "s" kernels, anything else the "h" kernels
    op:    two letters, each 'N' or 'T', for A and B
    ones:  fill the inputs with 1.0 instead of uniform randoms in [-1, 1]
    out:   on a mismatch, dump the difference, reference and kernel output
           with np.savetxt before exiting

    Each tile is run with (alpha, beta) = (1, 0) and (0.5, 0.5).  The process
    exits on the first non-finite result, on an error above the tolerance, or
    on a driver error.
    """

    prefix = "s" if dtype is np.float32 else "h"

    # dense (unsliced) operands: the contiguous stride equals the inner dimension
    if op[0] == 'T':
        vec4A = (m & 3) == 0
        vec8A = (m & 7) == 0
        dimA = (k,m)
        cda = m
    else:
        vec4A = (k & 3) == 0
        vec8A = (k & 7) == 0
        dimA = (m,k)
        cda = k

    if op[1] == 'T':
        vec4B = (k & 3) == 0
        vec8B = (k & 7) == 0
        dimB = (n,k)
        cdb = k
    else:
        vec4B = (n & 3) == 0
        vec8B = (n & 7) == 0
        dimB = (k,n)
        cdb = n

    vec4C = (n & 3) == 0
    dimC = (m,n)
    cdc = n

    # C1 receives the kernel under test, C2 the cuBLAS reference
    A1 = ng.empty(dimA, dtype=dtype)
    B1 = ng.empty(dimB, dtype=dtype)
    C1 = ng.empty(dimC, dtype=dtype)
    C2 = ng.empty(dimC, dtype=dtype)

    if ones:
        A1[:] = 1.0
        B1[:] = 1.0
    else:
        # fill with uniform randoms from -1 to 1
        A1[:] = 2 * (.5 - ng.rand())
        B1[:] = 2 * (.5 - ng.rand())

    # for reducing outputs
    # (two-step reduction: per row into partial1, then into the 1x1 partial2)
    partial1 = ng.empty((C1.shape[0],1), dtype=np.float32)
    partial2 = partial1[0:1,0:1]

    cublas_handle = _get_cublas()

    for tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared in selections[prefix][op]:

        vecA = (vecA == 4 and vec4A) or (vecA == 8 and vec8A)
        vecB = (vecB == 4 and vec4B) or (vecB == 8 and vec8B)
        vecC = vecC == 1 or vec4C
        vec = vecA and vecB and vecC

        if base_op:
            # The op is part of the base kernel name
            base = "%sgemm_%dx%dx%d_%s" % (prefix, tileM, tileN, tileK, op)
            opts = ( "vec", ) if vec else ()
        else:
            # The op is an option passed to a more generic kernel
            base = "%sgemm_%dx%dx%d" % (prefix, tileM, tileN, tileK)
            opts = ( op, "vec" ) if vec else (op,)

        kernel = get_kernel(base, opts)

        blk_A = _ceil_div(m, tileM)
        blk_B = _ceil_div(n, tileN)

        blk_a, blk_A = _closest_divisor(blk_A, div)
        blk_b, blk_B = _closest_divisor(blk_B, div)
        if blk_a == 1:
            blk_a, blk_A = (blk_A, 1)

        for alpha, beta in ( (1.0,0.0), (0.5,0.5) ):

            try:
                # both outputs start from identical contents so the beta term matches
                if ones:
                    C1[:] = 1.0
                else:
                    C1[:] = 2 * (.5 - ng.rand())
                C2[:] = C1

                params = [
                    (blk_a * blk_b, blk_B, blk_A), (kernel.threads, 1, 1), None,
                    C1.gpudata, A1.gpudata, B1.gpudata, alpha, beta,
                    cda, cdb, cdc, m, n, k, blk_a, blk_b ]

                kernel.prepared_async_call(*params)

                # convert row order to col order
                cublasXgemm[prefix](cublas_handle, op[1], op[0], n, m, k, alpha, B1.gpudata, cdb, A1.gpudata, cda, beta, C2.gpudata, cdc)

                # Check for NaNs
                # (the minimum of the finite mask is 0 if any element is not finite)
                partial1[:] = ng.min(ng.finite(C1), axis=1)
                partial2[:] = ng.min(partial1, axis=0)
                if partial2.get()[0,0] == 0.0:
                    print "Error: NaN kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (kernel.name, m,n,k, alpha,beta)
                    exit()

                # Get Max Diff
                partial1[:] = ng.max(abs(C2 - C1), axis=1)
                partial2[:] = ng.max(partial1, axis=0)
                diff = partial2.get()[0,0]

                # Get Mean
                partial1[:] = ng.sum(abs(C2), axis=1)
                partial2[:] = ng.sum(partial1, axis=0)
                mean = partial2.get()[0,0] / C2.size

                # Scale diff by the mean
                pctErr = 100 * diff / mean

                #print "Error: %.3f %s" % (pctErr, kernel.name)

                # tolerance in percent: tight for float32, loose for float16
                maxerr = .005 if dtype is np.float32 else 0.7

                if pctErr > maxerr:
                    print "Error: %.3f%% diff: %.5f mean %.5f kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (pctErr, diff, mean, kernel.name, m,n,k, alpha,beta)
                    print params
                    if out:
                        # pull both results to the host and dump them for inspection
                        C1 = C1.get()
                        C2 = C2.get()
                        D = abs(C2 - C1)
                        np.savetxt("out_diff.txt", D, fmt='%3.1f')
                        np.savetxt("out_correct.txt", C2, fmt='%5.1f')
                        np.savetxt("out_error", C1, fmt='%5.1f')
                    exit()

            except drv.Error as e:
                print "kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (kernel.name, m,n,k, alpha,beta)
                print e
                exit()
481 |
482 | ### below code adapted from Nervana Neon: kernel_specs.py
483 |
def _get_cache_dir(subdir=None):
    """
    Return the per-user cache directory for openai-gemm, creating it if needed.

    subdir: optional path component, or a list/tuple of components, appended
            below the cache root.

    Raises OSError if the directory cannot be created.
    """
    import errno

    cache_dir = appdirs.user_cache_dir("openai-gemm")

    if subdir:
        parts = subdir if isinstance(subdir, (list, tuple)) else [subdir]
        cache_dir = os.path.join(cache_dir, *parts)

    # Create unconditionally and tolerate "already exists": an exists() check
    # followed by makedirs() fails when another process creates the directory
    # in between.
    try:
        os.makedirs(cache_dir)
    except OSError as err:
        if err.errno != errno.EEXIST or not os.path.isdir(cache_dir):
            raise

    return cache_dir
496 |
# helpful for kernel development
# (when set, run_command echoes each command and its output)
debug = 0

# maxas and sass directories are resolved relative to this module's location
base_dir = os.path.dirname(__file__)
maxas_dir = os.path.join(base_dir, "maxas")
sass_dir = os.path.join(base_dir, "sass")

# Kernel specifications keyed by base kernel name:
#   threads - thread count, emitted as .reqntid in the generated PTX
#   sass    - name of the sass source for the kernel
#   params  - key into _params selecting the parameter list
#   share   - Python expression (evaluated in get_ptx_file) giving the
#             number of .b32 elements in the shared array
#   args    - optional build arguments, recorded in the PTX "// args:" line
kernels = {
    # Generic gemm tiles
    "sgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "s"} },
    "hgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "h"} },
    "sgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "s"} },
    "hgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "h"} },

    # Custom hgemm tiles designed for small minibatch RNNs
    "hgemm_32x64x32_NN": {"threads": 128, "sass": "hgemm_32x64x32_NN", "params": "xgemm", "share": "32*33*2 + 64*32*2 + 4" },
    "hgemm_32x32x64_NT": {"threads": 128, "sass": "hgemm_32x32x64_NT", "params": "xgemm", "share": "32*65*4 + 4" },
    "hgemm_16x64x64_NN": {"threads": 128, "sass": "hgemm_16x64x64_NN", "params": "xgemm", "share": "(16*64 + 32)*2 + 64*64*2 + 4" },
    "hgemm_16x64x64_NT": {"threads": 128, "sass": "hgemm_16x64x64_NT", "params": "xgemm", "share": "(16*64 + 32)*2 + (64*64 + 32)*2 + 4" },
}

# Kernel parameter lists as "type name" strings.  get_ptx_file maps a type
# ending in '*' to .u64, 'float' to .f32 and everything else to .u32.  The
# order matches the argument lists built for prepared_async_call.
_params = {
    "xgemm": [
        "float* param_C",
        "float* param_A",
        "float* param_B",
        "float param_alpha",
        "float param_beta",
        "unsigned param_cda",
        "unsigned param_cdb",
        "unsigned param_cdc",
        "unsigned param_m",
        "unsigned param_n",
        "unsigned param_k",
        "unsigned param_blk_a",
        "unsigned param_blk_b",
    ],
}

# splits a "type name" parameter string on whitespace
_space_re = re.compile(r"\s+")

# shared memory declaration; {0} is the evaluated "share" element count
_share_template = r"""
.shared .align 4 .b32 share[{0}];
"""

# PTX stub for one kernel.  Placeholders, in get_ptx_file's format() order:
#   {0} target arch, {1} kernel name, {2} parameter declarations,
#   {3} thread count, {4} shared declaration, {5} args, {6} PTX version
_kernel_template = r"""
.version {6}
.target {0}
.address_size 64

// args: {5}

.visible .entry {1}(
{2}
)
.reqntid {3}
{{
{4}
ret;
}}
"""
558 |
def get_ptx_file(kernel_spec, kernel_name, arch, ptx_ver):
    """
    Generate the PTX stub for a kernel in the cache directory and return its path.

    kernel_spec: entry of ``kernels`` ("threads", "params", optional
                 "share" and "args")
    kernel_name: entry point name, also used as the file name
    arch:        PTX target string, also the cache sub-directory
    ptx_ver:     value for the .version directive

    The file is only rewritten when the generated text differs from what is
    already on disk, so its timestamp stays stable for unchanged kernels.
    """

    ptx_dir = _get_cache_dir([arch, 'ptx'])

    thread_spec = kernel_spec["threads"]
    args_spec = str(kernel_spec.get("args",""))
    param_spec = _params[kernel_spec["params"]]

    kernel_params = []
    for p in param_spec:
        ptype, pname = _space_re.split(p)

        # map the C-like type to a PTX parameter type
        if ptype[-1] == '*':
            ptype = '.u64'
        elif ptype == 'float':
            ptype = '.f32'
        else:
            ptype = '.u32'

        kernel_params.append(" .param %s %s" % (ptype, pname))

    kernel_params = ",\n".join(kernel_params)

    if "share" in kernel_spec:
        # NOTE: eval() is applied to the arithmetic expression stored in the
        # kernel spec.  That is only acceptable because the specs are defined
        # in this module; never feed externally supplied specs through here.
        share = _share_template.format(eval(kernel_spec["share"]))
    else:
        share = ""

    kernel_text = _kernel_template.format(arch, kernel_name, kernel_params, thread_spec, share, args_spec, ptx_ver)
    kernel_ptx = os.path.join(ptx_dir, kernel_name + ".ptx")

    current_text = ""
    if os.path.exists(kernel_ptx):
        # context managers close the handle even if read/write raises
        with open(kernel_ptx, "r") as f:
            current_text = f.read()
    # only write out the kernel if text has changed.
    if kernel_text != current_text:
        with open(kernel_ptx, "w") as f:
            f.write(kernel_text)

    return kernel_ptx
602 |
603 |
# Matches a maxas include directive at the start of a line and captures the
# referenced file name in group 1, e.g.  <INCLUDE file="common.sass"/>
# (A bare r'^' pattern has no group 1, so match.group(1) below would raise
# IndexError on the first line of every file.)
include_re = re.compile(r'^<INCLUDE\s+file="([^"]+)"\s*/?>')


def extract_includes(name, includes=None):
    """Collect a sass file and everything it includes, with modification times.

    Args:
        name: file name relative to ``sass_dir``.
        includes: list to append to; used by the recursive calls.  A new list
            is created when omitted.

    Returns:
        List of ``(path, mtime)`` tuples: the file itself first, followed by
        the files referenced through include directives, depth first.

    Raises:
        OSError: if ``name`` or an included file does not exist.
    """
    if includes is None:
        includes = []
    sass_file = os.path.join(sass_dir, name)
    includes.append((sass_file, os.path.getmtime(sass_file)))
    # Context manager so the handle is released even if a nested file is missing.
    with open(sass_file, "r") as f:
        for line in f:
            match = include_re.search(line)
            if match:
                extract_includes(match.group(1), includes)
    return includes
616 |
def run_command(cmdlist):
    """Join ``cmdlist`` into one shell command line and run it.

    Args:
        cmdlist: list of strings; joined with single spaces.

    Raises:
        RuntimeError: if the command exits with a non-zero status.  The
            message contains the exit code, the command line and its stderr.

    When the module-level ``debug`` flag is set, the command and any captured
    stdout/stderr are printed after a successful run.
    """
    cmd = " ".join(cmdlist)
    # shell=True is required: callers pass shell syntax in the list (an
    # environment prefix such as PERL5LIB=..., ";" to chain commands and ">"
    # to redirect output).  Only pass trusted, internally built strings.
    proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    if proc.returncode:
        raise RuntimeError("Error(%d):\n%s\n%s" % (proc.returncode, cmd, err))
    if debug:
        # Parenthesised single-argument print works on Python 2 and 3 alike.
        print(cmd)
        if out:
            print(out)
        if err:
            print(err)
627 |
628 |
@context_dependent_memoize
def get_kernel(base_name, options=None):
    """Build (if stale) and load the sass kernel ``base_name``.

    Args:
        base_name: key into ``kernels``.
        options: optional iterable of build options.  A tuple ``(name, value)``
            is passed to maxas as ``-Dname value`` and appends ``_namevalue``
            to the kernel name; any other item ``name`` is passed as
            ``-Dname 1`` and appends ``_name``.

    Returns:
        The function object returned by ``module.get_function()`` after
        ``prepare()`` has been called on it, with ``threads``, ``name`` and
        ``static_shared`` attributes attached.

    Raises:
        RuntimeError: if the device's compute capability major version is
            below 5, if the kernel's .sass source is missing, or (through
            run_command) if an external tool fails.
    """
    major, minor = _get_compute_capability()
    if major < 5:
        raise RuntimeError("sass kernels require Maxwell or greater class hardware")

    arch = "sm_%d%d" % (major, minor)

    # maxas command prefixes: "-i" inserts the sass into a cubin, "-p" only
    # writes the preprocessed source (used for the debug dump below).
    libprefix = "PERL5LIB=%s" % maxas_dir
    maxas_i = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -i -w"]
    maxas_p = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -p"]

    kernel_spec = kernels[base_name]
    kernel_name = base_name

    # static options
    if "args" in kernel_spec:
        for pair in kernel_spec["args"].items():
            maxas_i.append("-D%s %s" % pair)
            maxas_p.append("-D%s %s" % pair)

    # dynamic options
    if options is not None:
        for opt in options:
            if isinstance(opt, tuple):
                maxas_i.append("-D%s %s" % opt)
                maxas_p.append("-D%s %s" % opt)
                kernel_name += "_%s%s" % opt
            else:
                maxas_i.append("-D%s 1" % opt)
                maxas_p.append("-D%s 1" % opt)
                kernel_name += "_%s" % opt

    # The final kernel name is only known now; place "-k" right after the
    # maxas.pl invocation, ahead of the -D definitions.
    maxas_i.insert(2, "-k " + kernel_name)

    sass_name = kernel_spec["sass"] + ".sass"
    cubin_name = kernel_name + ".cubin"
    cubin_dir = _get_cache_dir([arch, 'cubin'])

    ptx_version = "4.2" if major < 6 else "5.0"
    ptx_file = get_ptx_file(kernel_spec, kernel_name, arch, ptx_version)
    sass_file = os.path.join(sass_dir, sass_name)
    cubin_file = os.path.join(cubin_dir, cubin_name)

    if not os.path.exists(sass_file):
        raise RuntimeError("Missing: %s for kernel: %s" % (sass_name, kernel_name))

    ptx_mtime = os.path.getmtime(ptx_file)
    cubin_mtime = os.path.getmtime(cubin_file) if os.path.exists(cubin_file) else 0

    # Rebuild when the ptx stub or any sass source is newer than the cubin
    # (a missing cubin has mtime 0 and is therefore always rebuilt).
    build_cubin = False
    if ptx_mtime > cubin_mtime:
        build_cubin = True

    includes = extract_includes(sass_name)
    for include, include_mtime in includes:
        if include_mtime > cubin_mtime:
            build_cubin = True
            break

    if build_cubin:
        # build the cubin and run maxas in the same command
        # we don't want the chance of a generated cubin not processed by maxas (in case user hits ^C in between these steps)
        run_command(["ptxas -v -arch", arch, "-o", cubin_file, ptx_file, ";"] + maxas_i + [sass_file, cubin_file])
        cubin_mtime = time.time()

    # output preprocessed and disassembled versions in debug mode
    if debug:
        pre_dir = _get_cache_dir([arch, 'pre'])
        dump_dir = _get_cache_dir([arch, 'dump'])

        pre_file = os.path.join(pre_dir, kernel_name + "_pre.sass")
        dump_file = os.path.join(dump_dir, kernel_name + "_dump.sass")
        pre_mtime = os.path.getmtime(pre_file) if os.path.exists(pre_file) else 0
        dump_mtime = os.path.getmtime(dump_file) if os.path.exists(dump_file) else 0

        for include, include_mtime in includes:
            if include_mtime > pre_mtime:
                run_command(maxas_p + [sass_file, pre_file])
                break

        if cubin_mtime > dump_mtime:
            run_command(["nvdisasm -c", cubin_file, ">", dump_file])

    # generate the function signature for pycuda
    params = _params[kernel_spec["params"]]
    sig = ""
    for p in params:
        ptype, pname = _space_re.split(p)
        if ptype[-1] == '*':
            sig += "Q"
        elif ptype == 'float':
            sig += "f"
        elif ptype == 'unsigned':
            sig += "I"
        else:
            sig += "i"

    module = drv.module_from_file(cubin_file)
    func = module.get_function(kernel_name)
    func.prepare(sig)
    func.threads = kernel_spec["threads"]
    func.name = kernel_name
    # "share" is optional in a kernel spec (get_ptx_file() emits no shared
    # declaration without it), so do not fail with KeyError here; such a
    # kernel is recorded as using no static shared memory.
    # NOTE(review): eval() runs on the spec's arithmetic expression -- only
    # safe while specs are the literals defined in this module.
    if "share" in kernel_spec:
        func.static_shared = eval(kernel_spec["share"])
    else:
        func.static_shared = 0

    return func
--------------------------------------------------------------------------------
/sass/hgemm_32x64x32_NN.sass:
--------------------------------------------------------------------------------
1 | # Kernel: hgemm_32x64x32_NN
2 |
3 | [-
4 | our $vec;
5 | sub vector { return $vec; }
6 | -]
7 |
8 |
9 | addr_zero : 4x<32*33*2 + 64*32*2>
10 | szShareA : (32*33)
11 | szShareB : (64*32)
12 |
13 | param_C[0] : c[0x0][0x140]
14 | param_C[1] : c[0x0][0x144]
15 | param_A[0] : c[0x0][0x148]
16 | param_A[1] : c[0x0][0x14c]
17 | param_B[0] : c[0x0][0x150]
18 | param_B[1] : c[0x0][0x154]
19 | param_alpha : c[0x0][0x158]
20 | param_beta : c[0x0][0x15c]
21 | param_cda : c[0x0][0x160]
22 | param_cdb : c[0x0][0x164]
23 | param_cdc : c[0x0][0x168]
24 | param_m : c[0x0][0x16c]
25 | param_n : c[0x0][0x170]
26 | param_k : c[0x0][0x174]
27 | param_blk_a : c[0x0][0x178]
28 | param_blk_b : c[0x0][0x17c]
29 |
30 |
31 |
32 |
33 | 3, 2,11,10,19,18,27,26 : cx<0-7>y0
34 | 7, 6,15,14,23,22,31,30 : cx<0-7>y1
35 | 1, 0, 9, 8,17,16,25,24 : cx<0-7>y2
36 | 5, 4,13,12,21,20,29,28 : cx<0-7>y3
37 | 35,34,43,42,51,50,59,58 : cx<0-7>y4
38 | 39,38,47,46,55,54,63,62 : cx<0-7>y5
39 | 33,32,41,40,49,48,57,56 : cx<0-7>y6
40 | 37,36,45,44,53,52,61,60 : cx<0-7>y7
41 |
42 | 0-63 : czero<00-63>
43 | 64-79 : j0Ay<0-7>, j0Bx<0-7>
44 | 80-95 : j1Ay<0-7>, j1Bx<0-7>
45 |
46 | 64-95 ~ cda, cdb, cdb8, tidAX, tidAY, tidBX, tidBY, tidAY<1-3>, tidBY<8|16|24>, tid1, tid32, tb, shiftAX, partialK, partialB, ta, txa, txb, txb<1-3>, xmad_ta, xmad_tb
47 | 96-119 ~ idx_ab, idx_ab_f, idx_a, neg_blk_b, rcp_blk_b, idx_b
48 |
49 |
50 | 96-119 : load0A<0-7>, load0B<0-3>, load1B<0-3>, load2B<0-3>, load3B<0-3>
51 | 120-129 : track0A<0-1>, track0B<0-1>, track1B<0-1>, track2B<0-1>, track3B<0-1>
52 |
53 | 130-137 ~ swapBuf, readAs, readBs, writeAs, writeBs, k, cdb32
54 | 138-144 ~ tid, idx_A, idx_B, writeCs, preds
55 |
56 | 0-15 : part0C<0-3>, part1C<0-3>, part2C<0-3>, part3C<0-3>
57 | 64-95 : shuffle_x<0-7>y0, shuffle_x<0-7>y1, shuffle_x<0-7>y2, shuffle_x<0-7>y3
58 | 64-95 : shuffle_x<0-7>y4, shuffle_x<0-7>y5, shuffle_x<0-7>y6, shuffle_x<0-7>y7
59 | 96-99 : loadC<0-3>
60 | 100-103 : b<0-3>
61 | 104-107 : c<0-3>
62 | 108-109 : C<0-1>
63 | 110-137 ~ cdc, cx, cx<1-3>, cy, ci, xmad_c, cdc8, readCs, alpha, beta, tid15, tid16
64 |
65 |
66 |
67 | --:-:5:-:1 I2F.F32.S32 rcp_blk_b, param_blk_b;
68 | --:-:1:-:1 S2R tid, SR_TID.X;
69 | --:-:2:-:1 S2R idx_ab, SR_CTAID.X;
70 | --:-:3:-:1 S2R idx_A, SR_CTAID.Z;
71 | --:-:4:-:1 S2R idx_B, SR_CTAID.Y;
72 |
73 |
74 | 10:-:5:-:1 MUFU.RCP rcp_blk_b, rcp_blk_b;
75 |
76 | --:-:-:-:1 MOV k, param_k;
77 | --:-:-:-:1 MOV cda, param_cda;
78 | --:-:-:-:1 MOV cdb, param_cdb;
79 | --:-:-:-:1 SHL cdb8, cdb, 3;
80 | --:-:-:-:1 SHL cdb32, cdb, 6;
81 |
82 | // If k is not a multiple of 32 we want to grab the partial amount on the first fetch.
83 | // If it is a multiple of 32 then make a full 32 line fetch.
84 | --:-:-:-:1 LOP.AND.Z P0, partialK, k, 31;
85 | --:-:-:-:1 @P0 MOV partialK, 32;
86 | --:-:-:-:1 IADD k, k, -partialK;
87 |
88 | # idx_a = idx_ab // blk_b
89 | 02:-:2:-:1 I2F.F32.S32 idx_ab_f, idx_ab;
90 | 12:-:-:-:1 FMUL idx_a, idx_ab_f, rcp_blk_b;
91 | --:-:-:-:1 FFMA idx_a, idx_a, 5.9604644775390625e-08, idx_a;
92 | --:-:2:-:1 F2I.S32.F32.TRUNC idx_a, idx_a;
93 | # idx_b = idx_AB % blk_b
94 | --:-:-:-:1 IADD neg_blk_b, RZ, -param_blk_b;
95 | 02:-:-:-:1 XMAD.S16.U16 idx_b, neg_blk_b, idx_a, idx_ab;
96 |
97 | # idx_A = idx_A * blk_a + idx_a
98 | # idx_B = idx_B * blk_b + idx_b
99 | 06:-:-:-:1 XMAD.U16.U16 idx_A, idx_A, param_blk_a, idx_a;
100 | 08:-:-:-:1 XMAD.U16.U16 idx_B, idx_B, param_blk_b, idx_b;
101 |
102 |
103 | --:-:-:-:1 STS.128 [addr_zero], RZ;
104 | [+ join '', map sprintf("--:-:-:-:1 LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15; +]
105 |
106 | // tidAX = tid >> 2
107 | // tidAY = (tid & 3) << 3
108 | // shiftAX = (tid & 3) << 3
109 | 01:-:-:-:1 SHR.U32 tidAX, tid, 2;
110 | --:-:-:-:1 LOP.AND tidAY, tid, 3;
111 | --:-:-:-:1 SHL shiftAX, tidAY, 3;
112 | --:-:-:-:1 SHL tidAY, tidAY, 3;
113 |
114 | // tidBX = (tid & 15) << 2
115 | // tidBY = tid >> 4
116 | 01:-:-:-:1 LOP.AND tidBX, tid, 15;
117 | --:-:-:-:1 SHL tidBX, tidBX, 2;
118 | --:-:-:-:1 SHR.U32 tidBY, tid, 4;
119 |
120 | --:-:-:-:1 IADD tidBY8, tidBY, 8;
121 | --:-:-:-:1 IADD tidBY16, tidBY, 16;
122 | --:-:-:-:1 IADD tidBY24, tidBY, 24;
123 |
124 | // trackA += ((idx_A*32 + tidAX) * cda + tidAY) * 2
125 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 5;
126 | --:-:-:-:1 XMAD.LO ta, cda, txa, tidAY, xmad_ta;
127 | --:-:-:-:1 LEA track0A0.CC, ta, param_A[0], 1;
128 | --:-:-:-:1 LEA.HI.X track0A1, ta, param_A[1], RZ, 1;
129 |
130 | --:-:-:-:1 ISETP.LT.AND P2, PT, txa, param_m, PT;
131 |
132 | // trackB += (idx_B*64 + tidBX + cdb*tidBY) * 2
133 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 6;
134 | --:-:-:-:1 XMAD.LO2 tb, cdb, tidBY, txb;
135 | --:-:-:-:1 LEA track0B0.CC, tb, param_B[0], 1;
136 | --:-:-:-:1 LEA.HI.X track0B1, tb, param_B[1], RZ, 1;
137 | --:-:-:-:1 IADD tb, tb, cdb8;
138 | --:-:-:-:1 LEA track1B0.CC, tb, param_B[0], 1;
139 | --:-:-:-:1 LEA.HI.X track1B1, tb, param_B[1], RZ, 1;
140 | --:-:-:-:1 IADD tb, tb, cdb8;
141 | --:-:-:-:1 LEA track2B0.CC, tb, param_B[0], 1;
142 | --:-:-:-:1 LEA.HI.X track2B1, tb, param_B[1], RZ, 1;
143 | --:-:-:-:1 IADD tb, tb, cdb8;
144 | --:-:-:-:1 LEA track3B0.CC, tb, param_B[0], 1;
145 | --:-:-:-:1 LEA.HI.X track3B1, tb, param_B[1], RZ, 1;
146 |
147 | --:-:-:-:1 ISETP.LT.AND P3, PT, txb, param_n, PT;
148 | [+
149 | return vector() ? '' : q{
150 | --:-:-:-:1 IADD txb1, txb, 1;
151 | --:-:-:-:1 IADD txb2, txb, 2;
152 | --:-:-:-:1 IADD txb3, txb, 3;
153 | --:-:-:-:1 ISETP.LT.AND P4, PT, txb1, param_n, PT;
154 | --:-:-:-:1 ISETP.LT.AND P5, PT, txb2, param_n, PT;
155 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb3, param_n, PT;
156 | };
157 | +]
158 | --:-:-:-:1 P2R preds, PR, RZ, 0x7c;
159 |
160 | // writeAs = (tidAY*32 + tidAX + shiftAX) * 4
161 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 5;
162 | --:-:-:-:1 IADD writeAs, writeAs, shiftAX;
163 | --:-:-:-:1 SHL writeAs, writeAs, 2;
164 |
165 | // writeBs = (tidBY*64 + tidBX) * 4
166 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 6;
167 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2;
168 |
169 | // readAs = (((tid & 16) >> 2) | (tid & 1)) << 4
170 | --:-:-:-:1 LOP.AND tid1, tid, 1;
171 | --:-:-:-:1 LOP.AND readAs, tid, 16;
172 | --:-:-:-:1 SHR.U32 readAs, readAs, 3;
173 | --:-:-:-:1 LOP.OR readAs, readAs, tid1;
174 | --:-:-:-:1 SHL readAs, readAs, 4;
175 |
176 | // readBs = (((tid >> 1) & 7) << 4
177 | --:-:-:-:1 BFE.U32 readBs, tid, 0x301; // 2 bits at position 1
178 | --:-:-:-:1 SHL readBs, readBs, 4;
179 |
180 | // Each tile has 32 threads so this is an index into the 4 tiles (at bit position 5)
181 | // tid32 = tid & -32
182 | --:-:-:-:1 LOP.AND tid32, tid, -32;
183 |
184 | // Write out the 4 groups of 32 rows 16 at a time
185 | // writeCs = (readAs + tid32/2*4) * 64 + readBs
186 | --:-:-:-:1 ISCADD writeCs, tid32, readAs, 1;
187 | --:-:-:-:1 ISCADD writeCs, writeCs, readBs, 6;
188 |
189 | // Each block of 32 threads works on 8 lines,
190 | // readAs is also shifted over by 8 for each group of 32 threads
191 | // readAs += tid32/4 * 32 * 4 + tid32/4 * 4
192 | // readBs += tid32/4 * 64 * 4 + 4x
193 | --:-:-:-:1 ISCADD readAs, tid32, readAs, 5;
194 | --:-:-:-:1 ISCADD readBs, tid32, readBs, 6;
195 | --:-:-:-:1 IADD readAs, tid32, readAs;
196 | --:-:-:-:1 IADD readBs, readBs, 4x;
197 |
198 | --:-:-:-:1 MOV32I swapBuf, 4x;
199 |
200 | [+
201 | return vector() ? q{
202 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidAY, partialK, P2;
203 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidBY8, partialK, P3;
204 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidBY16, partialK, P3;
205 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidBY24, partialK, P3;
206 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidBY, partialK, P3;
207 |
208 | --:-:2:-:1 @P2 LDG.E.CI.128 load0A, [track0A];
209 | --:-:3:-:1 @P3 LDG.E.CI.64 load0B, [track0B];
210 | --:-:4:-:1 @P4 LDG.E.CI.64 load1B, [track1B];
211 | --:-:5:-:1 @P5 LDG.E.CI.64 load2B, [track2B];
212 | --:-:6:-:1 @P6 LDG.E.CI.64 load3B, [track3B];
213 |
214 |
215 | --:-:-:-:1 @!P2 LDS.U.128 load0A, [addr_zero];
216 | --:-:-:-:1 @!P3 LDS.U.64 load0B, [addr_zero];
217 | --:-:-:-:1 @!P4 LDS.U.64 load1B, [addr_zero];
218 | --:-:-:-:1 @!P5 LDS.U.64 load2B, [addr_zero];
219 | --:-:1:-:1 @!P6 LDS.U.64 load3B, [addr_zero];
220 |
221 | } : q{
222 |
223 | --:-:-:-:1 IADD tidAY1, tidAY, 1;
224 | --:-:-:-:1 IADD tidAY2, tidAY, 2;
225 | --:-:-:-:1 IADD tidAY3, tidAY, 3;
226 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY, partialK, P2;
227 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidAY1, partialK, P2;
228 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidAY2, partialK, P2;
229 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidAY3, partialK, P2;
230 |
231 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0A0, [track0A + 2x<0>];
232 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0A1, [track0A + 2x<1>];
233 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0A2, [track0A + 2x<2>];
234 | --:-:2:-:1 @P6 LDG.E.CI.U16 load0A3, [track0A + 2x<3>];
235 |
236 | --:-:-:-:1 @!P3 MOV load0A0, RZ;
237 | --:-:-:-:1 @!P4 MOV load0A1, RZ;
238 | --:-:-:-:1 @!P5 MOV load0A2, RZ;
239 | --:-:-:-:1 @!P6 MOV load0A3, RZ;
240 |
241 | --:-:-:-:1 IADD tidAY, tidAY, 4;
242 | --:-:-:-:1 IADD tidAY1, tidAY1, 4;
243 | --:-:-:-:1 IADD tidAY2, tidAY2, 4;
244 | --:-:-:-:1 IADD tidAY3, tidAY3, 4;
245 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY, partialK, P2;
246 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidAY1, partialK, P2;
247 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidAY2, partialK, P2;
248 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidAY3, partialK, P2;
249 |
250 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0A4, [track0A + 2x<4>];
251 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0A5, [track0A + 2x<5>];
252 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0A6, [track0A + 2x<6>];
253 | --:-:2:-:1 @P6 LDG.E.CI.U16 load0A7, [track0A + 2x<7>];
254 |
255 | --:-:-:-:1 @!P3 MOV load0A4, RZ;
256 | --:-:-:-:1 @!P4 MOV load0A5, RZ;
257 | --:-:-:-:1 @!P5 MOV load0A6, RZ;
258 | --:-:-:-:1 @!P6 MOV load0A7, RZ;
259 |
260 |
261 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY, partialK, PT;
262 | --:-:-:-:1 @P0 R2P PR, preds, 0x78;
263 | --:-:-:-:1 @!P0 R2P PR, RZ, 0x78;
264 |
265 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>];
266 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>];
267 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>];
268 | --:-:3:-:1 @P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>];
269 |
270 | --:-:-:-:1 @!P3 MOV load0B0, RZ;
271 | --:-:-:-:1 @!P4 MOV load0B1, RZ;
272 | --:-:-:-:1 @!P5 MOV load0B2, RZ;
273 | --:-:-:-:1 @!P6 MOV load0B3, RZ;
274 |
275 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY8, partialK, PT;
276 | --:-:-:-:1 @P1 R2P PR, preds, 0x78;
277 | --:-:-:-:1 @!P1 R2P PR, RZ, 0x78;
278 |
279 | --:-:-:-:1 @P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>];
280 | --:-:-:-:1 @P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>];
281 | --:-:-:-:1 @P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>];
282 | --:-:4:-:1 @P6 LDG.E.CI.U16 load1B3, [track1B + 2x<3>];
283 |
284 | --:-:-:-:1 @!P3 MOV load1B0, RZ;
285 | --:-:-:-:1 @!P4 MOV load1B1, RZ;
286 | --:-:-:-:1 @!P5 MOV load1B2, RZ;
287 | --:-:-:-:1 @!P6 MOV load1B3, RZ;
288 |
289 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidBY16, partialK, PT;
290 | --:-:-:-:1 @P2 R2P PR, preds, 0x78;
291 | --:-:-:-:1 @!P2 R2P PR, RZ, 0x78;
292 |
293 | --:-:-:-:1 @P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>];
294 | --:-:-:-:1 @P4 LDG.E.CI.U16 load2B1, [track2B + 2x<1>];
295 | --:-:-:-:1 @P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>];
296 | --:-:5:-:1 @P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>];
297 |
298 | --:-:-:-:1 @!P3 MOV load2B0, RZ;
299 | --:-:-:-:1 @!P4 MOV load2B1, RZ;
300 | --:-:-:-:1 @!P5 MOV load2B2, RZ;
301 | --:-:-:-:1 @!P6 MOV load2B3, RZ;
302 |
303 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY24, partialK, PT;
304 | --:-:-:-:1 @P0 R2P PR, preds, 0x78;
305 | --:-:-:-:1 @!P0 R2P PR, RZ, 0x78;
306 |
307 | --:-:-:-:1 @P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>];
308 | --:-:-:-:1 @P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>];
309 | --:-:-:-:1 @P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>];
310 | --:-:6:-:1 @P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>];
311 |
312 | --:-:-:-:1 @!P3 MOV load3B0, RZ;
313 | --:-:-:-:1 @!P4 MOV load3B1, RZ;
314 | --:-:-:-:1 @!P5 MOV load3B2, RZ;
315 | --:-:-:-:1 @!P6 MOV load3B3, RZ;
316 |
317 | };
318 | +]
319 | // partialB = partialK * cdb
320 | --:-:-:-:1 XMAD.LO2 partialB, cdb, partialK, RZ;
321 |
322 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 32, PT;
323 | --:-:-:-:1 IADD k, k, -32;
324 | --:-:-:-:1 @P1 R2P PR, preds, 0x7c;
325 | --:-:-:-:1 @!P1 R2P PR, RZ, 0x7c;
326 |
327 |
328 | [+
329 | return vector() ? q{
330 | 03:-:-:-:1 F2F.F32.F16 load0A7, load0A3.H1;
331 | --:-:-:-:1 F2F.F32.F16 load0A6, load0A3.H0;
332 | --:-:-:-:1 F2F.F32.F16 load0A5, load0A2.H1;
333 | --:-:1:-:1 F2F.F32.F16 load0A4, load0A2.H0;
334 | --:-:-:-:1 F2F.F32.F16 load0A3, load0A1.H1;
335 | --:-:-:-:1 F2F.F32.F16 load0A2, load0A1.H0;
336 | --:-:-:-:1 F2F.F32.F16 load0A1, load0A0.H1;
337 | --:-:2:-:1 F2F.F32.F16 load0A0, load0A0.H0;
338 | } : q{
339 | 02:-:-:-:1 F2F.F32.F16 load0A7, load0A7;
340 | --:-:-:-:1 F2F.F32.F16 load0A6, load0A6;
341 | --:-:-:-:1 F2F.F32.F16 load0A5, load0A5;
342 | --:-:1:-:1 F2F.F32.F16 load0A4, load0A4;
343 | --:-:-:-:1 F2F.F32.F16 load0A3, load0A3;
344 | --:-:-:-:1 F2F.F32.F16 load0A2, load0A2;
345 | --:-:-:-:1 F2F.F32.F16 load0A1, load0A1;
346 | --:-:2:-:1 F2F.F32.F16 load0A0, load0A0;
347 | };
348 | +]
349 | --:-:-:-:0 LEA track0A0.CC, partialK, track0A0, 1;
350 | 01:-:-:-:1 STS [writeAs + 4x<7*32>], load0A7;
351 | --:-:-:-:1 STS [writeAs + 4x<6*32>], load0A6;
352 | --:-:-:-:1 STS [writeAs + 4x<5*32>], load0A5;
353 | --:-:-:-:1 STS [writeAs + 4x<4*32>], load0A4;
354 | 02:-:-:-:1 STS [writeAs + 4x<3*32>], load0A3;
355 | --:-:-:-:1 STS [writeAs + 4x<2*32>], load0A2;
356 | --:-:-:-:1 STS [writeAs + 4x<1*32>], load0A1;
357 | --:-:-:-:1 STS [writeAs + 4x<0*32>], load0A0;
358 | --:-:-:-:0 IADD.X track0A1, track0A1, RZ;
359 |
360 | [+
361 | return vector() ? q{
362 | 04:-:-:-:1 F2F.F32.F16 load0B3, load0B1.H1;
363 | --:-:-:-:1 F2F.F32.F16 load0B2, load0B1.H0;
364 | --:-:-:-:1 F2F.F32.F16 load0B1, load0B0.H1;
365 | --:-:3:-:1 F2F.F32.F16 load0B0, load0B0.H0;
366 |
367 | 08:-:-:-:1 F2F.F32.F16 load1B3, load1B1.H1;
368 | --:-:-:-:1 F2F.F32.F16 load1B2, load1B1.H0;
369 | --:-:-:-:1 F2F.F32.F16 load1B1, load1B0.H1;
370 | --:-:4:-:1 F2F.F32.F16 load1B0, load1B0.H0;
371 |
372 | 10:-:-:-:1 F2F.F32.F16 load2B3, load2B1.H1;
373 | --:-:-:-:1 F2F.F32.F16 load2B2, load2B1.H0;
374 | --:-:-:-:1 F2F.F32.F16 load2B1, load2B0.H1;
375 | --:-:5:-:1 F2F.F32.F16 load2B0, load2B0.H0;
376 |
377 | 20:-:-:-:1 F2F.F32.F16 load3B3, load3B1.H1;
378 | --:-:-:-:1 F2F.F32.F16 load3B2, load3B1.H0;
379 | --:-:-:-:1 F2F.F32.F16 load3B1, load3B0.H1;
380 | --:-:6:-:1 F2F.F32.F16 load3B0, load3B0.H0;
381 | } : q{
382 | 04:-:-:-:1 F2F.F32.F16 load0B0, load0B0;
383 | --:-:-:-:1 F2F.F32.F16 load0B1, load0B1;
384 | --:-:-:-:1 F2F.F32.F16 load0B2, load0B2;
385 | --:-:3:-:1 F2F.F32.F16 load0B3, load0B3;
386 |
387 | 08:-:-:-:1 F2F.F32.F16 load1B0, load1B0;
388 | --:-:-:-:1 F2F.F32.F16 load1B1, load1B1;
389 | --:-:-:-:1 F2F.F32.F16 load1B2, load1B2;
390 | --:-:4:-:1 F2F.F32.F16 load1B3, load1B3;
391 |
392 | 10:-:-:-:1 F2F.F32.F16 load2B0, load2B0;
393 | --:-:-:-:1 F2F.F32.F16 load2B1, load2B1;
394 | --:-:-:-:1 F2F.F32.F16 load2B2, load2B2;
395 | --:-:5:-:1 F2F.F32.F16 load2B3, load2B3;
396 |
397 | 20:-:-:-:1 F2F.F32.F16 load3B0, load3B0;
398 | --:-:-:-:1 F2F.F32.F16 load3B1, load3B1;
399 | --:-:-:-:1 F2F.F32.F16 load3B2, load3B2;
400 | --:-:6:-:1 F2F.F32.F16 load3B3, load3B3;
401 | };
402 | +]
403 |
404 | --:-:-:-:0 LEA track0B0.CC, partialB, track0B0, 1;
405 | 04:-:-:-:6 STS.128 [writeBs + 4x<0*64>], load0B;
406 | --:-:-:-:1 IADD.X track0B1, track0B1, RZ;
407 |
408 | --:-:-:-:0 LEA track1B0.CC, partialB, track1B0, 1;
409 | 08:-:-:-:6 STS.128 [writeBs + 4x<8*64>], load1B;
410 | --:-:-:-:1 IADD.X track1B1, track1B1, RZ;
411 |
412 | --:-:-:-:0 LEA track2B0.CC, partialB, track2B0, 1;
413 | 10:-:-:-:6 STS.128 [writeBs + 4x<16*64>], load2B;
414 | --:-:-:-:1 IADD.X track2B1, track2B1, RZ;
415 |
416 | --:-:-:-:0 LEA track3B0.CC, partialB, track3B0, 1;
417 | 20:-:-:-:6 STS.128 [writeBs + 4x<24*64>], load3B;
418 | --:-:-:-:0 IADD.X track3B1, track3B1, RZ;
419 |
420 | --:-:-:-:5 BAR.SYNC 0;
421 | --:-:-:-:1 IADD writeBs, writeBs, swapBuf;
422 | --:-:-:-:1 IADD writeAs, writeAs, swapBuf;
423 | --:-:-:-:0 IADD swapBuf, RZ, -swapBuf;
424 |
425 | --:-:-:-:1 LDS.U.128 j0Ay0, [readAs + 4x<0*32 + 00>];
426 | --:-:-:-:1 LDS.U.128 j0Bx0, [readBs + 4x<0*64 + 00>];
427 | --:-:-:-:1 LDS.U.128 j0Ay4, [readAs + 4x<0*32 + 16>];
428 | --:-:1:-:1 LDS.U.128 j0Bx4, [readBs + 4x<0*64 + 32>];
429 |
430 | [+
431 | return vector() ? q{
432 | --:-:2:-:1 @P2 LDG.E.CI.128 load0A, [track0A];
433 | --:-:3:-:1 @P3 LDG.E.CI.64 load0B, [track0B];
434 | --:-:4:-:1 @P3 LDG.E.CI.64 load1B, [track1B];
435 | --:-:5:-:1 @P3 LDG.E.CI.64 load2B, [track2B];
436 | --:-:6:-:1 @P3 LDG.E.CI.64 load3B, [track3B];
437 | } : q{
438 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A0, [track0A + 2x<0>];
439 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A1, [track0A + 2x<1>];
440 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A2, [track0A + 2x<2>];
441 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A3, [track0A + 2x<3>];
442 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A4, [track0A + 2x<4>];
443 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A5, [track0A + 2x<5>];
444 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A6, [track0A + 2x<6>];
445 | --:-:2:-:1 @P2 LDG.E.CI.U16 load0A7, [track0A + 2x<7>];
446 |
447 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>];
448 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>];
449 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>];
450 | --:-:3:-:1 @P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>];
451 |
452 | --:-:-:-:1 @P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>];
453 | --:-:-:-:1 @P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>];
454 | --:-:-:-:1 @P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>];
455 | --:-:4:-:1 @P6 LDG.E.CI.U16 load1B3, [track1B + 2x<3>];
456 |
457 | --:-:-:-:1 @P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>];
458 | --:-:-:-:1 @P4 LDG.E.CI.U16 load2B1, [track2B + 2x<1>];
459 | --:-:-:-:1 @P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>];
460 | --:-:5:-:1 @P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>];
461 |
462 | --:-:-:-:1 @P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>];
463 | --:-:-:-:1 @P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>];
464 | --:-:-:-:1 @P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>];
465 | --:-:6:-:1 @P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>];
466 | };
467 | +]
468 |
469 | LOOP:
470 |
471 | [+
472 | our %insert =
473 | (
474 | j0c8 => "--:-:-:-:1 ISETP.GE.AND P0, PT, k, RZ, PT;\n",
475 | j0c10 => "--:-:-:-:1 ISETP.GE.AND P1, PT, k, 32, PT;\n" .
476 | "--:-:-:-:1 IADD k, k, -32;\n",
477 |
478 | j0c23 => "--:-:-:-:1 \@P1 R2P PR, preds, 0x7c;\n",
479 | j0c24 => "--:-:-:-:1 \@!P1 R2P PR, RZ, 0x7c;\n",
480 |
481 | j2c32 => "--:-:-:-:1 \@P2 IADD track0A0.CC, track0A0, 2x<32>;\n",
482 | j2c37 => "--:-:-:-:1 \@P2 IADD.X track0A1, track0A1, RZ;\n",
483 | j3c32 => "--:-:-:-:1 \@P3 IADD track0B0.CC, track0B0, cdb32;\n",
484 | j3c37 => "--:-:-:-:1 \@P3 IADD.X track0B1, track0B1, RZ;\n",
485 | j4c32 => "--:-:-:-:1 \@P3 IADD track1B0.CC, track1B0, cdb32;\n",
486 | j4c37 => "--:-:-:-:1 \@P3 IADD.X track1B1, track1B1, RZ;\n",
487 | j5c32 => "--:-:-:-:1 \@P3 IADD track2B0.CC, track2B0, cdb32;\n",
488 | j5c37 => "--:-:-:-:1 \@P3 IADD.X track2B1, track2B1, RZ;\n",
489 | j6c32 => "--:-:-:-:1 \@P3 IADD track3B0.CC, track3B0, cdb32;\n",
490 | j6c37 => "--:-:-:-:1 \@P3 IADD.X track3B1, track3B1, RZ;\n",
491 |
492 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" .
493 | "--:-:-:-:1 \@P0 IADD readAs, readAs, -swapBuf;\n" .
494 | "--:-:-:-:1 \@P0 IADD readBs, readBs, -swapBuf;\n" .
495 | "--:-:-:-:1 \@P0 IADD writeAs, writeAs, swapBuf;\n" .
496 | "--:-:-:-:1 \@P0 IADD writeBs, writeBs, swapBuf;\n" .
497 | "--:-:-:-:1 \@P0 IADD swapBuf, RZ, -swapBuf;\n",
498 |
499 | j2c16 => "02:-:-:-:1 \@P0 STS [writeAs + 4x<7*32>], load0A7;\n",
500 | j2c18 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<6*32>], load0A6;\n",
501 | j2c20 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<5*32>], load0A5;\n",
502 | j2c22 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<4*32>], load0A4;\n",
503 | j2c24 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<3*32>], load0A3;\n",
504 | j2c26 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<2*32>], load0A2;\n",
505 | j2c28 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<1*32>], load0A1;\n",
506 | j2c30 => "--:2:-:-:1 \@P0 STS [writeAs + 4x<0*32>], load0A0;\n",
507 |
508 | j3c16 => "04:3:-:-:1 \@P0 STS.128 [writeBs + 4x< 0*64>], load0B;\n",
509 | j4c16 => "08:4:-:-:1 \@P0 STS.128 [writeBs + 4x< 8*64>], load1B;\n",
510 | j5c16 => "10:5:-:-:1 \@P0 STS.128 [writeBs + 4x<16*64>], load2B;\n",
511 | j6c16 => "20:6:-:-:1 \@P0 STS.128 [writeBs + 4x<24*64>], load3B;\n",
512 |
513 | (vector() ?
514 | (
515 | j1c35 => "02:-:-:-:1 \@P0 F2F.F32.F16 load0A7, load0A3.H1;\n",
516 | j1c39 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A6, load0A3.H0;\n",
517 | j1c43 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A5, load0A2.H1;\n",
518 | j1c47 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A4, load0A2.H0;\n",
519 | j1c51 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A3, load0A1.H1;\n",
520 | j1c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A2, load0A1.H0;\n",
521 | j1c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A1, load0A0.H1;\n",
522 | j1c63 => "--:-:2:-:1 \@P0 F2F.F32.F16 load0A0, load0A0.H0;\n",
523 |
524 | j2c51 => "04:-:-:-:1 \@P0 F2F.F32.F16 load0B3, load0B1.H1;\n",
525 | j2c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B2, load0B1.H0;\n",
526 | j2c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B1, load0B0.H1;\n",
527 | j2c63 => "--:-:3:-:1 \@P0 F2F.F32.F16 load0B0, load0B0.H0;\n",
528 |
529 | j3c51 => "08:-:-:-:1 \@P0 F2F.F32.F16 load1B3, load1B1.H1;\n",
530 | j3c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B2, load1B1.H0;\n",
531 | j3c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B1, load1B0.H1;\n",
532 | j3c63 => "--:-:4:-:1 \@P0 F2F.F32.F16 load1B0, load1B0.H0;\n",
533 |
534 | j4c51 => "10:-:-:-:1 \@P0 F2F.F32.F16 load2B3, load2B1.H1;\n",
535 | j4c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B2, load2B1.H0;\n",
536 | j4c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B1, load2B0.H1;\n",
537 | j4c63 => "--:-:5:-:1 \@P0 F2F.F32.F16 load2B0, load2B0.H0;\n",
538 |
539 | j5c51 => "20:-:-:-:1 \@P0 F2F.F32.F16 load3B3, load3B1.H1;\n",
540 | j5c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B2, load3B1.H0;\n",
541 | j5c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B1, load3B0.H1;\n",
542 | j5c63 => "--:-:6:-:1 \@P0 F2F.F32.F16 load3B0, load3B0.H0;\n",
543 |
544 | j2c61 => "02:-:2:-:1 \@P2 LDG.E.CI.128 load0A, [track0A];\n",
545 | j3c61 => "04:-:3:-:1 \@P3 LDG.E.CI.64 load0B, [track0B];\n",
546 | j4c61 => "08:-:4:-:1 \@P3 LDG.E.CI.64 load1B, [track1B];\n",
547 | j5c61 => "10:-:5:-:1 \@P3 LDG.E.CI.64 load2B, [track2B];\n",
548 | j6c61 => "20:-:6:-:1 \@P3 LDG.E.CI.64 load3B, [track3B];\n",
549 | ) :
550 | (
551 | j1c35 => "02:-:-:-:1 \@P0 F2F.F32.F16 load0A0, load0A0;\n",
552 | j1c39 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A1, load0A1;\n",
553 | j1c43 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A2, load0A2;\n",
554 | j1c47 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A3, load0A3;\n",
555 | j1c51 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A4, load0A4;\n",
556 | j1c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A5, load0A5;\n",
557 | j1c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A6, load0A6;\n",
558 | j1c63 => "--:2:-:-:1 \@P0 F2F.F32.F16 load0A7, load0A7;\n",
559 |
560 | j2c51 => "04:-:-:-:1 \@P0 F2F.F32.F16 load0B0, load0B0;\n",
561 | j2c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B1, load0B1;\n",
562 | j2c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B2, load0B2;\n",
563 | j2c63 => "--:-:3:-:1 \@P0 F2F.F32.F16 load0B3, load0B3;\n",
564 |
565 | j3c51 => "08:-:-:-:1 \@P0 F2F.F32.F16 load1B0, load1B0;\n",
566 | j3c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B1, load1B1;\n",
567 | j3c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B2, load1B2;\n",
568 | j3c63 => "--:-:4:-:1 \@P0 F2F.F32.F16 load1B3, load1B3;\n",
569 |
570 | j4c51 => "10:-:-:-:1 \@P0 F2F.F32.F16 load2B0, load2B0;\n",
571 | j4c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B1, load2B1;\n",
572 | j4c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B2, load2B2;\n",
573 | j4c63 => "--:-:5:-:1 \@P0 F2F.F32.F16 load2B3, load2B3;\n",
574 |
575 | j5c51 => "20:-:-:-:1 \@P0 F2F.F32.F16 load3B0, load3B0;\n",
576 | j5c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B1, load3B1;\n",
577 | j5c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B2, load3B2;\n",
578 | j5c63 => "--:-:6:-:1 \@P0 F2F.F32.F16 load3B3, load3B3;\n",
579 |
580 | j2c48 => "02:-:-:-:1 \@P2 LDG.E.CI.U16 load0A0, [track0A + 2x<0>];\n",
581 | j2c50 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A1, [track0A + 2x<1>];\n",
582 | j2c52 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A2, [track0A + 2x<2>];\n",
583 | j2c54 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A3, [track0A + 2x<3>];\n",
584 | j2c56 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A4, [track0A + 2x<4>];\n",
585 | j2c58 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A5, [track0A + 2x<5>];\n",
586 | j2c60 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A6, [track0A + 2x<6>];\n",
587 | j2c62 => "--:-:2:-:1 \@P2 LDG.E.CI.U16 load0A7, [track0A + 2x<7>];\n",
588 |
589 | j3c56 => "04:-:-:-:1 \@P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>];\n",
590 | j3c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>];\n",
591 | j3c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>];\n",
592 | j3c62 => "--:-:3:-:1 \@P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>];\n",
593 |
594 | j4c56 => "08:-:-:-:1 \@P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>];\n",
595 | j4c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>];\n",
596 | j4c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>];\n",
597 | j4c62 => "--:-:4:-:1 \@P6 LDG.E.CI.U16 load1B3, [track1B + 2x<3>];\n",
598 |
599 | j5c56 => "10:-:-:-:1 \@P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>];\n",
600 | j5c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load2B1, [track2B + 2x<1>];\n",
601 | j5c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>];\n",
602 | j5c62 => "--:-:5:-:1 \@P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>];\n",
603 |
604 | j6c56 => "20:-:-:-:1 \@P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>];\n",
605 | j6c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>];\n",
606 | j6c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>];\n",
607 | j6c62 => "--:-:6:-:1 \@P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>];\n",
608 | )
609 | ),
610 | j7c63 => "--:-:-:Y:5 \@P0 BRA.U LOOP;\n",
611 | );
612 | my @cOrder;
613 | my @swirl = ([0,2],[1,2],[1,0],[0,0]);
614 | my @y = (0,1,4,5);
615 | foreach my $x (0,2,4,6)
616 | {
617 | foreach my $y (@y)
618 | {
619 | push @cOrder, [$x + $_->[0], $y + $_->[1]] foreach @swirl;
620 | }
621 | @y = reverse @y;
622 | }
623 | my $out = '';
624 | foreach my $j (0 .. 7)
625 | {
626 | my $odd = $j & 1;
627 | my $nOdd = !$odd + 0;
628 | my $rsOffset = ($j + 1) % 8;
629 | my $rsPred = $j == 7 ? '@P0' : ' ';
630 |
631 | $insert{"j${j}c0"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy0, [readAs + 4x<%d*32 + 00>];\n", $rsPred, $nOdd, $rsOffset;
632 | $insert{"j${j}c2"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dBx0, [readBs + 4x<%d*64 + 00>];\n", $rsPred, $nOdd, $rsOffset;
633 | $insert{"j${j}c4"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy4, [readAs + 4x<%d*32 + 16>];\n", $rsPred, $nOdd, $rsOffset;
634 | $insert{"j${j}c6"} = sprintf "--:-:1:-:1 %s LDS.U.128 j%dBx4, [readBs + 4x<%d*64 + 32>];\n", $rsPred, $nOdd, $rsOffset;
635 |
636 | foreach my $c (0 .. 63)
637 | {
638 | my ($x,$y) = @{$cOrder[$c]};
639 |
640 | my $ins = $insert{"j${j}c$c"} || '';
641 |
642 | my $stall = $ins =~ /LDS|I2I|I2F|F2I|F2F|LDG|STS|BAR|BRA/ ? 0 : 1;
643 |
644 | my $yield = $c == 32 && $stall ? 'Y' : '-';
645 |
646 | my $wait = $c == 0 ? '01' : '--';
647 |
648 | my $ctrl = "$wait:-:-:$yield:$stall";
649 |
650 | $out .= sprintf "%s FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n%s", $ctrl, $x,$y, $odd,$x, $odd,$y, $x,$y, $ins;
651 | }
652 | }
653 | return $out;
654 | +]
655 |
656 | // Epilogue setup: load alpha/beta, then derive readCs, the output coordinates (cx, cy), the C pointer and the bounds predicates.
657 | --:-:-:-:1 MOV alpha, param_alpha;
658 | --:-:-:-:1 MOV beta, param_beta;
659 |
660 | // readCs = ((tid & 15) * 4 + (tid / 16) * 64) * 4
661 | --:-:-:-:1 LOP.AND tid15, tid, 15;
662 | --:-:-:-:1 SHR.U32 tid16, tid, 4;
663 | --:-:-:-:1 SHL tid15, tid15, 2;
664 | --:-:-:-:1 ISCADD readCs, tid16, tid15, 6;
665 | --:-:-:-:1 SHL readCs, readCs, 2;
666 |
667 | // cx = idx_B*64 + tid15; (tid15 already holds (tid & 15) * 4 at this point)
668 | --:-:-:-:1 ISCADD cx, idx_B, tid15, 6;
669 | --:-:-:-:1 IADD cx1, cx, 1;
670 | --:-:-:-:1 IADD cx2, cx, 2;
671 | --:-:-:-:1 IADD cx3, cx, 3;
672 |
673 | // cy = idx_A*32 + tid16
674 | --:-:-:-:1 ISCADD cy, idx_A, tid16, 5;
675 |
676 | // C += (cy*cdc + cx) * 2;
677 | --:-:-:-:1 MOV cdc, param_cdc;
678 | --:-:-:-:1 SHL cdc8, cdc, 4;
679 | // cdc8 = cdc * 16; STORE_C adds it to the C pointer at the end of every call
680 | --:-:-:-:1 XMAD.LO ci, cy, cdc, cx, xmad_c;
681 | --:-:-:-:1 LEA C0.CC, ci, param_C[0], 1;
682 | --:-:-:-:1 LEA.HI.X C1, ci, param_C[1], RZ, 1;
683 |
684 | // P0-P3 = cx, cx1, cx2, cx3 < n; the four results are saved to preds (mask 0x0f)
685 | --:-:-:-:1 ISETP.LT.AND P0, PT, cx, param_n, PT;
686 | --:-:-:-:1 ISETP.LT.AND P1, PT, cx1, param_n, PT;
687 | --:-:-:-:1 ISETP.LT.AND P2, PT, cx2, param_n, PT;
688 | --:-:-:-:1 ISETP.LT.AND P3, PT, cx3, param_n, PT;
689 | --:-:-:-:1 P2R preds, PR, RZ, 0x0f;
690 |
691 | // P4 = cy < m
692 | --:-:-:-:1 ISETP.LT.AND P4, PT, cy, param_m, PT;
693 |
694 | // P5 = beta != 0 && P4
695 | --:-:-:-:1 ISETP.NE.AND P5, PT, beta, RZ, P4;
696 |
697 | // Init beta preds: P0-P3 are taken from preds when P5 is set, otherwise from RZ (all clear)
698 | --:-:-:-:1 @P5 R2P PR, preds, 0x0f;
699 | --:-:-:-:1 @!P5 R2P PR, RZ, 0x0f;
700 | // Scale the cx<0-7>y<0-3> accumulators by alpha into shuffle_* and write them to [writeCs] with STS.128,
701 | // then synchronize (BAR.SYNC 0) before the STORE_C calls, which LDS from readCs.
702 |
703 | --:-:-:-:1 FMUL shuffle_x0y0, cx0y0, alpha;
704 | --:-:-:-:1 FMUL shuffle_x1y0, cx1y0, alpha;
705 | --:-:-:-:1 FMUL shuffle_x2y0, cx2y0, alpha;
706 | --:-:-:-:1 FMUL shuffle_x3y0, cx3y0, alpha;
707 | --:-:-:-:1 FMUL shuffle_x4y0, cx4y0, alpha;
708 | --:-:-:-:1 FMUL shuffle_x5y0, cx5y0, alpha;
709 | --:-:-:-:1 FMUL shuffle_x6y0, cx6y0, alpha;
710 | --:-:-:-:0 FMUL shuffle_x7y0, cx7y0, alpha;
711 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 00>], shuffle_x0y0;
712 | --:-:-:-:1 FMUL shuffle_x0y1, cx0y1, alpha;
713 | --:-:-:-:1 FMUL shuffle_x1y1, cx1y1, alpha;
714 | --:-:-:-:1 FMUL shuffle_x2y1, cx2y1, alpha;
715 | --:-:-:-:0 FMUL shuffle_x3y1, cx3y1, alpha;
716 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 32>], shuffle_x4y0;
717 | --:-:-:-:1 FMUL shuffle_x4y1, cx4y1, alpha;
718 | --:-:-:-:1 FMUL shuffle_x5y1, cx5y1, alpha;
719 | --:-:-:-:1 FMUL shuffle_x6y1, cx6y1, alpha;
720 | --:-:-:-:0 FMUL shuffle_x7y1, cx7y1, alpha;
721 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 00>], shuffle_x0y1;
722 | --:-:-:-:1 FMUL shuffle_x0y2, cx0y2, alpha;
723 | --:-:-:-:1 FMUL shuffle_x1y2, cx1y2, alpha;
724 | --:-:-:-:1 FMUL shuffle_x2y2, cx2y2, alpha;
725 | --:-:-:-:0 FMUL shuffle_x3y2, cx3y2, alpha;
726 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 32>], shuffle_x4y1;
727 | --:-:-:-:1 FMUL shuffle_x4y2, cx4y2, alpha;
728 | --:-:-:-:1 FMUL shuffle_x5y2, cx5y2, alpha;
729 | --:-:-:-:1 FMUL shuffle_x6y2, cx6y2, alpha;
730 | --:-:-:-:0 FMUL shuffle_x7y2, cx7y2, alpha;
731 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 00>], shuffle_x0y2;
732 | --:-:-:-:1 FMUL shuffle_x0y3, cx0y3, alpha;
733 | --:-:-:-:1 FMUL shuffle_x1y3, cx1y3, alpha;
734 | --:-:-:-:1 FMUL shuffle_x2y3, cx2y3, alpha;
735 | --:-:-:-:0 FMUL shuffle_x3y3, cx3y3, alpha;
736 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 32>], shuffle_x4y2;
737 | --:-:-:-:1 FMUL shuffle_x4y3, cx4y3, alpha;
738 | --:-:-:-:1 FMUL shuffle_x5y3, cx5y3, alpha;
739 | --:-:-:-:1 FMUL shuffle_x6y3, cx6y3, alpha;
740 | --:-:-:-:0 FMUL shuffle_x7y3, cx7y3, alpha;
741 | --:-:-:-:4 STS.128 [writeCs+4x<3*64 + 00>], shuffle_x0y3;
742 | --:-:-:-:1 STS.128 [writeCs+4x<3*64 + 32>], shuffle_x4y3;
743 | --:-:-:-:5 BAR.SYNC 0;
744 | // Two STORE_C calls; each one advances cy by 8 and the C pointer by cdc8.
745 | --:-:-:-:5 CAL STORE_C;
746 | --:-:-:-:5 CAL STORE_C;
747 | // Same sequence for the cx<0-7>y<4-7> accumulators.
748 | --:-:-:-:1 FMUL shuffle_x0y4, cx0y4, alpha;
749 | --:-:-:-:1 FMUL shuffle_x1y4, cx1y4, alpha;
750 | --:-:-:-:1 FMUL shuffle_x2y4, cx2y4, alpha;
751 | --:-:-:-:1 FMUL shuffle_x3y4, cx3y4, alpha;
752 | --:-:-:-:1 FMUL shuffle_x4y4, cx4y4, alpha;
753 | --:-:-:-:1 FMUL shuffle_x5y4, cx5y4, alpha;
754 | --:-:-:-:0 FMUL shuffle_x6y4, cx6y4, alpha;
755 | --:-:-:-:5 BAR.SYNC 0;
756 | --:-:-:-:0 FMUL shuffle_x7y4, cx7y4, alpha;
757 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 00>], shuffle_x0y4;
758 | --:-:-:-:1 FMUL shuffle_x0y5, cx0y5, alpha;
759 | --:-:-:-:1 FMUL shuffle_x1y5, cx1y5, alpha;
760 | --:-:-:-:1 FMUL shuffle_x2y5, cx2y5, alpha;
761 | --:-:-:-:0 FMUL shuffle_x3y5, cx3y5, alpha;
762 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 32>], shuffle_x4y4;
763 | --:-:-:-:1 FMUL shuffle_x4y5, cx4y5, alpha;
764 | --:-:-:-:1 FMUL shuffle_x5y5, cx5y5, alpha;
765 | --:-:-:-:1 FMUL shuffle_x6y5, cx6y5, alpha;
766 | --:-:-:-:0 FMUL shuffle_x7y5, cx7y5, alpha;
767 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 00>], shuffle_x0y5;
768 | --:-:-:-:1 FMUL shuffle_x0y6, cx0y6, alpha;
769 | --:-:-:-:1 FMUL shuffle_x1y6, cx1y6, alpha;
770 | --:-:-:-:1 FMUL shuffle_x2y6, cx2y6, alpha;
771 | --:-:-:-:0 FMUL shuffle_x3y6, cx3y6, alpha;
772 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 32>], shuffle_x4y5;
773 | --:-:-:-:1 FMUL shuffle_x4y6, cx4y6, alpha;
774 | --:-:-:-:1 FMUL shuffle_x5y6, cx5y6, alpha;
775 | --:-:-:-:1 FMUL shuffle_x6y6, cx6y6, alpha;
776 | --:-:-:-:0 FMUL shuffle_x7y6, cx7y6, alpha;
777 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 00>], shuffle_x0y6;
778 | --:-:-:-:1 FMUL shuffle_x0y7, cx0y7, alpha;
779 | --:-:-:-:1 FMUL shuffle_x1y7, cx1y7, alpha;
780 | --:-:-:-:1 FMUL shuffle_x2y7, cx2y7, alpha;
781 | --:-:-:-:0 FMUL shuffle_x3y7, cx3y7, alpha;
782 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 32>], shuffle_x4y6;
783 | --:-:-:-:1 FMUL shuffle_x4y7, cx4y7, alpha;
784 | --:-:-:-:1 FMUL shuffle_x5y7, cx5y7, alpha;
785 | --:-:-:-:1 FMUL shuffle_x6y7, cx6y7, alpha;
786 | --:-:-:-:0 FMUL shuffle_x7y7, cx7y7, alpha;
787 | --:-:-:-:4 STS.128 [writeCs+4x<3*64 + 00>], shuffle_x0y7;
788 | --:-:-:-:1 STS.128 [writeCs+4x<3*64 + 32>], shuffle_x4y7;
789 | --:-:-:-:5 BAR.SYNC 0;
790 |
791 | --:-:-:-:5 CAL STORE_C;
792 | --:-:-:-:5 CAL STORE_C;
793 |
794 | --:-:-:-:5 EXIT;
795 | // STORE_C: sum four partial results read via readCs, optionally add beta * C, convert to F16 and store up to four C values; then advance cy, readCs and C for the next call.
796 | STORE_C:
797 | // Load the existing C values; P0-P3 hold the beta preds when this routine is entered.
798 | [+
799 | return vector() ? q{
800 | --:-:1:-:1 @P0 LDG.E.64 loadC, [C];
801 | } : q{
802 | --:-:-:-:0 @!P0 MOV loadC0, RZ;
803 | --:-:-:-:1 @P0 LDG.E.CI.U16 loadC0, [C + 2x<0>];
804 | --:-:-:-:0 @!P1 MOV loadC1, RZ;
805 | --:-:-:-:1 @P1 LDG.E.CI.U16 loadC1, [C + 2x<1>];
806 | --:-:-:-:0 @!P2 MOV loadC2, RZ;
807 | --:-:-:-:1 @P2 LDG.E.CI.U16 loadC2, [C + 2x<2>];
808 | --:-:-:-:0 @!P3 MOV loadC3, RZ;
809 | --:-:1:-:1 @P3 LDG.E.CI.U16 loadC3, [C + 2x<3>];
810 | };
811 | +]
812 |
813 | // Restore output preds
814 | --:-:-:-:1 @P4 R2P PR, preds, 0x0f;
815 | --:-:-:-:1 @!P4 R2P PR, RZ, 0x0f;
816 | // Read the four partial results at readCs + 4x<{0,16,32,48}*64>.
817 | --:-:-:-:1 LDS.U.128 part0C, [readCs + 4x< 0*64>];
818 | --:-:2:-:1 LDS.U.128 part1C, [readCs + 4x<16*64>];
819 | --:-:-:-:1 LDS.U.128 part2C, [readCs + 4x<32*64>];
820 | --:-:3:-:1 LDS.U.128 part3C, [readCs + 4x<48*64>];
821 | // c<0-3> = (part0C + part1C) + (part2C + part3C), each lane under its output pred.
822 |
823 | 02:-:-:-:1 @P0 FADD part0C0, part0C0, part1C0;
824 | --:-:-:-:1 @P1 FADD part0C1, part0C1, part1C1;
825 | --:-:-:-:1 @P2 FADD part0C2, part0C2, part1C2;
826 | --:-:-:-:1 @P3 FADD part0C3, part0C3, part1C3;
827 |
828 | 04:-:-:-:1 @P0 FADD part2C0, part2C0, part3C0;
829 | --:-:-:-:1 @P1 FADD part2C1, part2C1, part3C1;
830 | --:-:-:-:1 @P2 FADD part2C2, part2C2, part3C2;
831 | --:-:-:-:1 @P3 FADD part2C3, part2C3, part3C3;
832 |
833 | --:-:-:-:1 @P0 FADD c0, part0C0, part2C0;
834 | --:-:-:-:1 @P1 FADD c1, part0C1, part2C1;
835 | --:-:-:-:1 @P2 FADD c2, part0C2, part2C2;
836 | --:-:-:-:1 @P3 FADD c3, part0C3, part2C3;
837 |
838 | // cy += 8: the P4/P5 updates below test the row handled by the next call.
839 | --:-:-:-:0 IADD cy, cy, 8;
840 | // Convert the loaded C values from F16 to F32 (b0-b3); only executed when P5 is set.
841 | [+
842 | return vector() ? q{
843 | 01:-:1:-:1 @P5 F2F.F32.F16 b0, loadC0.H0;
844 | --:-:2:-:1 @P5 F2F.F32.F16 b1, loadC0.H1;
845 | --:-:3:-:1 @P5 F2F.F32.F16 b2, loadC1.H0;
846 | --:-:4:-:1 @P5 F2F.F32.F16 b3, loadC1.H1;
847 | } : q{
848 | 01:-:1:-:1 @P5 F2F.F32.F16 b0, loadC0;
849 | --:-:2:-:1 @P5 F2F.F32.F16 b1, loadC1;
850 | --:-:3:-:1 @P5 F2F.F32.F16 b2, loadC2;
851 | --:-:4:-:1 @P5 F2F.F32.F16 b3, loadC3;
852 | };
853 | +]
854 | // c += b * beta (only when P5 is set)
855 | 01:-:-:-:1 @P5 FFMA c0, b0, beta, c0;
856 | 02:-:-:-:1 @P5 FFMA c1, b1, beta, c1;
857 | 04:-:-:-:1 @P5 FFMA c2, b2, beta, c2;
858 | 08:-:-:-:1 @P5 FFMA c3, b3, beta, c3;
859 | // Next-call predicates from the advanced cy: P5 = P5 && cy < m, P4 = cy < m.
860 | --:-:-:-:0 ISETP.LT.AND P5, PT, cy, param_m, P5;
861 | // Convert the results to F16; the predicate and readCs updates are interleaved with the conversions.
862 | --:-:1:-:1 @P0 F2F.F16.F32 c0, c0;
863 | --:-:2:-:1 @P1 F2F.F16.F32 c1, c1;
864 |
865 | --:-:-:-:0 ISETP.LT.AND P4, PT, cy, param_m, PT;
866 |
867 | --:-:3:-:1 @P2 F2F.F16.F32 c2, c2;
868 | // Flip readCs by 4x<8*64> (XOR) for the next call.
869 | --:-:-:-:0 LOP.XOR readCs, readCs, 4x<8*64>;
870 |
871 | --:-:4:-:1 @P3 F2F.F16.F32 c3, c3;
872 | // Store: the vector path combines c0/c1 and c2/c3 with BFI for one 64-bit STG; otherwise four predicated U16 stores.
873 | [+
874 | return vector() ? q{
875 | 03:-:-:-:2 @P0 BFI c0, c1, 0x1010, c0;
876 | 0c:-:-:-:2 @P0 BFI c1, c3, 0x1010, c2;
877 |
878 | --:1:-:-:1 @P0 STG.E.CG.64 [C], c;
879 | } : q{
880 | 01:-:-:-:1 @P0 STG.E.U16 [C + 2x<0>], c0;
881 | 02:-:-:-:1 @P1 STG.E.U16 [C + 2x<1>], c1;
882 | 04:-:-:-:1 @P2 STG.E.U16 [C + 2x<2>], c2;
883 | 08:1:-:-:1 @P3 STG.E.U16 [C + 2x<3>], c3;
884 | };
885 | +]
886 |
887 | // Restore beta preds
888 | --:-:-:-:1 @P5 R2P PR, preds, 0x0f;
889 | --:-:-:-:1 @!P5 R2P PR, RZ, 0x0f;
890 | // Advance the C pointer (C0 with carry into C1) by cdc8.
891 | 01:-:-:-:6 IADD C0.CC, C0, cdc8;
892 | --:-:-:-:0 IADD.X C1, C1, RZ;
893 |
894 | --:-:-:-:5 RET;
895 |
--------------------------------------------------------------------------------
/sass/xgemm_128x128x8.sass:
--------------------------------------------------------------------------------
1 | # Kernel: xgemm_128x128x8
2 | # Perl setup: per-operand data type, layout and vectorization helpers used by the templates in the rest of this kernel.
3 | [-
4 | our ($type, $A16, $B16, $C16);
5 | sub A16 { return $type eq 'h' || $A16 } # true when $type is 'h' or $A16 is set
6 | sub B16 { return $type eq 'h' || $B16 }
7 | sub C16 { return $type eq 'h' || $C16 }
8 | # Load type suffix: U16 for 16-bit operands, 32 otherwise (used as LDG.E.CI.$dtypeA).
9 | our $dtypeA = A16() ? 'U16' : '32';
10 | our $dtypeB = B16() ? 'U16' : '32';
11 | our $dtypeC = C16() ? 'U16' : '32';
12 | # Shift that scales an element index to a byte offset (1 for 16-bit, 2 otherwise).
13 | our $dshiftA = A16() ? '1' : '2';
14 | our $dshiftB = B16() ? '1' : '2';
15 | our $dshiftC = C16() ? '1' : '2';
16 | # Element size in bytes (used as ${dsizeA}x<i> address offsets).
17 | our $dsizeA = A16() ? '2' : '4';
18 | our $dsizeB = B16() ? '2' : '4';
19 | our $dsizeC = C16() ? '2' : '4';
20 | # Width in bits of a four-element vector access (used as LDG.E.CI.$vsizeA).
21 | our $vsizeA = A16() ? '64' : '128';
22 | our $vsizeB = B16() ? '64' : '128';
23 | our $vsizeC = C16() ? '64' : '128';
24 | # Accessor subs so the values can be used inline, e.g. [+ dshiftA() +].
25 | sub dtypeA { return $dtypeA; }
26 | sub dtypeB { return $dtypeB; }
27 | sub dtypeC { return $dtypeC; }
28 |
29 | sub dsizeA { return $dsizeA; }
30 | sub dsizeB { return $dsizeB; }
31 | sub dsizeC { return $dsizeC; }
32 |
33 | sub dshiftA { return $dshiftA; }
34 | sub dshiftB { return $dshiftB; }
35 | sub dshiftC { return $dshiftC; }
36 | # Layout flags: the first set flag among $NN/$NT/$TN/$TT picks the pair; both stay undef if none is set.
37 | our ($outerContigA, $outerContigB, $NN, $NT, $TN, $TT);
38 | if ($NN) { $outerContigA = 0; $outerContigB = 1 }
39 | elsif ($NT) { $outerContigA = 0; $outerContigB = 0 }
40 | elsif ($TN) { $outerContigA = 1; $outerContigB = 1 }
41 | elsif ($TT) { $outerContigA = 1; $outerContigB = 0 }
42 |
43 | sub outerContigA { return $outerContigA; }
44 | sub outerContigB { return $outerContigB; }
45 | # Vector loads are enabled per operand ($vecA / $vecB) or for both at once ($vec).
46 | our ($vecA, $vecB, $vec);
47 | sub vecA { return $vecA || $vec; }
48 | sub vecB { return $vecB || $vec; }
49 | -]
50 |
51 | // Buffer sizes: szShareA/szShareB are 128*8, plus 32 unless both outerContigA() and outerContigB() are true; addr_zero is 4x<size*4>.
52 | [+
53 | return outerContigA() && outerContigB() ? q{
54 | addr_zero : 4x<128*8*4>
55 | szShareA : (128*8)
56 | szShareB : (128*8)
57 | } : q{
58 | addr_zero : 4x<(128*8 + 32)*4>
59 | szShareA : (128*8 + 32)
60 | szShareB : (128*8 + 32)
61 | };
62 | +]
63 | // Kernel parameter names mapped to c[0x0] offsets 0x140-0x17c ([0]/[1] are the two 32-bit halves of each pointer).
64 | param_C[0] : c[0x0][0x140]
65 | param_C[1] : c[0x0][0x144]
66 | param_A[0] : c[0x0][0x148]
67 | param_A[1] : c[0x0][0x14c]
68 | param_B[0] : c[0x0][0x150]
69 | param_B[1] : c[0x0][0x154]
70 | param_alpha : c[0x0][0x158]
71 | param_beta : c[0x0][0x15c]
72 | param_cda : c[0x0][0x160]
73 | param_cdb : c[0x0][0x164]
74 | param_cdc : c[0x0][0x168]
75 | param_m : c[0x0][0x16c]
76 | param_n : c[0x0][0x170]
77 | param_k : c[0x0][0x174]
78 | param_blk_a : c[0x0][0x178]
79 | param_blk_b : c[0x0][0x17c]
80 |
81 |
82 |
83 | // Register map (register numbers on the left, names on the right). The cx<0-7>y<0-7> accumulators occupy registers 0-63.
84 | 3, 2,11,10,19,18,27,26 : cx<0-7>y0
85 | 7, 6,15,14,23,22,31,30 : cx<0-7>y1
86 | 1, 0, 9, 8,17,16,25,24 : cx<0-7>y2
87 | 5, 4,13,12,21,20,29,28 : cx<0-7>y3
88 | 35,34,43,42,51,50,59,58 : cx<0-7>y4
89 | 39,38,47,46,55,54,63,62 : cx<0-7>y5
90 | 33,32,41,40,49,48,57,56 : cx<0-7>y6
91 | 37,36,45,44,53,52,61,60 : cx<0-7>y7
92 | // czero<00-63> names the same registers 0-63; they are filled by LDS.U.128 from addr_zero.
93 | 0-63 : czero<00-63>
94 | 64-79 : j0Ay<0-7>, j0Bx<0-7>
95 | 80-95 : j1Ay<0-7>, j1Bx<0-7>
96 | // The setup temporaries below share the 64-95 range with the j0/j1 operands.
97 | 64-95 ~ idx_ab, cda, cdb, idx_ab_f, idx_a, neg_blk_b, rcp_blk_b, idx_b, tidAX, txa, txa<1-3>, ta, xmad_ta, tidBX, txb, txb<1-3>, tb, xmad_tb, tid1, tid32_2, tid96, tid128
98 |
99 | 96-99 : tidAY, tidAY<1-3>
100 | 104-107 : tidBY, tidBY<1-3>
101 | 100-103 ~ predsA, partialK, partialA, partialB
102 | 108-111 ~ predsB
103 | // loadA/loadB share registers 96-111 with the tidAY/tidBY, preds and partial* names above.
104 | 96-99 : loadA<0-3>
105 | 100-103 : loadA<4-7>
106 | 104-107 : loadB<0-3>
107 | 108-111 : loadB<4-7>
108 |
109 | 112-115 : trackA<0-1>, trackB<0-1>
110 |
111 | 116-123 ~ k, swapBuf, readAs, readBs, writeAs, writeBs, cda8, cdb8
112 | 124-127 ~ tid, idx_A, idx_B, writeCs
113 | // The names below reuse registers 64-123 - presumably for the output stage; confirm against the rest of the file.
114 | 64-71 : track00C<0-1>, track04C<0-1>, track08C<0-1>, track12C<0-1>
115 | 72-79 : c<0-7>
116 | 80-87 ~ c<00|04|08|12>_00, c<00|04|08|12>_32
117 | 88-95 ~ b<00|04|08|12>_00, b<00|04|08|12>_32
118 | 96-123 ~ readCs, cx<00|32>, cy<00|04|08|12>, cdc, cdc1, cdc4, cdc13, tc, xmad_tc, tid_31, tid_32, tid_96, tid_128, alpha, beta
119 |
120 |
121 | // Prologue: read thread/block ids, split k into a first partial fetch plus full 8-line fetches, derive idx_A/idx_B and zero the accumulators.
122 | --:-:5:-:1 I2F.F32.S32 rcp_blk_b, param_blk_b;
123 | --:-:1:-:1 S2R tid, SR_TID.X;
124 | --:-:2:-:1 S2R idx_ab, SR_CTAID.X;
125 | --:-:3:-:1 S2R idx_A, SR_CTAID.Z;
126 | --:-:4:-:1 S2R idx_B, SR_CTAID.Y;
127 |
128 | // rcp_blk_b = 1.0 / float(param_blk_b)
129 | 10:-:5:-:1 MUFU.RCP rcp_blk_b, rcp_blk_b;
130 |
131 | --:-:-:-:1 MOV k, param_k;
132 | --:-:-:-:1 MOV cda, param_cda;
133 | --:-:-:-:1 MOV cdb, param_cdb;
134 | // partialK = k & 7, replaced by 8 when that is zero; then k -= partialK
135 | // If k is not a multiple of 8 we want to grab the partial amount on the first fetch.
136 | // If it is a multiple of 8 then make a full 8 line fetch.
137 | --:-:-:-:1 LOP.AND.Z P0, partialK, k, 7;
138 | --:-:-:-:1 @P0 MOV partialK, 8;
139 | --:-:-:-:1 IADD k, k, -partialK;
140 |
141 | # idx_a = idx_ab // blk_b
142 | 02:-:2:-:1 I2F.F32.S32 idx_ab_f, idx_ab;
143 | 12:-:-:-:1 FMUL idx_a, idx_ab_f, rcp_blk_b;
144 | --:-:-:-:1 FFMA idx_a, idx_a, 5.9604644775390625e-08, idx_a; // idx_a += idx_a * 2^-24 before the TRUNC conversion
145 | --:-:2:-:1 F2I.S32.F32.TRUNC idx_a, idx_a;
146 | # idx_b = idx_ab % blk_b (computed as idx_ab - blk_b * idx_a)
147 | --:-:-:-:1 IADD neg_blk_b, RZ, -param_blk_b;
148 | 02:-:-:-:1 XMAD.S16.U16 idx_b, neg_blk_b, idx_a, idx_ab;
149 |
150 | # idx_A = idx_A * blk_a + idx_a
151 | # idx_B = idx_B * blk_b + idx_b
152 | 06:-:-:-:1 XMAD.U16.U16 idx_A, idx_A, param_blk_a, idx_a;
153 | 08:-:-:-:1 XMAD.U16.U16 idx_B, idx_B, param_blk_b, idx_b;
154 | // Zero the accumulators: store RZ to addr_zero, then LDS.U.128 it into czero00, czero04, ..., czero60.
155 | --:-:-:-:1 STS.128 [addr_zero], RZ;
156 | [+ join '', map sprintf("--:-:-:-:1 LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15; +]
157 | // A setup: thread coordinates (tidAX, tidAY), bounds predicate(s) against m, element offset ta and write offset writeAs; trackA is then formed from param_A and ta.
158 | [+
159 | our $dshiftA; # interpolated by the qq{} (outerContigA) branch below
160 | my $predsA = vecA() ? q{
161 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa, param_m, PT;
162 |
163 | } : q{
164 | --:-:-:-:1 IADD txa1, txa, 1;
165 | --:-:-:-:1 IADD txa2, txa, 2;
166 | --:-:-:-:1 IADD txa3, txa, 3;
167 | --:-:-:-:1 ISETP.LT.AND P2, PT, txa, param_m, PT;
168 | --:-:-:-:1 ISETP.LT.AND P3, PT, txa1, param_m, PT;
169 | --:-:-:-:1 ISETP.LT.AND P4, PT, txa2, param_m, PT;
170 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa3, param_m, PT;
171 | --:-:-:-:1 P2R predsA, PR, RZ, 0x3c;
172 | }; # $predsA is spliced in only by the outerContigA branch; the other branch sets P5 itself
173 | return outerContigA() ? qq{
174 | // tidAX = (tid & 31) << 2
175 | // tidAY = tid >> 5
176 | 01:-:-:-:1 LOP.AND tidAX, tid, 31;
177 | --:-:-:-:1 SHL tidAX, tidAX, 2;
178 | 01:-:-:-:1 SHR.U32 tidAY, tid, 5;
179 |
180 |
181 | // trackA += (idx_A*128 + tidAX + cda*tidAY) * dsize
182 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 7;
183 | --:-:-:-:1 XMAD.LO2 ta, cda, tidAY, txa;
184 | --:-:-:-:1 SHL cda8, cda, 1x<$dshiftA + 3>;$predsA
185 |
186 | // writeAs = (tidAY*128 + tidAX) * 4
187 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 7;
188 | --:-:-:-:1 SHL writeAs, writeAs, 2;
189 |
190 | // partialA = partialK * cda
191 | --:-:-:-:1 XMAD.LO2 partialA, cda, partialK, RZ;
192 |
193 | } : q{
194 | // tidAX = tid >> 1
195 | // tidAY = (tid & 1) << 2
196 | 01:-:-:-:1 SHR.U32 tidAX, tid, 1;
197 | 01:-:-:-:1 LOP.AND tidAY, tid, 1;
198 | --:-:-:-:1 SHL tidAY, tidAY, 2;
199 |
200 | // trackA += ((idx_A*128 + tidAX) * cda + tidAY) * dsize
201 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 7;
202 | --:-:-:-:1 XMAD.LO ta, cda, txa, tidAY, xmad_ta;
203 |
204 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa, param_m, PT;
205 |
206 | // The extra shiftAX here is to avoid bank conflicts on write
207 | // shiftAX = tidAY * 4
208 | // writeAs = (tidAY*128 + tidAX + shiftAX) * 4
209 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 7;
210 | --:-:-:-:1 ISCADD writeAs, tidAY, writeAs, 2;
211 | --:-:-:-:1 SHL writeAs, writeAs, 2;
212 |
213 | };
214 | +]
215 | --:-:-:-:1 LEA trackA0.CC, ta, param_A[0], [+ dshiftA() +];
216 | --:-:-:-:1 LEA.HI.X trackA1, ta, param_A[1], RZ, [+ dshiftA() +];
217 |
218 |
219 | // B setup: thread coordinates (tidBX, tidBY), bounds predicate(s) against n, element offset tb and write offset writeBs; trackB is then formed from param_B and tb.
220 | [+
221 | our $dshiftB; # interpolated by the qq{} (outerContigB) branch below
222 | my $predsB = vecB() ? q{
223 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb, param_n, PT;
224 |
225 | } : q{
226 | --:-:-:-:1 IADD txb1, txb, 1;
227 | --:-:-:-:1 IADD txb2, txb, 2;
228 | --:-:-:-:1 IADD txb3, txb, 3;
229 | --:-:-:-:1 ISETP.LT.AND P2, PT, txb, param_n, PT;
230 | --:-:-:-:1 ISETP.LT.AND P3, PT, txb1, param_n, PT;
231 | --:-:-:-:1 ISETP.LT.AND P4, PT, txb2, param_n, PT;
232 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb3, param_n, PT;
233 | --:-:-:-:1 P2R predsB, PR, RZ, 0x5c;
234 | }; # mask 0x5c selects P2, P3, P4 and P6, so P5 is not touched by the B predicates
235 | return outerContigB() ? qq{
236 | // tidBX = (tid & 31) << 2
237 | // tidBY = tid >> 5
238 | 01:-:-:-:1 LOP.AND tidBX, tid, 31;
239 | --:-:-:-:1 SHL tidBX, tidBX, 2;
240 | 01:-:-:-:1 SHR.U32 tidBY, tid, 5;
241 |
242 |
243 | // trackB += (idx_B*128 + tidBX + cdb*tidBY) * dsize
244 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 7;
245 | --:-:-:-:1 XMAD.LO2 tb, cdb, tidBY, txb;
246 | --:-:-:-:1 SHL cdb8, cdb, 1x<$dshiftB + 3>;$predsB
247 |
248 |
249 | // writeBs = (tidBY*128 + tidBX) * 4
250 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 7;
251 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2;
252 |
253 | // partialB = partialK * cdb
254 | --:-:-:-:1 XMAD.LO2 partialB, cdb, partialK, RZ;
255 |
256 | } : q{
257 | // tidBX = tid >> 1
258 | // tidBY = (tid & 1) << 2
259 | 01:-:-:-:1 SHR.U32 tidBX, tid, 1;
260 | 01:-:-:-:1 LOP.AND tidBY, tid, 1;
261 | --:-:-:-:1 SHL tidBY, tidBY, 2;
262 |
263 | // trackB += ((idx_B*128 + tidBX) * cdb + tidBY) * dsize
264 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 7;
265 | --:-:-:-:1 XMAD.LO tb, cdb, txb, tidBY, xmad_tb;
266 |
267 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb, param_n, PT;
268 |
269 |
270 | // The extra shiftBX here is to avoid bank conflicts on write
271 | // shiftBX = tidBY * 4
272 | // writeBs = (tidBY*128 + tidBX + shiftBX) * 4
273 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 7;
274 | --:-:-:-:1 ISCADD writeBs, tidBY, writeBs, 2;
275 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2;
276 | };
277 | +]
278 | --:-:-:-:1 LEA trackB0.CC, tb, param_B[0], [+ dshiftB() +];
279 | --:-:-:-:1 LEA.HI.X trackB1, tb, param_B[1], RZ, [+ dshiftB() +];
280 | // Per-thread offsets derived from tid: readAs, readBs and writeCs.
281 |
282 | // readAs = ((tid & 16) >> 3) | (tid & 1)
283 | 01:-:-:-:1 LOP.AND tid1, tid, 1;
284 | 01:-:-:-:1 LOP.AND readAs, tid, 16;
285 | --:-:-:-:1 SHR.U32 readAs, readAs, 3;
286 | --:-:-:-:1 LOP.OR readAs, readAs, tid1;
287 |
288 | // readBs = (tid >> 1) & 7
289 | 01:-:-:-:1 BFE.U32 readBs, tid, 0x301; // 3 bits at position 1
290 |
291 | // writeCs = readAs*64*8*4 + readAs*16*4 + readBs*4*4 + (tid & -32)*2*4
292 | 01:-:-:-:1 LOP.AND tid32_2, tid, -32;
293 | --:-:-:-:1 SHL tid32_2, tid32_2, 3;
294 | --:-:-:-:1 ISCADD writeCs, readAs, tid32_2, 11;
295 | --:-:-:-:1 ISCADD writeCs, readAs, writeCs, 6;
296 | --:-:-:-:1 ISCADD writeCs, readBs, writeCs, 4;
297 |
298 | // readAs = (readAs + ((tid & 96) >> 2)) * 16
299 | // readAs = readAs*16 + ((tid & 96)*4)
300 | 01:-:-:-:1 LOP.AND tid96, tid, 96;
301 | --:-:-:-:1 SHL tid96, tid96, 2;
302 | --:-:-:-:1 ISCADD readAs, readAs, tid96, 4;
303 |
304 | // readBs = (readBs + ((tid & 128) >> 3)) * 16 + 4x
305 | // readBs = readBs*16 + (tid & 128)*2 + 4x
306 | 01:-:-:-:1 LOP.AND tid128, tid, 128;
307 | --:-:-:-:1 SHL tid128, tid128, 1;
308 | --:-:-:-:1 ISCADD readBs, readBs, tid128, 4;
309 | --:-:-:-:1 IADD readBs, readBs, 4x;
310 | // swapBuf is only emitted when A and B are not both outerContig; that path swaps write buffers with IADD instead of LOP.XOR.
311 | [+
312 | return outerContigA() && outerContigB() ? '' : q{
313 | --:-:-:-:1 MOV32I swapBuf, 4x;
314 | };
315 | +]
316 | // First fetch of A, limited to partialK: elements failing the tidAY < partialK or m bounds checks are zero-filled (LDS from addr_zero or MOV RZ).
317 | [+
318 | our ($dsizeA, $vsizeA, $dtypeA);
319 | return
320 | outerContigA() ?
321 | vecA() ? qq{
322 |
323 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5;
324 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA];
325 | --:-:2:-:1 \@!P0 LDS.U.$vsizeA loadA, [addr_zero];
326 | } : qq{
327 |
328 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, PT;
329 | --:-:-:-:1 \@P0 R2P PR, predsA, 0x3c;
330 | --:-:-:-:1 \@!P0 R2P PR, RZ, 0x3c;
331 | --:-:2:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>];
332 | --:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>];
333 | --:-:2:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>];
334 | --:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>];
335 | --:-:-:-:1 \@!P2 MOV loadA0, RZ;
336 | --:-:-:-:1 \@!P3 MOV loadA1, RZ;
337 | --:-:-:-:1 \@!P4 MOV loadA2, RZ;
338 | --:-:-:-:1 \@!P5 MOV loadA3, RZ; }
339 | :
340 | # not outerContigA
341 | vecA() ? qq{
342 |
343 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5;
344 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA];
345 | --:-:2:-:1 \@!P0 LDS.U.$vsizeA loadA, [addr_zero];
346 |
347 | } : qq{
348 |
349 | --:-:-:-:1 IADD tidAY1, tidAY, 1;
350 | --:-:-:-:1 IADD tidAY2, tidAY, 2;
351 | --:-:-:-:1 IADD tidAY3, tidAY, 3;
352 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5;
353 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidAY1, partialK, P5;
354 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidAY2, partialK, P5;
355 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY3, partialK, P5;
356 | --:-:2:-:1 \@P0 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>];
357 | --:-:2:-:1 \@P1 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>];
358 | --:-:2:-:1 \@P2 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>];
359 | --:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>];
360 | --:-:-:-:1 \@!P0 MOV loadA0, RZ;
361 | --:-:-:-:1 \@!P1 MOV loadA1, RZ;
362 | --:-:-:-:1 \@!P2 MOV loadA2, RZ;
363 | --:-:-:-:1 \@!P3 MOV loadA3, RZ;
364 | };
365 | +]
366 | // First fetch of B, limited to partialK: elements failing the tidBY < partialK or n bounds checks are zero-filled (LDS from addr_zero or MOV RZ).
367 | [+
368 | our ($dsizeB, $vsizeB, $dtypeB);
369 | return
370 | outerContigB() ?
371 | vecB() ? qq{
372 |
373 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, P6;
374 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB];
375 | --:-:4:-:1 \@!P1 LDS.U.$vsizeB loadB, [addr_zero];
376 | } : qq{
377 |
378 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, PT;
379 | --:-:-:-:1 \@P1 R2P PR, predsB, 0x5c;
380 | --:-:-:-:1 \@!P1 R2P PR, RZ, 0x5c;
381 | --:-:4:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>];
382 | --:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>];
383 | --:-:4:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>];
384 | --:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>];
385 | --:-:-:-:1 \@!P2 MOV loadB0, RZ;
386 | --:-:-:-:1 \@!P3 MOV loadB1, RZ;
387 | --:-:-:-:1 \@!P4 MOV loadB2, RZ;
388 | --:-:-:-:1 \@!P6 MOV loadB3, RZ;
389 | }
390 | :
391 | # not outerContigB
392 | vecB() ? qq{
393 |
394 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, P6;
395 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB];
396 | --:-:4:-:1 \@!P1 LDS.U.$vsizeB loadB, [addr_zero];
397 |
398 | } : qq{
399 |
400 | --:-:-:-:1 IADD tidBY1, tidBY, 1;
401 | --:-:-:-:1 IADD tidBY2, tidBY, 2;
402 | --:-:-:-:1 IADD tidBY3, tidBY, 3;
403 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY, partialK, P6;
404 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY1, partialK, P6;
405 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidBY2, partialK, P6;
406 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidBY3, partialK, P6;
407 | --:-:4:-:1 \@P0 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>];
408 | --:-:4:-:1 \@P1 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>];
409 | --:-:4:-:1 \@P2 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>];
410 | --:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>];
411 | --:-:-:-:1 \@!P0 MOV loadB0, RZ;
412 | --:-:-:-:1 \@!P1 MOV loadB1, RZ;
413 | --:-:-:-:1 \@!P2 MOV loadB2, RZ;
414 | --:-:-:-:1 \@!P3 MOV loadB3, RZ;
415 | };
416 | +]
417 | // P0 / P1 = k >= 8 for A / B, additionally gated by P5 / P6 except on the non-vector outerContig paths.
418 | [+
419 | return outerContigA() && !vecA() ? q{
420 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, 8, PT;
421 | } : q{
422 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, 8, P5;
423 | };
424 | +]
425 | [+
426 | return outerContigB() && !vecB() ? q{
427 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 8, PT;
428 | } : q{
429 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 8, P6;
430 | };
431 | +]
432 | // k -= 8; then, only for A16()/B16() operands, convert the loaded F16 values to F32 in loadA<0-3>/loadB<0-3>.
433 | --:-:-:-:0 IADD k, k, -8;
434 | [+
435 | return A16() ?
436 | vecA() ? q{
437 | 02:-:-:-:1 F2F.F32.F16 loadA3, loadA1.H1;
438 | --:-:-:-:1 F2F.F32.F16 loadA2, loadA1.H0;
439 | --:-:-:-:1 F2F.F32.F16 loadA1, loadA0.H1;
440 | --:-:2:-:1 F2F.F32.F16 loadA0, loadA0.H0;
441 | } : q{
442 | 02:-:-:-:1 F2F.F32.F16 loadA3, loadA3;
443 | --:-:-:-:1 F2F.F32.F16 loadA2, loadA2;
444 | --:-:-:-:1 F2F.F32.F16 loadA1, loadA1;
445 | --:-:2:-:1 F2F.F32.F16 loadA0, loadA0;
446 | }
447 | : ''; # nothing is emitted unless A16()
448 | +]
449 | [+
450 | return B16() ?
451 | vecB() ? q{
452 | 08:-:-:-:1 F2F.F32.F16 loadB3, loadB1.H1;
453 | --:-:-:-:1 F2F.F32.F16 loadB2, loadB1.H0;
454 | --:-:-:-:1 F2F.F32.F16 loadB1, loadB0.H1;
455 | --:-:4:-:1 F2F.F32.F16 loadB0, loadB0.H0;
456 | } : q{
457 | 08:-:-:-:1 F2F.F32.F16 loadB3, loadB3;
458 | --:-:-:-:1 F2F.F32.F16 loadB2, loadB2;
459 | --:-:-:-:1 F2F.F32.F16 loadB1, loadB1;
460 | --:-:4:-:1 F2F.F32.F16 loadB0, loadB0;
461 | }
462 | : ''; # nothing is emitted unless B16()
463 | +]
464 | [+
465 | our $dshiftA; # Store the first A fetch at writeAs, then advance trackA by partialA (outerContigA) or partialK elements, scaled by $dshiftA.
466 | return outerContigA() ? qq{
467 |
468 | 02:-:-:-:1 STS.128 [writeAs], loadA;
469 |
470 | --:-:-:-:6 LEA trackA0.CC, partialA, trackA0, $dshiftA;
471 | --:-:-:-:0 IADD.X trackA1, trackA1, RZ;
472 |
473 | } : qq{
474 | 02:-:-:-:1 STS [writeAs + 4x<3*128>], loadA3;
475 | --:-:-:-:1 STS [writeAs + 4x<2*128>], loadA2;
476 | --:-:-:-:1 STS [writeAs + 4x<1*128>], loadA1;
477 | --:-:-:-:1 STS [writeAs + 4x<0*128>], loadA0;
478 |
479 | --:-:-:-:6 LEA trackA0.CC, partialK, trackA0, $dshiftA;
480 | --:-:-:-:0 IADD.X trackA1, trackA1, RZ;
481 | };
482 | +]
483 | [+
484 | our $dshiftB; # Store the first B fetch at writeBs, then advance trackB by partialB (outerContigB) or partialK elements, scaled by $dshiftB.
485 | return outerContigB() ? qq{
486 |
487 | 08:-:-:-:1 STS.128 [writeBs], loadB;
488 |
489 | --:-:-:-:6 LEA trackB0.CC, partialB, trackB0, $dshiftB;
490 | --:-:-:-:0 IADD.X trackB1, trackB1, RZ;
491 |
492 | } : qq{
493 | 08:-:-:-:1 STS [writeBs + 4x<3*128>], loadB3;
494 | --:-:-:-:1 STS [writeBs + 4x<2*128>], loadB2;
495 | --:-:-:-:1 STS [writeBs + 4x<1*128>], loadB1;
496 | --:-:-:-:1 STS [writeBs + 4x<0*128>], loadB0;
497 |
498 | --:-:-:-:6 LEA trackB0.CC, partialK, trackB0, $dshiftB;
499 | --:-:-:-:0 IADD.X trackB1, trackB1, RZ;
500 | };
501 | +]
502 | // Barrier, then toggle writeAs/writeBs (XOR 4x, or += swapBuf with swapBuf negated afterwards) and preload the j0 operands.
503 | --:-:-:-:5 BAR.SYNC 0;
504 | [+
505 | return outerContigA() && outerContigB() ? q{
506 | --:-:-:-:1 LOP.XOR writeBs, writeBs, 4x;
507 | --:-:-:-:0 LOP.XOR writeAs, writeAs, 4x;
508 | } : q{
509 | --:-:-:-:1 IADD writeBs, writeBs, swapBuf;
510 | --:-:-:-:1 IADD writeAs, writeAs, swapBuf;
511 | --:-:-:-:0 IADD swapBuf, RZ, -swapBuf;
512 | };
513 | +]
514 | --:-:-:-:1 LDS.U.128 j0Ay0, [readAs + 4x<0*128 + 00>];
515 | --:-:-:-:1 LDS.U.128 j0Bx0, [readBs + 4x<0*128 + 00>];
516 | --:-:-:-:1 LDS.U.128 j0Ay4, [readAs + 4x<0*128 + 16>];
517 | --:-:1:-:1 LDS.U.128 j0Bx4, [readBs + 4x<0*128 + 32>];
518 | [+
519 | our ($dsizeA, $vsizeA, $dtypeA); # Next A fetch, gated by P0 (the k >= 8 test above); the non-outerContigA paths load into loadA<4-7> and zero-fill when P0 is clear.
520 | return
521 | outerContigA() ?
522 | vecA() ? qq{
523 |
524 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA];
525 |
526 | } : qq{
527 |
528 | --:-:-:-:2 \@P0 R2P PR, predsA, 0x3c;
529 | --:-:-:Y:d \@!P0 R2P PR, RZ, 0x3c;
530 | --:-:-:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>];
531 | --:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>];
532 | --:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>];
533 | --:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>];
534 | }
535 | :
536 | vecA() ? qq{
537 |
538 | --:-:3:-:1 \@P0 LDG.E.CI.$vsizeA loadA4, [trackA];
539 | --:-:3:-:1 \@!P0 LDS.U.$vsizeA loadA4, [addr_zero];
540 |
541 | } : qq{
542 |
543 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA4, [trackA + ${dsizeA}x<0>];
544 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA5, [trackA + ${dsizeA}x<1>];
545 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA6, [trackA + ${dsizeA}x<2>];
546 | --:-:3:-:1 \@P0 LDG.E.CI.$dtypeA loadA7, [trackA + ${dsizeA}x<3>];
547 | --:-:3:-:1 \@!P0 LDS.U.128 loadA4, [addr_zero];
548 |
549 | };
550 | +]
551 | [+
552 | our ($dsizeB, $vsizeB, $dtypeB); # Next B fetch, gated by P1 (the k >= 8 test above); the non-outerContigB paths load into loadB<4-7> and zero-fill when P1 is clear.
553 | return
554 | outerContigB() ?
555 | vecB() ? qq{
556 |
557 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB];
558 |
559 | } : qq{
560 |
561 | --:-:-:-:2 \@P1 R2P PR, predsB, 0x5c;
562 | --:-:-:Y:d \@!P1 R2P PR, RZ, 0x5c;
563 | --:-:-:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>];
564 | --:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>];
565 | --:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>];
566 | --:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>];
567 | }
568 | :
569 | vecB() ? qq{
570 |
571 | --:-:5:-:1 \@P1 LDG.E.CI.$vsizeB loadB4, [trackB];
572 | --:-:5:-:1 \@!P1 LDS.U.$vsizeB loadB4, [addr_zero];
573 |
574 | } : qq{
575 |
576 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB4, [trackB + ${dsizeB}x<0>];
577 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB5, [trackB + ${dsizeB}x<1>];
578 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB6, [trackB + ${dsizeB}x<2>];
579 | --:-:5:-:1 \@P1 LDG.E.CI.$dtypeB loadB7, [trackB + ${dsizeB}x<3>];
580 | --:-:5:-:1 \@!P1 LDS.U.128 loadB4, [addr_zero];
581 |
582 | };
583 | +]
584 | --:-:-:-:2 PSETP.AND.AND P1, PT, !PT, !PT, !PT; // P1 = false on loop entry
585 | // Main loop. P0 = (k >= 0) gates the shared-memory stores and the back-branch emitted below.
586 | LOOP:
587 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, RZ, PT;
588 | [+
589 | our ($dsizeA, $vsizeA, $dtypeA, $dsizeB, $vsizeB, $dtypeB);
590 | our %insert =
591 | (
592 | ( outerContigA() && outerContigB() ?
593 | () : (
594 | j0c10 => "--:-:-:-:1 PSETP.AND.AND P1, PT, !P1, PT, PT;\n",
595 | )
596 | ),
597 | # Key "jXcY": text emitted right after FFMA number Y of unroll step X (see the generator loop at the end of this block).
598 | ( outerContigA() ?
599 | (
600 | ( vecA() ?
601 | (
602 | j2c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, P5;\n",
603 | ) : (
604 | j2c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, PT;\n",
605 | j2c26 => "--:-:-:-:1 \@!P2 R2P PR, RZ, 0x3c;\n",
606 | j2c27 => "--:-:-:-:1 \@P2 R2P PR, predsA, 0x3c;\n",
607 | )
608 | ),
609 | # F16 inputs are widened to F32 in registers before they are stored to shared memory.
610 | ( A16() ?
611 | ( vecA() ?
612 | (
613 | j2c44 => "02:-:-:-:1 \@P0 F2F.F32.F16 loadA3, loadA1.H1;\n",
614 | j2c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA2, loadA1.H0;\n",
615 | j2c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA1, loadA0.H1;\n",
616 | j2c56 => "--:-:2:-:1 \@P0 F2F.F32.F16 loadA0, loadA0.H0;\n",
617 | ) : (
618 | j2c44 => "02:-:-:-:1 \@P0 F2F.F32.F16 loadA3, loadA3;\n",
619 | j2c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA2, loadA2;\n",
620 | j2c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA1, loadA1;\n",
621 | j2c56 => "--:-:2:-:1 \@P0 F2F.F32.F16 loadA0, loadA0;\n",
622 | )
623 | ) : ()
624 | ),
625 | # Store the staged A values (STS), advance trackA by cda8, then issue the next global load (LDG).
626 | j3c8 => "02:2:-:-:1 \@P0 STS.128 [writeAs], loadA;\n",
627 |
628 | j3c10 => "--:-:-:-:1 \@P2 IADD trackA0.CC, trackA0, cda8;\n",
629 | j3c15 => "--:-:-:-:1 \@P2 IADD.X trackA1, trackA1, RZ;\n",
630 |
631 | ( vecA() ?
632 | (
633 | j3c56 => "02:-:2:-:1 \@P2 LDG.E.CI.$vsizeA loadA, [trackA];\n",
634 | ) : (
635 | j3c56 => "02:-:-:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>];\n",
636 | j3c58 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>];\n",
637 | j3c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>];\n",
638 | j3c62 => "--:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>];\n",
639 | )
640 | ),
641 |
642 | # Not outerContigA
643 | ) : (
644 | j2c13 => "--:-:-:-:1 PSETP.AND.AND P2, PT, PT, P1, P5;\n",
645 | j2c26 => "--:-:-:-:1 ISETP.GE.AND P3, PT, k, 8, P2;\n",
646 | j2c27 => "--:-:-:-:1 ISETP.GE.AND P4, PT, k, 16, P2;\n",
647 | # P1 (toggled by j0c10) selects the register set consumed this pass: !P1 -> loadA0-3, P1 -> loadA4-7.
648 | ( A16() ?
649 | (
650 | ( vecA() ?
651 | (
652 | j2c28 => "02:-:-:-:1 \@!P1 F2F.F32.F16 loadA3, loadA1.H1;\n",
653 | j2c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA2, loadA1.H0;\n",
654 | j2c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA1, loadA0.H1;\n",
655 | j2c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA0, loadA0.H0;\n",
656 |
657 | j2c44 => "04:-:-:-:1 \@P1 F2F.F32.F16 loadA3, loadA5.H1;\n",
658 | j2c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA2, loadA5.H0;\n",
659 | j2c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA1, loadA4.H1;\n",
660 | j2c56 => "--:-:2:-:1 \@P1 F2F.F32.F16 loadA0, loadA4.H0;\n",
661 | ) : (
662 | j2c28 => "02:-:-:-:1 \@!P1 F2F.F32.F16 loadA3, loadA3;\n",
663 | j2c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA2, loadA2;\n",
664 | j2c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA1, loadA1;\n",
665 | j2c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA0, loadA0;\n",
666 |
667 | j2c44 => "04:-:-:-:1 \@P1 F2F.F32.F16 loadA3, loadA7;\n",
668 | j2c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA2, loadA6;\n",
669 | j2c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA1, loadA5;\n",
670 | j2c56 => "--:-:2:-:1 \@P1 F2F.F32.F16 loadA0, loadA4;\n",
671 | ),
672 | ),
673 | j3c8 => "02:-:-:-:1 \@P0 STS [writeAs + 4x<0*128>], loadA0;\n",
674 | j3c10 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<1*128>], loadA1;\n",
675 | j3c12 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<2*128>], loadA2;\n",
676 | j3c14 => "--:2:-:-:1 \@P0 STS [writeAs + 4x<3*128>], loadA3;\n",
677 |
678 | ) : (
679 | j2c54 => "02:-:-:-:1 \@!P1 STS [writeAs + 4x<0*128>], loadA0;\n",
680 | j2c56 => "--:-:-:-:1 \@!P1 STS [writeAs + 4x<1*128>], loadA1;\n",
681 | j2c58 => "--:-:-:-:1 \@!P1 STS [writeAs + 4x<2*128>], loadA2;\n",
682 | j2c60 => "--:2:-:-:1 \@!P1 STS [writeAs + 4x<3*128>], loadA3;\n",
683 |
684 | j3c8 => "04:-:-:-:1 \@P1 STS [writeAs + 4x<0*128>], loadA4;\n",
685 | j3c10 => "--:-:-:-:1 \@P1 STS [writeAs + 4x<1*128>], loadA5;\n",
686 | j3c12 => "--:-:-:-:1 \@P1 STS [writeAs + 4x<2*128>], loadA6;\n",
687 | j3c14 => "--:3:-:-:1 \@P1 STS [writeAs + 4x<3*128>], loadA7;\n",
688 | )
689 | ),
690 |
691 | j3c15 => "--:-:-:-:1 \@P5 IADD trackA0.CC, trackA0, ${dsizeA}x<8>;\n",
692 | j3c20 => "--:-:-:-:1 \@P5 IADD.X trackA1, trackA1, RZ;\n",
693 |
694 | ( vecA() ?
695 | (
696 | j3c60 => "02:-:2:-:1 \@P3 LDG.E.CI.$vsizeA loadA0, [trackA + ${dsizeA}x<0>];\n",
697 | j3c62 => "04:-:3:-:1 \@P4 LDG.E.CI.$vsizeA loadA4, [trackA + ${dsizeA}x<8>];\n",
698 |
699 | ) : (
700 | j3c48 => "02:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x< 0>];\n",
701 | j3c50 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x< 1>];\n",
702 | j3c52 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x< 2>];\n",
703 | j3c54 => "--:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x< 3>];\n",
704 |
705 | j3c56 => "04:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA4, [trackA + ${dsizeA}x< 8>];\n",
706 | j3c58 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA5, [trackA + ${dsizeA}x< 9>];\n",
707 | j3c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA6, [trackA + ${dsizeA}x<10>];\n",
708 | j3c62 => "--:-:3:-:1 \@P4 LDG.E.CI.$dtypeA loadA7, [trackA + ${dsizeA}x<11>];\n",
709 | )
710 | ),
711 | ),
712 | ),
713 | # B operand: mirrors the A section above, placed in slots j5/j6 and predicated on P6 instead of P5.
714 | ( outerContigB() ?
715 | (
716 | ( vecB() ?
717 | (
718 | j5c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, P6;\n",
719 | ) : (
720 | j5c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, PT;\n",
721 | j5c26 => "--:-:-:-:1 \@!P2 R2P PR, RZ, 0x5c;\n",
722 | j5c27 => "--:-:-:-:1 \@P2 R2P PR, predsB, 0x5c;\n",
723 | )
724 | ),
725 |
726 | ( B16() ?
727 | ( vecB() ?
728 | (
729 | j5c44 => "08:-:-:-:1 \@P0 F2F.F32.F16 loadB3, loadB1.H1;\n",
730 | j5c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB2, loadB1.H0;\n",
731 | j5c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB1, loadB0.H1;\n",
732 | j5c56 => "--:-:4:-:1 \@P0 F2F.F32.F16 loadB0, loadB0.H0;\n",
733 | ) : (
734 | j5c44 => "08:-:-:-:1 \@P0 F2F.F32.F16 loadB3, loadB3;\n",
735 | j5c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB2, loadB2;\n",
736 | j5c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB1, loadB1;\n",
737 | j5c56 => "--:-:4:-:1 \@P0 F2F.F32.F16 loadB0, loadB0;\n",
738 | )
739 | ) : ()
740 | ),
741 |
742 | j6c8 => "08:4:-:-:1 \@P0 STS.128 [writeBs], loadB;\n",
743 |
744 | j6c10 => "--:-:-:-:1 \@P2 IADD trackB0.CC, trackB0, cdb8;\n",
745 | j6c15 => "--:-:-:-:1 \@P2 IADD.X trackB1, trackB1, RZ;\n",
746 |
747 | ( vecB() ?
748 | (
749 | j6c56 => "08:-:4:-:1 \@P2 LDG.E.CI.$vsizeB loadB, [trackB];\n",
750 | ) : (
751 | j6c56 => "08:-:-:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>];\n",
752 | j6c58 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>];\n",
753 | j6c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>];\n",
754 | j6c62 => "--:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>];\n",
755 | )
756 | ),
757 |
758 | # Not outerContigB
759 | ) : (
760 | j5c13 => "--:-:-:-:1 PSETP.AND.AND P2, PT, PT, P1, P6;\n",
761 | j5c26 => "--:-:-:-:1 ISETP.GE.AND P3, PT, k, 8, P2;\n",
762 | j5c27 => "--:-:-:-:1 ISETP.GE.AND P4, PT, k, 16, P2;\n",
763 |
764 | ( B16() ?
765 | (
766 | ( vecB() ?
767 | (
768 | j5c28 => "08:-:-:-:1 \@!P1 F2F.F32.F16 loadB3, loadB1.H1;\n",
769 | j5c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB2, loadB1.H0;\n",
770 | j5c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB1, loadB0.H1;\n",
771 | j5c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB0, loadB0.H0;\n",
772 |
773 | j5c44 => "10:-:-:-:1 \@P1 F2F.F32.F16 loadB3, loadB5.H1;\n",
774 | j5c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB2, loadB5.H0;\n",
775 | j5c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB1, loadB4.H1;\n",
776 | j5c56 => "--:-:4:-:1 \@P1 F2F.F32.F16 loadB0, loadB4.H0;\n",
777 | ) : (
778 | j5c28 => "08:-:-:-:1 \@!P1 F2F.F32.F16 loadB3, loadB3;\n",
779 | j5c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB2, loadB2;\n",
780 | j5c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB1, loadB1;\n",
781 | j5c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB0, loadB0;\n",
782 |
783 | j5c44 => "10:-:-:-:1 \@P1 F2F.F32.F16 loadB3, loadB7;\n",
784 | j5c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB2, loadB6;\n",
785 | j5c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB1, loadB5;\n",
786 | j5c56 => "--:-:4:-:1 \@P1 F2F.F32.F16 loadB0, loadB4;\n",
787 | ),
788 | ),
789 | j6c8 => "08:-:-:-:1 \@P0 STS [writeBs + 4x<0*128>], loadB0;\n",
790 | j6c10 => "--:-:-:-:1 \@P0 STS [writeBs + 4x<1*128>], loadB1;\n",
791 | j6c12 => "--:-:-:-:1 \@P0 STS [writeBs + 4x<2*128>], loadB2;\n",
792 | j6c14 => "--:4:-:-:1 \@P0 STS [writeBs + 4x<3*128>], loadB3;\n",
793 |
794 | ) : (
795 | j5c54 => "08:-:-:-:1 \@!P1 STS [writeBs + 4x<0*128>], loadB0;\n",
796 | j5c56 => "--:-:-:-:1 \@!P1 STS [writeBs + 4x<1*128>], loadB1;\n",
797 | j5c58 => "--:-:-:-:1 \@!P1 STS [writeBs + 4x<2*128>], loadB2;\n",
798 | j5c60 => "--:4:-:-:1 \@!P1 STS [writeBs + 4x<3*128>], loadB3;\n",
799 |
800 | j6c8 => "10:-:-:-:1 \@P1 STS [writeBs + 4x<0*128>], loadB4;\n",
801 | j6c10 => "--:-:-:-:1 \@P1 STS [writeBs + 4x<1*128>], loadB5;\n",
802 | j6c12 => "--:-:-:-:1 \@P1 STS [writeBs + 4x<2*128>], loadB6;\n",
803 | j6c14 => "--:5:-:-:1 \@P1 STS [writeBs + 4x<3*128>], loadB7;\n",
804 | )
805 | ),
806 |
807 | j6c15 => "--:-:-:-:1 \@P6 IADD trackB0.CC, trackB0, ${dsizeB}x<8>;\n",
808 | j6c20 => "--:-:-:-:1 \@P6 IADD.X trackB1, trackB1, RZ;\n",
809 |
810 | ( vecB() ?
811 | (
812 | j6c60 => "08:-:4:-:1 \@P3 LDG.E.CI.$vsizeB loadB0, [trackB + ${dsizeB}x<0>];\n",
813 | j6c62 => "10:-:5:-:1 \@P4 LDG.E.CI.$vsizeB loadB4, [trackB + ${dsizeB}x<8>];\n",
814 |
815 | ) : (
816 | j6c48 => "08:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x< 0>];\n",
817 | j6c50 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x< 1>];\n",
818 | j6c52 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x< 2>];\n",
819 | j6c54 => "--:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x< 3>];\n",
820 |
821 | j6c56 => "10:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB4, [trackB + ${dsizeB}x< 8>];\n",
822 | j6c58 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB5, [trackB + ${dsizeB}x< 9>];\n",
823 | j6c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB6, [trackB + ${dsizeB}x<10>];\n",
824 | j6c62 => "--:-:5:-:1 \@P4 LDG.E.CI.$dtypeB loadB7, [trackB + ${dsizeB}x<11>];\n",
825 | )
826 | ),
827 | ),
828 | ),
829 | # End of step 6: barrier, flip the read/write shared-memory addresses (only while P0) and subtract 8 from k.
830 | ( outerContigA() && outerContigB() ?
831 | (
832 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" .
833 | "--:-:-:-:1 \@P0 LOP.XOR readAs, readAs, 4x;\n" .
834 | "--:-:-:-:1 \@P0 LOP.XOR readBs, readBs, 4x;\n" .
835 | "--:-:-:-:1 \@P0 LOP.XOR writeAs, writeAs, 4x;\n" .
836 | "--:-:-:-:1 \@P0 LOP.XOR writeBs, writeBs, 4x;\n" .
837 | "--:-:-:-:1 IADD k, k, -8;\n",
838 | ) : (
839 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" .
840 | "--:-:-:-:1 \@P0 IADD readAs, readAs, -swapBuf;\n" .
841 | "--:-:-:-:1 \@P0 IADD readBs, readBs, -swapBuf;\n" .
842 | "--:-:-:-:1 \@P0 IADD writeAs, writeAs, swapBuf;\n" .
843 | "--:-:-:-:1 \@P0 IADD writeBs, writeBs, swapBuf;\n" .
844 | "--:-:-:-:1 \@P0 IADD swapBuf, RZ, -swapBuf;\n" .
845 | "--:-:-:-:1 IADD k, k, -8;\n",
846 | )
847 | ),
848 | # Branch back to LOOP while P0 holds (k was >= 0 when this iteration started).
849 | j7c63 => "--:-:-:Y:5 \@P0 BRA.U LOOP;\n",
850 | );
851 | # Build the 64-entry FFMA issue order (@cOrder), then emit 8 unrolled steps of 64 FFMAs each, splicing the %insert entries in after the matching FFMA.
852 | my @cOrder;
853 | my @swirl = ([0,2],[1,2],[1,0],[0,0]);
854 | my @y = (0,1,4,5);
855 | foreach my $x (0,2,4,6)
856 | {
857 | foreach my $y (@y)
858 | {
859 | push @cOrder, [$x + $_->[0], $y + $_->[1]] foreach @swirl;
860 | }
861 | @y = reverse @y;
862 | }
863 | my $out = '';
864 | foreach my $j (0 .. 7)
865 | {
866 | my $odd = $j & 1;
867 | my $nOdd = !$odd + 0;
868 | my $rsOffset = ($j + 1) % 8;
869 | my $rsPred = $j == 7 ? '@P0' : ' ';
870 | my $shiftA = outerContigA() || $rsOffset < 4 ? 0 : 16;
871 | my $shiftB = outerContigB() || $rsOffset < 4 ? 0 : 16;
872 | # Prefetch the A/B fragments of step (j+1)%8 into the other register set (j$nOdd); on the last step only when P0.
873 | $insert{"j${j}c0"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy0, [readAs + 4x<%d*128 + 00 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftA;
874 | $insert{"j${j}c2"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dBx0, [readBs + 4x<%d*128 + 00 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftB;
875 | $insert{"j${j}c4"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy4, [readAs + 4x<%d*128 + 16 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftA;
876 | $insert{"j${j}c6"} = sprintf "--:-:1:-:1 %s LDS.U.128 j%dBx4, [readBs + 4x<%d*128 + 32 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftB;
877 |
878 | foreach my $c (0 .. 63)
879 | {
880 | my ($x,$y) = @{$cOrder[$c]};
881 | # Optional instruction scheduled right after this FFMA (empty string when none).
882 | my $ins = $insert{"j${j}c$c"} || '';
883 | # Stall count is 0 when the inserted instruction uses one of the listed opcodes, otherwise 1.
884 | my $stall = $ins =~ /LDS|I2I|I2F|F2I|F2F|LDG|STS|BAR|BRA/ ? 0 : 1;
885 | # Yield only on FFMA 32, and only when that FFMA has a non-zero stall.
886 | my $yield = $c == 32 && $stall ? 'Y' : '-';
887 | # The first FFMA of every step carries wait mask 01.
888 | my $wait = $c == 0 ? '01' : '--';
889 | # Control fields, in order: wait : - : - : yield : stall.
890 | my $ctrl = "$wait:-:-:$yield:$stall";
891 |
892 | $out .= sprintf "%s FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n%s", $ctrl, $x,$y, $odd,$x, $odd,$y, $x,$y, $ins;
893 | }
894 | }
895 | return $out;
896 | +]
897 |
898 | // Epilogue: scale the accumulators by alpha and write them out through STORE_C, one y index (cx0..cx7 of that y) per call.
899 | --:-:-:-:1 MOV alpha, param_alpha;
900 | --:-:-:-:1 MOV beta, param_beta;
901 |
902 | // P4 = beta != 0
903 | --:-:-:-:1 ISETP.NE.AND P4, PT, RZ, param_beta, PT;
904 |
905 | --:-:-:-:1 SHR.U32 tid_32, tid, 5;
906 | --:-:-:-:1 LOP.AND tid_31, tid, 31;
907 | --:-:-:-:1 LOP.AND tid_96, tid, 96;
908 | --:-:-:-:1 LOP.AND tid_128, tid, 128;
909 |
910 | // readCs = (tid_32*64 + tid_31) * 4
911 | --:-:-:-:1 ISCADD readCs, tid_32, tid_31, 6;
912 | --:-:-:-:1 SHL readCs, readCs, 2;
913 |
914 | // cx = idx_B*128 + tid_128>>1 + tid_31;
915 | --:-:-:-:1 SHR.U32 cx00, tid_128, 1;
916 | --:-:-:-:1 LOP.OR cx00, tid_31, cx00;
917 | --:-:-:-:1 ISCADD cx00, idx_B, cx00, 7;
918 | --:-:-:-:1 IADD cx32, cx00, 32;
919 | // P5/P6: column cx00 / cx32 is inside n AND beta != 0 (P4); STORE_C uses them to gate the reads of C.
920 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, P4;
921 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, P4;
922 |
923 | // cy = idx_A*128 + tid_96
924 | --:-:-:-:1 ISCADD cy00, idx_A, tid_96, 7;
925 | --:-:-:-:1 IADD cy04, cy00, 4;
926 | --:-:-:-:1 IADD cy08, cy00, 8;
927 | --:-:-:-:1 IADD cy12, cy00, 12;
928 | // cdc1, cdc4, cdc13 = cdc scaled for steps of 1, 4 and 13 rows (dshiftC/dsizeC presumably convert elements to bytes - confirm).
929 | --:-:-:-:1 MOV cdc, param_cdc;
930 | --:-:-:-:1 SHL cdc1, cdc, [+ dshiftC() + 0 +];
931 | --:-:-:-:1 SHL cdc4, cdc, [+ dshiftC() + 2 +];
932 | --:-:-:-:1 XMAD.LO2 cdc13, cdc, [+ dsizeC() * 13 +], RZ;
933 |
934 | // trackC += cy*cdc + cx;
935 | --:-:-:-:1 XMAD.LO tc, cy00, cdc, cx00, xmad_tc;
936 |
937 | --:-:-:-:1 LEA track00C0.CC, tc, param_C[0], [+ dshiftC() +];
938 | --:-:-:-:1 LEA.HI.X track00C1, tc, param_C[1], RZ, [+ dshiftC() +];
939 | --:-:-:-:1 IADD track04C0.CC, track00C0, cdc4;
940 | --:-:-:-:1 IADD.X track04C1, track00C1, RZ;
941 | --:-:-:-:1 IADD track08C0.CC, track04C0, cdc4;
942 | --:-:-:-:1 IADD.X track08C1, track04C1, RZ;
943 | --:-:-:-:1 IADD track12C0.CC, track08C0, cdc4;
944 | --:-:-:-:0 IADD.X track12C1, track08C1, RZ;
945 | // y = 0: c0..c7 = cx0y0..cx7y0 * alpha, then store.
946 | --:-:-:-:1 FMUL c0, cx0y0, alpha;
947 | --:-:-:-:1 FMUL c1, cx1y0, alpha;
948 | --:-:-:-:1 FMUL c2, cx2y0, alpha;
949 | --:-:-:-:1 FMUL c3, cx3y0, alpha;
950 | --:-:-:-:1 FMUL c4, cx4y0, alpha;
951 | --:-:-:-:1 FMUL c5, cx5y0, alpha;
952 | --:-:-:-:1 FMUL c6, cx6y0, alpha;
953 | --:-:-:-:1 FMUL c7, cx7y0, alpha;
954 |
955 |
956 |
957 | --:-:-:-:5 CAL STORE_C;
958 | // y = 1..7 are generated below; the row pointers and cy advance by 1, except by 13 when y == 4.
959 | [+
960 | my $out;
961 | foreach my $y (1..7)
962 | {
963 | my $inc = $y == 4 ? 13 : 1;
964 | # "cdc$inc" in the template resolves to the cdc1 or cdc13 register computed before the first store.
965 | $out .= qq{
966 |
967 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, P4;
968 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, P4;
969 |
970 | 01:-:-:-:1 IADD track00C0.CC, track00C0, cdc$inc;
971 | --:-:-:-:1 IADD.X track00C1, track00C1, RZ;
972 | 02:-:-:-:1 IADD track04C0.CC, track04C0, cdc$inc;
973 | --:-:-:-:1 IADD.X track04C1, track04C1, RZ;
974 | 04:-:-:-:1 IADD track08C0.CC, track08C0, cdc$inc;
975 | --:-:-:-:1 IADD.X track08C1, track08C1, RZ;
976 | 08:-:-:-:1 IADD track12C0.CC, track12C0, cdc$inc;
977 | --:-:-:-:0 IADD.X track12C1, track12C1, RZ;
978 |
979 | --:-:-:-:1 IADD cy00, cy00, $inc;
980 | --:-:-:-:1 IADD cy04, cy04, $inc;
981 | --:-:-:-:1 IADD cy08, cy08, $inc;
982 | --:-:-:-:1 IADD cy12, cy12, $inc;
983 |
984 | --:-:-:-:1 FMUL c0, cx0y$y, alpha;
985 | --:-:-:-:1 FMUL c1, cx1y$y, alpha;
986 | --:-:-:-:1 FMUL c2, cx2y$y, alpha;
987 | --:-:-:-:1 FMUL c3, cx3y$y, alpha;
988 | --:-:-:-:1 FMUL c4, cx4y$y, alpha;
989 | --:-:-:-:1 FMUL c5, cx5y$y, alpha;
990 | --:-:-:-:1 FMUL c6, cx6y$y, alpha;
991 | --:-:-:-:1 FMUL c7, cx7y$y, alpha;
992 |
993 |
994 | --:-:-:-:5 CAL STORE_C;
995 | };
996 | }
997 | return $out;
998 | +]
999 |
1000 | --:-:-:-:5 EXIT;
1001 |
1002 | // STORE_C: combine c0..c7 (already scaled by alpha) with beta * C and store through the row pointers track00C/04C/08C/12C.
1003 | STORE_C:
1004 |
1005 | // On entry P5/P6 include P4 (beta != 0), so C is only read when beta is non-zero and in bounds; otherwise b is zeroed.
1006 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy00, param_m, P5;
1007 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy00, param_m, P6;
1008 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy04, param_m, P5;
1009 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy04, param_m, P6;
1010 |
1011 | --:-:-:-:1 @!P0 MOV b00_00, RZ;
1012 | --:-:-:-:1 @!P1 MOV b00_32, RZ;
1013 | --:-:-:-:1 @!P2 MOV b04_00, RZ;
1014 | --:-:-:-:1 @!P3 MOV b04_32, RZ;
1015 | --:-:-:-:1 @P0 LDG.E.CI.[+ dtypeC() +] b00_00, [track00C + 1x<$dsizeC * 00>];
1016 | --:-:-:-:1 @P1 LDG.E.CI.[+ dtypeC() +] b00_32, [track00C + 1x<$dsizeC * 32>];
1017 | --:-:-:-:1 @P2 LDG.E.CI.[+ dtypeC() +] b04_00, [track04C + 1x<$dsizeC * 00>];
1018 | --:-:5:-:1 @P3 LDG.E.CI.[+ dtypeC() +] b04_32, [track04C + 1x<$dsizeC * 32>];
1019 |
1020 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy08, param_m, P5;
1021 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy08, param_m, P6;
1022 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy12, param_m, P5;
1023 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy12, param_m, P6;
1024 | // Recompute the column bounds without P4: from here on P5/P6 feed the store predicates.
1025 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, PT;
1026 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, PT;
1027 |
1028 | --:-:-:-:1 @!P0 MOV b08_00, RZ;
1029 | --:-:-:-:1 @!P1 MOV b08_32, RZ;
1030 | --:-:-:-:1 @!P2 MOV b12_00, RZ;
1031 | --:-:-:-:1 @!P3 MOV b12_32, RZ;
1032 | --:-:-:-:1 @P0 LDG.E.CI.[+ dtypeC() +] b08_00, [track08C + 1x<$dsizeC * 00>];
1033 | --:-:-:-:1 @P1 LDG.E.CI.[+ dtypeC() +] b08_32, [track08C + 1x<$dsizeC * 32>];
1034 | --:-:-:-:1 @P2 LDG.E.CI.[+ dtypeC() +] b12_00, [track12C + 1x<$dsizeC * 00>];
1035 | --:-:6:-:1 @P3 LDG.E.CI.[+ dtypeC() +] b12_32, [track12C + 1x<$dsizeC * 32>];
1036 |
1037 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy00, param_m, P5;
1038 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy00, param_m, P6;
1039 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy04, param_m, P5;
1040 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy04, param_m, P6;
1041 |
1042 | // Write c0..c7 through writeCs and read back eight scalars from readCs (rows 00/04/08/12, column offsets 0 and 32).
1043 | --:-:-:-:1 STS.128 [writeCs + 4x<00>], c0;
1044 | --:-:-:-:1 STS.128 [writeCs + 4x<32>], c4;
1045 | --:-:-:-:1 LDS c00_00, [readCs + 4x<0*8*64 + 00 + 0*16>];
1046 | --:-:1:-:1 LDS c00_32, [readCs + 4x<0*8*64 + 32 + 0*16>];
1047 | --:-:-:-:1 LDS c04_00, [readCs + 4x<1*8*64 + 00 + 1*16>];
1048 | --:-:2:-:1 LDS c04_32, [readCs + 4x<1*8*64 + 32 + 1*16>];
1049 | --:-:-:-:1 LDS c08_00, [readCs + 4x<2*8*64 + 00 + 2*16>];
1050 | --:-:3:-:1 LDS c08_32, [readCs + 4x<2*8*64 + 32 + 2*16>];
1051 | --:-:-:-:1 LDS c12_00, [readCs + 4x<3*8*64 + 00 + 3*16>];
1052 | --:-:4:-:1 LDS c12_32, [readCs + 4x<3*8*64 + 32 + 3*16>];
1053 | // c = b * beta + c; when C16() holds, b is widened from F16 first and c is narrowed back to F16 afterwards.
1054 | [+
1055 | C16() ? q{
1056 | 10:-:-:-:1 F2F.F32.F16 b00_00, b00_00;
1057 | --:-:-:-:1 F2F.F32.F16 b00_32, b00_32;
1058 | --:-:-:-:1 F2F.F32.F16 b04_00, b04_00;
1059 | --:-:5:-:1 F2F.F32.F16 b04_32, b04_32;
1060 | 20:-:-:-:1 F2F.F32.F16 b08_00, b08_00;
1061 | --:-:-:-:1 F2F.F32.F16 b08_32, b08_32;
1062 | --:-:-:-:1 F2F.F32.F16 b12_00, b12_00;
1063 | --:-:6:-:1 F2F.F32.F16 b12_32, b12_32;
1064 | } : '';
1065 | +]
1066 | 11:-:-:-:1 FFMA c00_00, b00_00, beta, c00_00;
1067 | --:-:-:-:1 FFMA c00_32, b00_32, beta, c00_32;
1068 | 02:-:-:-:1 FFMA c04_00, b04_00, beta, c04_00;
1069 | --:-:-:-:1 FFMA c04_32, b04_32, beta, c04_32;
1070 | 24:-:-:-:1 FFMA c08_00, b08_00, beta, c08_00;
1071 | --:-:-:-:1 FFMA c08_32, b08_32, beta, c08_32;
1072 | 08:-:-:-:1 FFMA c12_00, b12_00, beta, c12_00;
1073 | --:-:-:-:1 FFMA c12_32, b12_32, beta, c12_32;
1074 | [+
1075 | C16() ? q{
1076 | --:-:-:-:1 F2F.F16.F32 c00_00, c00_00;
1077 | --:-:1:-:1 F2F.F16.F32 c00_32, c00_32;
1078 | --:-:-:-:1 F2F.F16.F32 c04_00, c04_00;
1079 | --:-:2:-:1 F2F.F16.F32 c04_32, c04_32;
1080 | --:-:-:-:1 F2F.F16.F32 c08_00, c08_00;
1081 | --:-:3:-:1 F2F.F16.F32 c08_32, c08_32;
1082 | --:-:-:-:1 F2F.F16.F32 c12_00, c12_00;
1083 | --:-:4:-:1 F2F.F16.F32 c12_32, c12_32;
1084 | } : '';
1085 | +]
1086 |
1087 | // Rows 00 and 04: P0-P3 were recomputed above from the P4-free P5/P6.
1088 | 01:-:-:-:1 @P0 STG.E.CG.[+ dtypeC() +] [track00C + 1x<$dsizeC * 00>], c00_00;
1089 | --:1:-:-:1 @P1 STG.E.CG.[+ dtypeC() +] [track00C + 1x<$dsizeC * 32>], c00_32;
1090 | 02:-:-:-:1 @P2 STG.E.CG.[+ dtypeC() +] [track04C + 1x<$dsizeC * 00>], c04_00;
1091 | --:2:-:-:1 @P3 STG.E.CG.[+ dtypeC() +] [track04C + 1x<$dsizeC * 32>], c04_32;
1092 | // Rows 08 and 12: P0-P3 are reused for their store predicates.
1093 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy08, param_m, P5;
1094 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy08, param_m, P6;
1095 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy12, param_m, P5;
1096 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy12, param_m, P6;
1097 |
1098 | 04:-:-:-:1 @P0 STG.E.CG.[+ dtypeC() +] [track08C + 1x<$dsizeC * 00>], c08_00;
1099 | --:3:-:-:1 @P1 STG.E.CG.[+ dtypeC() +] [track08C + 1x<$dsizeC * 32>], c08_32;
1100 | 08:-:-:-:1 @P2 STG.E.CG.[+ dtypeC() +] [track12C + 1x<$dsizeC * 00>], c12_00;
1101 | --:4:-:-:1 @P3 STG.E.CG.[+ dtypeC() +] [track12C + 1x<$dsizeC * 32>], c12_32;
1102 |
1103 |
1104 | --:-:-:-:5 RET;
1105 |
--------------------------------------------------------------------------------
/src/c_interface.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdint>
2 | #include <iostream>
3 | #include <mutex>
4 | #include <string>
5 | #include <tuple>
6 | #include <unordered_map>
7 | #include <utility>
8 |
9 | // Opaque CUDA driver handle types, redeclared locally so that cuda.h is not needed.
10 | typedef struct CUfunc_st *CUfunction;
11 | typedef struct CUmod_st *CUmodule;
12 | typedef struct CUstream_st *CUstream;
13 | typedef int CUdevice;
14 |
15 | // Declare the CUDA driver function prototypes here so that this file can be
16 | // compiled without knowing anything about CUDA - this makes building with
17 | // bazel a million times easier. Each function returns 0 on success.
18 | extern "C" {
19 | // Arguments: (out) module handle, in-memory module image.
20 | int
21 | #ifdef _WIN32
22 | __stdcall
23 | #endif
24 | cuModuleLoadData(CUmodule *, const void *);
25 | // Arguments: (out) function handle, module, kernel name.
26 | int
27 | #ifdef _WIN32
28 | __stdcall
29 | #endif
30 | cuModuleGetFunction(CUfunction *, CUmodule, const char *);
31 | // Arguments: (out) attribute value, attribute id, device.
32 | int
33 | #ifdef _WIN32
34 | __stdcall
35 | #endif
36 | cuDeviceGetAttribute(int *, int, CUdevice);
37 | // Parameter names below follow the CUDA driver API documentation for cuLaunchKernel.
38 | int
39 | #ifdef _WIN32
40 | __stdcall
41 | #endif
42 | cuLaunchKernel(CUfunction,
43 | unsigned int, // gridDimX
44 | unsigned int, // gridDimY
45 | unsigned int, // gridDimZ
46 | unsigned int, // blockDimX
47 | unsigned int, // blockDimY
48 | unsigned int, // blockDimZ
49 | unsigned int, // sharedMemBytes
50 | CUstream,
51 | void **, // kernelParams
52 | void **); // extra
53 | // Argument: (out) device of the calling thread's current context.
54 | int
55 | #ifdef _WIN32
56 | __stdcall
57 | #endif
58 | cuCtxGetDevice(CUdevice *);
59 |
60 | };
61 |
62 | #include "include/c_interface.h"
63 | #include "include/kernel_headers.h"
64 |
65 | namespace {
66 | #include "include/static_kernel_information.h"
67 |
68 | std::mutex load_kernel_mutex_;  // held by loadKernels() while kernels_ is being filled
69 | // Loaded kernel functions keyed by full kernel name, including the "_sm_50" / "_sm_60" suffix.
70 | std::unordered_map<std::string, CUfunction> kernels_;
71 | // Load every not-yet-cached kernel image in `kernels` into kernels_; returns false at the first driver failure.
72 | bool loadKernelsHelper(const std::unordered_map& kernels) {
73 |     for (const auto& kernel : kernels) {  // bind by reference: no per-entry copy of the name string
74 |         if (kernels_.count(kernel.first) > 0)
75 |             continue;
76 |
77 |         CUmodule module;
78 |
79 |         int res = cuModuleLoadData(&module, kernel.second);
80 |         if (res != 0) {
81 |             std::cerr << "Failed to load " << kernel.first << " " <<
82 |                 res << std::endl;
83 |             return false;
84 |         }
85 |
86 |         CUfunction function;
87 |         // The entry point is the map key minus its 6-character "_sm_NN" suffix.
88 |         std::string kernel_name = kernel.first.substr(0, kernel.first.size() - 6);
89 |
90 |         res = cuModuleGetFunction(&function, module, kernel_name.c_str());
91 |         if (res != 0) {
92 |             std::cerr << "Failed to extract " << kernel_name << " " <<
93 |                 res << std::endl;
94 |             return false;
95 |         }
96 |
97 |         kernels_.insert(std::make_pair(kernel.first, function));
98 |     }
99 |
100 |     return true;
101 | }
102 | // Load all kernels for compute capability `major` (5 or 6) under load_kernel_mutex_; returns false for any other value or on load failure.
103 | bool loadKernels(int major) {
104 |     std::lock_guard<std::mutex> lock(load_kernel_mutex_);
105 |
106 |     if (major == 5)
107 |         return loadKernelsHelper(kernels_50);
108 |     else if (major == 6)
109 |         return loadKernelsHelper(kernels_60);
110 |     else {
111 |         std::cerr << "Arch must be 5 or 6" << std::endl;
112 |         return false;
113 |     }
114 | }
115 |
116 | std::tuple getDeviceProperties(CUdevice& device) {
117 | int major, minor;
118 | int res = cuDeviceGetAttribute(&major, 75, device);
119 | if (res != 0)
120 | return std::make_tuple(res, -1, -1);
121 |
122 | res = cuDeviceGetAttribute(&minor, 76, device);
123 | if (res != 0)
124 | return std::make_tuple(res, -1, -1);
125 |
126 | return std::make_tuple(0, major, minor);
127 | }
128 | // Split val into (factor, val / factor). div == 2 tries factor 2; div == 4 tries 4, 3, 5, 2, 7 in that order; otherwise, or when nothing divides, returns (1, val).
129 | std::pair<int, int> closest_divisor(int val, int div) {
130 |     if (div == 2) {
131 |         if ((val & 1) == 0) { return std::make_pair(2, val >> 1); }
132 |         else { return std::make_pair(1, val); }
133 |     }
134 |     else if (div == 4) {
135 |         if ((val & 3) == 0) { return std::make_pair(4, val >> 2); }
136 |         else if ((val % 3) == 0) { return std::make_pair(3, val / 3); }
137 |         else if ((val % 5) == 0) { return std::make_pair(5, val / 5); }
138 |         else if ((val & 1) == 0) { return std::make_pair(2, val >> 1); }
139 |         else if ((val % 7) == 0) { return std::make_pair(7, val / 7); }
140 |         else { return std::make_pair(1, val); }
141 |     }
142 |     else {
143 |         return std::make_pair(1, val);
144 |     }
145 | }
146 | // Two-letter layout code: first letter for A, second for B; 'T' when the transpose flag is set, 'N' otherwise.
147 | std::string get_op_string(bool a_t, bool b_t) {
148 |     std::string op(2, 'N');
149 |     if (a_t) op[0] = 'T';
150 |     if (b_t) op[1] = 'T';
151 |     return op;
152 | }
153 | // Pick, lazily load and launch the kernel computing C = alpha * op(A) * op(B) + beta * C; returns false on any failure.
154 | bool gemm(std::string precision, void *A, void *B, void *C,
155 |           bool a_t, bool b_t,
156 |           int m, int n, int k,
157 |           int lda, int ldb, int ldc,
158 |           float alpha, float beta,
159 |           CUstream stream, unsigned int grid, unsigned int shared) {
160 |     std::string kernel_op = get_op_string(a_t, b_t);
161 |     // `grid` indexes the kernel variants registered for this precision / op pair.
162 |     if (grid >= selections[precision][kernel_op].size())
163 |         return false;
164 |
165 |     kernel_properties kp = selections[precision][kernel_op][grid];
166 |     // `shared` indexes the shared-memory sizes offered by that variant.
167 |     if (shared >= kp.shared_sizes.size())
168 |         return false;
169 |     // The "_vec" kernels are only chosen when the leading dimension and one extent are multiples of 4 / 8.
170 |     bool vec4A, vec8A;
171 |     bool vec4B, vec8B;
172 |     if (a_t) {
173 |         vec4A = (lda & 3) == 0 && (m & 3) == 0; //multiple of 4
174 |         vec8A = (lda & 7) == 0 && (m & 7) == 0; //multiple of 8
175 |     }
176 |     else {
177 |         vec4A = (lda & 3) == 0 && (k & 3) == 0; //multiple of 4
178 |         vec8A = (lda & 7) == 0 && (k & 7) == 0; //multiple of 8
179 |     }
180 |
181 |     if (b_t) {
182 |         vec4B = (ldb & 3) == 0 && (k & 3) == 0; //multiple of 4
183 |         vec8B = (ldb & 7) == 0 && (k & 7) == 0; //multiple of 8
184 |     }
185 |     else {
186 |         vec4B = (ldb & 3) == 0 && (n & 3) == 0; //multiple of 4
187 |         vec8B = (ldb & 7) == 0 && (n & 7) == 0; //multiple of 8
188 |     }
189 |
190 |     bool vec4C = (ldc & 3) == 0 && (n & 3) == 0;
191 |
192 |     bool vecA = (kp.vA == 4 && vec4A) || (kp.vA == 8 && vec8A);
193 |     bool vecB = (kp.vB == 4 && vec4B) || (kp.vB == 8 && vec8B);
194 |     bool vecC = kp.vC == 1 || vec4C;
195 |
196 |     bool vec = vecA && vecB && vecC;
197 |
198 |     CUdevice device;
199 |     int res = cuCtxGetDevice(&device);
200 |     if (res != 0)
201 |         return false;
202 |
203 |     // getDeviceProperties returns (status, major, minor); keep the status as an int error code.
204 |     int major, minor;
205 |
206 |     std::tie(res, major, minor) = getDeviceProperties(device);
207 |
208 |     if (res != 0)
209 |         return false;
210 |
211 |     std::string kernel_string;
212 |     kernel_string.reserve(64);
213 |
214 |     kernel_string += precision + "gemm_" + kp.tile_string +
215 |                      "_" + kernel_op;
216 |     if (vec)
217 |         kernel_string += "_vec";
218 |
219 |     if (major == 5)
220 |         kernel_string += "_sm_50";
221 |     else if (major == 6)
222 |         kernel_string += "_sm_60";
223 |     else
224 |         return false;
225 |     // NOTE(review): this lookup runs without load_kernel_mutex_ while loadKernels() inserts under it - confirm callers do not race.
226 |     auto kernel = kernels_.find(kernel_string);
227 |     if (kernel == kernels_.end()) {
228 |         if (!loadKernels(major)) return false;
229 |         kernel = kernels_.find(kernel_string);
230 |     }
231 |     if (kernel == kernels_.end()) { std::cerr << "Unknown kernel " << kernel_string << std::endl; return false; }  // never dereference end()
232 |     int blk_A = (m + kp.tile_m - 1) / kp.tile_m;
233 |     int blk_B = (n + kp.tile_n - 1) / kp.tile_n;
234 |     // Factor each tile count into (small factor, remainder); the small factors are folded into grid.x below.
235 |     int blk_a, blk_b;
236 |     std::tie(blk_a, blk_A) = closest_divisor(blk_A, kp.div);
237 |     std::tie(blk_b, blk_B) = closest_divisor(blk_B, kp.div);
238 |
239 |     if (blk_a == 1)
240 |         std::tie(blk_a, blk_A) = std::make_pair(blk_A, 1);
241 |
242 |     void *args[13] = {&C, &A, &B, &alpha, &beta, &lda, &ldb, &ldc,
243 |                       &m, &n, &k, &blk_a, &blk_b};
244 |
245 |     res = cuLaunchKernel(kernel->second, blk_a * blk_b, blk_B, blk_A,
246 |                          kp.threads, 1, 1,
247 |                          kp.shared_sizes[shared], stream, args, NULL);
248 |
249 |     if (res != 0) {
250 |         std::cerr << "Failed to execute " << kernel_string << " " <<
251 |             res << std::endl;
252 |         return false;
253 |     }
254 |
255 |     return true;
256 | }
257 |
258 | };
259 | // Report through *grid how many kernel variants exist for this precision / transpose combination; always returns true.
260 | bool get_grid_limits(char precision, bool a_t, bool b_t, unsigned int *grid)
261 | {
262 |     const std::string op = get_op_string(a_t, b_t);
263 |     const std::string prec_string(1, precision);
264 |     *grid = selections[prec_string][op].size();
265 |     return true;
266 | }
267 | // Report through *shared how many shared-memory sizes variant `grid` offers; returns false when `grid` is out of range.
268 | bool get_shared_limits(char precision, bool a_t, bool b_t, unsigned int grid, unsigned int *shared) {
269 |     const std::string op = get_op_string(a_t, b_t);
270 |     const std::string prec_string(1, precision);
271 |     auto &variants = selections[prec_string][op];
272 |     if (grid >= variants.size())
273 |         return false;
274 |
275 |     *shared = variants[grid].shared_sizes.size();
276 |     return true;
277 | }
278 | // Single-precision entry point: forwards to gemm() with precision "s"; returns gemm()'s result.
279 | bool openai_sgemm(float *A, float *B, float *C,
280 |                   bool a_t, bool b_t,
281 |                   int m, int n, int k,
282 |                   int lda, int ldb, int ldc,
283 |                   float alpha, float beta,
284 |                   CUstream stream, unsigned int grid, unsigned int shared) {
285 |     return gemm("s",
286 |                 static_cast<void *>(A),
287 |                 static_cast<void *>(B),
288 |                 static_cast<void *>(C),
289 |                 a_t, b_t, m, n, k, lda, ldb, ldc,
290 |                 alpha, beta, stream, grid, shared);
291 | }
292 | // Half-precision entry point (values passed as raw 16-bit words): forwards to gemm() with precision "h".
293 | bool openai_hgemm(uint16_t *A, uint16_t *B, uint16_t *C,
294 |                   bool a_t, bool b_t,
295 |                   int m, int n, int k,
296 |                   int lda, int ldb, int ldc,
297 |                   float alpha, float beta,
298 |                   CUstream stream, unsigned int grid, unsigned int shared) {
299 |     return gemm("h",
300 |                 static_cast<void *>(A),
301 |                 static_cast<void *>(B),
302 |                 static_cast<void *>(C),
303 |                 a_t, b_t, m, n, k, lda, ldb, ldc,
304 |                 alpha, beta, stream, grid, shared);
305 | }
306 |
--------------------------------------------------------------------------------
/src/test.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include
5 |
6 | #include "include/c_interface.h"
7 |
8 | /* Simple program to call all possible kernel variants to make sure they are
9 | callable and produce sane results. Not meant to check the correctness of
10 | the underlying routines. */
11 |
12 | int main(void) {
13 |     cudaFree(0);  // presumably forces CUDA context creation before the driver-API calls - confirm
14 |     // Every (transpose A, transpose B) combination.
15 |     std::vector<std::pair<bool, bool>> ops = { {false, false}, {false, true},
16 |                                                {true, false}, {true, true} };
17 |     {
18 |         float *A, *B, *C;
19 |         const int size = 1024;
20 |
21 |         cudaMalloc(&A, size * size * sizeof(float));
22 |         cudaMalloc(&B, size * size * sizeof(float));
23 |         cudaMalloc(&C, size * size * sizeof(float));
24 |         float *C_host = (float *)malloc(size * size * sizeof(float));
25 |         // All-ones inputs: every element of a 1024 x 1024 product must equal 1024.
26 |         thrust::fill_n(thrust::device, A, size * size, 1.f);
27 |         thrust::fill_n(thrust::device, B, size * size, 1.f);
28 |
29 |         for (auto op : ops) {
30 |             unsigned int grid;
31 |             get_grid_limits('s', op.first, op.second, &grid);
32 |             for (unsigned int g = 0; g < grid; ++g) {
33 |                 unsigned int shared;
34 |                 get_shared_limits('s', op.first, op.second, g, &shared);
35 |                 for (unsigned int s = 0; s < shared; ++s) {
36 |                     bool res = openai_sgemm(A, B, C, op.first, op.second, size, size, size,
37 |                                             size, size, size, 1.0, 0.0, NULL, g, s);
38 |                     assert(res);
39 |                     cudaMemcpy(C_host, C, size * size * sizeof(float), cudaMemcpyDeviceToHost);
40 |
41 |                     assert(C_host[0] == 1024);
42 |                 }
43 |             }
44 |         }
45 |
46 |         cudaFree(A);
47 |         cudaFree(B);
48 |         cudaFree(C);
49 |         free(C_host);
50 |     }
51 |
52 |     {
53 |         uint16_t *A, *B, *C;
54 |         const int size = 1024;
55 |
56 |         cudaMalloc(&A, size * size * sizeof(uint16_t));
57 |         cudaMalloc(&B, size * size * sizeof(uint16_t));
58 |         cudaMalloc(&C, size * size * sizeof(uint16_t));
59 |         uint16_t *C_host = (uint16_t *)malloc(size * size * sizeof(uint16_t));
60 |         // 0x3c00 is 1.0 in IEEE half precision; the expected result 1024.0 is 0x6400 == 25600.
61 |         thrust::fill_n(thrust::device, A, size * size, 0x3c00);
62 |         thrust::fill_n(thrust::device, B, size * size, 0x3c00);
63 |
64 |         for (auto op : ops) {
65 |             unsigned int grid;
66 |             get_grid_limits('h', op.first, op.second, &grid);
67 |             for (unsigned int g = 0; g < grid; ++g) {
68 |                 unsigned int shared;
69 |                 get_shared_limits('h', op.first, op.second, g, &shared);
70 |                 for (unsigned int s = 0; s < shared; ++s) {
71 |                     bool res = openai_hgemm(A, B, C, op.first, op.second, size, size, size,
72 |                                             size, size, size, 1.0, 0.0, NULL, g, s);
73 |                     assert(res);
74 |                     cudaMemcpy(C_host, C, size * size * sizeof(uint16_t), cudaMemcpyDeviceToHost);
75 |
76 |                     assert(C_host[0] == 25600);
77 |                 }
78 |             }
79 |         }
80 |
81 |         cudaFree(A);
82 |         cudaFree(B);
83 |         cudaFree(C);
84 |         free(C_host);
85 |     }
86 |
87 |     return 0;
88 | }
89 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 | import pycuda.driver as drv
5 | from neon.backends.nervanagpu import NervanaGPU
6 |
7 | from openai_gemm import matmul_test
8 |
9 | ng = NervanaGPU()
10 | print drv.Context.get_current().get_device().name()
11 |
12 | ones = 0
13 | out = 0
14 |
15 | # for i in range(1000): # np.float32, np.float16
16 |
17 | # matmul_test(ng, np.float32, "TN", 4096*4, 4096*4, 33, ones=ones, out=out) # update
18 |
19 | # if i % 100 == 0: print i
20 |
21 | # exit()
22 |
23 | small_1 = (1,2,3,4,5,6,7,8,9,16,32,64,65,72,120,127,128,192)
24 | medium_1 = (32,64,128,192,778,785,786,787,794)
25 | big_1 = (32,64,128,1532,1535,1536,1537,1540,3073,4095)
26 |
27 | small_2 = (8,16,32,64,72,96,120,128,192)
28 | medium_2 = (32,64,128,192,256,768-4,768-8,768,768+16,768+32)
29 | big_2 = (32,64,128,1536-12,1536-24,1536,1536+28,1536+32,3072,4096)
30 |
31 | for dtype in (np.float32, np.float16, ): # np.float32, np.float16
32 | print dtype
33 |
34 | for size in (small_1, small_2, medium_1, medium_2, big_1, big_2,): # small_1, small_2, medium_1, medium_2, big_1, big_2
35 | print size
36 |
37 | for K in size:
38 | print "K:", K
39 |
40 | for C in (size):
41 | print "C:", C
42 |
43 | for N in (size):
44 |
45 | matmul_test(ng, dtype, "NN", N, K, C, ones=ones, out=out) # fprop
46 | matmul_test(ng, dtype, "NT", N, C, K, ones=ones, out=out) # bprop
47 | matmul_test(ng, dtype, "TN", C, K, N, ones=ones, out=out) # update
48 | matmul_test(ng, dtype, "TT", K, N, C, ones=ones, out=out) # ------
49 |
--------------------------------------------------------------------------------