├── .gitignore ├── BUILD ├── LICENSE ├── Makefile ├── README.md ├── WORKSPACE ├── benchmark.py ├── gen_kernels.py ├── include ├── c_interface.h ├── kernel_headers.h ├── kernels │ ├── hgemm_128x128x8_NN_sm_50.h │ ├── hgemm_128x128x8_NN_sm_60.h │ ├── hgemm_128x128x8_NN_vec_sm_50.h │ ├── hgemm_128x128x8_NN_vec_sm_60.h │ ├── hgemm_128x128x8_NT_sm_50.h │ ├── hgemm_128x128x8_NT_sm_60.h │ ├── hgemm_128x128x8_NT_vec_sm_50.h │ ├── hgemm_128x128x8_NT_vec_sm_60.h │ ├── hgemm_128x128x8_TN_sm_50.h │ ├── hgemm_128x128x8_TN_sm_60.h │ ├── hgemm_128x128x8_TN_vec_sm_50.h │ ├── hgemm_128x128x8_TN_vec_sm_60.h │ ├── hgemm_128x128x8_TT_sm_50.h │ ├── hgemm_128x128x8_TT_sm_60.h │ ├── hgemm_128x128x8_TT_vec_sm_50.h │ ├── hgemm_128x128x8_TT_vec_sm_60.h │ ├── hgemm_16x64x64_NN_sm_50.h │ ├── hgemm_16x64x64_NN_sm_60.h │ ├── hgemm_16x64x64_NN_vec_sm_50.h │ ├── hgemm_16x64x64_NN_vec_sm_60.h │ ├── hgemm_16x64x64_NT_sm_50.h │ ├── hgemm_16x64x64_NT_sm_60.h │ ├── hgemm_16x64x64_NT_vec_sm_50.h │ ├── hgemm_16x64x64_NT_vec_sm_60.h │ ├── hgemm_32x32x32_NN_sm_50.h │ ├── hgemm_32x32x32_NN_sm_60.h │ ├── hgemm_32x32x32_NN_vec_sm_50.h │ ├── hgemm_32x32x32_NN_vec_sm_60.h │ ├── hgemm_32x32x32_NT_sm_50.h │ ├── hgemm_32x32x32_NT_sm_60.h │ ├── hgemm_32x32x32_NT_vec_sm_50.h │ ├── hgemm_32x32x32_NT_vec_sm_60.h │ ├── hgemm_32x32x32_TN_sm_50.h │ ├── hgemm_32x32x32_TN_sm_60.h │ ├── hgemm_32x32x32_TN_vec_sm_50.h │ ├── hgemm_32x32x32_TN_vec_sm_60.h │ ├── hgemm_32x32x32_TT_sm_50.h │ ├── hgemm_32x32x32_TT_sm_60.h │ ├── hgemm_32x32x32_TT_vec_sm_50.h │ ├── hgemm_32x32x32_TT_vec_sm_60.h │ ├── hgemm_32x32x64_NT_sm_50.h │ ├── hgemm_32x32x64_NT_sm_60.h │ ├── hgemm_32x32x64_NT_vec_sm_50.h │ ├── hgemm_32x32x64_NT_vec_sm_60.h │ ├── hgemm_32x64x32_NN_sm_50.h │ ├── hgemm_32x64x32_NN_sm_60.h │ ├── hgemm_32x64x32_NN_vec_sm_50.h │ ├── hgemm_32x64x32_NN_vec_sm_60.h │ ├── sgemm_128x128x8_NN_sm_50.h │ ├── sgemm_128x128x8_NN_sm_60.h │ ├── sgemm_128x128x8_NN_vec_sm_50.h │ ├── sgemm_128x128x8_NN_vec_sm_60.h │ ├── sgemm_128x128x8_NT_sm_50.h │ 
├── sgemm_128x128x8_NT_sm_60.h │ ├── sgemm_128x128x8_NT_vec_sm_50.h │ ├── sgemm_128x128x8_NT_vec_sm_60.h │ ├── sgemm_128x128x8_TN_sm_50.h │ ├── sgemm_128x128x8_TN_sm_60.h │ ├── sgemm_128x128x8_TN_vec_sm_50.h │ ├── sgemm_128x128x8_TN_vec_sm_60.h │ ├── sgemm_128x128x8_TT_sm_50.h │ ├── sgemm_128x128x8_TT_sm_60.h │ ├── sgemm_128x128x8_TT_vec_sm_50.h │ ├── sgemm_128x128x8_TT_vec_sm_60.h │ ├── sgemm_32x32x32_NN_sm_50.h │ ├── sgemm_32x32x32_NN_sm_60.h │ ├── sgemm_32x32x32_NN_vec_sm_50.h │ ├── sgemm_32x32x32_NN_vec_sm_60.h │ ├── sgemm_32x32x32_NT_sm_50.h │ ├── sgemm_32x32x32_NT_sm_60.h │ ├── sgemm_32x32x32_NT_vec_sm_50.h │ ├── sgemm_32x32x32_NT_vec_sm_60.h │ ├── sgemm_32x32x32_TN_sm_50.h │ ├── sgemm_32x32x32_TN_sm_60.h │ ├── sgemm_32x32x32_TN_vec_sm_50.h │ ├── sgemm_32x32x32_TN_vec_sm_60.h │ ├── sgemm_32x32x32_TT_sm_50.h │ ├── sgemm_32x32x32_TT_sm_60.h │ ├── sgemm_32x32x32_TT_vec_sm_50.h │ └── sgemm_32x32x32_TT_vec_sm_60.h └── static_kernel_information.h ├── lib └── .gitignore ├── maxas ├── MaxAs │ ├── Cubin.pm │ ├── MaxAs.pm │ └── MaxAsGrammar.pm └── maxas.pl ├── openai_gemm.py ├── sass ├── hgemm_16x64x64_NN.sass ├── hgemm_16x64x64_NT.sass ├── hgemm_32x32x64_NT.sass ├── hgemm_32x64x32_NN.sass ├── xgemm_128x128x8.sass └── xgemm_32x32x32.sass ├── src ├── c_interface.cpp └── test.cu └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | temp/ 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | #test executable 93 | test 94 | 95 | bazel-* 96 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | # C interface library. The include/kernels/*.h files are produced by gen_kernels.py (Makefile target gen_kernels); each one embeds a single assembled cubin as a const uint8_t byte array. 3 | cc_library( 4 | name = "openai_gemm", 5 | srcs = ["src/c_interface.cpp", 6 | "include/kernel_headers.h", 7 | "include/static_kernel_information.h", 8 | "include/kernels/hgemm_128x128x8_NN_sm_50.h", 9 | "include/kernels/hgemm_128x128x8_NN_sm_60.h", 10 | "include/kernels/hgemm_128x128x8_NN_vec_sm_50.h", 11 | "include/kernels/hgemm_128x128x8_NN_vec_sm_60.h", 12 | "include/kernels/hgemm_128x128x8_NT_sm_50.h", 13 | "include/kernels/hgemm_128x128x8_NT_vec_sm_50.h", 14 | "include/kernels/hgemm_128x128x8_NT_sm_60.h", 15 | "include/kernels/hgemm_128x128x8_NT_vec_sm_60.h", 16 | "include/kernels/hgemm_128x128x8_TN_sm_50.h", 17 | "include/kernels/hgemm_128x128x8_TN_vec_sm_50.h", 18 | "include/kernels/hgemm_128x128x8_TN_sm_60.h", 19 | "include/kernels/hgemm_128x128x8_TN_vec_sm_60.h", 20 | 
"include/kernels/hgemm_128x128x8_TT_sm_50.h", 21 | "include/kernels/hgemm_128x128x8_TT_vec_sm_50.h", 22 | "include/kernels/hgemm_128x128x8_TT_sm_60.h", 23 | "include/kernels/hgemm_128x128x8_TT_vec_sm_60.h", 24 | "include/kernels/hgemm_16x64x64_NN_sm_50.h", 25 | "include/kernels/hgemm_16x64x64_NN_vec_sm_50.h", 26 | "include/kernels/hgemm_16x64x64_NN_sm_60.h", 27 | "include/kernels/hgemm_16x64x64_NN_vec_sm_60.h", 28 | "include/kernels/hgemm_16x64x64_NT_sm_50.h", 29 | "include/kernels/hgemm_16x64x64_NT_vec_sm_50.h", 30 | "include/kernels/hgemm_16x64x64_NT_sm_60.h", 31 | "include/kernels/hgemm_16x64x64_NT_vec_sm_60.h", 32 | "include/kernels/hgemm_32x32x32_NN_sm_50.h", 33 | "include/kernels/hgemm_32x32x32_NN_sm_60.h", 34 | "include/kernels/hgemm_32x32x32_NN_vec_sm_50.h", 35 | "include/kernels/hgemm_32x32x32_NN_vec_sm_60.h", 36 | "include/kernels/hgemm_32x32x32_NT_sm_50.h", 37 | "include/kernels/hgemm_32x32x32_NT_vec_sm_50.h", 38 | "include/kernels/hgemm_32x32x32_NT_sm_60.h", 39 | "include/kernels/hgemm_32x32x32_NT_vec_sm_60.h", 40 | "include/kernels/hgemm_32x32x32_TN_sm_50.h", 41 | "include/kernels/hgemm_32x32x32_TN_vec_sm_50.h", 42 | "include/kernels/hgemm_32x32x32_TN_sm_60.h", 43 | "include/kernels/hgemm_32x32x32_TN_vec_sm_60.h", 44 | "include/kernels/hgemm_32x32x32_TT_sm_50.h", 45 | "include/kernels/hgemm_32x32x32_TT_vec_sm_50.h", 46 | "include/kernels/hgemm_32x32x32_TT_sm_60.h", 47 | "include/kernels/hgemm_32x32x32_TT_vec_sm_60.h", 48 | "include/kernels/hgemm_32x32x64_NT_sm_50.h", 49 | "include/kernels/hgemm_32x32x64_NT_vec_sm_50.h", 50 | "include/kernels/hgemm_32x32x64_NT_sm_60.h", 51 | "include/kernels/hgemm_32x32x64_NT_vec_sm_60.h", 52 | "include/kernels/hgemm_32x64x32_NN_sm_50.h", 53 | "include/kernels/hgemm_32x64x32_NN_vec_sm_50.h", 54 | "include/kernels/hgemm_32x64x32_NN_sm_60.h", 55 | "include/kernels/hgemm_32x64x32_NN_vec_sm_60.h", 56 | "include/kernels/sgemm_128x128x8_NN_sm_50.h", 57 | "include/kernels/sgemm_128x128x8_NN_sm_60.h", 58 | 
"include/kernels/sgemm_128x128x8_NN_vec_sm_50.h", 59 | "include/kernels/sgemm_128x128x8_NN_vec_sm_60.h", 60 | "include/kernels/sgemm_128x128x8_NT_sm_50.h", 61 | "include/kernels/sgemm_128x128x8_NT_vec_sm_50.h", 62 | "include/kernels/sgemm_128x128x8_NT_sm_60.h", 63 | "include/kernels/sgemm_128x128x8_NT_vec_sm_60.h", 64 | "include/kernels/sgemm_128x128x8_TN_sm_50.h", 65 | "include/kernels/sgemm_128x128x8_TN_vec_sm_50.h", 66 | "include/kernels/sgemm_128x128x8_TN_sm_60.h", 67 | "include/kernels/sgemm_128x128x8_TN_vec_sm_60.h", 68 | "include/kernels/sgemm_128x128x8_TT_sm_50.h", 69 | "include/kernels/sgemm_128x128x8_TT_vec_sm_50.h", 70 | "include/kernels/sgemm_128x128x8_TT_sm_60.h", 71 | "include/kernels/sgemm_128x128x8_TT_vec_sm_60.h", 72 | "include/kernels/sgemm_32x32x32_NN_sm_50.h", 73 | "include/kernels/sgemm_32x32x32_NN_sm_60.h", 74 | "include/kernels/sgemm_32x32x32_NN_vec_sm_50.h", 75 | "include/kernels/sgemm_32x32x32_NN_vec_sm_60.h", 76 | "include/kernels/sgemm_32x32x32_NT_sm_50.h", 77 | "include/kernels/sgemm_32x32x32_NT_vec_sm_50.h", 78 | "include/kernels/sgemm_32x32x32_NT_sm_60.h", 79 | "include/kernels/sgemm_32x32x32_NT_vec_sm_60.h", 80 | "include/kernels/sgemm_32x32x32_TN_sm_50.h", 81 | "include/kernels/sgemm_32x32x32_TN_vec_sm_50.h", 82 | "include/kernels/sgemm_32x32x32_TN_sm_60.h", 83 | "include/kernels/sgemm_32x32x32_TN_vec_sm_60.h", 84 | "include/kernels/sgemm_32x32x32_TT_sm_50.h", 85 | "include/kernels/sgemm_32x32x32_TT_vec_sm_50.h", 86 | "include/kernels/sgemm_32x32x32_TT_sm_60.h", 87 | "include/kernels/sgemm_32x32x32_TT_vec_sm_60.h", 88 | ], 89 | hdrs = ["include/c_interface.h"], 90 | ) 91 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2016 OpenAI (http://openai.com), 2016 Google Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | gen_kernels: gen_kernels.py include/kernel_headers.h $(wildcard sass/*.sass) 2 | python gen_kernels.py 3 | # gen_kernels writes include/kernels/*.h, never a file named gen_kernels, so it (and lib/c_interface.o, which depends on it) is re-run on every build. 4 | lib/c_interface.o: gen_kernels src/c_interface.cpp include/kernel_headers.h include/static_kernel_information.h 5 | nvcc -c src/c_interface.cpp -o lib/c_interface.o -std=c++11 -I . 6 | 7 | test: src/test.cu lib/c_interface.o 8 | nvcc -o test src/test.cu lib/c_interface.o -std=c++11 -I .
-lcuda 9 | 10 | clean: 11 | rm -f include/kernels/* 12 | rm -rf temp/ 13 | rm -f lib/c_interface.o 14 | rm -f test 15 | .PHONY: gen_kernels clean # neither target creates a file of its own name; declaring them phony keeps a stray file called "clean" from disabling "make clean" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | # openai-gemm 4 | Open single and half precision gemm implementations. The main speedups over cublas are with small minibatch and in fp16 data formats. 5 | 6 | ## Quick Install 7 | 8 | The demonstration code currently depends on [Nervana neon](https://github.com/NervanaSystems/neon): 9 | 10 | git clone git@github.com:NervanaSystems/neon.git 11 | cd neon 12 | make 13 | . .venv/bin/activate 14 | 15 | Clone and run this repo: 16 | 17 | git clone git@github.com:openai/openai-gemm.git 18 | 19 | Run the benchmark: 20 | ./benchmark.py 21 | 22 | Run the unit test: 23 | ./test.py 24 | 25 | 26 | ### DeepBench on Pascal TITAN X 27 | ( https://github.com/baidu-research/DeepBench ) 28 | 29 | 30 | | M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16| 31 | |------|------|------|---|---------|---------|--------|---------|---------|--------| 32 | | 16| 1760| 1760| NN| 2557| 2195| 1.2| 3507| 346| 10.1| 33 | | 32| 1760| 1760| NN| 5010| 1128| 4.4| 6814| 526| 13.0| 34 | | 64| 1760| 1760| NN| 6486| 4112| 1.6| 8235| 2801| 2.9| 35 | | 128| 1760| 1760| NN| 7068| 6931| 1.0| 9400| 5307| 1.8| 36 | | 7000| 1760| 1760| NN| 9968| 9584| 1.0| 10515| 9807| 1.1| 37 | | 16| 2048| 2048| NN| 2569| 1516| 1.7| 3619| 242| 15.0| 38 | | 32| 2048| 2048| NN| 5034| 1356| 3.7| 6576| 606| 10.8| 39 | | 64| 2048| 2048| NN| 6636| 2815| 2.4| 8285| 3241| 2.6| 40 | | 128| 2048| 2048| NN| 7316| 6373| 1.1| 9066| 5334| 1.7| 41 | | 7000| 2048| 2048| NN| 10081| 9900| 1.0| 11275| 9948| 1.1| 42 | | 16| 2560| 2560| NN| 2718| 1312| 2.1| 4312| 251| 17.2| 43 | | 32| 2560| 2560| NN| 5370| 1660| 3.2| 7525| 749| 10.0| 44 | | 64| 2560| 2560| NN|
7331| 2687| 2.7| 8436| 951| 8.9| 45 | | 128| 2560| 2560| NN| 8007| 5238| 1.5| 9277| 6123| 1.5| 46 | | 7000| 2560| 2560| NN| 10282| 10131| 1.0| 11027| 9974| 1.1| 47 | | 16| 4096| 4096| NN| 2695| 1110| 2.4| 4442| 266| 16.7| 48 | | 32| 4096| 4096| NN| 5266| 2264| 2.3| 7723| 758| 10.2| 49 | | 64| 4096| 4096| NN| 6942| 3922| 1.8| 8904| 1055| 8.4| 50 | | 128| 4096| 4096| NN| 8127| 5686| 1.4| 9711| 5681| 1.7| 51 | | 7000| 4096| 4096| NN| 10462| 10082| 1.0| 11152| 9991| 1.1| 52 | | 16| 1760| 1760| NT| 1719| 1095| 1.6| 2692| 290| 9.3| 53 | | 32| 1760| 1760| NT| 3316| 1312| 2.5| 5068| 447| 11.3| 54 | | 64| 1760| 1760| NT| 5247| 1955| 2.7| 7621| 1797| 4.2| 55 | | 128| 1760| 1760| NT| 6720| 3393| 2.0| 8886| 3342| 2.7| 56 | | 7000| 1760| 1760| NT| 9341| 8513| 1.1| 10085| 9635| 1.0| 57 | | 16| 2048| 2048| NT| 2442| 1231| 2.0| 3641| 299| 12.2| 58 | | 32| 2048| 2048| NT| 4801| 1251| 3.8| 5849| 468| 12.5| 59 | | 64| 2048| 2048| NT| 6317| 1967| 3.2| 7825| 3128| 2.5| 60 | | 128| 2048| 2048| NT| 7176| 5041| 1.4| 8616| 4843| 1.8| 61 | | 7000| 2048| 2048| NT| 9975| 9173| 1.1| 10741| 9560| 1.1| 62 | | 16| 2560| 2560| NT| 1834| 1208| 1.5| 3154| 297| 10.6| 63 | | 32| 2560| 2560| NT| 3610| 1436| 2.5| 5418| 584| 9.3| 64 | | 64| 2560| 2560| NT| 6083| 2815| 2.2| 8331| 1042| 8.0| 65 | | 128| 2560| 2560| NT| 7702| 3246| 2.4| 8857| 5259| 1.7| 66 | | 7000| 2560| 2560| NT| 9257| 7829| 1.2| 10659| 9548| 1.1| 67 | | 16| 4096| 4096| NT| 2546| 1297| 2.0| 4164| 309| 13.5| 68 | | 32| 4096| 4096| NT| 4992| 2290| 2.2| 8156| 775| 10.5| 69 | | 64| 4096| 4096| NT| 6746| 4157| 1.6| 8429| 1381| 6.1| 70 | | 128| 4096| 4096| NT| 7843| 5425| 1.4| 9298| 5527| 1.7| 71 | | 7000| 4096| 4096| NT| 9925| 6879| 1.4| 10630| 9784| 1.1| 72 | | 7133| 1760| 1760| TN| 9752| 10186| 1.0| 10517| 8912| 1.2| 73 | | 7133| 2048| 2048| TN| 10485| 10319| 1.0| 10674| 9608| 1.1| 74 | | 7133| 2560| 2560| TN| 10743| 11057| 1.0| 11195| 10059| 1.1| 75 | | 7133| 4096| 4096| TN| 10384| 10290| 1.0| 10980| 10558| 1.0| 76 | | 9124| 5124| 1760| NN| 
9920| 9480| 1.0| 10580| 9743| 1.1| 77 | | 9124| 5124| 2048| NN| 10008| 9415| 1.1| 10602| 9796| 1.1| 78 | | 9124| 5124| 2560| NN| 9925| 9426| 1.1| 10586| 9850| 1.1| 79 | | 9124| 5124| 4096| NN| 9982| 9489| 1.1| 10580| 9472| 1.1| 80 | | 9124| 5124| 1760| NT| 9093| 3497| 2.6| 9302| 8692| 1.1| 81 | | 9124| 5124| 2048| NT| 9506| 6512| 1.5| 9506| 8883| 1.1| 82 | | 9124| 5124| 2560| NT| 8704| 3364| 2.6| 9855| 7733| 1.3| 83 | | 9124| 5124| 4096| NT| 9733| 6109| 1.6| 10278| 8760| 1.2| 84 | | 8457| 35| 1760| NN| 3343| 1020| 3.3| 3841| 736| 5.2| 85 | | 8457| 35| 2048| NN| 3419| 1996| 1.7| 4782| 803| 6.0| 86 | | 8457| 35| 2560| NN| 3415| 1072| 3.2| 3868| 789| 4.9| 87 | | 8457| 35| 4096| NN| 3743| 2009| 1.9| 4741| 804| 5.9| 88 | | 8457| 35| 1760| NT| 3574| 1970| 1.8| 4176| 1243| 3.4| 89 | | 8457| 35| 2048| NT| 4564| 3069| 1.5| 4818| 1255| 3.8| 90 | | 8457| 35| 2560| NT| 3598| 2062| 1.7| 3597| 1135| 3.2| 91 | | 8457| 35| 4096| NT| 4311| 2990| 1.4| 4927| 1303| 3.8| 92 | | 16| 7680| 2560| NN| 2683| 718| 3.7| 4449| 289| 15.4| 93 | | 32| 7680| 2560| NN| 5304| 3660| 1.4| 7837| 979| 8.0| 94 | | 64| 7680| 2560| NN| 7311| 4979| 1.5| 9310| 1274| 7.3| 95 | | 128| 7680| 2560| NN| 7931| 6109| 1.3| 9390| 6591| 1.4| 96 | | 16| 7680| 2560| NT| 1885| 1191| 1.6| 3401| 290| 11.7| 97 | | 32| 7680| 2560| NT| 3731| 1808| 2.1| 6373| 1004| 6.3| 98 | | 64| 7680| 2560| NT| 6274| 3509| 1.8| 8809| 1655| 5.3| 99 | | 128| 7680| 2560| NT| 7957| 2988| 2.7| 9246| 4695| 2.0| 100 | | 16| 3072| 1024| NN| 2277| 1295| 1.8| 3373| 282| 12.0| 101 | | 32| 3072| 1024| NN| 4494| 1798| 2.5| 6011| 807| 7.4| 102 | | 64| 3072| 1024| NN| 6272| 3046| 2.1| 6790| 917| 7.4| 103 | | 128| 3072| 1024| NN| 7364| 5436| 1.4| 7768| 5749| 1.4| 104 | | 16| 3072| 1024| NT| 2285| 1077| 2.1| 3439| 244| 14.1| 105 | | 32| 3072| 1024| NT| 4597| 1540| 3.0| 5645| 677| 8.3| 106 | | 64| 3072| 1024| NT| 6392| 2969| 2.2| 7555| 1204| 6.3| 107 | | 128| 3072| 1024| NT| 7460| 5058| 1.5| 8586| 5535| 1.6| 108 | | 7435| 3072| 1024| TN| 9829| 8804| 1.1| 
10123| 9365| 1.1| 109 | | 5481| 7680| 2560| TN| 9448| 9309| 1.0| 9466| 9394| 1.0| 110 | 111 | 112 | ### DeepBench on DGX1 (P100) 113 | Note that the OpenAI kernels do not yet implement fp16x2 instructions. Even so, the current cuBLAS hgemm implementation appears to perform well only for large dimensions. There are also accuracy considerations when accumulating large reductions in fp16. 114 | 115 | | M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16| 116 | |------|------|------|---|---------|---------|--------|---------|---------|--------| 117 | | 16| 1760| 1760| NN| 2595| 2048| 1.3| 2935| 463| 6.3| 118 | | 32| 1760| 1760| NN| 4963| 864| 5.7| 5766| 895| 6.4| 119 | | 64| 1760| 1760| NN| 7565| 3909| 1.9| 7760| 1711| 4.5| 120 | | 128| 1760| 1760| NN| 8140| 6053| 1.3| 8422| 4089| 2.1| 121 | | 7000| 1760| 1760| NN| 9653| 8722| 1.1| 9617| 16143| 0.6| 122 | | 16| 2048| 2048| NN| 2255| 1746| 1.3| 3211| 546| 5.9| 123 | | 32| 2048| 2048| NN| 4467| 1012| 4.4| 4533| 1019| 4.4| 124 | | 64| 2048| 2048| NN| 6618| 4198| 1.6| 6591| 2018| 3.3| 125 | | 128| 2048| 2048| NN| 8059| 5921| 1.4| 7936| 4667| 1.7| 126 | | 7000| 2048| 2048| NN| 9761| 9346| 1.0| 9910| 18715| 0.5| 127 | | 16| 2560| 2560| NN| 2883| 2108| 1.4| 4210| 685| 6.1| 128 | | 32| 2560| 2560| NN| 5701| 1279| 4.5| 5820| 1297| 4.5| 129 | | 64| 2560| 2560| NN| 8100| 6054| 1.3| 8099| 2558| 3.2| 130 | | 128| 2560| 2560| NN| 8308| 6799| 1.2| 8790| 5901| 1.5| 131 | | 7000| 2560| 2560| NN| 9740| 9538| 1.0| 9845| 18499| 0.5| 132 | | 16| 4096| 4096| NN| 3449| 1342| 2.6| 4299| 1069| 4.0| 133 | | 32| 4096| 4096| NN| 6863| 2045| 3.4| 6907| 2103| 3.3| 134 | | 64| 4096| 4096| NN| 8404| 4059| 2.1| 8248| 4183| 2.0| 135 | | 128| 4096| 4096| NN| 8224| 8039| 1.0| 8853| 8669| 1.0| 136 | | 7000| 4096| 4096| NN| 9818| 9519| 1.0| 10011| 18588| 0.5| 137 | | 16| 1760| 1760| NT| 2579| 1324| 1.9| 2763| 428| 6.4| 138 | | 32| 1760| 1760| NT| 5089| 878| 5.8| 5382| 857| 6.3| 139 | | 64| 1760| 1760| NT| 7501| 3017| 2.5| 7695| 1695| 4.5| 140 | 
| 128| 1760| 1760| NT| 8043| 5494| 1.5| 8192| 3426| 2.4| 141 | | 7000| 1760| 1760| NT| 9477| 7571| 1.3| 9355| 16113| 0.6| 142 | | 16| 2048| 2048| NT| 2267| 1276| 1.8| 3171| 504| 6.3| 143 | | 32| 2048| 2048| NT| 4484| 1026| 4.4| 4489| 1009| 4.4| 144 | | 64| 2048| 2048| NT| 6567| 3986| 1.6| 6551| 2018| 3.2| 145 | | 128| 2048| 2048| NT| 8019| 5825| 1.4| 7968| 4496| 1.8| 146 | | 7000| 2048| 2048| NT| 9625| 9373| 1.0| 9713| 17878| 0.5| 147 | | 16| 2560| 2560| NT| 2870| 1460| 2.0| 4256| 638| 6.7| 148 | | 32| 2560| 2560| NT| 5614| 1299| 4.3| 5705| 1271| 4.5| 149 | | 64| 2560| 2560| NT| 8014| 4402| 1.8| 8085| 2521| 3.2| 150 | | 128| 2560| 2560| NT| 8219| 5640| 1.5| 8240| 5137| 1.6| 151 | | 7000| 2560| 2560| NT| 9534| 9091| 1.0| 9735| 18025| 0.5| 152 | | 16| 4096| 4096| NT| 3366| 1547| 2.2| 4354| 1047| 4.2| 153 | | 32| 4096| 4096| NT| 6714| 2055| 3.3| 6859| 2093| 3.3| 154 | | 64| 4096| 4096| NT| 8297| 3445| 2.4| 8289| 4178| 2.0| 155 | | 128| 4096| 4096| NT| 8335| 7450| 1.1| 7911| 7973| 1.0| 156 | | 7000| 4096| 4096| NT| 9578| 9214| 1.0| 9877| 18073| 0.5| 157 | | 7133| 1760| 1760| TN| 9704| 9267| 1.0| 9506| 15605| 0.6| 158 | | 7133| 2048| 2048| TN| 9747| 9836| 1.0| 10012| 19110| 0.5| 159 | | 7133| 2560| 2560| TN| 9742| 9748| 1.0| 9805| 19107| 0.5| 160 | | 7133| 4096| 4096| TN| 9807| 9733| 1.0| 10122| 19559| 0.5| 161 | | 9124| 5124| 1760| NN| 9326| 9076| 1.0| 9631| 17496| 0.6| 162 | | 9124| 5124| 2048| NN| 9414| 9054| 1.0| 9602| 17523| 0.5| 163 | | 9124| 5124| 2560| NN| 9353| 9041| 1.0| 9698| 17380| 0.6| 164 | | 9124| 5124| 4096| NN| 9370| 9051| 1.0| 9689| 17617| 0.5| 165 | | 9124| 5124| 1760| NT| 9124| 8746| 1.0| 9524| 16777| 0.6| 166 | | 9124| 5124| 2048| NT| 9294| 8817| 1.1| 9641| 16935| 0.6| 167 | | 9124| 5124| 2560| NT| 9221| 8499| 1.1| 9637| 16820| 0.6| 168 | | 9124| 5124| 4096| NT| 9270| 8961| 1.0| 9568| 17080| 0.6| 169 | | 8457| 35| 1760| NN| 3301| 2233| 1.5| 4505| 3154| 1.4| 170 | | 8457| 35| 2048| NN| 3265| 3066| 1.1| 4501| 3335| 1.3| 171 | | 8457| 35| 2560| NN| 
3127| 2300| 1.4| 4516| 3135| 1.4| 172 | | 8457| 35| 4096| NN| 3257| 3272| 1.0| 4729| 3485| 1.4| 173 | | 8457| 35| 1760| NT| 4563| 3142| 1.5| 4612| 2998| 1.5| 174 | | 8457| 35| 2048| NT| 4554| 3202| 1.4| 4601| 3109| 1.5| 175 | | 8457| 35| 2560| NT| 4567| 3144| 1.5| 4654| 3039| 1.5| 176 | | 8457| 35| 4096| NT| 4353| 3415| 1.3| 4457| 3257| 1.4| 177 | | 16| 7680| 2560| NN| 3668| 1200| 3.1| 5020| 1236| 4.1| 178 | | 32| 7680| 2560| NN| 7245| 3385| 2.1| 7519| 2465| 3.1| 179 | | 64| 7680| 2560| NN| 8440| 5210| 1.6| 8349| 4910| 1.7| 180 | | 128| 7680| 2560| NN| 8765| 4872| 1.8| 9131| 11349| 0.8| 181 | | 16| 7680| 2560| NT| 3229| 1515| 2.1| 5032| 1157| 4.3| 182 | | 32| 7680| 2560| NT| 6640| 2721| 2.4| 6810| 2307| 3.0| 183 | | 64| 7680| 2560| NT| 8282| 5113| 1.6| 8362| 4494| 1.9| 184 | | 128| 7680| 2560| NT| 8763| 4646| 1.9| 8617| 9159| 0.9| 185 | | 16| 3072| 1024| NN| 2929| 1717| 1.7| 3335| 750| 4.4| 186 | | 32| 3072| 1024| NN| 5801| 1399| 4.1| 6116| 1420| 4.3| 187 | | 64| 3072| 1024| NN| 6958| 4340| 1.6| 6923| 2814| 2.5| 188 | | 128| 3072| 1024| NN| 8047| 6492| 1.2| 7769| 6302| 1.2| 189 | | 16| 3072| 1024| NT| 2990| 1068| 2.8| 3384| 705| 4.8| 190 | | 32| 3072| 1024| NT| 5834| 1429| 4.1| 6021| 1411| 4.3| 191 | | 64| 3072| 1024| NT| 6921| 3500| 2.0| 6893| 2819| 2.4| 192 | | 128| 3072| 1024| NT| 7918| 6034| 1.3| 7876| 5760| 1.4| 193 | | 7435| 3072| 1024| TN| 9367| 9391| 1.0| 9559| 17234| 0.6| 194 | | 5481| 7680| 2560| TN| 9672| 9520| 1.0| 9967| 18832| 0.5| 195 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/openai-gemm/db5da9e6656f6dfa55a18df77cee2e2f95d7ee9c/WORKSPACE -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import 
pycuda.driver as drv 5 | from neon.backends.nervanagpu import NervanaGPU 6 | 7 | from openai_gemm import matmul 8 | 9 | 10 | ng = NervanaGPU() 11 | print drv.Context.get_current().get_device().name() 12 | 13 | config = ( 14 | # m, n, k, AT, BT (row order) 15 | ( 16, 1760, 1760, False, False), 16 | ( 32, 1760, 1760, False, False), 17 | ( 64, 1760, 1760, False, False), 18 | ( 128, 1760, 1760, False, False), 19 | ( 7000, 1760, 1760, False, False), 20 | ( 16, 2048, 2048, False, False), 21 | ( 32, 2048, 2048, False, False), 22 | ( 64, 2048, 2048, False, False), 23 | ( 128, 2048, 2048, False, False), 24 | ( 7000, 2048, 2048, False, False), 25 | ( 16, 2560, 2560, False, False), 26 | ( 32, 2560, 2560, False, False), 27 | ( 64, 2560, 2560, False, False), 28 | ( 128, 2560, 2560, False, False), 29 | ( 7000, 2560, 2560, False, False), 30 | ( 16, 4096, 4096, False, False), 31 | ( 32, 4096, 4096, False, False), 32 | ( 64, 4096, 4096, False, False), 33 | ( 128, 4096, 4096, False, False), 34 | ( 7000, 4096, 4096, False, False), 35 | ( 16, 1760, 1760, False, True), 36 | ( 32, 1760, 1760, False, True), 37 | ( 64, 1760, 1760, False, True), 38 | ( 128, 1760, 1760, False, True), 39 | ( 7000, 1760, 1760, False, True), 40 | ( 16, 2048, 2048, False, True), 41 | ( 32, 2048, 2048, False, True), 42 | ( 64, 2048, 2048, False, True), 43 | ( 128, 2048, 2048, False, True), 44 | ( 7000, 2048, 2048, False, True), 45 | ( 16, 2560, 2560, False, True), 46 | ( 32, 2560, 2560, False, True), 47 | ( 64, 2560, 2560, False, True), 48 | ( 128, 2560, 2560, False, True), 49 | ( 7000, 2560, 2560, False, True), 50 | ( 16, 4096, 4096, False, True), 51 | ( 32, 4096, 4096, False, True), 52 | ( 64, 4096, 4096, False, True), 53 | ( 128, 4096, 4096, False, True), 54 | ( 7000, 4096, 4096, False, True), 55 | ( 7133, 1760, 1760, True , False), 56 | ( 7133, 2048, 2048, True , False), 57 | ( 7133, 2560, 2560, True , False), 58 | ( 7133, 4096, 4096, True , False), 59 | ( 9124, 5124, 1760, False, False), 60 | ( 9124, 5124, 
2048, False, False), 61 | ( 9124, 5124, 2560, False, False), 62 | ( 9124, 5124, 4096, False, False), 63 | ( 9124, 5124, 1760, False, True), 64 | ( 9124, 5124, 2048, False, True), 65 | ( 9124, 5124, 2560, False, True), 66 | ( 9124, 5124, 4096, False, True), 67 | ( 8457, 35, 1760, False, False), 68 | ( 8457, 35, 2048, False, False), 69 | ( 8457, 35, 2560, False, False), 70 | ( 8457, 35, 4096, False, False), 71 | ( 8457, 35, 1760, False, True), 72 | ( 8457, 35, 2048, False, True), 73 | ( 8457, 35, 2560, False, True), 74 | ( 8457, 35, 4096, False, True), 75 | ( 16, 7680, 2560, False, False), 76 | ( 32, 7680, 2560, False, False), 77 | ( 64, 7680, 2560, False, False), 78 | ( 128, 7680, 2560, False, False), 79 | ( 16, 7680, 2560, False, True), 80 | ( 32, 7680, 2560, False, True), 81 | ( 64, 7680, 2560, False, True), 82 | ( 128, 7680, 2560, False, True), 83 | ( 16, 3072, 1024, False, False), 84 | ( 32, 3072, 1024, False, False), 85 | ( 64, 3072, 1024, False, False), 86 | ( 128, 3072, 1024, False, False), 87 | ( 16, 3072, 1024, False, True), 88 | ( 32, 3072, 1024, False, True), 89 | ( 64, 3072, 1024, False, True), 90 | ( 128, 3072, 1024, False, True), 91 | ( 7435, 3072, 1024, True , False), 92 | ( 5481, 7680, 2560, True , False), 93 | 94 | # (60000, 32, 32, True , False), 95 | # (60000, 256, 256, True , False), 96 | 97 | # ( 4096, 4096, 32, True , False), 98 | # ( 3456, 3456, 32, True , False), 99 | # ( 896, 896, 32, True , False), 100 | ) 101 | 102 | print "| M| N| K| Op|OpenAI_32|cuBLAS_32|ratio_32|OpenAI_16|cuBLAS_16|ratio_16|" 103 | print "|------|------|------|---|---------|---------|--------|---------|---------|--------|" 104 | 105 | for m, n, k, at, bt in config: 106 | 107 | dimA = (k,m) if at else (m,k) 108 | dimB = (n,k) if bt else (k,n) 109 | dimC = (m,n) 110 | 111 | opA = 'T' if at else 'N' 112 | opB = 'T' if bt else 'N' 113 | op = opA + opB 114 | 115 | dtype_data = list() 116 | 117 | for dtype in ( np.float32, np.float16 ): #np.float32, np.float16, 118 | 119 | A 
= ng.empty(dimA, dtype=dtype) 120 | B = ng.empty(dimB, dtype=dtype) 121 | C = ng.empty(dimC, dtype=dtype) 122 | 123 | if at: A = A.T 124 | if bt: B = B.T 125 | 126 | data = matmul(A, B, C, bench=True) 127 | 128 | # if dtype is np.float16: 129 | # print "" 130 | # for d in sorted(data): 131 | # print "%7.3f %5.0f %22s %5d" % d 132 | 133 | cublas = data.pop() 134 | openai = sorted(data)[0] 135 | 136 | text = "%9.0f|%9.0f|%8.1f" % (openai[1], cublas[1], openai[1] / cublas[1]) 137 | 138 | dtype_data.append(text) 139 | 140 | 141 | print "|%6d|%6d|%6d|%3s|%s|" % (m, n, k, op, "|".join(dtype_data)) 142 | -------------------------------------------------------------------------------- /gen_kernels.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path 3 | import re 4 | import subprocess 5 | import sys 6 | import time 7 | 8 | base_dir = os.path.dirname(__file__) 9 | maxas_dir = os.path.join(base_dir, "maxas") 10 | sass_dir = os.path.join(base_dir, "sass") 11 | 12 | # Tile sizes: m, n, k, vA,vB,vC div, op (dynamic shared options) 13 | k128x128x8 = (128, 128, 8, 4, 4, 1, 2, 0, (0,)) 14 | k32x32x32 = ( 32, 32, 32, 4, 4, 1, 4, 0, (0, 2**14)) 15 | k32x64x32_NN = ( 32, 64, 32, 8, 4, 4, 4, 1, (0, 2**13)) 16 | k32x32x64_NT = ( 32, 32, 64, 8, 8, 4, 4, 1, (0,)) 17 | k16x64x64_NN = ( 16, 64, 64, 8, 4, 4, 4, 1, (0,)) 18 | k16x64x64_NT = ( 16, 64, 64, 8, 8, 4, 4, 1, (0,)) 19 | 20 | selections = { 21 | "s" : { 22 | "TN" : (k128x128x8, k32x32x32), 23 | "NN" : (k128x128x8, k32x32x32), 24 | "NT" : (k128x128x8, k32x32x32), 25 | "TT" : (k128x128x8, k32x32x32), 26 | }, 27 | "h" : { 28 | "TN" : (k128x128x8, k32x32x32), 29 | "NN" : (k128x128x8, k32x32x32, k32x64x32_NN, k16x64x64_NN), 30 | "NT" : (k128x128x8, k32x32x32, k32x32x64_NT, k16x64x64_NT), 31 | "TT" : (k128x128x8, k32x32x32), 32 | }, 33 | } 34 | 35 | kernels = { 36 | # Generic gemm tiles 37 | "sgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", 
"share": "(128*8 + 32)*4 + 4", "args": {"type": "s"} }, 38 | "hgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "h"} }, 39 | "sgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "s"} }, 40 | "hgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "h"} }, 41 | 42 | # Custom hgemm tiles designed for small minibatch RNNs 43 | "hgemm_32x64x32_NN": {"threads": 128, "sass": "hgemm_32x64x32_NN", "params": "xgemm", "share": "32*33*2 + 64*32*2 + 4" }, 44 | "hgemm_32x32x64_NT": {"threads": 128, "sass": "hgemm_32x32x64_NT", "params": "xgemm", "share": "32*65*4 + 4" }, 45 | "hgemm_16x64x64_NN": {"threads": 128, "sass": "hgemm_16x64x64_NN", "params": "xgemm", "share": "(16*64 + 32)*2 + 64*64*2 + 4" }, 46 | "hgemm_16x64x64_NT": {"threads": 128, "sass": "hgemm_16x64x64_NT", "params": "xgemm", "share": "(16*64 + 32)*2 + (64*64 + 32)*2 + 4" }, 47 | } 48 | 49 | _params = { 50 | "xgemm": [ 51 | "float* param_C", 52 | "float* param_A", 53 | "float* param_B", 54 | "float param_alpha", 55 | "float param_beta", 56 | "unsigned param_cda", 57 | "unsigned param_cdb", 58 | "unsigned param_cdc", 59 | "unsigned param_m", 60 | "unsigned param_n", 61 | "unsigned param_k", 62 | "unsigned param_blk_a", 63 | "unsigned param_blk_b", 64 | ], 65 | } 66 | 67 | _space_re = re.compile(r"\s+") 68 | 69 | _share_template = r""" 70 | .shared .align 4 .b32 share[{0}]; 71 | """ 72 | 73 | _kernel_template = r""" 74 | .version {6} 75 | .target {0} 76 | .address_size 64 77 | 78 | // args: {5} 79 | 80 | .visible .entry {1}( 81 | {2} 82 | ) 83 | .reqntid {3} 84 | {{ 85 | {4} 86 | ret; 87 | }} 88 | """ 89 | 90 | 91 | def _get_cache_dir(subdir=None): 92 | cache_dir = 'temp/' 93 | 94 | if subdir: 95 | subdir = subdir if isinstance(subdir, list) else [subdir] 96 | cache_dir = os.path.join(cache_dir, 
*subdir) 97 | 98 | if not os.path.exists(cache_dir): 99 | os.makedirs(cache_dir) 100 | 101 | return cache_dir 102 | 103 | 104 | def get_ptx_file(kernel_spec, kernel_name, arch, ptx_ver): 105 | ptx_dir = _get_cache_dir([arch, 'ptx']) 106 | 107 | thread_spec = kernel_spec["threads"] 108 | args_spec = str(kernel_spec.get("args","")) 109 | param_spec = _params[kernel_spec["params"]] 110 | 111 | kernel_params = [] 112 | for p in param_spec: 113 | ptype, pname = _space_re.split(p) 114 | 115 | if ptype[-1] == '*': 116 | ptype = '.u64' 117 | elif ptype == 'float': 118 | ptype = '.f32' 119 | else: 120 | ptype = '.u32' 121 | 122 | kernel_params.append(" .param %s %s" % (ptype, pname)) 123 | 124 | kernel_params = ",\n".join(kernel_params) 125 | 126 | if "share" in kernel_spec: 127 | share = _share_template.format(eval(kernel_spec["share"])) 128 | else: 129 | share = "" 130 | 131 | kernel_text = _kernel_template.format(arch, kernel_name, kernel_params, thread_spec, share, args_spec, ptx_ver) 132 | kernel_ptx = os.path.join(ptx_dir, kernel_name + ".ptx") 133 | 134 | current_text = "" 135 | if os.path.exists(kernel_ptx): 136 | f = open(kernel_ptx, "r") 137 | current_text = f.read() 138 | f.close() 139 | # only write out the kernel if text has changed. 
140 | if kernel_text != current_text: 141 | f = open(kernel_ptx, "w") 142 | f.write(kernel_text) 143 | f.close() 144 | 145 | return kernel_ptx 146 | 147 | 148 | include_re = re.compile(r'^') 149 | 150 | 151 | def run_command(cmdlist): 152 | cmd = " ".join(cmdlist) 153 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 154 | out, err = proc.communicate() 155 | if proc.returncode: 156 | raise RuntimeError("Error(%d):\n%s\n%s" % (proc.returncode, cmd, err)) 157 | 158 | 159 | def get_kernel(base_name, major, minor, options=None): 160 | if major < 5: 161 | raise RuntimeError("sass kernels require Maxwell or greater class hardware") 162 | elif major >= 7: 163 | raise RuntimeError("sm version 7 or greater is not supported") 164 | 165 | arch = "sm_%d%d" % (major, minor) 166 | 167 | libprefix = "PERL5LIB=%s" % maxas_dir 168 | maxas_i = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -i -w"] 169 | maxas_p = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -p"] 170 | 171 | kernel_spec = kernels[base_name] 172 | kernel_name = base_name 173 | 174 | # static options 175 | if "args" in kernel_spec: 176 | for pair in kernel_spec["args"].items(): 177 | maxas_i.append("-D%s %s" % pair) 178 | maxas_p.append("-D%s %s" % pair) 179 | 180 | # dynamic options 181 | if options is not None: 182 | for opt in options: 183 | if type(opt) is tuple: 184 | maxas_i.append("-D%s %s" % opt) 185 | maxas_p.append("-D%s %s" % opt) 186 | kernel_name += "_%s%s" % opt 187 | else: 188 | maxas_i.append("-D%s 1" % opt) 189 | maxas_p.append("-D%s 1" % opt) 190 | kernel_name += "_%s" % opt 191 | 192 | maxas_i.insert(2, "-k " + kernel_name) 193 | 194 | sass_name = kernel_spec["sass"] + ".sass" 195 | cubin_name = kernel_name + ".cubin" 196 | cubin_dir = _get_cache_dir([arch, 'cubin']) 197 | header_dir = os.path.join(base_dir, "include/kernels") 198 | 199 | ptx_version = "4.2" if major < 6 else "5.0" 200 | ptx_file = get_ptx_file(kernel_spec, kernel_name, arch, 
ptx_version) 201 | cubin_file = os.path.join(cubin_dir, cubin_name) 202 | sass_file = os.path.join(sass_dir, sass_name) 203 | header_file = os.path.join(header_dir, kernel_name + "_" + arch + ".h") 204 | 205 | if not os.path.exists(sass_file): 206 | raise RuntimeError("Missing: %s for kernel: %s" % (sass_name, kernel_name)) 207 | 208 | # build the cubin and run maxas in the same command 209 | # we don't want the chance of a generated cubin not processed by maxas (in case user hits ^C in between these steps) 210 | command_string = [ "ptxas -v -arch", arch, "-o", cubin_file, ptx_file, ";" ] + maxas_i + [sass_file, cubin_file] 211 | run_command(command_string) 212 | cubin_mtime = time.time() 213 | 214 | # now also generate the associated header file containing the cubin 215 | with open(cubin_file, 'rb') as input_file: 216 | with open(header_file, 'wb') as output_file: 217 | output_file.write('const uint8_t %s[] = {' % (kernel_name + "_" + arch)) 218 | byte = input_file.read(1) 219 | count = 0 220 | while byte: 221 | if count % 12 == 0: 222 | output_file.write('\n ') 223 | output_file.write(' 0x' + byte.encode('hex') + ',') 224 | byte = input_file.read(1) 225 | count += 1 226 | output_file.write('\n};') 227 | 228 | 229 | def gen_kernels(): 230 | for prefix in ['s', 'h']: 231 | for op in ['NN', 'NT', 'TN', 'TT']: 232 | for tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared in selections[prefix][op]: 233 | for vec in [False, True]: 234 | for major, minor in [(5, 0), (6, 0)]: 235 | if base_op: 236 | # The op is part of the base kernel name 237 | base = "%sgemm_%dx%dx%d_%s" % (prefix, tileM, tileN, tileK, op) 238 | opts = ( "vec", ) if vec else () 239 | else: 240 | # The op is an option passed to a more generic kernel 241 | base = "%sgemm_%dx%dx%d" % (prefix, tileM, tileN, tileK) 242 | opts = ( op, "vec" ) if vec else (op,) 243 | 244 | get_kernel(base, major, minor, opts) 245 | 246 | 247 | def main(): 248 | gen_kernels() 249 | 250 | 251 | if __name__ == 
"__main__": 252 | main() 253 | -------------------------------------------------------------------------------- /include/c_interface.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #else 7 | #include 8 | #endif 9 | 10 | #include 11 | 12 | typedef struct CUstream_st *CUstream; 13 | 14 | bool openai_sgemm(float *A, float *B, float *C, 15 | bool a_t, bool b_t, 16 | int m, int n, int k, 17 | int lda, int ldb, int ldc, 18 | float alpha, float beta, 19 | CUstream stream, unsigned int grid, unsigned int shared); 20 | 21 | bool openai_hgemm(uint16_t *A, uint16_t *B, uint16_t *C, 22 | bool a_t, bool b_t, 23 | int m, int n, int k, 24 | int lda, int ldb, int ldc, 25 | float alpha, float beta, 26 | CUstream stream, unsigned int grid, unsigned int shared); 27 | 28 | bool get_grid_limits(char precision, bool a_t, bool b_t, unsigned int *grid); 29 | bool get_shared_limits(char precision, bool a_t, bool b_t, unsigned int grid, unsigned int *shared); 30 | 31 | #ifdef __cplusplus 32 | } 33 | #endif 34 | -------------------------------------------------------------------------------- /include/kernel_headers.h: -------------------------------------------------------------------------------- 1 | #include "include/kernels/hgemm_16x64x64_NN_sm_50.h" 2 | #include "include/kernels/hgemm_16x64x64_NN_vec_sm_50.h" 3 | #include "include/kernels/hgemm_16x64x64_NT_sm_50.h" 4 | #include "include/kernels/hgemm_16x64x64_NT_vec_sm_50.h" 5 | #include "include/kernels/hgemm_32x32x64_NT_sm_50.h" 6 | #include "include/kernels/hgemm_32x32x64_NT_vec_sm_50.h" 7 | #include "include/kernels/hgemm_32x64x32_NN_sm_50.h" 8 | #include "include/kernels/hgemm_32x64x32_NN_vec_sm_50.h" 9 | #include "include/kernels/hgemm_128x128x8_NN_sm_50.h" 10 | #include "include/kernels/hgemm_128x128x8_NT_sm_50.h" 11 | #include "include/kernels/hgemm_128x128x8_TN_sm_50.h" 12 | #include 
"include/kernels/hgemm_128x128x8_TT_sm_50.h" 13 | #include "include/kernels/hgemm_128x128x8_NN_vec_sm_50.h" 14 | #include "include/kernels/hgemm_128x128x8_NT_vec_sm_50.h" 15 | #include "include/kernels/hgemm_128x128x8_TN_vec_sm_50.h" 16 | #include "include/kernels/hgemm_128x128x8_TT_vec_sm_50.h" 17 | #include "include/kernels/hgemm_32x32x32_NN_sm_50.h" 18 | #include "include/kernels/hgemm_32x32x32_NT_sm_50.h" 19 | #include "include/kernels/hgemm_32x32x32_TN_sm_50.h" 20 | #include "include/kernels/hgemm_32x32x32_TT_sm_50.h" 21 | #include "include/kernels/hgemm_32x32x32_NN_vec_sm_50.h" 22 | #include "include/kernels/hgemm_32x32x32_NT_vec_sm_50.h" 23 | #include "include/kernels/hgemm_32x32x32_TN_vec_sm_50.h" 24 | #include "include/kernels/hgemm_32x32x32_TT_vec_sm_50.h" 25 | #include "include/kernels/sgemm_128x128x8_NN_sm_50.h" 26 | #include "include/kernels/sgemm_128x128x8_NT_sm_50.h" 27 | #include "include/kernels/sgemm_128x128x8_TN_sm_50.h" 28 | #include "include/kernels/sgemm_128x128x8_TT_sm_50.h" 29 | #include "include/kernels/sgemm_128x128x8_NN_vec_sm_50.h" 30 | #include "include/kernels/sgemm_128x128x8_NT_vec_sm_50.h" 31 | #include "include/kernels/sgemm_128x128x8_TN_vec_sm_50.h" 32 | #include "include/kernels/sgemm_128x128x8_TT_vec_sm_50.h" 33 | #include "include/kernels/sgemm_32x32x32_NN_sm_50.h" 34 | #include "include/kernels/sgemm_32x32x32_NT_sm_50.h" 35 | #include "include/kernels/sgemm_32x32x32_TN_sm_50.h" 36 | #include "include/kernels/sgemm_32x32x32_TT_sm_50.h" 37 | #include "include/kernels/sgemm_32x32x32_NN_vec_sm_50.h" 38 | #include "include/kernels/sgemm_32x32x32_NT_vec_sm_50.h" 39 | #include "include/kernels/sgemm_32x32x32_TN_vec_sm_50.h" 40 | #include "include/kernels/sgemm_32x32x32_TT_vec_sm_50.h" 41 | 42 | #include "include/kernels/hgemm_16x64x64_NN_sm_60.h" 43 | #include "include/kernels/hgemm_16x64x64_NN_vec_sm_60.h" 44 | #include "include/kernels/hgemm_16x64x64_NT_sm_60.h" 45 | #include "include/kernels/hgemm_16x64x64_NT_vec_sm_60.h" 46 | 
#include "include/kernels/hgemm_32x32x64_NT_sm_60.h" 47 | #include "include/kernels/hgemm_32x32x64_NT_vec_sm_60.h" 48 | #include "include/kernels/hgemm_32x64x32_NN_sm_60.h" 49 | #include "include/kernels/hgemm_32x64x32_NN_vec_sm_60.h" 50 | #include "include/kernels/hgemm_128x128x8_NN_sm_60.h" 51 | #include "include/kernels/hgemm_128x128x8_NT_sm_60.h" 52 | #include "include/kernels/hgemm_128x128x8_TN_sm_60.h" 53 | #include "include/kernels/hgemm_128x128x8_TT_sm_60.h" 54 | #include "include/kernels/hgemm_128x128x8_NN_vec_sm_60.h" 55 | #include "include/kernels/hgemm_128x128x8_NT_vec_sm_60.h" 56 | #include "include/kernels/hgemm_128x128x8_TN_vec_sm_60.h" 57 | #include "include/kernels/hgemm_128x128x8_TT_vec_sm_60.h" 58 | #include "include/kernels/hgemm_32x32x32_NN_sm_60.h" 59 | #include "include/kernels/hgemm_32x32x32_NT_sm_60.h" 60 | #include "include/kernels/hgemm_32x32x32_TN_sm_60.h" 61 | #include "include/kernels/hgemm_32x32x32_TT_sm_60.h" 62 | #include "include/kernels/hgemm_32x32x32_NN_vec_sm_60.h" 63 | #include "include/kernels/hgemm_32x32x32_NT_vec_sm_60.h" 64 | #include "include/kernels/hgemm_32x32x32_TN_vec_sm_60.h" 65 | #include "include/kernels/hgemm_32x32x32_TT_vec_sm_60.h" 66 | #include "include/kernels/sgemm_128x128x8_NN_sm_60.h" 67 | #include "include/kernels/sgemm_128x128x8_NT_sm_60.h" 68 | #include "include/kernels/sgemm_128x128x8_TN_sm_60.h" 69 | #include "include/kernels/sgemm_128x128x8_TT_sm_60.h" 70 | #include "include/kernels/sgemm_128x128x8_NN_vec_sm_60.h" 71 | #include "include/kernels/sgemm_128x128x8_NT_vec_sm_60.h" 72 | #include "include/kernels/sgemm_128x128x8_TN_vec_sm_60.h" 73 | #include "include/kernels/sgemm_128x128x8_TT_vec_sm_60.h" 74 | #include "include/kernels/sgemm_32x32x32_NN_sm_60.h" 75 | #include "include/kernels/sgemm_32x32x32_NT_sm_60.h" 76 | #include "include/kernels/sgemm_32x32x32_TN_sm_60.h" 77 | #include "include/kernels/sgemm_32x32x32_TT_sm_60.h" 78 | #include "include/kernels/sgemm_32x32x32_NN_vec_sm_60.h" 79 | #include 
"include/kernels/sgemm_32x32x32_NT_vec_sm_60.h" 80 | #include "include/kernels/sgemm_32x32x32_TN_vec_sm_60.h" 81 | #include "include/kernels/sgemm_32x32x32_TT_vec_sm_60.h" 82 | -------------------------------------------------------------------------------- /include/static_kernel_information.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | struct kernel_properties { 4 | int tile_m; 5 | int tile_n; 6 | int tile_k; 7 | int vA; 8 | int vB; 9 | int vC; 10 | int div; 11 | bool op; 12 | int threads; 13 | std::vector shared_sizes; 14 | std::string tile_string; 15 | }; 16 | 17 | kernel_properties kernel_properties_[6] = { 18 | {128, 128, 8, 4, 4, 1, 2, false, 256, {0}, "128x128x8"}, 19 | {32, 32, 32, 4, 4, 1, 4, false, 128, {0, 1 << 14}, "32x32x32"}, 20 | {32, 64, 32, 8, 4, 4, 4, true, 128, {0, 1 << 13}, "32x64x32"}, 21 | {32, 32, 64, 8, 8, 4, 4, true, 128, {0}, "32x32x64"}, 22 | {16, 64, 64, 8, 4, 4, 4, true, 128, {0}, "16x64x64"}, 23 | {16, 64, 64, 8, 8, 4, 4, true, 128, {0}, "16x64x64"}, 24 | }; 25 | 26 | std::unordered_map>> 27 | selections = { 28 | {"s", { 29 | {"TN", {kernel_properties_[0], kernel_properties_[1]}}, 30 | {"NN", {kernel_properties_[0], kernel_properties_[1]}}, 31 | {"NT", {kernel_properties_[0], kernel_properties_[1]}}, 32 | {"TT", {kernel_properties_[0], kernel_properties_[1]}} 33 | } 34 | }, 35 | {"h", { 36 | {"TN", {kernel_properties_[0], kernel_properties_[1]}}, 37 | {"NN", {kernel_properties_[0], kernel_properties_[1], kernel_properties_[2], kernel_properties_[4]}}, 38 | {"NT", {kernel_properties_[0], kernel_properties_[1], kernel_properties_[3], kernel_properties_[5]}}, 39 | {"TT", {kernel_properties_[0], kernel_properties_[1]}} 40 | } 41 | } 42 | }; 43 | 44 | std::unordered_map kernels_60 = { 45 | {"hgemm_16x64x64_NN_sm_60", hgemm_16x64x64_NN_sm_60}, 46 | {"hgemm_16x64x64_NN_vec_sm_60", hgemm_16x64x64_NN_vec_sm_60}, 47 | {"hgemm_16x64x64_NT_sm_60", hgemm_16x64x64_NT_sm_60}, 48 | 
{"hgemm_16x64x64_NT_vec_sm_60", hgemm_16x64x64_NT_vec_sm_60}, 49 | {"hgemm_32x32x64_NT_sm_60", hgemm_32x32x64_NT_sm_60}, 50 | {"hgemm_32x32x64_NT_vec_sm_60", hgemm_32x32x64_NT_vec_sm_60}, 51 | {"hgemm_32x64x32_NN_sm_60", hgemm_32x64x32_NN_sm_60}, 52 | {"hgemm_32x64x32_NN_vec_sm_60", hgemm_32x64x32_NN_vec_sm_60}, 53 | {"hgemm_128x128x8_NN_sm_60", hgemm_128x128x8_NN_sm_60}, 54 | {"hgemm_128x128x8_TN_sm_60", hgemm_128x128x8_TN_sm_60}, 55 | {"hgemm_128x128x8_NT_sm_60", hgemm_128x128x8_NT_sm_60}, 56 | {"hgemm_128x128x8_TT_sm_60", hgemm_128x128x8_TT_sm_60}, 57 | {"hgemm_128x128x8_NN_vec_sm_60", hgemm_128x128x8_NN_vec_sm_60}, 58 | {"hgemm_128x128x8_TN_vec_sm_60", hgemm_128x128x8_TN_vec_sm_60}, 59 | {"hgemm_128x128x8_NT_vec_sm_60", hgemm_128x128x8_NT_vec_sm_60}, 60 | {"hgemm_128x128x8_TT_vec_sm_60", hgemm_128x128x8_TT_vec_sm_60}, 61 | {"hgemm_32x32x32_NN_sm_60", hgemm_32x32x32_NN_sm_60}, 62 | {"hgemm_32x32x32_TN_sm_60", hgemm_32x32x32_TN_sm_60}, 63 | {"hgemm_32x32x32_NT_sm_60", hgemm_32x32x32_NT_sm_60}, 64 | {"hgemm_32x32x32_TT_sm_60", hgemm_32x32x32_TT_sm_60}, 65 | {"hgemm_32x32x32_NN_vec_sm_60", hgemm_32x32x32_NN_vec_sm_60}, 66 | {"hgemm_32x32x32_TN_vec_sm_60", hgemm_32x32x32_TN_vec_sm_60}, 67 | {"hgemm_32x32x32_NT_vec_sm_60", hgemm_32x32x32_NT_vec_sm_60}, 68 | {"hgemm_32x32x32_TT_vec_sm_60", hgemm_32x32x32_TT_vec_sm_60}, 69 | {"sgemm_128x128x8_NN_sm_60", sgemm_128x128x8_NN_sm_60}, 70 | {"sgemm_128x128x8_TN_sm_60", sgemm_128x128x8_TN_sm_60}, 71 | {"sgemm_128x128x8_NT_sm_60", sgemm_128x128x8_NT_sm_60}, 72 | {"sgemm_128x128x8_TT_sm_60", sgemm_128x128x8_TT_sm_60}, 73 | {"sgemm_128x128x8_NN_vec_sm_60", sgemm_128x128x8_NN_vec_sm_60}, 74 | {"sgemm_128x128x8_TN_vec_sm_60", sgemm_128x128x8_TN_vec_sm_60}, 75 | {"sgemm_128x128x8_NT_vec_sm_60", sgemm_128x128x8_NT_vec_sm_60}, 76 | {"sgemm_128x128x8_TT_vec_sm_60", sgemm_128x128x8_TT_vec_sm_60}, 77 | {"sgemm_32x32x32_NN_sm_60", sgemm_32x32x32_NN_sm_60}, 78 | {"sgemm_32x32x32_TN_sm_60", sgemm_32x32x32_TN_sm_60}, 79 | 
{"sgemm_32x32x32_NT_sm_60", sgemm_32x32x32_NT_sm_60}, 80 | {"sgemm_32x32x32_TT_sm_60", sgemm_32x32x32_TT_sm_60}, 81 | {"sgemm_32x32x32_NN_vec_sm_60", sgemm_32x32x32_NN_vec_sm_60}, 82 | {"sgemm_32x32x32_TN_vec_sm_60", sgemm_32x32x32_TN_vec_sm_60}, 83 | {"sgemm_32x32x32_NT_vec_sm_60", sgemm_32x32x32_NT_vec_sm_60}, 84 | {"sgemm_32x32x32_TT_vec_sm_60", sgemm_32x32x32_TT_vec_sm_60}, 85 | }; 86 | 87 | std::unordered_map kernels_50 = { 88 | {"hgemm_16x64x64_NN_sm_50", hgemm_16x64x64_NN_sm_50}, 89 | {"hgemm_16x64x64_NN_vec_sm_50", hgemm_16x64x64_NN_vec_sm_50}, 90 | {"hgemm_16x64x64_NT_sm_50", hgemm_16x64x64_NT_sm_50}, 91 | {"hgemm_16x64x64_NT_vec_sm_50", hgemm_16x64x64_NT_vec_sm_50}, 92 | {"hgemm_32x32x64_NT_sm_50", hgemm_32x32x64_NT_sm_50}, 93 | {"hgemm_32x32x64_NT_vec_sm_50", hgemm_32x32x64_NT_vec_sm_50}, 94 | {"hgemm_32x64x32_NN_sm_50", hgemm_32x64x32_NN_sm_50}, 95 | {"hgemm_32x64x32_NN_vec_sm_50", hgemm_32x64x32_NN_vec_sm_50}, 96 | {"hgemm_128x128x8_NN_sm_50", hgemm_128x128x8_NN_sm_50}, 97 | {"hgemm_128x128x8_TN_sm_50", hgemm_128x128x8_TN_sm_50}, 98 | {"hgemm_128x128x8_NT_sm_50", hgemm_128x128x8_NT_sm_50}, 99 | {"hgemm_128x128x8_TT_sm_50", hgemm_128x128x8_TT_sm_50}, 100 | {"hgemm_128x128x8_NN_vec_sm_50", hgemm_128x128x8_NN_vec_sm_50}, 101 | {"hgemm_128x128x8_TN_vec_sm_50", hgemm_128x128x8_TN_vec_sm_50}, 102 | {"hgemm_128x128x8_NT_vec_sm_50", hgemm_128x128x8_NT_vec_sm_50}, 103 | {"hgemm_128x128x8_TT_vec_sm_50", hgemm_128x128x8_TT_vec_sm_50}, 104 | {"hgemm_32x32x32_NN_sm_50", hgemm_32x32x32_NN_sm_50}, 105 | {"hgemm_32x32x32_TN_sm_50", hgemm_32x32x32_TN_sm_50}, 106 | {"hgemm_32x32x32_NT_sm_50", hgemm_32x32x32_NT_sm_50}, 107 | {"hgemm_32x32x32_TT_sm_50", hgemm_32x32x32_TT_sm_50}, 108 | {"hgemm_32x32x32_NN_vec_sm_50", hgemm_32x32x32_NN_vec_sm_50}, 109 | {"hgemm_32x32x32_TN_vec_sm_50", hgemm_32x32x32_TN_vec_sm_50}, 110 | {"hgemm_32x32x32_NT_vec_sm_50", hgemm_32x32x32_NT_vec_sm_50}, 111 | {"hgemm_32x32x32_TT_vec_sm_50", hgemm_32x32x32_TT_vec_sm_50}, 112 | 
{"sgemm_128x128x8_NN_sm_50", sgemm_128x128x8_NN_sm_50}, 113 | {"sgemm_128x128x8_TN_sm_50", sgemm_128x128x8_TN_sm_50}, 114 | {"sgemm_128x128x8_NT_sm_50", sgemm_128x128x8_NT_sm_50}, 115 | {"sgemm_128x128x8_TT_sm_50", sgemm_128x128x8_TT_sm_50}, 116 | {"sgemm_128x128x8_NN_vec_sm_50", sgemm_128x128x8_NN_vec_sm_50}, 117 | {"sgemm_128x128x8_TN_vec_sm_50", sgemm_128x128x8_TN_vec_sm_50}, 118 | {"sgemm_128x128x8_NT_vec_sm_50", sgemm_128x128x8_NT_vec_sm_50}, 119 | {"sgemm_128x128x8_TT_vec_sm_50", sgemm_128x128x8_TT_vec_sm_50}, 120 | {"sgemm_32x32x32_NN_sm_50", sgemm_32x32x32_NN_sm_50}, 121 | {"sgemm_32x32x32_TN_sm_50", sgemm_32x32x32_TN_sm_50}, 122 | {"sgemm_32x32x32_NT_sm_50", sgemm_32x32x32_NT_sm_50}, 123 | {"sgemm_32x32x32_TT_sm_50", sgemm_32x32x32_TT_sm_50}, 124 | {"sgemm_32x32x32_NN_vec_sm_50", sgemm_32x32x32_NN_vec_sm_50}, 125 | {"sgemm_32x32x32_TN_vec_sm_50", sgemm_32x32x32_TN_vec_sm_50}, 126 | {"sgemm_32x32x32_NT_vec_sm_50", sgemm_32x32x32_NT_vec_sm_50}, 127 | {"sgemm_32x32x32_TT_vec_sm_50", sgemm_32x32x32_TT_vec_sm_50}, 128 | }; 129 | -------------------------------------------------------------------------------- /lib/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /maxas/MaxAs/Cubin.pm: -------------------------------------------------------------------------------- 1 | package MaxAs::Cubin; 2 | 3 | use strict; 4 | use Data::Dumper; 5 | 6 | my @Elf32_Hdr = qw( 7 | H8 magic 8 | C fileClass 9 | C encoding 10 | C fileVersion 11 | H18 padding 12 | S type 13 | S machine 14 | L version 15 | L entry 16 | L phOffset 17 | L shOffset 18 | L flags 19 | S ehSize 20 | S phEntSize 21 | S phNum 22 | S shEntSize 23 | S shNum 24 | S shStrIndx 25 | ); 26 | my @Elf64_Hdr = qw( 27 | H8 magic 28 | C fileClass 29 | C encoding 30 | C fileVersion 31 | H18 padding 32 | S type 33 | S machine 34 | L version 35 | Q entry 36 | Q phOffset 37 | Q 
shOffset 38 | L flags 39 | S ehSize 40 | S phEntSize 41 | S phNum 42 | S shEntSize 43 | S shNum 44 | S shStrIndx 45 | ); 46 | my @Elf32_PrgHdr = qw( 47 | L type 48 | L offset 49 | L vaddr 50 | L paddr 51 | L fileSize 52 | L memSize 53 | L flags 54 | L align 55 | ); 56 | my @Elf64_PrgHdr = qw( 57 | L type 58 | L flags 59 | Q offset 60 | Q vaddr 61 | Q paddr 62 | Q fileSize 63 | Q memSize 64 | Q align 65 | ); 66 | my @Elf32_SecHdr = qw( 67 | L name 68 | L type 69 | L flags 70 | L addr 71 | L offset 72 | L size 73 | L link 74 | L info 75 | L align 76 | L entSize 77 | ); 78 | my @Elf64_SecHdr = qw( 79 | L name 80 | L type 81 | Q flags 82 | Q addr 83 | Q offset 84 | Q size 85 | L link 86 | L info 87 | Q align 88 | Q entSize 89 | ); 90 | my @Elf32_SymEnt = qw( 91 | L name 92 | L value 93 | L size 94 | C info 95 | C other 96 | S shIndx 97 | ); 98 | my @Elf64_SymEnt = qw( 99 | L name 100 | C info 101 | C other 102 | S shIndx 103 | Q value 104 | Q size 105 | ); 106 | my @symBind = qw(LOCAL GLOBAL WEAK); 107 | 108 | # Split the Elf Header defs into template strings (T) and corresponding hash keys columns (C) 109 | my (@elfHdrT, @prgHdrT, @secHdrT, @symHdrT, @elfHdrC, @prgHdrC, @secHdrC, @symHdrC); 110 | 111 | $elfHdrT[1] = join '', grep { length($_) <= 3} @Elf32_Hdr; 112 | $prgHdrT[1] = join '', grep { length($_) <= 3} @Elf32_PrgHdr; 113 | $secHdrT[1] = join '', grep { length($_) <= 3} @Elf32_SecHdr; 114 | $symHdrT[1] = join '', grep { length($_) <= 3} @Elf32_SymEnt; 115 | 116 | $elfHdrT[2] = join '', grep { length($_) <= 3} @Elf64_Hdr; 117 | $prgHdrT[2] = join '', grep { length($_) <= 3} @Elf64_PrgHdr; 118 | $secHdrT[2] = join '', grep { length($_) <= 3} @Elf64_SecHdr; 119 | $symHdrT[2] = join '', grep { length($_) <= 3} @Elf64_SymEnt; 120 | 121 | $elfHdrC[1] = [ grep { length($_) > 3} @Elf32_Hdr ]; 122 | $prgHdrC[1] = [ grep { length($_) > 3} @Elf32_PrgHdr ]; 123 | $secHdrC[1] = [ grep { length($_) > 3} @Elf32_SecHdr ]; 124 | $symHdrC[1] = [ grep { length($_) > 3} 
@Elf32_SymEnt ]; 125 | 126 | $elfHdrC[2] = [ grep { length($_) > 3} @Elf64_Hdr ]; 127 | $prgHdrC[2] = [ grep { length($_) > 3} @Elf64_PrgHdr ]; 128 | $secHdrC[2] = [ grep { length($_) > 3} @Elf64_SecHdr ]; 129 | $symHdrC[2] = [ grep { length($_) > 3} @Elf64_SymEnt ]; 130 | 131 | # Load a cubin ELF file 132 | sub new 133 | { 134 | my ($package, $file) = @_; 135 | 136 | my $cubin = bless { fileName => $file }, $package; 137 | 138 | open my $fh, $file or die "$file: $!"; 139 | binmode($fh); 140 | 141 | # Read in assuming 32 bit header 142 | my $data; 143 | read $fh, $data, 0x34; 144 | my $elfHdr = $cubin->{elfHdr} = {}; 145 | @{$elfHdr}{@{$elfHdrC[1]}} = unpack $elfHdrT[1], $data; 146 | 147 | # 1: 32bit, 2: 64bit 148 | my $class = $elfHdr->{fileClass}; 149 | 150 | # re-read in with 64 bit header if needed 151 | if ($class == 2) 152 | { 153 | seek $fh, 0, 0; 154 | read $fh, $data, 0x46; 155 | @{$elfHdr}{@{$elfHdrC[$class]}} = unpack $elfHdrT[$class], $data; 156 | 157 | $cubin->{Class} = 64; 158 | } 159 | else 160 | { 161 | $cubin->{Class} = 32; 162 | } 163 | 164 | # verify sm_50 cubin 165 | $cubin->{Arch} = $elfHdr->{flags} & 0xFF; 166 | die "Cubin not in sm_50 or greater format. Found: sm_$cubin->{Arch}\n" if $cubin->{Arch} < 50; 167 | 168 | $cubin->{AddressSize} = $elfHdr->{flags} & 0x400 ? 64 : 32; 169 | 170 | # Read in Program Headers 171 | seek $fh, $elfHdr->{phOffset}, 0; 172 | foreach (1 .. $elfHdr->{phNum}) 173 | { 174 | read $fh, $data, $elfHdr->{phEntSize}; 175 | 176 | my %prgHdr = (Indx => $_ - 1); 177 | @prgHdr{@{$prgHdrC[$class]}} = unpack $prgHdrT[$class], $data; 178 | push @{$cubin->{prgHdrs}}, \%prgHdr; 179 | } 180 | 181 | # Read in Section Headers 182 | seek $fh, $elfHdr->{shOffset}, 0; 183 | foreach (1 .. 
$elfHdr->{shNum}) 184 | { 185 | read $fh, $data, $elfHdr->{shEntSize}; 186 | 187 | my %secHdr = (Indx => $_ - 1); 188 | @secHdr{@{$secHdrC[$class]}} = unpack $secHdrT[$class], $data; 189 | push @{$cubin->{secHdrs}}, \%secHdr; 190 | } 191 | 192 | # Read in Section data 193 | foreach my $secHdr (@{$cubin->{secHdrs}}) 194 | { 195 | $data = ''; 196 | # Skip sections with no data (type NULL or NOBITS) 197 | if ($secHdr->{size} && $secHdr->{type} != 8) 198 | { 199 | seek $fh, $secHdr->{offset}, 0; 200 | read $fh, $data, $secHdr->{size}; 201 | } 202 | # Convert string tables to maps 203 | if ($secHdr->{type} == 3) # STRTAB 204 | { 205 | my $strTab = $secHdr->{StrTab} = {}; 206 | my $indx = 0; 207 | foreach my $str (split "\0", $data) 208 | { 209 | $strTab->{$indx} = $str; 210 | $indx += 1 + length($str); 211 | } 212 | } 213 | # Read in Symbol data 214 | if ($secHdr->{type} == 2) # SYMTAB 215 | { 216 | my $offset = 0; 217 | while ($offset < $secHdr->{size}) 218 | { 219 | my $symEnt = {}; 220 | @{$symEnt}{@{$symHdrC[$class]}} = unpack $symHdrT[$class], substr($data, $offset, $secHdr->{entSize}); 221 | $offset += $secHdr->{entSize}; 222 | 223 | push @{$secHdr->{SymTab}}, $symEnt; 224 | } 225 | } 226 | # Cache raw data for further processing and writing 227 | $secHdr->{Data} = unpack 'H*', $data; 228 | } 229 | close $fh; 230 | 231 | # Update section headers with their names. Map names directly to headers. 
232 | my $shStrTab = $cubin->{secHdrs}[$elfHdr->{shStrIndx}]{StrTab}; 233 | foreach my $secHdr (@{$cubin->{secHdrs}}) 234 | { 235 | $secHdr->{Name} = $shStrTab->{$secHdr->{name}}; 236 | $cubin->{$secHdr->{Name}} = $secHdr; 237 | } 238 | 239 | # Update symbols with their names 240 | # For the Global functions, extract kernel meta data 241 | # Populate the kernel hash 242 | my $strTab = $cubin->{'.strtab'}{StrTab}; 243 | foreach my $symEnt (@{$cubin->{'.symtab'}{SymTab}}) 244 | { 245 | $symEnt->{Name} = $strTab->{$symEnt->{name}}; 246 | 247 | # Attach symbol to section 248 | my $secHdr = $cubin->{secHdrs}[$symEnt->{shIndx}]; 249 | $secHdr->{SymbolEnt} = $symEnt; 250 | 251 | # Look for symbols tagged FUNC 252 | if (($symEnt->{info} & 0x0f) == 0x02) 253 | { 254 | # Create a hash of kernels for output 255 | my $kernelSec = $cubin->{Kernels}{$symEnt->{Name}} = $secHdr; 256 | 257 | # Extract local/global/weak binding info 258 | $kernelSec->{Linkage} = $symBind[($symEnt->{info} & 0xf0) >> 4]; 259 | 260 | # Extract the kernel instructions 261 | $kernelSec->{KernelData} = [ unpack "Q*", pack "H*", $kernelSec->{Data} ]; 262 | 263 | # Extract the max barrier resource identifier used and add 1. Should be 0-16. 264 | # If a register is used as a barrier resource id, then this value is the max of 16. 265 | $kernelSec->{BarCnt} = ($kernelSec->{flags} & 0x01f00000) >> 20; 266 | 267 | # Extract the number of allocated registers for this kernel. 268 | $kernelSec->{RegCnt} = ($kernelSec->{info} & 0xff000000) >> 24; 269 | 270 | # Extract the size of shared memory this kernel uses. 271 | my $sharedSec = $kernelSec->{SharedSec} = $cubin->{".nv.shared.$symEnt->{Name}"}; 272 | $kernelSec->{SharedSize} = $sharedSec ? $sharedSec->{size} : 0; 273 | 274 | # Attach constant0 section 275 | $kernelSec->{ConstantSec} = $cubin->{".nv.constant0.$symEnt->{Name}"}; 276 | 277 | # Extract the kernel parameter data. 
278 | my $paramSec = $kernelSec->{ParamSec} = $cubin->{".nv.info.$symEnt->{Name}"}; 279 | if ($paramSec) 280 | { 281 | # Extract raw param data 282 | my @data = unpack "L*", pack "H*", $paramSec->{Data}; 283 | 284 | $paramSec->{ParamData} = \@data; 285 | $paramSec->{ParamHex} = [ map { sprintf '0x%08x', $_ } @data ]; 286 | 287 | # Find the first param delimiter 288 | my $idx = 0; 289 | $idx++ while $idx < @data && $data[$idx] != 0x00080a04; 290 | 291 | my $first = $data[$idx+2] & 0xFFFF; 292 | #my $size = $data[$idx+2] >> 16; 293 | $idx += 4; 294 | 295 | my @params; 296 | while ($idx < @data && $data[$idx] == 0x000c1704) 297 | { 298 | # Get the ordinal, offset, size and pointer alignment for each param 299 | my $ord = $data[$idx+2] & 0xFFFF; 300 | my $offset = sprintf '0x%02x', $first + ($data[$idx+2] >> 16); 301 | my $psize = $data[$idx+3] >> 18; 302 | my $align = $data[$idx+3] & 0x400 ? 1 << ($data[$idx+3] & 0x3ff) : 0; 303 | unshift @params, "$ord:$offset:$psize:$align"; 304 | $idx += 4; 305 | } 306 | my @staticParams = @data[0 .. 
($idx-1)]; 307 | 308 | my ($maxregCount, @exitOffsets, @ctaidOffsets, $ctaidzUsed, @reqntid, @maxntid, @stackSize); 309 | while ($idx < @data) 310 | { 311 | my $code = $data[$idx] & 0xffff; 312 | my $size = $data[$idx] >> 16; 313 | $idx++; 314 | 315 | # EIATTR_MAXREG_COUNT 316 | if ($code == 0x1b03) 317 | { 318 | $maxregCount = $size; 319 | } 320 | # EIATTR_S2RCTAID_INSTR_OFFSETS 321 | elsif ($code == 0x1d04) 322 | { 323 | while ($size > 0) 324 | { 325 | push @ctaidOffsets, $data[$idx++]; 326 | $size -= 4; 327 | } 328 | } 329 | # EIATTR_EXIT_INSTR_OFFSETS 330 | elsif ($code == 0x1c04) 331 | { 332 | while ($size > 0) 333 | { 334 | push @exitOffsets, $data[$idx++]; 335 | $size -= 4; 336 | } 337 | } 338 | # EIATTR_CTAIDZ_USED 339 | elsif ($code == 0x0401) 340 | { 341 | $ctaidzUsed = 1; 342 | } 343 | # EIATTR_REQNTID 344 | elsif ($code == 0x1004) 345 | { 346 | while ($size > 0) 347 | { 348 | push @reqntid, $data[$idx++]; 349 | $size -= 4; 350 | } 351 | } 352 | # EIATTR_MAX_THREADS 353 | elsif ($code == 0x0504) 354 | { 355 | while ($size > 0) 356 | { 357 | push @maxntid, $data[$idx++]; 358 | $size -= 4; 359 | } 360 | } 361 | # EIATTR_CRS_STACK_SIZE 362 | elsif ($code == 0x1e04) 363 | { 364 | while ($size > 0) 365 | { 366 | push @stackSize, $data[$idx++]; 367 | $size -= 4; 368 | } 369 | } 370 | else 371 | { 372 | printf STDERR "Unknown Code 0x%02x (size:%d)\n", $code, $size; 373 | } 374 | } 375 | $kernelSec->{Params} = \@params; 376 | $kernelSec->{ParamCnt} = scalar @params; 377 | 378 | $paramSec->{StaticParams} = \@staticParams; 379 | $paramSec->{MAXREG_COUNT} = $maxregCount; 380 | $paramSec->{ExitOffsets} = \@exitOffsets; 381 | $paramSec->{CTAIDOffsets} = \@ctaidOffsets; 382 | $paramSec->{CTAIDZUsed} = $ctaidzUsed; 383 | $paramSec->{REQNTID} = \@reqntid; 384 | $paramSec->{MAXNTID} = \@maxntid; 385 | $paramSec->{STACKSIZE} = \@stackSize; 386 | } 387 | # print Dumper($paramSec); 388 | # exit(); 389 | } 390 | # Note GLOBALs found in this cubin 391 | elsif 
(($symEnt->{info} & 0x10) == 0x10) 392 | { 393 | $cubin->{Symbols}{$symEnt->{Name}} = $symEnt; 394 | } 395 | } 396 | 397 | # print "phOffset: $elfHdr->{phOffset}\n"; 398 | # print "shOffset: $elfHdr->{shOffset}\n"; 399 | # foreach my $secHdr (@{$cubin->{secHdrs}}) 400 | # { 401 | # print "secHdr($secHdr->{Indx}): $secHdr->{offset}, $secHdr->{size}, $secHdr->{align} ($secHdr->{Name})\n"; 402 | # } 403 | # my $p = 0; 404 | # foreach my $prgHdr (@{$cubin->{prgHdrs}}) 405 | # { 406 | # print "prgHdr($p): type: $prgHdr->{type}, offset: $prgHdr->{offset}, fileSize: $prgHdr->{fileSize}, memSize: $prgHdr->{memSize}, align: $prgHdr->{align}\n"; 407 | # $p++; 408 | # } 409 | # exit(); 410 | 411 | #print map { sprintf "%016x\n", $_ } @{$cubin->{Kernels}{microbench}{KernelData}}; 412 | 413 | #print Dumper($cubin->{Kernels}{test}{KernelData}); 414 | #exit(); 415 | return $cubin; 416 | } 417 | sub class 418 | { 419 | return shift()->{Class}; 420 | } 421 | sub arch 422 | { 423 | return shift()->{Arch}; 424 | } 425 | sub address_size 426 | { 427 | return shift()->{AddressSize}; 428 | } 429 | sub listKernels 430 | { 431 | return shift()->{Kernels}; 432 | } 433 | sub listSymbols 434 | { 435 | return shift()->{Symbols}; 436 | } 437 | sub getKernel 438 | { 439 | my ($cubin, $kernel) = @_; 440 | return $cubin->{Kernels}{$kernel}; 441 | } 442 | 443 | sub modifyKernel 444 | { 445 | my ($cubin, %params) = @_; 446 | 447 | my $kernelSec = $params{Kernel}; 448 | my $newReg = $params{RegCnt}; 449 | my $newBar = $params{BarCnt}; 450 | my $exitOffsets = $params{ExitOffsets}; 451 | my $ctaidOffsets = $params{CTAIDOffsets}; 452 | my $ctaidzUsed = $params{CTAIDZUsed}; 453 | my $newData = $params{KernelData}; 454 | my $newSize = @$newData * 8; 455 | 456 | die "255 register max" if $newReg > 255; 457 | die "new kernel size must be multiple of 8 instructions (64 bytes)" if $newSize & 63; 458 | die "16 is max barrier count" if $newBar > 16; 459 | 460 | my $paramSec = $kernelSec->{ParamSec}; 461 | my 
$kernelName = $kernelSec->{SymbolEnt}{Name}; 462 | my $maxregCount = $paramSec->{MAXREG_COUNT}; 463 | my $stackSize = $paramSec->{STACKSIZE}; 464 | 465 | # update the kernel 466 | $kernelSec->{KernelData} = $newData; 467 | $kernelSec->{Data} = unpack "H*", pack "Q*", @$newData; 468 | 469 | if ($newReg != $kernelSec->{RegCnt}) 470 | { 471 | print "Modified $kernelName RegCnt: $kernelSec->{RegCnt} => $newReg\n"; 472 | $kernelSec->{RegCnt} = $newReg; 473 | $kernelSec->{info} &= ~0xff000000; 474 | $kernelSec->{info} |= $newReg << 24; 475 | } 476 | if ($newBar != $kernelSec->{BarCnt}) 477 | { 478 | print "Modified $kernelName BarCnt: $kernelSec->{BarCnt} => $newBar\n"; 479 | $kernelSec->{BarCnt} = $newBar; 480 | $kernelSec->{flags} &= ~0x01f00000; 481 | $kernelSec->{flags} |= $newBar << 20; 482 | } 483 | 484 | my @paramData = @{$paramSec->{StaticParams}}; 485 | 486 | if (defined $maxregCount) 487 | { 488 | push @paramData, ($maxregCount << 16) | 0x1b03; 489 | } 490 | 491 | my $newCTAIDs = join ',', map { sprintf '%04x', $_ } @$ctaidOffsets; 492 | my $oldCTAIDs = join ',', map { sprintf '%04x', $_ } @{$paramSec->{CTAIDOffsets}}; 493 | 494 | if ($newCTAIDs ne $oldCTAIDs) 495 | { 496 | print "Modified $kernelName CTAID Offsets: '$oldCTAIDs' => '$newCTAIDs'\n"; 497 | } 498 | if (@$ctaidOffsets) 499 | { 500 | push @paramData, (scalar(@$ctaidOffsets) << 18) | 0x1d04; 501 | push @paramData, @$ctaidOffsets; 502 | } 503 | 504 | my $newExits = join ',', map { sprintf '%04x', $_ } @$exitOffsets; 505 | my $oldExits = join ',', map { sprintf '%04x', $_ } @{$paramSec->{ExitOffsets}}; 506 | 507 | if ($newExits ne $oldExits) 508 | { 509 | print "Modified $kernelName Exit Offsets: '$oldExits' => '$newExits'\n"; 510 | } 511 | if (@$exitOffsets) 512 | { 513 | push @paramData, (scalar(@$exitOffsets) << 18) | 0x1c04; 514 | push @paramData, @$exitOffsets; 515 | } 516 | 517 | if ($ctaidzUsed != $paramSec->{CTAIDZUsed}) 518 | { 519 | print "Modified $kernelName CTAID.Z Used: 
'$paramSec->{CTAIDZUsed}' => '$ctaidzUsed'\n"; 520 | } 521 | if ($ctaidzUsed) 522 | { 523 | push @paramData, 0x0401; 524 | } 525 | 526 | if (@{$paramSec->{REQNTID}}) 527 | { 528 | push @paramData, (scalar(@{$paramSec->{REQNTID}}) << 18) | 0x1004; 529 | push @paramData, @{$paramSec->{REQNTID}}; 530 | } 531 | if (@{$paramSec->{MAXNTID}}) 532 | { 533 | push @paramData, (scalar(@{$paramSec->{MAXNTID}}) << 18) | 0x0504; 534 | push @paramData, @{$paramSec->{MAXNTID}}; 535 | } 536 | 537 | if (@$stackSize) 538 | { 539 | push @paramData, (scalar(@$stackSize) << 18) | 0x1e04; 540 | push @paramData, @$stackSize; 541 | } 542 | 543 | my $newParamSize = scalar(@paramData)*4; 544 | $paramSec->{Data} = unpack "H*", pack "L*", @paramData; 545 | if ($newParamSize != $paramSec->{size}) 546 | { 547 | print "Modified $kernelName ParamSecSize: $paramSec->{size} => $newParamSize\n"; 548 | $cubin->updateSize($paramSec, $newParamSize); 549 | } 550 | 551 | if ($newSize != $kernelSec->{size}) 552 | { 553 | print "Modified $kernelName KernelSize: $kernelSec->{size} => $newSize\n"; 554 | $cubin->updateSize($kernelSec, $newSize, 1); 555 | } 556 | } 557 | 558 | sub updateSize 559 | { 560 | my ($cubin, $sec, $newSize, $updatePrgSize) = @_; 561 | 562 | my $elfHdr = $cubin->{elfHdr}; 563 | my $class = $elfHdr->{fileClass}; 564 | 565 | # update section header 566 | my $delta = $newSize - $sec->{size}; 567 | $sec->{size} = $newSize; 568 | 569 | # update symtab section 570 | if ($sec->{SymbolEnt}) 571 | { 572 | $sec->{SymbolEnt}{size} = $newSize; 573 | my $symSection = $cubin->{'.symtab'}; 574 | $symSection->{Data} = ''; 575 | foreach my $symEnt (@{$symSection->{SymTab}}) 576 | { 577 | $symSection->{Data} .= unpack "H*", pack $symHdrT[$class], @{$symEnt}{@{$symHdrC[$class]}}; 578 | } 579 | } 580 | 581 | my $pos = $elfHdr->{ehSize}; 582 | my %sizeMap; 583 | 584 | # update section header offsets 585 | foreach my $secHdr (@{$cubin->{secHdrs}}) 586 | { 587 | # skip first header 588 | next if 
$secHdr->{align} == 0; 589 | 590 | # NOBITS data sections are size 0 591 | my $size = $secHdr->{type} == 8 ? 0 : $secHdr->{size}; 592 | 593 | # Add any needed padding between sections 594 | my $pad = $pos % $secHdr->{align}; 595 | if ($pad > 0) 596 | { 597 | $pos += $secHdr->{align} - $pad; 598 | } 599 | # map old offset to new 600 | $sizeMap{$secHdr->{offset}} = $pos; 601 | 602 | # update offset 603 | $secHdr->{offset} = $pos; 604 | 605 | # advance position by size 606 | $pos += $size; 607 | } 608 | 609 | # compute total section header size 610 | my $shSize = $elfHdr->{phOffset} - $elfHdr->{shOffset}; 611 | 612 | # map old offset to new 613 | $sizeMap{$elfHdr->{shOffset}} = $pos; 614 | $sizeMap{$elfHdr->{phOffset}} = $pos + $shSize; 615 | 616 | $elfHdr->{shOffset} = $pos; 617 | $elfHdr->{phOffset} = $pos + $shSize; 618 | 619 | # update program header offsets and sizes 620 | foreach my $prgHdr (@{$cubin->{prgHdrs}}) 621 | { 622 | # Not sure how best to adjust these so just assume they'll track other offsets. 623 | $prgHdr->{offset} = $sizeMap{$prgHdr->{offset}}; 624 | 625 | # If the kernel sizes changes, also update the associated ProgramHeader. 626 | # Note that this size is the kernel size plus any constant section sizes. 627 | if ($updatePrgSize && $prgHdr->{type} == 1 && 628 | $sec->{offset} >= $prgHdr->{offset} && 629 | $sec->{offset} < $prgHdr->{offset} + $prgHdr->{fileSize} + $delta) 630 | { 631 | $prgHdr->{fileSize} += $delta; 632 | $prgHdr->{memSize} += $delta; 633 | } 634 | } 635 | } 636 | 637 | # Write out the cubin after modifying it. 
638 | sub write 639 | { 640 | my ($cubin, $file) = @_; 641 | 642 | open my $fh, ">$file" or die "Error: could not open $file for writing: $!"; 643 | binmode($fh); 644 | 645 | my $elfHdr = $cubin->{elfHdr}; 646 | my $class = $elfHdr->{fileClass}; 647 | 648 | # write elf header 649 | print $fh pack $elfHdrT[$class], @{$elfHdr}{@{$elfHdrC[$class]}}; 650 | my $pos = $elfHdr->{ehSize}; 651 | 652 | # write section data 653 | foreach my $secHdr (@{$cubin->{secHdrs}}) 654 | { 655 | # Skip NULL and NOBITS data sections 656 | next if $secHdr->{size} == 0 || $secHdr->{type} == 8; 657 | 658 | # Add any needed padding between sections 659 | my $pad = $pos % $secHdr->{align}; 660 | if ($pad > 0) 661 | { 662 | $pad = $secHdr->{align} - $pad; 663 | print $fh join '', "\0" x $pad; 664 | $pos += $pad; 665 | } 666 | 667 | print $fh pack 'H*', $secHdr->{Data}; 668 | $pos += $secHdr->{size}; 669 | } 670 | 671 | # write section headers 672 | foreach my $secHdr (@{$cubin->{secHdrs}}) 673 | { 674 | print $fh pack $secHdrT[$class], @{$secHdr}{@{$secHdrC[$class]}}; 675 | } 676 | 677 | #write program headers 678 | foreach my $prgHdr (@{$cubin->{prgHdrs}}) 679 | { 680 | print $fh pack $prgHdrT[$class], @{$prgHdr}{@{$prgHdrC[$class]}}; 681 | } 682 | close $fh; 683 | } 684 | 685 | __END__ 686 | 687 | -------------------------------------------------------------------------------- /maxas/maxas.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl 2 | use strict; 3 | use MaxAs::Cubin; 4 | use MaxAs::MaxAs; 5 | use Data::Dumper; 6 | use File::Spec; 7 | 8 | require 5.10.0; 9 | 10 | $Data::Dumper::Sortkeys = 1; 11 | 12 | my $mode = shift; 13 | 14 | # List cubin contents 15 | if ($mode =~ /^\-?\-l/i) 16 | { 17 | my $cubinFile = shift or usage(); 18 | 19 | my $cubin = MaxAs::Cubin->new($cubinFile); 20 | 21 | my $arch = $cubin->arch; 22 | my $class = $cubin->class; 23 | my $asize = $cubin->address_size; 24 | my $kernels = $cubin->listKernels; 25 | my 
$symbols = $cubin->listSymbols; 26 | 27 | printf "%s: arch:sm_%d machine:%dbit address_size:%dbit\n", $cubinFile, $arch, $class, $asize; 28 | 29 | foreach my $ker (sort keys %$kernels) 30 | { 31 | printf "Kernel: %s (Linkage: %s, Params: %d, Size: %d, Registers: %d, SharedMem: %d, Barriers: %d)\n", $ker, @{$kernels->{$ker}}{qw(Linkage ParamCnt size RegCnt SharedSize BarCnt)}; 32 | } 33 | foreach my $sym (sort keys %$symbols) 34 | { 35 | printf "Symbol: %s\n", $sym; 36 | } 37 | } 38 | # Test that the assembler can reproduce the op codes this cubin or sass contains 39 | elsif ($mode =~ /^\-?\-t/i) 40 | { 41 | my $reg = shift if $ARGV[0] =~ /^\-?\-r/i; 42 | my $all = shift if $ARGV[0] =~ /^\-?\-a/i; 43 | my $file = shift or usage(); 44 | my $fh; 45 | # sass file 46 | if (-T $file) 47 | { 48 | open $fh, $file or die "$file: $!"; 49 | } 50 | # cubin file 51 | else 52 | { 53 | my $cubin = MaxAs::Cubin->new($file); 54 | my $arch = $cubin->arch; 55 | 56 | open $fh, "cuobjdump -arch sm_$arch -sass $file |" or die "cuobjdump -arch sm_$arch -sass $file: $!"; 57 | my $first = <$fh>; 58 | if ($first =~ /cuobjdump fatal/) 59 | { 60 | print $first; 61 | exit(1); 62 | } 63 | } 64 | exit(MaxAs::MaxAs::Test($fh, $reg, $all) ? 1 : 0); 65 | } 66 | # Extract an asm file containing the desired kernel 67 | elsif ($mode =~ /^\-?\-e/i) 68 | { 69 | my $kernelName; 70 | if ($ARGV[0] =~ /^\-?\-k/i) 71 | { 72 | shift; 73 | $kernelName = shift or usage(); 74 | } 75 | my $cubinFile = shift or usage(); 76 | my $asmFile = shift; 77 | my $cubin = MaxAs::Cubin->new($cubinFile); 78 | my $arch = $cubin->arch; 79 | my $kernels = $cubin->listKernels; 80 | 81 | #default the kernel name if not specified. 
82 | $kernelName ||= (sort keys %$kernels)[0]; 83 | 84 | my $kernel = $kernels->{$kernelName} or die "bad kernel: $kernelName"; 85 | 86 | open my $in, "cuobjdump -arch sm_$arch -sass -fun $kernelName $cubinFile |" or die "cuobjdump -arch sm_50 -sass -fun $kernelName $cubinFile: $!"; 87 | my $first = <$in>; 88 | if ($first =~ /cuobjdump fatal/) 89 | { 90 | print $first; 91 | exit(1); 92 | } 93 | my $out; 94 | if ($asmFile) 95 | { 96 | open $out, ">$asmFile" or die "$asmFile: $!"; 97 | } 98 | else 99 | { 100 | $out = \*STDOUT; 101 | } 102 | 103 | print $out "# Kernel: $kernelName\n# Arch: sm_$arch\n"; 104 | 105 | print $out "# $_: $kernel->{$_}\n" foreach (qw(InsCnt RegCnt SharedSize BarCnt)); 106 | 107 | print $out "# Params($kernel->{ParamCnt}):\n#\tord:addr:size:align\n"; 108 | 109 | print $out join('', map "#\t$_\n", @{$kernel->{Params}}) if $kernel->{Params}; 110 | 111 | print $out "#\n# Instructions:\n\n"; 112 | 113 | MaxAs::MaxAs::Extract($in, $out, $kernel->{Params}); 114 | 115 | close $out if $asmFile; 116 | close $in; 117 | } 118 | # Extract a kernel from a sass dump 119 | elsif ($mode =~ /^\-?\-s/i) 120 | { 121 | my $sassFile = shift or usage(); 122 | my $asmFile = shift; 123 | 124 | open my $in, $sassFile or die "$sassFile: $!"; 125 | 126 | my $out; 127 | if ($asmFile) 128 | { 129 | open $out, ">$asmFile" or die "$asmFile: $!"; 130 | } 131 | else 132 | { 133 | $out = \*STDOUT; 134 | } 135 | 136 | MaxAs::MaxAs::Extract($in, $out, []); 137 | 138 | close $out if $asmFile; 139 | close $in; 140 | } 141 | # Insert the kernel asm back into the cubin: 142 | elsif ($mode =~ /^\-?\-i/i) 143 | { 144 | my $nowarn; 145 | if ($ARGV[0] =~ /^\-?\-w/i) 146 | { 147 | $nowarn = shift; 148 | } 149 | my $kernelName; 150 | if ($ARGV[0] =~ /^\-?\-k/i) 151 | { 152 | shift; 153 | $kernelName = shift or usage(); 154 | } 155 | my $noReuse = shift if $ARGV[0] =~ /^\-?\-n/i; 156 | while ($ARGV[0] =~ /^\-?\-D(\w+)/) 157 | { 158 | shift; 159 | my $name = $1; 160 | my $value = shift; 
161 | eval "package MaxAs::MaxAs::CODE; our \$$name = '$value';" 162 | } 163 | 164 | my $asmFile = shift or usage(); 165 | my $cubinFile = shift or usage(); 166 | my $newCubin = shift || $cubinFile; 167 | 168 | my $file; 169 | if (open my $fh, $asmFile) 170 | { 171 | local $/; 172 | $file = <$fh>; 173 | close $fh; 174 | } 175 | else { die "$asmFile: $!" } 176 | 177 | my ($vol,$dir) = File::Spec->splitpath($asmFile); 178 | my $include = [$vol, $dir]; 179 | 180 | # extract the kernel name from the file 181 | ($kernelName) = $file =~ /^# Kernel: (\w+)/ unless $kernelName; 182 | die "asm file missing kernel name or is badly formatted" unless $kernelName; 183 | 184 | my $kernel = MaxAs::MaxAs::Assemble($file, $include, !$noReuse, $nowarn); 185 | 186 | my $cubin = MaxAs::Cubin->new($cubinFile); 187 | $kernel->{Kernel} = $cubin->getKernel($kernelName) or die "cubin does not contain kernel: $kernelName"; 188 | 189 | $cubin->modifyKernel(%$kernel); 190 | 191 | $cubin->write($newCubin); 192 | 193 | printf "Kernel: $kernelName, Instructions: %d, Register Count: %d, Bank Conflicts: %d, Reuse: %.1f% (%d/%d)\n", 194 | @{$kernel}{qw(InsCnt RegCnt ConflictCnt ReusePct ReuseCnt ReuseTot)}; 195 | 196 | } 197 | # Preprocessing: 198 | elsif ($mode =~ /^\-?\-p/i) 199 | { 200 | while ($ARGV[0] =~ /^\-?\-D(\w+)/) 201 | { 202 | shift; 203 | my $name = $1; 204 | my $value = shift; 205 | eval "package MaxAs::MaxAs::CODE; our \$$name = '$value';"; 206 | } 207 | my $debug = shift if $ARGV[0] =~ /^\-?\-d/i; 208 | my $asmFile = shift or usage(); 209 | my $asmFile2 = shift; 210 | 211 | die "source and destination probably shouldn't be the same file\n" if $asmFile eq $asmFile2; 212 | 213 | open my $fh, $asmFile or die "$asmFile: $!"; 214 | local $/; 215 | my $file = <$fh>; 216 | close $fh; 217 | 218 | my ($vol,$dir) = File::Spec->splitpath($asmFile); 219 | my $include = [$vol, $dir]; 220 | 221 | if ($asmFile2) 222 | { 223 | open $fh, ">$asmFile2" or die "$asmFile2: $!"; 224 | } 225 | else 226 | { 
227 | $fh = \*STDOUT; 228 | } 229 | print $fh MaxAs::MaxAs::Preprocess($file, $include, $debug); 230 | close $fh; 231 | } 232 | # get version information 233 | elsif ($mode =~ /^\-?\-v/i) 234 | { 235 | print "$MaxAs::MaxAs::VERSION\n"; 236 | } 237 | else 238 | { 239 | print "$mode\n"; 240 | usage(); 241 | } 242 | 243 | exit(0); 244 | 245 | 246 | 247 | sub usage 248 | { 249 | print < 255 | 256 | Test a cubin or sass file to to see if the assembler can reproduce all of the contained opcodes. 257 | Also useful for extending the missing grammar rules. Defaults to only showing failures without --all. 258 | With the --reg flag it will show register bank conflicts not hidden by reuse flags. 259 | 260 | maxas.pl --test|-t [--reg|-r] [--all|-a] 261 | 262 | Extract a single kernel into an asm file from a cubin. 263 | Works much like cuobjdump but outputs in a format that can be re-assembled back into the cubin. 264 | 265 | maxas.pl --extract|-e [--kernel|-k kernel_name] [asm_file] 266 | 267 | Preprocess the asm: expand CODE sections, perform scheduling. Mainly used for debugging purposes. 268 | Include the debug flag to print out detailed scheduler info. 269 | 270 | maxas.pl --pre|-p [--debug|-d] [new_asm_file] 271 | 272 | Insert the kernel asm back into the cubin. Overwrite existing or create new cubin. 273 | Optionally you can skip register reuse flag auto insertion. This allows you to observe 274 | performance without any reuse or you can use it to set the flags manually in your sass. 
275 | 276 | maxas.pl --insert|-i [--noreuse|-n] [new_cubin_file] 277 | 278 | Display version information and exit: 279 | 280 | maxas.pl --version|-v 281 | 282 | EOF 283 | exit(1); 284 | } 285 | 286 | __END__ 287 | -------------------------------------------------------------------------------- /openai_gemm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import ctypes 4 | import appdirs 5 | import os.path 6 | import subprocess 7 | import numpy as np 8 | import pycuda.driver as drv 9 | from operator import mul 10 | from struct import unpack_from 11 | from pycuda.tools import context_dependent_memoize 12 | from scikits.cuda import cublas 13 | 14 | def matmul(A, B, C, alpha=1.0, beta=0.0, stream=None, bench=False): 15 | """ 16 | C = alpha * A . B + beta * C 17 | C = alpha * A.T . B + beta * C 18 | C = alpha * A . B.T + beta * C 19 | C = alpha * A.T . B.T + beta * C 20 | 21 | bench: return benchmark data for all available tiles + cublas 22 | """ 23 | 24 | # this could be relaxed, kernels are capable of mixed precision (with minor tweaks) 25 | # the s/h prefix would then go away and each type would be specified with kernel build option 26 | assert A.dtype.type == B.dtype.type == C.dtype.type 27 | 28 | if C.dtype.type is np.float32: 29 | prefix = "s" 30 | elif C.dtype.type is np.float16: 31 | prefix = "h" 32 | else: 33 | raise TypeError("Only floating point dot currently supported.") 34 | 35 | # (m,n) = (m,k) . (k,n) 36 | m = A.shape[0] 37 | n = B.shape[1] 38 | k = A.shape[1] 39 | assert m == C.shape[0] 40 | assert n == C.shape[1] 41 | assert k == B.shape[0] 42 | 43 | # Extract the operations and contiguous dimension sizes (cda, cdb, cdc). 44 | # Note that these can be the same as from the shape unless the non-contiguous dimension is sliced. 45 | # One dimension must be contiguous (DRAM efficiency demands this). 46 | # Note that the strides here do not include the datatype size as they would in numpy. 
47 | # A transpose op (.T) on a GPUTensor reverses the shape and strides then flags the tensor as transposed (is_trans=True) - 48 | # The underlying data is unchanged. 49 | if A.is_trans: 50 | opA = 'T' 51 | cda = A.strides[1] 52 | assert A.strides[0] == 1 53 | else: 54 | opA = 'N' 55 | cda = A.strides[0] 56 | assert A.strides[1] == 1 57 | 58 | if B.is_trans: 59 | opB = 'T' 60 | cdb = B.strides[1] 61 | assert B.strides[0] == 1 62 | else: 63 | opB = 'N' 64 | cdb = B.strides[0] 65 | assert B.strides[1] == 1 66 | 67 | cdc = C.strides[0] 68 | assert C.strides[1] == 1 69 | 70 | op = opA + opB 71 | 72 | # get and autotune the kernel selection 73 | kernel, params, dynamic_shared = _get_gemm_kernel(prefix, op, cda, cdb, cdc, m, n, k) 74 | 75 | # bind dynamic params 76 | params[2:8] = (stream, C.gpudata, A.gpudata, B.gpudata, alpha, beta) 77 | 78 | # call the kernel 79 | kernel.prepared_async_call(*params, shared_size=dynamic_shared) 80 | 81 | # unbind dynamic params 82 | params[2:8] = (None,) * 6 83 | 84 | # return benchmark data if requested 85 | if bench: 86 | return _get_bench_data()[(prefix, op, cda, cdb, cdc, m, n, k)] 87 | 88 | return C 89 | 90 | 91 | 92 | #################################################################################################### 93 | 94 | 95 | # scikits.cuda doesn't expose cublasSgemmEx or cublasHgemm 96 | cublas._libcublas.cublasSgemmEx.restype = int 97 | cublas._libcublas.cublasSgemmEx.argtypes = [ 98 | cublas._types.handle, 99 | ctypes.c_int, ctypes.c_int, 100 | ctypes.c_int, ctypes.c_int, ctypes.c_int, 101 | ctypes.c_void_p, 102 | ctypes.c_void_p, ctypes.c_int, ctypes.c_int, 103 | ctypes.c_void_p, ctypes.c_int, ctypes.c_int, 104 | ctypes.c_void_p, 105 | ctypes.c_void_p, ctypes.c_int, ctypes.c_int ] 106 | 107 | def cublasSgemmEx(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc): 108 | status = cublas._libcublas.cublasSgemmEx(handle, 109 | cublas._CUBLAS_OP[transa], cublas._CUBLAS_OP[transb], 110 | m, n, k, 111 | 
ctypes.byref(ctypes.c_float(alpha)), 112 | int(A), 2, lda, 113 | int(B), 2, ldb, 114 | ctypes.byref(ctypes.c_float(beta)), 115 | int(C), 2, ldc) 116 | cublas.cublasCheckStatus(status) 117 | 118 | cublas._libcublas.cublasHgemm.restype = int 119 | cublas._libcublas.cublasHgemm.argtypes = [ 120 | cublas._types.handle, 121 | ctypes.c_int, ctypes.c_int, 122 | ctypes.c_int, ctypes.c_int, ctypes.c_int, 123 | ctypes.c_void_p, 124 | ctypes.c_void_p, ctypes.c_int, 125 | ctypes.c_void_p, ctypes.c_int, 126 | ctypes.c_void_p, 127 | ctypes.c_void_p, ctypes.c_int ] 128 | 129 | h_dtype = np.dtype(np.float16) 130 | 131 | def cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc): 132 | 133 | alpha = unpack_from('H', h_dtype.type(alpha))[0] 134 | beta = unpack_from('H', h_dtype.type(beta))[0] 135 | 136 | status = cublas._libcublas.cublasHgemm(handle, 137 | cublas._CUBLAS_OP[transa], cublas._CUBLAS_OP[transb], 138 | m, n, k, 139 | ctypes.byref(ctypes.c_uint16(alpha)), 140 | int(A), lda, 141 | int(B), ldb, 142 | ctypes.byref(ctypes.c_uint16(beta)), 143 | int(C), ldc) 144 | cublas.cublasCheckStatus(status) 145 | 146 | cublasXgemm = { 147 | "s" : cublas.cublasSgemm, 148 | "h" : cublasSgemmEx, 149 | "h2" : cublasHgemm, 150 | } 151 | 152 | 153 | @context_dependent_memoize 154 | def _get_sm_count(): 155 | attributes = drv.Context.get_device().get_attributes() 156 | return attributes[drv.device_attribute.MULTIPROCESSOR_COUNT] 157 | 158 | @context_dependent_memoize 159 | def _get_compute_capability(): 160 | 161 | attributes = drv.Context.get_device().get_attributes() 162 | major = attributes[drv.device_attribute.COMPUTE_CAPABILITY_MAJOR] 163 | minor = attributes[drv.device_attribute.COMPUTE_CAPABILITY_MINOR] 164 | return major, minor 165 | 166 | @context_dependent_memoize 167 | def _get_events(): 168 | return (drv.Event(), drv.Event()) 169 | 170 | @context_dependent_memoize 171 | def _get_cublas(): 172 | return cublas.cublasCreate() 173 | 174 | 
@context_dependent_memoize 175 | def _get_bench_data(): 176 | return dict() 177 | 178 | def _ceil_div(x, y): 179 | return -(-x // y) 180 | 181 | def _closest_divisor(val, div): 182 | divisors = sorted([(abs(i - div), i) for i in range(2, 8) if val % i == 0]) 183 | if len(divisors): 184 | return (divisors[0][1], val // divisors[0][1]) 185 | else: 186 | return (1, val) 187 | 188 | 189 | # Tile sizes: m, n, k, vA,vB,vC div, op (dynamic shared options) 190 | k128x128x8 = (128, 128, 8, 4, 4, 1, 2, 0, (0,)) 191 | k32x32x32 = ( 32, 32, 32, 4, 4, 1, 4, 0, (0, 2**14)) 192 | k32x64x32_NN = ( 32, 64, 32, 8, 4, 4, 4, 1, (0, 2**13)) 193 | k32x32x64_NT = ( 32, 32, 64, 8, 8, 4, 4, 1, (0,)) 194 | k16x64x64_NN = ( 16, 64, 64, 8, 4, 4, 4, 1, (0,)) 195 | k16x64x64_NT = ( 16, 64, 64, 8, 8, 4, 4, 1, (0,)) 196 | 197 | selections = { 198 | "s" : { 199 | "TN" : (k128x128x8, k32x32x32), 200 | "NN" : (k128x128x8, k32x32x32), 201 | "NT" : (k128x128x8, k32x32x32), 202 | "TT" : (k128x128x8, k32x32x32), 203 | }, 204 | "h" : { 205 | "TN" : (k128x128x8, k32x32x32), 206 | "NN" : (k128x128x8, k32x32x32, k32x64x32_NN, k16x64x64_NN), 207 | "NT" : (k128x128x8, k32x32x32, k32x32x64_NT, k16x64x64_NT), 208 | "TT" : (k128x128x8, k32x32x32), 209 | }, 210 | } 211 | 212 | # Autotune kernel selection 213 | @context_dependent_memoize 214 | def _get_gemm_kernel(prefix, op, cda, cdb, cdc, m, n, k): 215 | 216 | if op[0] == 'T': 217 | vec4A = (cda & 3) == 0 and (m & 3) == 0 218 | vec8A = (cda & 7) == 0 and (m & 7) == 0 219 | dimA = (k,cda) 220 | else: 221 | vec4A = (cda & 3) == 0 and (k & 3) == 0 222 | vec8A = (cda & 7) == 0 and (k & 7) == 0 223 | dimA = (m,cda) 224 | 225 | if op[1] == 'T': 226 | vec4B = (cdb & 3) == 0 and (k & 3) == 0 227 | vec8B = (cdb & 7) == 0 and (k & 7) == 0 228 | dimB = (n,cdb) 229 | else: 230 | vec4B = (cdb & 3) == 0 and (n & 3) == 0 231 | vec8B = (cdb & 7) == 0 and (n & 7) == 0 232 | dimB = (k,cdb) 233 | 234 | vec4C = (cdc & 3) == 0 and (n & 3) == 0 235 | dimC = (m,cdc) 236 | 237 | dtype 
= np.dtype(np.float32 if prefix == 's' else np.float16) 238 | 239 | A = drv.mem_alloc(mul(*dimA) * dtype.itemsize) 240 | B = drv.mem_alloc(mul(*dimB) * dtype.itemsize) 241 | C = drv.mem_alloc(mul(*dimC) * dtype.itemsize) 242 | 243 | # TODO: use curand 244 | dataA = np.random.uniform(-1.0, 1.0, dimA).astype(dtype) 245 | dataB = np.random.uniform(-1.0, 1.0, dimB).astype(dtype) 246 | drv.memcpy_htod(int(A), dataA) 247 | drv.memcpy_htod(int(B), dataB) 248 | 249 | # Using random data gets you more accurate autotune results 250 | # drv.memset_d8(int(A), 0, mul(*dimA) * dtype.itemsize) 251 | # drv.memset_d8(int(B), 0, mul(*dimB) * dtype.itemsize) 252 | 253 | timings = [] 254 | cache = [] 255 | 256 | # scale the repeat count to amount of work 257 | repeat = min(max(int(5e11 * 28 / (m*n*k * 2.0 * _get_sm_count()) ), 10), 5000) 258 | warmup = repeat 259 | #print repeat 260 | 261 | start, end = _get_events() 262 | flops = m * n * k * 2.0 263 | 264 | for tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared in selections[prefix][op]: 265 | 266 | vecA = (vecA == 4 and vec4A) or (vecA == 8 and vec8A) 267 | vecB = (vecB == 4 and vec4B) or (vecB == 8 and vec8B) 268 | vecC = vecC == 1 or vec4C 269 | vec = vecA and vecB and vecC 270 | 271 | if base_op: 272 | # The op is part of the base kernel name 273 | base = "%sgemm_%dx%dx%d_%s" % (prefix, tileM, tileN, tileK, op) 274 | opts = ( "vec", ) if vec else () 275 | else: 276 | # The op is an option passed to a more generic kernel 277 | base = "%sgemm_%dx%dx%d" % (prefix, tileM, tileN, tileK) 278 | opts = ( op, "vec" ) if vec else (op,) 279 | 280 | kernel = get_kernel(base, opts) 281 | 282 | blk_A = _ceil_div(m, tileM) 283 | blk_B = _ceil_div(n, tileN) 284 | 285 | # TODO: perhaps autotune all possible small divisors 286 | blk_a, blk_A = _closest_divisor(blk_A, div) 287 | blk_b, blk_B = _closest_divisor(blk_B, div) 288 | if blk_a == 1: 289 | blk_a, blk_A = (blk_A, 1) 290 | 291 | for dynamic_shared in dyn_shared: 292 | 293 | 
params = [ 294 | (blk_a * blk_b, blk_B, blk_A), (kernel.threads, 1, 1), None, 295 | C, A, B, 1.0, 0.0, 296 | cda, cdb, cdc, m, n, k, blk_a, blk_b ] 297 | 298 | #print kernel.name, params, dynamic_shared 299 | 300 | # Warmup (once per config) 301 | for r in range(warmup): 302 | kernel.prepared_async_call(*params) 303 | warmup = 0 304 | 305 | # Benchmark 306 | start.record() 307 | for r in range(repeat): 308 | kernel.prepared_async_call(*params, shared_size=dynamic_shared) 309 | end.record() 310 | end.synchronize() 311 | msecs = end.time_since(start) / float(repeat) 312 | gflops = flops / (msecs * 1000000.0) 313 | 314 | params[3:8] = (None,) * 5 315 | 316 | timings.append((msecs, gflops, kernel, params, dynamic_shared)) 317 | cache.append((msecs, gflops, kernel.name, dynamic_shared)) 318 | 319 | major, minor = _get_compute_capability() 320 | if prefix == "h" and major == 6 and minor == 0: 321 | cublas_gemm = cublasXgemm["h2"] 322 | else: 323 | cublas_gemm = cublasXgemm[prefix] 324 | 325 | # record a cublas time for reference 326 | cublas_handle = _get_cublas() 327 | start.record() 328 | for r in range(repeat): 329 | # convert row order to col order 330 | cublas_gemm(cublas_handle, op[1], op[0], n, m, k, 1.0, B, cdb, A, cda, 0.0, C, cdc) 331 | end.record() 332 | end.synchronize() 333 | msecs = end.time_since(start) / float(repeat) 334 | gflops = flops / (msecs * 1000000.0) 335 | cache.append( (msecs, gflops, "cuBLAS", 0) ) 336 | 337 | # cache complete timing data for benchmark comparisons 338 | # this data could be cached to disk for quicker autotuning on future runs 339 | _get_bench_data()[(prefix, op, cda, cdb, cdc, m, n, k)] = cache 340 | 341 | # return the fastest kernel 342 | return tuple(sorted(timings)[0][2:5]) 343 | 344 | 345 | 346 | # Utility function to test all tiles for the given dimensions and dtype 347 | def matmul_test(ng, dtype, op, m, n, k, ones=False, out=False): 348 | 349 | prefix = "s" if dtype is np.float32 else "h" 350 | 351 | if op[0] == 'T': 
352 | vec4A = (m & 3) == 0 353 | vec8A = (m & 7) == 0 354 | dimA = (k,m) 355 | cda = m 356 | else: 357 | vec4A = (k & 3) == 0 358 | vec8A = (k & 7) == 0 359 | dimA = (m,k) 360 | cda = k 361 | 362 | if op[1] == 'T': 363 | vec4B = (k & 3) == 0 364 | vec8B = (k & 7) == 0 365 | dimB = (n,k) 366 | cdb = k 367 | else: 368 | vec4B = (n & 3) == 0 369 | vec8B = (n & 7) == 0 370 | dimB = (k,n) 371 | cdb = n 372 | 373 | vec4C = (n & 3) == 0 374 | dimC = (m,n) 375 | cdc = n 376 | 377 | A1 = ng.empty(dimA, dtype=dtype) 378 | B1 = ng.empty(dimB, dtype=dtype) 379 | C1 = ng.empty(dimC, dtype=dtype) 380 | C2 = ng.empty(dimC, dtype=dtype) 381 | 382 | if ones: 383 | A1[:] = 1.0 384 | B1[:] = 1.0 385 | else: 386 | # fill with uniform randoms from -1 to 1 387 | A1[:] = 2 * (.5 - ng.rand()) 388 | B1[:] = 2 * (.5 - ng.rand()) 389 | 390 | # for reducing outputs 391 | partial1 = ng.empty((C1.shape[0],1), dtype=np.float32) 392 | partial2 = partial1[0:1,0:1] 393 | 394 | cublas_handle = _get_cublas() 395 | 396 | for tileM, tileN, tileK, vecA, vecB, vecC, div, base_op, dyn_shared in selections[prefix][op]: 397 | 398 | vecA = (vecA == 4 and vec4A) or (vecA == 8 and vec8A) 399 | vecB = (vecB == 4 and vec4B) or (vecB == 8 and vec8B) 400 | vecC = vecC == 1 or vec4C 401 | vec = vecA and vecB and vecC 402 | 403 | if base_op: 404 | # The op is part of the base kernel name 405 | base = "%sgemm_%dx%dx%d_%s" % (prefix, tileM, tileN, tileK, op) 406 | opts = ( "vec", ) if vec else () 407 | else: 408 | # The op is an option passed to a more generic kernel 409 | base = "%sgemm_%dx%dx%d" % (prefix, tileM, tileN, tileK) 410 | opts = ( op, "vec" ) if vec else (op,) 411 | 412 | kernel = get_kernel(base, opts) 413 | 414 | blk_A = _ceil_div(m, tileM) 415 | blk_B = _ceil_div(n, tileN) 416 | 417 | blk_a, blk_A = _closest_divisor(blk_A, div) 418 | blk_b, blk_B = _closest_divisor(blk_B, div) 419 | if blk_a == 1: 420 | blk_a, blk_A = (blk_A, 1) 421 | 422 | for alpha, beta in ( (1.0,0.0), (0.5,0.5) ): 423 | 424 | try: 
425 | if ones: 426 | C1[:] = 1.0 427 | else: 428 | C1[:] = 2 * (.5 - ng.rand()) 429 | C2[:] = C1 430 | 431 | params = [ 432 | (blk_a * blk_b, blk_B, blk_A), (kernel.threads, 1, 1), None, 433 | C1.gpudata, A1.gpudata, B1.gpudata, alpha, beta, 434 | cda, cdb, cdc, m, n, k, blk_a, blk_b ] 435 | 436 | kernel.prepared_async_call(*params) 437 | 438 | # convert row order to col order 439 | cublasXgemm[prefix](cublas_handle, op[1], op[0], n, m, k, alpha, B1.gpudata, cdb, A1.gpudata, cda, beta, C2.gpudata, cdc) 440 | 441 | # Check for NaNs 442 | partial1[:] = ng.min(ng.finite(C1), axis=1) 443 | partial2[:] = ng.min(partial1, axis=0) 444 | if partial2.get()[0,0] == 0.0: 445 | print "Error: NaN kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (kernel.name, m,n,k, alpha,beta) 446 | exit() 447 | 448 | # Get Max Diff 449 | partial1[:] = ng.max(abs(C2 - C1), axis=1) 450 | partial2[:] = ng.max(partial1, axis=0) 451 | diff = partial2.get()[0,0] 452 | 453 | # Get Mean 454 | partial1[:] = ng.sum(abs(C2), axis=1) 455 | partial2[:] = ng.sum(partial1, axis=0) 456 | mean = partial2.get()[0,0] / C2.size 457 | 458 | # Scale diff by the mean 459 | pctErr = 100 * diff / mean 460 | 461 | #print "Error: %.3f %s" % (pctErr, kernel.name) 462 | 463 | maxerr = .005 if dtype is np.float32 else 0.7 464 | 465 | if pctErr > maxerr: 466 | print "Error: %.3f%% diff: %.5f mean %.5f kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (pctErr, diff, mean, kernel.name, m,n,k, alpha,beta) 467 | print params 468 | if out: 469 | C1 = C1.get() 470 | C2 = C2.get() 471 | D = abs(C2 - C1) 472 | np.savetxt("out_diff.txt", D, fmt='%3.1f') 473 | np.savetxt("out_correct.txt", C2, fmt='%5.1f') 474 | np.savetxt("out_error", C1, fmt='%5.1f') 475 | exit() 476 | 477 | except drv.Error as e: 478 | print "kernel: %s mnk: (%d,%d,%d) ab: (%f,%f)" % (kernel.name, m,n,k, alpha,beta) 479 | print e 480 | exit() 481 | 482 | ### below code adapted from Nervana Neon: kernel_specs.py 483 | 484 | def _get_cache_dir(subdir=None): 485 | 486 | cache_dir = 
appdirs.user_cache_dir("openai-gemm") 487 | 488 | if subdir: 489 | subdir = subdir if isinstance(subdir, list) else [subdir] 490 | cache_dir = os.path.join(cache_dir, *subdir) 491 | 492 | if not os.path.exists(cache_dir): 493 | os.makedirs(cache_dir) 494 | 495 | return cache_dir 496 | 497 | # helpful for kernel development 498 | debug = 0 499 | 500 | base_dir = os.path.dirname(__file__) 501 | maxas_dir = os.path.join(base_dir, "maxas") 502 | sass_dir = os.path.join(base_dir, "sass") 503 | 504 | kernels = { 505 | # Generic gemm tiles 506 | "sgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "s"} }, 507 | "hgemm_128x128x8": {"threads": 256, "sass": "xgemm_128x128x8", "params": "xgemm", "share": "(128*8 + 32)*4 + 4", "args": {"type": "h"} }, 508 | "sgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "s"} }, 509 | "hgemm_32x32x32": {"threads": 128, "sass": "xgemm_32x32x32", "params": "xgemm", "share": "(32*33)*4 + 4", "args": {"type": "h"} }, 510 | 511 | # Custom hgemm tiles designed for small minibatch RNNs 512 | "hgemm_32x64x32_NN": {"threads": 128, "sass": "hgemm_32x64x32_NN", "params": "xgemm", "share": "32*33*2 + 64*32*2 + 4" }, 513 | "hgemm_32x32x64_NT": {"threads": 128, "sass": "hgemm_32x32x64_NT", "params": "xgemm", "share": "32*65*4 + 4" }, 514 | "hgemm_16x64x64_NN": {"threads": 128, "sass": "hgemm_16x64x64_NN", "params": "xgemm", "share": "(16*64 + 32)*2 + 64*64*2 + 4" }, 515 | "hgemm_16x64x64_NT": {"threads": 128, "sass": "hgemm_16x64x64_NT", "params": "xgemm", "share": "(16*64 + 32)*2 + (64*64 + 32)*2 + 4" }, 516 | } 517 | 518 | _params = { 519 | "xgemm": [ 520 | "float* param_C", 521 | "float* param_A", 522 | "float* param_B", 523 | "float param_alpha", 524 | "float param_beta", 525 | "unsigned param_cda", 526 | "unsigned param_cdb", 527 | "unsigned param_cdc", 528 | "unsigned param_m", 529 | "unsigned 
param_n", 530 | "unsigned param_k", 531 | "unsigned param_blk_a", 532 | "unsigned param_blk_b", 533 | ], 534 | } 535 | 536 | _space_re = re.compile(r"\s+") 537 | 538 | _share_template = r""" 539 | .shared .align 4 .b32 share[{0}]; 540 | """ 541 | 542 | _kernel_template = r""" 543 | .version {6} 544 | .target {0} 545 | .address_size 64 546 | 547 | // args: {5} 548 | 549 | .visible .entry {1}( 550 | {2} 551 | ) 552 | .reqntid {3} 553 | {{ 554 | {4} 555 | ret; 556 | }} 557 | """ 558 | 559 | def get_ptx_file(kernel_spec, kernel_name, arch, ptx_ver): 560 | 561 | ptx_dir = _get_cache_dir([arch, 'ptx']) 562 | 563 | thread_spec = kernel_spec["threads"] 564 | args_spec = str(kernel_spec.get("args","")) 565 | param_spec = _params[kernel_spec["params"]] 566 | 567 | kernel_params = [] 568 | for p in param_spec: 569 | ptype, pname = _space_re.split(p) 570 | 571 | if ptype[-1] == '*': 572 | ptype = '.u64' 573 | elif ptype == 'float': 574 | ptype = '.f32' 575 | else: 576 | ptype = '.u32' 577 | 578 | kernel_params.append(" .param %s %s" % (ptype, pname)) 579 | 580 | kernel_params = ",\n".join(kernel_params) 581 | 582 | if "share" in kernel_spec: 583 | share = _share_template.format(eval(kernel_spec["share"])) 584 | else: 585 | share = "" 586 | 587 | kernel_text = _kernel_template.format(arch, kernel_name, kernel_params, thread_spec, share, args_spec, ptx_ver) 588 | kernel_ptx = os.path.join(ptx_dir, kernel_name + ".ptx") 589 | 590 | current_text = "" 591 | if os.path.exists(kernel_ptx): 592 | f = open(kernel_ptx, "r") 593 | current_text = f.read() 594 | f.close() 595 | # only write out the kernel if text has changed. 
596 | if kernel_text != current_text: 597 | f = open(kernel_ptx, "w") 598 | f.write(kernel_text) 599 | f.close() 600 | 601 | return kernel_ptx 602 | 603 | 604 | include_re = re.compile(r'^') 605 | 606 | def extract_includes(name, includes=None): 607 | if not includes: 608 | includes = list() 609 | sass_file = os.path.join(sass_dir, name) 610 | includes.append((sass_file, os.path.getmtime(sass_file))) 611 | for line in open(sass_file, "r"): 612 | match = include_re.search(line) 613 | if match: 614 | extract_includes(match.group(1), includes) 615 | return includes 616 | 617 | def run_command(cmdlist): 618 | cmd = " ".join(cmdlist) 619 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 620 | out, err = proc.communicate() 621 | if proc.returncode: 622 | raise RuntimeError("Error(%d):\n%s\n%s" % (proc.returncode, cmd, err)) 623 | if debug: 624 | print cmd 625 | if out: print out 626 | if err: print err 627 | 628 | 629 | @context_dependent_memoize 630 | def get_kernel(base_name, options=None): 631 | 632 | major, minor = _get_compute_capability() 633 | if major < 5: 634 | raise RuntimeError("sass kernels require Maxwell or greater class hardware") 635 | 636 | arch = "sm_%d%d" % (major, minor) 637 | 638 | libprefix = "PERL5LIB=%s" % maxas_dir 639 | maxas_i = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -i -w"] 640 | maxas_p = [libprefix, os.path.join(maxas_dir, "maxas.pl") + " -p"] 641 | 642 | kernel_spec = kernels[base_name] 643 | kernel_name = base_name 644 | 645 | # static options 646 | if "args" in kernel_spec: 647 | for pair in kernel_spec["args"].items(): 648 | maxas_i.append("-D%s %s" % pair) 649 | maxas_p.append("-D%s %s" % pair) 650 | 651 | # dynamic options 652 | if options is not None: 653 | for opt in options: 654 | if type(opt) is tuple: 655 | maxas_i.append("-D%s %s" % opt) 656 | maxas_p.append("-D%s %s" % opt) 657 | kernel_name += "_%s%s" % opt 658 | else: 659 | maxas_i.append("-D%s 1" % opt) 660 | 
maxas_p.append("-D%s 1" % opt) 661 | kernel_name += "_%s" % opt 662 | 663 | maxas_i.insert(2, "-k " + kernel_name) 664 | 665 | sass_name = kernel_spec["sass"] + ".sass" 666 | cubin_name = kernel_name + ".cubin" 667 | cubin_dir = _get_cache_dir([arch, 'cubin']) 668 | 669 | ptx_version = "4.2" if major < 6 else "5.0" 670 | ptx_file = get_ptx_file(kernel_spec, kernel_name, arch, ptx_version) 671 | sass_file = os.path.join(sass_dir, sass_name) 672 | cubin_file = os.path.join(cubin_dir, cubin_name) 673 | 674 | if not os.path.exists(sass_file): 675 | raise RuntimeError("Missing: %s for kernel: %s" % (sass_name, kernel_name)) 676 | 677 | ptx_mtime = os.path.getmtime(ptx_file) 678 | cubin_mtime = os.path.getmtime(cubin_file) if os.path.exists(cubin_file) else 0 679 | 680 | build_cubin = False 681 | if ptx_mtime > cubin_mtime: 682 | build_cubin = True 683 | 684 | includes = extract_includes(sass_name) 685 | for include, include_mtime in includes: 686 | if include_mtime > cubin_mtime: 687 | build_cubin = True 688 | break 689 | 690 | if build_cubin: 691 | # build the cubin and run maxas in the same command 692 | # we don't want the chance of a generated cubin not processed by maxas (in case user hits ^C in between these steps) 693 | run_command([ "ptxas -v -arch", arch, "-o", cubin_file, ptx_file, ";" ] + maxas_i + [sass_file, cubin_file]) 694 | cubin_mtime = time.time() 695 | 696 | # output preprocessed and disassembled versions in debug mode 697 | if debug: 698 | pre_dir = _get_cache_dir([arch, 'pre']) 699 | dump_dir = _get_cache_dir([arch, 'dump']) 700 | 701 | pre_file = os.path.join(pre_dir, kernel_name + "_pre.sass") 702 | dump_file = os.path.join(dump_dir, kernel_name + "_dump.sass") 703 | pre_mtime = os.path.getmtime(pre_file) if os.path.exists(pre_file) else 0 704 | dump_mtime = os.path.getmtime(dump_file) if os.path.exists(dump_file) else 0 705 | 706 | for include, include_mtime in includes: 707 | if include_mtime > pre_mtime: 708 | run_command(maxas_p + [sass_file, 
pre_file]) 709 | break 710 | 711 | if cubin_mtime > dump_mtime: 712 | run_command(["nvdisasm -c", cubin_file, ">", dump_file]) 713 | 714 | # generate the function signature for pycuda 715 | params = _params[kernel_spec["params"]] 716 | sig = "" 717 | for p in params: 718 | ptype, pname = _space_re.split(p) 719 | if ptype[-1] == '*': 720 | sig += "Q" 721 | elif ptype == 'float': 722 | sig += "f" 723 | elif ptype == 'unsigned': 724 | sig += "I" 725 | else: 726 | sig += "i" 727 | 728 | module = drv.module_from_file(cubin_file) 729 | func = module.get_function(kernel_name) 730 | func.prepare(sig) 731 | func.threads = kernel_spec["threads"] 732 | func.name = kernel_name 733 | func.static_shared = eval(kernel_spec["share"]) 734 | 735 | return func -------------------------------------------------------------------------------- /sass/hgemm_32x64x32_NN.sass: -------------------------------------------------------------------------------- 1 | # Kernel: hgemm_32x64x32_NN 2 | 3 | [- 4 | our $vec; 5 | sub vector { return $vec; } 6 | -] 7 | 8 | 9 | addr_zero : 4x<32*33*2 + 64*32*2> 10 | szShareA : (32*33) 11 | szShareB : (64*32) 12 | 13 | param_C[0] : c[0x0][0x140] 14 | param_C[1] : c[0x0][0x144] 15 | param_A[0] : c[0x0][0x148] 16 | param_A[1] : c[0x0][0x14c] 17 | param_B[0] : c[0x0][0x150] 18 | param_B[1] : c[0x0][0x154] 19 | param_alpha : c[0x0][0x158] 20 | param_beta : c[0x0][0x15c] 21 | param_cda : c[0x0][0x160] 22 | param_cdb : c[0x0][0x164] 23 | param_cdc : c[0x0][0x168] 24 | param_m : c[0x0][0x16c] 25 | param_n : c[0x0][0x170] 26 | param_k : c[0x0][0x174] 27 | param_blk_a : c[0x0][0x178] 28 | param_blk_b : c[0x0][0x17c] 29 | 30 | 31 | 32 | 33 | 3, 2,11,10,19,18,27,26 : cx<0-7>y0 34 | 7, 6,15,14,23,22,31,30 : cx<0-7>y1 35 | 1, 0, 9, 8,17,16,25,24 : cx<0-7>y2 36 | 5, 4,13,12,21,20,29,28 : cx<0-7>y3 37 | 35,34,43,42,51,50,59,58 : cx<0-7>y4 38 | 39,38,47,46,55,54,63,62 : cx<0-7>y5 39 | 33,32,41,40,49,48,57,56 : cx<0-7>y6 40 | 37,36,45,44,53,52,61,60 : cx<0-7>y7 41 | 42 | 
0-63 : czero<00-63> 43 | 64-79 : j0Ay<0-7>, j0Bx<0-7> 44 | 80-95 : j1Ay<0-7>, j1Bx<0-7> 45 | 46 | 64-95 ~ cda, cdb, cdb8, tidAX, tidAY, tidBX, tidBY, tidAY<1-3>, tidBY<8|16|24>, tid1, tid32, tb, shiftAX, partialK, partialB, ta, txa, txb, txb<1-3>, xmad_ta, xmad_tb 47 | 96-119 ~ idx_ab, idx_ab_f, idx_a, neg_blk_b, rcp_blk_b, idx_b 48 | 49 | 50 | 96-119 : load0A<0-7>, load0B<0-3>, load1B<0-3>, load2B<0-3>, load3B<0-3> 51 | 120-129 : track0A<0-1>, track0B<0-1>, track1B<0-1>, track2B<0-1>, track3B<0-1> 52 | 53 | 130-137 ~ swapBuf, readAs, readBs, writeAs, writeBs, k, cdb32 54 | 138-144 ~ tid, idx_A, idx_B, writeCs, preds 55 | 56 | 0-15 : part0C<0-3>, part1C<0-3>, part2C<0-3>, part3C<0-3> 57 | 64-95 : shuffle_x<0-7>y0, shuffle_x<0-7>y1, shuffle_x<0-7>y2, shuffle_x<0-7>y3 58 | 64-95 : shuffle_x<0-7>y4, shuffle_x<0-7>y5, shuffle_x<0-7>y6, shuffle_x<0-7>y7 59 | 96-99 : loadC<0-3> 60 | 100-103 : b<0-3> 61 | 104-107 : c<0-3> 62 | 108-109 : C<0-1> 63 | 110-137 ~ cdc, cx, cx<1-3>, cy, ci, xmad_c, cdc8, readCs, alpha, beta, tid15, tid16 64 | 65 | 66 | 67 | --:-:5:-:1 I2F.F32.S32 rcp_blk_b, param_blk_b; 68 | --:-:1:-:1 S2R tid, SR_TID.X; 69 | --:-:2:-:1 S2R idx_ab, SR_CTAID.X; 70 | --:-:3:-:1 S2R idx_A, SR_CTAID.Z; 71 | --:-:4:-:1 S2R idx_B, SR_CTAID.Y; 72 | 73 | 74 | 10:-:5:-:1 MUFU.RCP rcp_blk_b, rcp_blk_b; 75 | 76 | --:-:-:-:1 MOV k, param_k; 77 | --:-:-:-:1 MOV cda, param_cda; 78 | --:-:-:-:1 MOV cdb, param_cdb; 79 | --:-:-:-:1 SHL cdb8, cdb, 3; 80 | --:-:-:-:1 SHL cdb32, cdb, 6; 81 | 82 | // If k is not a multiple of 32 we want to grab the partial amount on the first fetch. 83 | // If it is a multiple of 32 then make a full 32 line fetch. 
84 | --:-:-:-:1 LOP.AND.Z P0, partialK, k, 31; 85 | --:-:-:-:1 @P0 MOV partialK, 32; 86 | --:-:-:-:1 IADD k, k, -partialK; 87 | 88 | # idx_a = idx_ab // blk_b 89 | 02:-:2:-:1 I2F.F32.S32 idx_ab_f, idx_ab; 90 | 12:-:-:-:1 FMUL idx_a, idx_ab_f, rcp_blk_b; 91 | --:-:-:-:1 FFMA idx_a, idx_a, 5.9604644775390625e-08, idx_a; 92 | --:-:2:-:1 F2I.S32.F32.TRUNC idx_a, idx_a; 93 | # idx_b = idx_AB % blk_b 94 | --:-:-:-:1 IADD neg_blk_b, RZ, -param_blk_b; 95 | 02:-:-:-:1 XMAD.S16.U16 idx_b, neg_blk_b, idx_a, idx_ab; 96 | 97 | # idx_A = idx_A * blk_a + idx_a 98 | # idx_B = idx_B * blk_b + idx_b 99 | 06:-:-:-:1 XMAD.U16.U16 idx_A, idx_A, param_blk_a, idx_a; 100 | 08:-:-:-:1 XMAD.U16.U16 idx_B, idx_B, param_blk_b, idx_b; 101 | 102 | 103 | --:-:-:-:1 STS.128 [addr_zero], RZ; 104 | [+ join '', map sprintf("--:-:-:-:1 LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15; +] 105 | 106 | // tidAX = tid >> 2 107 | // tidAY = (tid & 3) << 3 108 | // shiftAX = (tid & 3) << 3 109 | 01:-:-:-:1 SHR.U32 tidAX, tid, 2; 110 | --:-:-:-:1 LOP.AND tidAY, tid, 3; 111 | --:-:-:-:1 SHL shiftAX, tidAY, 3; 112 | --:-:-:-:1 SHL tidAY, tidAY, 3; 113 | 114 | // tidBX = (tid & 15) << 2 115 | // tidBY = tid >> 4 116 | 01:-:-:-:1 LOP.AND tidBX, tid, 15; 117 | --:-:-:-:1 SHL tidBX, tidBX, 2; 118 | --:-:-:-:1 SHR.U32 tidBY, tid, 4; 119 | 120 | --:-:-:-:1 IADD tidBY8, tidBY, 8; 121 | --:-:-:-:1 IADD tidBY16, tidBY, 16; 122 | --:-:-:-:1 IADD tidBY24, tidBY, 24; 123 | 124 | // trackA += ((idx_A*32 + tidAX) * cda + tidAY) * 2 125 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 5; 126 | --:-:-:-:1 XMAD.LO ta, cda, txa, tidAY, xmad_ta; 127 | --:-:-:-:1 LEA track0A0.CC, ta, param_A[0], 1; 128 | --:-:-:-:1 LEA.HI.X track0A1, ta, param_A[1], RZ, 1; 129 | 130 | --:-:-:-:1 ISETP.LT.AND P2, PT, txa, param_m, PT; 131 | 132 | // trackB += (idx_B*64 + tidBX + cdb*tidBY) * 2 133 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 6; 134 | --:-:-:-:1 XMAD.LO2 tb, cdb, tidBY, txb; 135 | --:-:-:-:1 LEA track0B0.CC, tb, param_B[0], 1; 136 | 
--:-:-:-:1 LEA.HI.X track0B1, tb, param_B[1], RZ, 1; 137 | --:-:-:-:1 IADD tb, tb, cdb8; 138 | --:-:-:-:1 LEA track1B0.CC, tb, param_B[0], 1; 139 | --:-:-:-:1 LEA.HI.X track1B1, tb, param_B[1], RZ, 1; 140 | --:-:-:-:1 IADD tb, tb, cdb8; 141 | --:-:-:-:1 LEA track2B0.CC, tb, param_B[0], 1; 142 | --:-:-:-:1 LEA.HI.X track2B1, tb, param_B[1], RZ, 1; 143 | --:-:-:-:1 IADD tb, tb, cdb8; 144 | --:-:-:-:1 LEA track3B0.CC, tb, param_B[0], 1; 145 | --:-:-:-:1 LEA.HI.X track3B1, tb, param_B[1], RZ, 1; 146 | 147 | --:-:-:-:1 ISETP.LT.AND P3, PT, txb, param_n, PT; 148 | [+ 149 | return vector() ? '' : q{ 150 | --:-:-:-:1 IADD txb1, txb, 1; 151 | --:-:-:-:1 IADD txb2, txb, 2; 152 | --:-:-:-:1 IADD txb3, txb, 3; 153 | --:-:-:-:1 ISETP.LT.AND P4, PT, txb1, param_n, PT; 154 | --:-:-:-:1 ISETP.LT.AND P5, PT, txb2, param_n, PT; 155 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb3, param_n, PT; 156 | }; 157 | +] 158 | --:-:-:-:1 P2R preds, PR, RZ, 0x7c; 159 | 160 | // writeAs = (tidAY*32 + tidAX + shiftAX) * 4 161 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 5; 162 | --:-:-:-:1 IADD writeAs, writeAs, shiftAX; 163 | --:-:-:-:1 SHL writeAs, writeAs, 2; 164 | 165 | // writeBs = (tidBY*64 + tidBX) * 4 166 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 6; 167 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2; 168 | 169 | // readAs = (((tid & 16) >> 2) | (tid & 1)) << 4 170 | --:-:-:-:1 LOP.AND tid1, tid, 1; 171 | --:-:-:-:1 LOP.AND readAs, tid, 16; 172 | --:-:-:-:1 SHR.U32 readAs, readAs, 3; 173 | --:-:-:-:1 LOP.OR readAs, readAs, tid1; 174 | --:-:-:-:1 SHL readAs, readAs, 4; 175 | 176 | // readBs = (((tid >> 1) & 7) << 4 177 | --:-:-:-:1 BFE.U32 readBs, tid, 0x301; // 2 bits at position 1 178 | --:-:-:-:1 SHL readBs, readBs, 4; 179 | 180 | // Each tile has 32 threads so this is an index into the 4 tiles (at bit position 5) 181 | // tid32 = tid & -32 182 | --:-:-:-:1 LOP.AND tid32, tid, -32; 183 | 184 | // Write out the 4 groups of 32 rows 16 at a time 185 | // writeCs = (readAs + tid32/2*4) * 64 + readBs 186 
| --:-:-:-:1 ISCADD writeCs, tid32, readAs, 1; 187 | --:-:-:-:1 ISCADD writeCs, writeCs, readBs, 6; 188 | 189 | // Each block of 32 threads works on 8 lines, 190 | // readAs is also shifted over by 8 for each group of 32 threads 191 | // readAs += tid32/4 * 32 * 4 + tid32/4 * 4 192 | // readBs += tid32/4 * 64 * 4 + 4x 193 | --:-:-:-:1 ISCADD readAs, tid32, readAs, 5; 194 | --:-:-:-:1 ISCADD readBs, tid32, readBs, 6; 195 | --:-:-:-:1 IADD readAs, tid32, readAs; 196 | --:-:-:-:1 IADD readBs, readBs, 4x; 197 | 198 | --:-:-:-:1 MOV32I swapBuf, 4x; 199 | 200 | [+ 201 | return vector() ? q{ 202 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidAY, partialK, P2; 203 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidBY8, partialK, P3; 204 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidBY16, partialK, P3; 205 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidBY24, partialK, P3; 206 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidBY, partialK, P3; 207 | 208 | --:-:2:-:1 @P2 LDG.E.CI.128 load0A, [track0A]; 209 | --:-:3:-:1 @P3 LDG.E.CI.64 load0B, [track0B]; 210 | --:-:4:-:1 @P4 LDG.E.CI.64 load1B, [track1B]; 211 | --:-:5:-:1 @P5 LDG.E.CI.64 load2B, [track2B]; 212 | --:-:6:-:1 @P6 LDG.E.CI.64 load3B, [track3B]; 213 | 214 | 215 | --:-:-:-:1 @!P2 LDS.U.128 load0A, [addr_zero]; 216 | --:-:-:-:1 @!P3 LDS.U.64 load0B, [addr_zero]; 217 | --:-:-:-:1 @!P4 LDS.U.64 load1B, [addr_zero]; 218 | --:-:-:-:1 @!P5 LDS.U.64 load2B, [addr_zero]; 219 | --:-:1:-:1 @!P6 LDS.U.64 load3B, [addr_zero]; 220 | 221 | } : q{ 222 | 223 | --:-:-:-:1 IADD tidAY1, tidAY, 1; 224 | --:-:-:-:1 IADD tidAY2, tidAY, 2; 225 | --:-:-:-:1 IADD tidAY3, tidAY, 3; 226 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY, partialK, P2; 227 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidAY1, partialK, P2; 228 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidAY2, partialK, P2; 229 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidAY3, partialK, P2; 230 | 231 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0A0, [track0A + 2x<0>]; 232 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0A1, [track0A + 2x<1>]; 233 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0A2, [track0A + 
2x<2>]; 234 | --:-:2:-:1 @P6 LDG.E.CI.U16 load0A3, [track0A + 2x<3>]; 235 | 236 | --:-:-:-:1 @!P3 MOV load0A0, RZ; 237 | --:-:-:-:1 @!P4 MOV load0A1, RZ; 238 | --:-:-:-:1 @!P5 MOV load0A2, RZ; 239 | --:-:-:-:1 @!P6 MOV load0A3, RZ; 240 | 241 | --:-:-:-:1 IADD tidAY, tidAY, 4; 242 | --:-:-:-:1 IADD tidAY1, tidAY1, 4; 243 | --:-:-:-:1 IADD tidAY2, tidAY2, 4; 244 | --:-:-:-:1 IADD tidAY3, tidAY3, 4; 245 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY, partialK, P2; 246 | --:-:-:-:1 ISETP.LT.AND P4, PT, tidAY1, partialK, P2; 247 | --:-:-:-:1 ISETP.LT.AND P5, PT, tidAY2, partialK, P2; 248 | --:-:-:-:1 ISETP.LT.AND P6, PT, tidAY3, partialK, P2; 249 | 250 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0A4, [track0A + 2x<4>]; 251 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0A5, [track0A + 2x<5>]; 252 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0A6, [track0A + 2x<6>]; 253 | --:-:2:-:1 @P6 LDG.E.CI.U16 load0A7, [track0A + 2x<7>]; 254 | 255 | --:-:-:-:1 @!P3 MOV load0A4, RZ; 256 | --:-:-:-:1 @!P4 MOV load0A5, RZ; 257 | --:-:-:-:1 @!P5 MOV load0A6, RZ; 258 | --:-:-:-:1 @!P6 MOV load0A7, RZ; 259 | 260 | 261 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY, partialK, PT; 262 | --:-:-:-:1 @P0 R2P PR, preds, 0x78; 263 | --:-:-:-:1 @!P0 R2P PR, RZ, 0x78; 264 | 265 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>]; 266 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>]; 267 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>]; 268 | --:-:3:-:1 @P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>]; 269 | 270 | --:-:-:-:1 @!P3 MOV load0B0, RZ; 271 | --:-:-:-:1 @!P4 MOV load0B1, RZ; 272 | --:-:-:-:1 @!P5 MOV load0B2, RZ; 273 | --:-:-:-:1 @!P6 MOV load0B3, RZ; 274 | 275 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY8, partialK, PT; 276 | --:-:-:-:1 @P1 R2P PR, preds, 0x78; 277 | --:-:-:-:1 @!P1 R2P PR, RZ, 0x78; 278 | 279 | --:-:-:-:1 @P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>]; 280 | --:-:-:-:1 @P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>]; 281 | --:-:-:-:1 @P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>]; 282 | --:-:4:-:1 @P6 LDG.E.CI.U16 
load1B3, [track1B + 2x<3>]; 283 | 284 | --:-:-:-:1 @!P3 MOV load1B0, RZ; 285 | --:-:-:-:1 @!P4 MOV load1B1, RZ; 286 | --:-:-:-:1 @!P5 MOV load1B2, RZ; 287 | --:-:-:-:1 @!P6 MOV load1B3, RZ; 288 | 289 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidBY16, partialK, PT; 290 | --:-:-:-:1 @P2 R2P PR, preds, 0x78; 291 | --:-:-:-:1 @!P2 R2P PR, RZ, 0x78; 292 | 293 | --:-:-:-:1 @P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>]; 294 | --:-:-:-:1 @P4 LDG.E.CI.U16 load2B1, [track2B + 2x<1>]; 295 | --:-:-:-:1 @P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>]; 296 | --:-:5:-:1 @P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>]; 297 | 298 | --:-:-:-:1 @!P3 MOV load2B0, RZ; 299 | --:-:-:-:1 @!P4 MOV load2B1, RZ; 300 | --:-:-:-:1 @!P5 MOV load2B2, RZ; 301 | --:-:-:-:1 @!P6 MOV load2B3, RZ; 302 | 303 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY24, partialK, PT; 304 | --:-:-:-:1 @P0 R2P PR, preds, 0x78; 305 | --:-:-:-:1 @!P0 R2P PR, RZ, 0x78; 306 | 307 | --:-:-:-:1 @P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>]; 308 | --:-:-:-:1 @P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>]; 309 | --:-:-:-:1 @P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>]; 310 | --:-:6:-:1 @P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>]; 311 | 312 | --:-:-:-:1 @!P3 MOV load3B0, RZ; 313 | --:-:-:-:1 @!P4 MOV load3B1, RZ; 314 | --:-:-:-:1 @!P5 MOV load3B2, RZ; 315 | --:-:-:-:1 @!P6 MOV load3B3, RZ; 316 | 317 | }; 318 | +] 319 | // partialB = partialK * cdb 320 | --:-:-:-:1 XMAD.LO2 partialB, cdb, partialK, RZ; 321 | 322 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 32, PT; 323 | --:-:-:-:1 IADD k, k, -32; 324 | --:-:-:-:1 @P1 R2P PR, preds, 0x7c; 325 | --:-:-:-:1 @!P1 R2P PR, RZ, 0x7c; 326 | 327 | 328 | [+ 329 | return vector() ? 
q{ 330 | 03:-:-:-:1 F2F.F32.F16 load0A7, load0A3.H1; 331 | --:-:-:-:1 F2F.F32.F16 load0A6, load0A3.H0; 332 | --:-:-:-:1 F2F.F32.F16 load0A5, load0A2.H1; 333 | --:-:1:-:1 F2F.F32.F16 load0A4, load0A2.H0; 334 | --:-:-:-:1 F2F.F32.F16 load0A3, load0A1.H1; 335 | --:-:-:-:1 F2F.F32.F16 load0A2, load0A1.H0; 336 | --:-:-:-:1 F2F.F32.F16 load0A1, load0A0.H1; 337 | --:-:2:-:1 F2F.F32.F16 load0A0, load0A0.H0; 338 | } : q{ 339 | 02:-:-:-:1 F2F.F32.F16 load0A7, load0A7; 340 | --:-:-:-:1 F2F.F32.F16 load0A6, load0A6; 341 | --:-:-:-:1 F2F.F32.F16 load0A5, load0A5; 342 | --:-:1:-:1 F2F.F32.F16 load0A4, load0A4; 343 | --:-:-:-:1 F2F.F32.F16 load0A3, load0A3; 344 | --:-:-:-:1 F2F.F32.F16 load0A2, load0A2; 345 | --:-:-:-:1 F2F.F32.F16 load0A1, load0A1; 346 | --:-:2:-:1 F2F.F32.F16 load0A0, load0A0; 347 | }; 348 | +] 349 | --:-:-:-:0 LEA track0A0.CC, partialK, track0A0, 1; 350 | 01:-:-:-:1 STS [writeAs + 4x<7*32>], load0A7; 351 | --:-:-:-:1 STS [writeAs + 4x<6*32>], load0A6; 352 | --:-:-:-:1 STS [writeAs + 4x<5*32>], load0A5; 353 | --:-:-:-:1 STS [writeAs + 4x<4*32>], load0A4; 354 | 02:-:-:-:1 STS [writeAs + 4x<3*32>], load0A3; 355 | --:-:-:-:1 STS [writeAs + 4x<2*32>], load0A2; 356 | --:-:-:-:1 STS [writeAs + 4x<1*32>], load0A1; 357 | --:-:-:-:1 STS [writeAs + 4x<0*32>], load0A0; 358 | --:-:-:-:0 IADD.X track0A1, track0A1, RZ; 359 | 360 | [+ 361 | return vector() ? 
q{ 362 | 04:-:-:-:1 F2F.F32.F16 load0B3, load0B1.H1; 363 | --:-:-:-:1 F2F.F32.F16 load0B2, load0B1.H0; 364 | --:-:-:-:1 F2F.F32.F16 load0B1, load0B0.H1; 365 | --:-:3:-:1 F2F.F32.F16 load0B0, load0B0.H0; 366 | 367 | 08:-:-:-:1 F2F.F32.F16 load1B3, load1B1.H1; 368 | --:-:-:-:1 F2F.F32.F16 load1B2, load1B1.H0; 369 | --:-:-:-:1 F2F.F32.F16 load1B1, load1B0.H1; 370 | --:-:4:-:1 F2F.F32.F16 load1B0, load1B0.H0; 371 | 372 | 10:-:-:-:1 F2F.F32.F16 load2B3, load2B1.H1; 373 | --:-:-:-:1 F2F.F32.F16 load2B2, load2B1.H0; 374 | --:-:-:-:1 F2F.F32.F16 load2B1, load2B0.H1; 375 | --:-:5:-:1 F2F.F32.F16 load2B0, load2B0.H0; 376 | 377 | 20:-:-:-:1 F2F.F32.F16 load3B3, load3B1.H1; 378 | --:-:-:-:1 F2F.F32.F16 load3B2, load3B1.H0; 379 | --:-:-:-:1 F2F.F32.F16 load3B1, load3B0.H1; 380 | --:-:6:-:1 F2F.F32.F16 load3B0, load3B0.H0; 381 | } : q{ 382 | 04:-:-:-:1 F2F.F32.F16 load0B0, load0B0; 383 | --:-:-:-:1 F2F.F32.F16 load0B1, load0B1; 384 | --:-:-:-:1 F2F.F32.F16 load0B2, load0B2; 385 | --:-:3:-:1 F2F.F32.F16 load0B3, load0B3; 386 | 387 | 08:-:-:-:1 F2F.F32.F16 load1B0, load1B0; 388 | --:-:-:-:1 F2F.F32.F16 load1B1, load1B1; 389 | --:-:-:-:1 F2F.F32.F16 load1B2, load1B2; 390 | --:-:4:-:1 F2F.F32.F16 load1B3, load1B3; 391 | 392 | 10:-:-:-:1 F2F.F32.F16 load2B0, load2B0; 393 | --:-:-:-:1 F2F.F32.F16 load2B1, load2B1; 394 | --:-:-:-:1 F2F.F32.F16 load2B2, load2B2; 395 | --:-:5:-:1 F2F.F32.F16 load2B3, load2B3; 396 | 397 | 20:-:-:-:1 F2F.F32.F16 load3B0, load3B0; 398 | --:-:-:-:1 F2F.F32.F16 load3B1, load3B1; 399 | --:-:-:-:1 F2F.F32.F16 load3B2, load3B2; 400 | --:-:6:-:1 F2F.F32.F16 load3B3, load3B3; 401 | }; 402 | +] 403 | 404 | --:-:-:-:0 LEA track0B0.CC, partialB, track0B0, 1; 405 | 04:-:-:-:6 STS.128 [writeBs + 4x<0*64>], load0B; 406 | --:-:-:-:1 IADD.X track0B1, track0B1, RZ; 407 | 408 | --:-:-:-:0 LEA track1B0.CC, partialB, track1B0, 1; 409 | 08:-:-:-:6 STS.128 [writeBs + 4x<8*64>], load1B; 410 | --:-:-:-:1 IADD.X track1B1, track1B1, RZ; 411 | 412 | --:-:-:-:0 LEA track2B0.CC, 
partialB, track2B0, 1; 413 | 10:-:-:-:6 STS.128 [writeBs + 4x<16*64>], load2B; 414 | --:-:-:-:1 IADD.X track2B1, track2B1, RZ; 415 | 416 | --:-:-:-:0 LEA track3B0.CC, partialB, track3B0, 1; 417 | 20:-:-:-:6 STS.128 [writeBs + 4x<24*64>], load3B; 418 | --:-:-:-:0 IADD.X track3B1, track3B1, RZ; 419 | 420 | --:-:-:-:5 BAR.SYNC 0; 421 | --:-:-:-:1 IADD writeBs, writeBs, swapBuf; 422 | --:-:-:-:1 IADD writeAs, writeAs, swapBuf; 423 | --:-:-:-:0 IADD swapBuf, RZ, -swapBuf; 424 | 425 | --:-:-:-:1 LDS.U.128 j0Ay0, [readAs + 4x<0*32 + 00>]; 426 | --:-:-:-:1 LDS.U.128 j0Bx0, [readBs + 4x<0*64 + 00>]; 427 | --:-:-:-:1 LDS.U.128 j0Ay4, [readAs + 4x<0*32 + 16>]; 428 | --:-:1:-:1 LDS.U.128 j0Bx4, [readBs + 4x<0*64 + 32>]; 429 | 430 | [+ 431 | return vector() ? q{ 432 | --:-:2:-:1 @P2 LDG.E.CI.128 load0A, [track0A]; 433 | --:-:3:-:1 @P3 LDG.E.CI.64 load0B, [track0B]; 434 | --:-:4:-:1 @P3 LDG.E.CI.64 load1B, [track1B]; 435 | --:-:5:-:1 @P3 LDG.E.CI.64 load2B, [track2B]; 436 | --:-:6:-:1 @P3 LDG.E.CI.64 load3B, [track3B]; 437 | } : q{ 438 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A0, [track0A + 2x<0>]; 439 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A1, [track0A + 2x<1>]; 440 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A2, [track0A + 2x<2>]; 441 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A3, [track0A + 2x<3>]; 442 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A4, [track0A + 2x<4>]; 443 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A5, [track0A + 2x<5>]; 444 | --:-:-:-:1 @P2 LDG.E.CI.U16 load0A6, [track0A + 2x<6>]; 445 | --:-:2:-:1 @P2 LDG.E.CI.U16 load0A7, [track0A + 2x<7>]; 446 | 447 | --:-:-:-:1 @P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>]; 448 | --:-:-:-:1 @P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>]; 449 | --:-:-:-:1 @P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>]; 450 | --:-:3:-:1 @P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>]; 451 | 452 | --:-:-:-:1 @P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>]; 453 | --:-:-:-:1 @P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>]; 454 | --:-:-:-:1 @P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>]; 455 | --:-:4:-:1 @P6 
LDG.E.CI.U16 load1B3, [track1B + 2x<3>]; 456 | 457 | --:-:-:-:1 @P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>]; 458 | --:-:-:-:1 @P4 LDG.E.CI.U16 load2B1, [track2B + 2x<1>]; 459 | --:-:-:-:1 @P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>]; 460 | --:-:5:-:1 @P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>]; 461 | 462 | --:-:-:-:1 @P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>]; 463 | --:-:-:-:1 @P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>]; 464 | --:-:-:-:1 @P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>]; 465 | --:-:6:-:1 @P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>]; 466 | }; 467 | +] 468 | 469 | LOOP: 470 | 471 | [+ 472 | our %insert = 473 | ( 474 | j0c8 => "--:-:-:-:1 ISETP.GE.AND P0, PT, k, RZ, PT;\n", 475 | j0c10 => "--:-:-:-:1 ISETP.GE.AND P1, PT, k, 32, PT;\n" . 476 | "--:-:-:-:1 IADD k, k, -32;\n", 477 | 478 | j0c23 => "--:-:-:-:1 \@P1 R2P PR, preds, 0x7c;\n", 479 | j0c24 => "--:-:-:-:1 \@!P1 R2P PR, RZ, 0x7c;\n", 480 | 481 | j2c32 => "--:-:-:-:1 \@P2 IADD track0A0.CC, track0A0, 2x<32>;\n", 482 | j2c37 => "--:-:-:-:1 \@P2 IADD.X track0A1, track0A1, RZ;\n", 483 | j3c32 => "--:-:-:-:1 \@P3 IADD track0B0.CC, track0B0, cdb32;\n", 484 | j3c37 => "--:-:-:-:1 \@P3 IADD.X track0B1, track0B1, RZ;\n", 485 | j4c32 => "--:-:-:-:1 \@P3 IADD track1B0.CC, track1B0, cdb32;\n", 486 | j4c37 => "--:-:-:-:1 \@P3 IADD.X track1B1, track1B1, RZ;\n", 487 | j5c32 => "--:-:-:-:1 \@P3 IADD track2B0.CC, track2B0, cdb32;\n", 488 | j5c37 => "--:-:-:-:1 \@P3 IADD.X track2B1, track2B1, RZ;\n", 489 | j6c32 => "--:-:-:-:1 \@P3 IADD track3B0.CC, track3B0, cdb32;\n", 490 | j6c37 => "--:-:-:-:1 \@P3 IADD.X track3B1, track3B1, RZ;\n", 491 | 492 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" . 493 | "--:-:-:-:1 \@P0 IADD readAs, readAs, -swapBuf;\n" . 494 | "--:-:-:-:1 \@P0 IADD readBs, readBs, -swapBuf;\n" . 495 | "--:-:-:-:1 \@P0 IADD writeAs, writeAs, swapBuf;\n" . 496 | "--:-:-:-:1 \@P0 IADD writeBs, writeBs, swapBuf;\n" . 
497 | "--:-:-:-:1 \@P0 IADD swapBuf, RZ, -swapBuf;\n", 498 | 499 | j2c16 => "02:-:-:-:1 \@P0 STS [writeAs + 4x<7*32>], load0A7;\n", 500 | j2c18 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<6*32>], load0A6;\n", 501 | j2c20 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<5*32>], load0A5;\n", 502 | j2c22 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<4*32>], load0A4;\n", 503 | j2c24 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<3*32>], load0A3;\n", 504 | j2c26 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<2*32>], load0A2;\n", 505 | j2c28 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<1*32>], load0A1;\n", 506 | j2c30 => "--:2:-:-:1 \@P0 STS [writeAs + 4x<0*32>], load0A0;\n", 507 | 508 | j3c16 => "04:3:-:-:1 \@P0 STS.128 [writeBs + 4x< 0*64>], load0B;\n", 509 | j4c16 => "08:4:-:-:1 \@P0 STS.128 [writeBs + 4x< 8*64>], load1B;\n", 510 | j5c16 => "10:5:-:-:1 \@P0 STS.128 [writeBs + 4x<16*64>], load2B;\n", 511 | j6c16 => "20:6:-:-:1 \@P0 STS.128 [writeBs + 4x<24*64>], load3B;\n", 512 | 513 | (vector() ? 514 | ( 515 | j1c35 => "02:-:-:-:1 \@P0 F2F.F32.F16 load0A7, load0A3.H1;\n", 516 | j1c39 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A6, load0A3.H0;\n", 517 | j1c43 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A5, load0A2.H1;\n", 518 | j1c47 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A4, load0A2.H0;\n", 519 | j1c51 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A3, load0A1.H1;\n", 520 | j1c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A2, load0A1.H0;\n", 521 | j1c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A1, load0A0.H1;\n", 522 | j1c63 => "--:-:2:-:1 \@P0 F2F.F32.F16 load0A0, load0A0.H0;\n", 523 | 524 | j2c51 => "04:-:-:-:1 \@P0 F2F.F32.F16 load0B3, load0B1.H1;\n", 525 | j2c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B2, load0B1.H0;\n", 526 | j2c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B1, load0B0.H1;\n", 527 | j2c63 => "--:-:3:-:1 \@P0 F2F.F32.F16 load0B0, load0B0.H0;\n", 528 | 529 | j3c51 => "08:-:-:-:1 \@P0 F2F.F32.F16 load1B3, load1B1.H1;\n", 530 | j3c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B2, load1B1.H0;\n", 531 | j3c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B1, 
load1B0.H1;\n", 532 | j3c63 => "--:-:4:-:1 \@P0 F2F.F32.F16 load1B0, load1B0.H0;\n", 533 | 534 | j4c51 => "10:-:-:-:1 \@P0 F2F.F32.F16 load2B3, load2B1.H1;\n", 535 | j4c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B2, load2B1.H0;\n", 536 | j4c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B1, load2B0.H1;\n", 537 | j4c63 => "--:-:5:-:1 \@P0 F2F.F32.F16 load2B0, load2B0.H0;\n", 538 | 539 | j5c51 => "20:-:-:-:1 \@P0 F2F.F32.F16 load3B3, load3B1.H1;\n", 540 | j5c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B2, load3B1.H0;\n", 541 | j5c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B1, load3B0.H1;\n", 542 | j5c63 => "--:-:6:-:1 \@P0 F2F.F32.F16 load3B0, load3B0.H0;\n", 543 | 544 | j2c61 => "02:-:2:-:1 \@P2 LDG.E.CI.128 load0A, [track0A];\n", 545 | j3c61 => "04:-:3:-:1 \@P3 LDG.E.CI.64 load0B, [track0B];\n", 546 | j4c61 => "08:-:4:-:1 \@P3 LDG.E.CI.64 load1B, [track1B];\n", 547 | j5c61 => "10:-:5:-:1 \@P3 LDG.E.CI.64 load2B, [track2B];\n", 548 | j6c61 => "20:-:6:-:1 \@P3 LDG.E.CI.64 load3B, [track3B];\n", 549 | ) : 550 | ( 551 | j1c35 => "02:-:-:-:1 \@P0 F2F.F32.F16 load0A0, load0A0;\n", 552 | j1c39 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A1, load0A1;\n", 553 | j1c43 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A2, load0A2;\n", 554 | j1c47 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A3, load0A3;\n", 555 | j1c51 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A4, load0A4;\n", 556 | j1c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A5, load0A5;\n", 557 | j1c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0A6, load0A6;\n", 558 | j1c63 => "--:2:-:-:1 \@P0 F2F.F32.F16 load0A7, load0A7;\n", 559 | 560 | j2c51 => "04:-:-:-:1 \@P0 F2F.F32.F16 load0B0, load0B0;\n", 561 | j2c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B1, load0B1;\n", 562 | j2c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load0B2, load0B2;\n", 563 | j2c63 => "--:-:3:-:1 \@P0 F2F.F32.F16 load0B3, load0B3;\n", 564 | 565 | j3c51 => "08:-:-:-:1 \@P0 F2F.F32.F16 load1B0, load1B0;\n", 566 | j3c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B1, load1B1;\n", 567 | j3c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load1B2, 
load1B2;\n", 568 | j3c63 => "--:-:4:-:1 \@P0 F2F.F32.F16 load1B3, load1B3;\n", 569 | 570 | j4c51 => "10:-:-:-:1 \@P0 F2F.F32.F16 load2B0, load2B0;\n", 571 | j4c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B1, load2B1;\n", 572 | j4c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load2B2, load2B2;\n", 573 | j4c63 => "--:-:5:-:1 \@P0 F2F.F32.F16 load2B3, load2B3;\n", 574 | 575 | j5c51 => "20:-:-:-:1 \@P0 F2F.F32.F16 load3B0, load3B0;\n", 576 | j5c55 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B1, load3B1;\n", 577 | j5c59 => "--:-:-:-:1 \@P0 F2F.F32.F16 load3B2, load3B2;\n", 578 | j5c63 => "--:-:6:-:1 \@P0 F2F.F32.F16 load3B3, load3B3;\n", 579 | 580 | j2c48 => "02:-:-:-:1 \@P2 LDG.E.CI.U16 load0A0, [track0A + 2x<0>];\n", 581 | j2c50 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A1, [track0A + 2x<1>];\n", 582 | j2c52 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A2, [track0A + 2x<2>];\n", 583 | j2c54 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A3, [track0A + 2x<3>];\n", 584 | j2c56 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A4, [track0A + 2x<4>];\n", 585 | j2c58 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A5, [track0A + 2x<5>];\n", 586 | j2c60 => "--:-:-:-:1 \@P2 LDG.E.CI.U16 load0A6, [track0A + 2x<6>];\n", 587 | j2c62 => "--:-:2:-:1 \@P2 LDG.E.CI.U16 load0A7, [track0A + 2x<7>];\n", 588 | 589 | j3c56 => "04:-:-:-:1 \@P3 LDG.E.CI.U16 load0B0, [track0B + 2x<0>];\n", 590 | j3c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load0B1, [track0B + 2x<1>];\n", 591 | j3c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load0B2, [track0B + 2x<2>];\n", 592 | j3c62 => "--:-:3:-:1 \@P6 LDG.E.CI.U16 load0B3, [track0B + 2x<3>];\n", 593 | 594 | j4c56 => "08:-:-:-:1 \@P3 LDG.E.CI.U16 load1B0, [track1B + 2x<0>];\n", 595 | j4c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load1B1, [track1B + 2x<1>];\n", 596 | j4c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load1B2, [track1B + 2x<2>];\n", 597 | j4c62 => "--:-:4:-:1 \@P6 LDG.E.CI.U16 load1B3, [track1B + 2x<3>];\n", 598 | 599 | j5c56 => "10:-:-:-:1 \@P3 LDG.E.CI.U16 load2B0, [track2B + 2x<0>];\n", 600 | j5c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 
load2B1, [track2B + 2x<1>];\n", 601 | j5c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load2B2, [track2B + 2x<2>];\n", 602 | j5c62 => "--:-:5:-:1 \@P6 LDG.E.CI.U16 load2B3, [track2B + 2x<3>];\n", 603 | 604 | j6c56 => "20:-:-:-:1 \@P3 LDG.E.CI.U16 load3B0, [track3B + 2x<0>];\n", 605 | j6c58 => "--:-:-:-:1 \@P4 LDG.E.CI.U16 load3B1, [track3B + 2x<1>];\n", 606 | j6c60 => "--:-:-:-:1 \@P5 LDG.E.CI.U16 load3B2, [track3B + 2x<2>];\n", 607 | j6c62 => "--:-:6:-:1 \@P6 LDG.E.CI.U16 load3B3, [track3B + 2x<3>];\n", 608 | ) 609 | ), 610 | j7c63 => "--:-:-:Y:5 \@P0 BRA.U LOOP;\n", 611 | ); 612 | my @cOrder; 613 | my @swirl = ([0,2],[1,2],[1,0],[0,0]); 614 | my @y = (0,1,4,5); 615 | foreach my $x (0,2,4,6) 616 | { 617 | foreach my $y (@y) 618 | { 619 | push @cOrder, [$x + $_->[0], $y + $_->[1]] foreach @swirl; 620 | } 621 | @y = reverse @y; 622 | } 623 | my $out = ''; 624 | foreach my $j (0 .. 7) 625 | { 626 | my $odd = $j & 1; 627 | my $nOdd = !$odd + 0; 628 | my $rsOffset = ($j + 1) % 8; 629 | my $rsPred = $j == 7 ? '@P0' : ' '; 630 | 631 | $insert{"j${j}c0"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy0, [readAs + 4x<%d*32 + 00>];\n", $rsPred, $nOdd, $rsOffset; 632 | $insert{"j${j}c2"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dBx0, [readBs + 4x<%d*64 + 00>];\n", $rsPred, $nOdd, $rsOffset; 633 | $insert{"j${j}c4"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy4, [readAs + 4x<%d*32 + 16>];\n", $rsPred, $nOdd, $rsOffset; 634 | $insert{"j${j}c6"} = sprintf "--:-:1:-:1 %s LDS.U.128 j%dBx4, [readBs + 4x<%d*64 + 32>];\n", $rsPred, $nOdd, $rsOffset; 635 | 636 | foreach my $c (0 .. 63) 637 | { 638 | my ($x,$y) = @{$cOrder[$c]}; 639 | 640 | my $ins = $insert{"j${j}c$c"} || ''; 641 | 642 | my $stall = $ins =~ /LDS|I2I|I2F|F2I|F2F|LDG|STS|BAR|BRA/ ? 0 : 1; 643 | 644 | my $yield = $c == 32 && $stall ? 'Y' : '-'; 645 | 646 | my $wait = $c == 0 ? 
'01' : '--'; 647 | 648 | my $ctrl = "$wait:-:-:$yield:$stall"; 649 | 650 | $out .= sprintf "%s FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n%s", $ctrl, $x,$y, $odd,$x, $odd,$y, $x,$y, $ins; 651 | } 652 | } 653 | return $out; 654 | +] 655 | 656 | 657 | --:-:-:-:1 MOV alpha, param_alpha; 658 | --:-:-:-:1 MOV beta, param_beta; 659 | 660 | // readCs = ((tid & 15) * 4 + (tid / 16) * 64) * 4 661 | --:-:-:-:1 LOP.AND tid15, tid, 15; 662 | --:-:-:-:1 SHR.U32 tid16, tid, 4; 663 | --:-:-:-:1 SHL tid15, tid15, 2; 664 | --:-:-:-:1 ISCADD readCs, tid16, tid15, 6; 665 | --:-:-:-:1 SHL readCs, readCs, 2; 666 | 667 | // cx = idx_B*64 + tid15; 668 | --:-:-:-:1 ISCADD cx, idx_B, tid15, 6; 669 | --:-:-:-:1 IADD cx1, cx, 1; 670 | --:-:-:-:1 IADD cx2, cx, 2; 671 | --:-:-:-:1 IADD cx3, cx, 3; 672 | 673 | // cy = idx_A*32 + tid16 674 | --:-:-:-:1 ISCADD cy, idx_A, tid16, 5; 675 | 676 | // C += (cy*cdc + cx) * 2; 677 | --:-:-:-:1 MOV cdc, param_cdc; 678 | --:-:-:-:1 SHL cdc8, cdc, 4; 679 | 680 | --:-:-:-:1 XMAD.LO ci, cy, cdc, cx, xmad_c; 681 | --:-:-:-:1 LEA C0.CC, ci, param_C[0], 1; 682 | --:-:-:-:1 LEA.HI.X C1, ci, param_C[1], RZ, 1; 683 | 684 | // P0 = cx < n 685 | --:-:-:-:1 ISETP.LT.AND P0, PT, cx, param_n, PT; 686 | --:-:-:-:1 ISETP.LT.AND P1, PT, cx1, param_n, PT; 687 | --:-:-:-:1 ISETP.LT.AND P2, PT, cx2, param_n, PT; 688 | --:-:-:-:1 ISETP.LT.AND P3, PT, cx3, param_n, PT; 689 | --:-:-:-:1 P2R preds, PR, RZ, 0x0f; 690 | 691 | // P4 = cy < m 692 | --:-:-:-:1 ISETP.LT.AND P4, PT, cy, param_m, PT; 693 | 694 | // P5 = beta != 0 && P4 695 | --:-:-:-:1 ISETP.NE.AND P5, PT, beta, RZ, P4; 696 | 697 | // Init beta preds 698 | --:-:-:-:1 @P5 R2P PR, preds, 0x0f; 699 | --:-:-:-:1 @!P5 R2P PR, RZ, 0x0f; 700 | 701 | 702 | 703 | --:-:-:-:1 FMUL shuffle_x0y0, cx0y0, alpha; 704 | --:-:-:-:1 FMUL shuffle_x1y0, cx1y0, alpha; 705 | --:-:-:-:1 FMUL shuffle_x2y0, cx2y0, alpha; 706 | --:-:-:-:1 FMUL shuffle_x3y0, cx3y0, alpha; 707 | --:-:-:-:1 FMUL shuffle_x4y0, cx4y0, alpha; 708 | --:-:-:-:1 FMUL 
shuffle_x5y0, cx5y0, alpha; 709 | --:-:-:-:1 FMUL shuffle_x6y0, cx6y0, alpha; 710 | --:-:-:-:0 FMUL shuffle_x7y0, cx7y0, alpha; 711 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 00>], shuffle_x0y0; 712 | --:-:-:-:1 FMUL shuffle_x0y1, cx0y1, alpha; 713 | --:-:-:-:1 FMUL shuffle_x1y1, cx1y1, alpha; 714 | --:-:-:-:1 FMUL shuffle_x2y1, cx2y1, alpha; 715 | --:-:-:-:0 FMUL shuffle_x3y1, cx3y1, alpha; 716 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 32>], shuffle_x4y0; 717 | --:-:-:-:1 FMUL shuffle_x4y1, cx4y1, alpha; 718 | --:-:-:-:1 FMUL shuffle_x5y1, cx5y1, alpha; 719 | --:-:-:-:1 FMUL shuffle_x6y1, cx6y1, alpha; 720 | --:-:-:-:0 FMUL shuffle_x7y1, cx7y1, alpha; 721 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 00>], shuffle_x0y1; 722 | --:-:-:-:1 FMUL shuffle_x0y2, cx0y2, alpha; 723 | --:-:-:-:1 FMUL shuffle_x1y2, cx1y2, alpha; 724 | --:-:-:-:1 FMUL shuffle_x2y2, cx2y2, alpha; 725 | --:-:-:-:0 FMUL shuffle_x3y2, cx3y2, alpha; 726 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 32>], shuffle_x4y1; 727 | --:-:-:-:1 FMUL shuffle_x4y2, cx4y2, alpha; 728 | --:-:-:-:1 FMUL shuffle_x5y2, cx5y2, alpha; 729 | --:-:-:-:1 FMUL shuffle_x6y2, cx6y2, alpha; 730 | --:-:-:-:0 FMUL shuffle_x7y2, cx7y2, alpha; 731 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 00>], shuffle_x0y2; 732 | --:-:-:-:1 FMUL shuffle_x0y3, cx0y3, alpha; 733 | --:-:-:-:1 FMUL shuffle_x1y3, cx1y3, alpha; 734 | --:-:-:-:1 FMUL shuffle_x2y3, cx2y3, alpha; 735 | --:-:-:-:0 FMUL shuffle_x3y3, cx3y3, alpha; 736 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 32>], shuffle_x4y2; 737 | --:-:-:-:1 FMUL shuffle_x4y3, cx4y3, alpha; 738 | --:-:-:-:1 FMUL shuffle_x5y3, cx5y3, alpha; 739 | --:-:-:-:1 FMUL shuffle_x6y3, cx6y3, alpha; 740 | --:-:-:-:0 FMUL shuffle_x7y3, cx7y3, alpha; 741 | --:-:-:-:4 STS.128 [writeCs+4x<3*64 + 00>], shuffle_x0y3; 742 | --:-:-:-:1 STS.128 [writeCs+4x<3*64 + 32>], shuffle_x4y3; 743 | --:-:-:-:5 BAR.SYNC 0; 744 | 745 | --:-:-:-:5 CAL STORE_C; 746 | --:-:-:-:5 CAL STORE_C; 747 | 748 | --:-:-:-:1 FMUL shuffle_x0y4, cx0y4, alpha; 749 | 
--:-:-:-:1 FMUL shuffle_x1y4, cx1y4, alpha; 750 | --:-:-:-:1 FMUL shuffle_x2y4, cx2y4, alpha; 751 | --:-:-:-:1 FMUL shuffle_x3y4, cx3y4, alpha; 752 | --:-:-:-:1 FMUL shuffle_x4y4, cx4y4, alpha; 753 | --:-:-:-:1 FMUL shuffle_x5y4, cx5y4, alpha; 754 | --:-:-:-:0 FMUL shuffle_x6y4, cx6y4, alpha; 755 | --:-:-:-:5 BAR.SYNC 0; 756 | --:-:-:-:0 FMUL shuffle_x7y4, cx7y4, alpha; 757 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 00>], shuffle_x0y4; 758 | --:-:-:-:1 FMUL shuffle_x0y5, cx0y5, alpha; 759 | --:-:-:-:1 FMUL shuffle_x1y5, cx1y5, alpha; 760 | --:-:-:-:1 FMUL shuffle_x2y5, cx2y5, alpha; 761 | --:-:-:-:0 FMUL shuffle_x3y5, cx3y5, alpha; 762 | --:-:-:-:1 STS.128 [writeCs+4x<0*64 + 32>], shuffle_x4y4; 763 | --:-:-:-:1 FMUL shuffle_x4y5, cx4y5, alpha; 764 | --:-:-:-:1 FMUL shuffle_x5y5, cx5y5, alpha; 765 | --:-:-:-:1 FMUL shuffle_x6y5, cx6y5, alpha; 766 | --:-:-:-:0 FMUL shuffle_x7y5, cx7y5, alpha; 767 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 00>], shuffle_x0y5; 768 | --:-:-:-:1 FMUL shuffle_x0y6, cx0y6, alpha; 769 | --:-:-:-:1 FMUL shuffle_x1y6, cx1y6, alpha; 770 | --:-:-:-:1 FMUL shuffle_x2y6, cx2y6, alpha; 771 | --:-:-:-:0 FMUL shuffle_x3y6, cx3y6, alpha; 772 | --:-:-:-:1 STS.128 [writeCs+4x<1*64 + 32>], shuffle_x4y5; 773 | --:-:-:-:1 FMUL shuffle_x4y6, cx4y6, alpha; 774 | --:-:-:-:1 FMUL shuffle_x5y6, cx5y6, alpha; 775 | --:-:-:-:1 FMUL shuffle_x6y6, cx6y6, alpha; 776 | --:-:-:-:0 FMUL shuffle_x7y6, cx7y6, alpha; 777 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 00>], shuffle_x0y6; 778 | --:-:-:-:1 FMUL shuffle_x0y7, cx0y7, alpha; 779 | --:-:-:-:1 FMUL shuffle_x1y7, cx1y7, alpha; 780 | --:-:-:-:1 FMUL shuffle_x2y7, cx2y7, alpha; 781 | --:-:-:-:0 FMUL shuffle_x3y7, cx3y7, alpha; 782 | --:-:-:-:1 STS.128 [writeCs+4x<2*64 + 32>], shuffle_x4y6; 783 | --:-:-:-:1 FMUL shuffle_x4y7, cx4y7, alpha; 784 | --:-:-:-:1 FMUL shuffle_x5y7, cx5y7, alpha; 785 | --:-:-:-:1 FMUL shuffle_x6y7, cx6y7, alpha; 786 | --:-:-:-:0 FMUL shuffle_x7y7, cx7y7, alpha; 787 | --:-:-:-:4 STS.128 [writeCs+4x<3*64 
+ 00>], shuffle_x0y7; 788 | --:-:-:-:1 STS.128 [writeCs+4x<3*64 + 32>], shuffle_x4y7; 789 | --:-:-:-:5 BAR.SYNC 0; 790 | 791 | --:-:-:-:5 CAL STORE_C; 792 | --:-:-:-:5 CAL STORE_C; 793 | 794 | --:-:-:-:5 EXIT; 795 | 796 | STORE_C: 797 | 798 | [+ 799 | return vector() ? q{ 800 | --:-:1:-:1 @P0 LDG.E.64 loadC, [C]; 801 | } : q{ 802 | --:-:-:-:0 @!P0 MOV loadC0, RZ; 803 | --:-:-:-:1 @P0 LDG.E.CI.U16 loadC0, [C + 2x<0>]; 804 | --:-:-:-:0 @!P1 MOV loadC1, RZ; 805 | --:-:-:-:1 @P1 LDG.E.CI.U16 loadC1, [C + 2x<1>]; 806 | --:-:-:-:0 @!P2 MOV loadC2, RZ; 807 | --:-:-:-:1 @P2 LDG.E.CI.U16 loadC2, [C + 2x<2>]; 808 | --:-:-:-:0 @!P3 MOV loadC3, RZ; 809 | --:-:1:-:1 @P3 LDG.E.CI.U16 loadC3, [C + 2x<3>]; 810 | }; 811 | +] 812 | 813 | // Restore output preds 814 | --:-:-:-:1 @P4 R2P PR, preds, 0x0f; 815 | --:-:-:-:1 @!P4 R2P PR, RZ, 0x0f; 816 | 817 | --:-:-:-:1 LDS.U.128 part0C, [readCs + 4x< 0*64>]; 818 | --:-:2:-:1 LDS.U.128 part1C, [readCs + 4x<16*64>]; 819 | --:-:-:-:1 LDS.U.128 part2C, [readCs + 4x<32*64>]; 820 | --:-:3:-:1 LDS.U.128 part3C, [readCs + 4x<48*64>]; 821 | 822 | 823 | 02:-:-:-:1 @P0 FADD part0C0, part0C0, part1C0; 824 | --:-:-:-:1 @P1 FADD part0C1, part0C1, part1C1; 825 | --:-:-:-:1 @P2 FADD part0C2, part0C2, part1C2; 826 | --:-:-:-:1 @P3 FADD part0C3, part0C3, part1C3; 827 | 828 | 04:-:-:-:1 @P0 FADD part2C0, part2C0, part3C0; 829 | --:-:-:-:1 @P1 FADD part2C1, part2C1, part3C1; 830 | --:-:-:-:1 @P2 FADD part2C2, part2C2, part3C2; 831 | --:-:-:-:1 @P3 FADD part2C3, part2C3, part3C3; 832 | 833 | --:-:-:-:1 @P0 FADD c0, part0C0, part2C0; 834 | --:-:-:-:1 @P1 FADD c1, part0C1, part2C1; 835 | --:-:-:-:1 @P2 FADD c2, part0C2, part2C2; 836 | --:-:-:-:1 @P3 FADD c3, part0C3, part2C3; 837 | 838 | 839 | --:-:-:-:0 IADD cy, cy, 8; 840 | 841 | [+ 842 | return vector() ? 
q{ 843 | 01:-:1:-:1 @P5 F2F.F32.F16 b0, loadC0.H0; 844 | --:-:2:-:1 @P5 F2F.F32.F16 b1, loadC0.H1; 845 | --:-:3:-:1 @P5 F2F.F32.F16 b2, loadC1.H0; 846 | --:-:4:-:1 @P5 F2F.F32.F16 b3, loadC1.H1; 847 | } : q{ 848 | 01:-:1:-:1 @P5 F2F.F32.F16 b0, loadC0; 849 | --:-:2:-:1 @P5 F2F.F32.F16 b1, loadC1; 850 | --:-:3:-:1 @P5 F2F.F32.F16 b2, loadC2; 851 | --:-:4:-:1 @P5 F2F.F32.F16 b3, loadC3; 852 | }; 853 | +] 854 | 855 | 01:-:-:-:1 @P5 FFMA c0, b0, beta, c0; 856 | 02:-:-:-:1 @P5 FFMA c1, b1, beta, c1; 857 | 04:-:-:-:1 @P5 FFMA c2, b2, beta, c2; 858 | 08:-:-:-:1 @P5 FFMA c3, b3, beta, c3; 859 | 860 | --:-:-:-:0 ISETP.LT.AND P5, PT, cy, param_m, P5; 861 | 862 | --:-:1:-:1 @P0 F2F.F16.F32 c0, c0; 863 | --:-:2:-:1 @P1 F2F.F16.F32 c1, c1; 864 | 865 | --:-:-:-:0 ISETP.LT.AND P4, PT, cy, param_m, PT; 866 | 867 | --:-:3:-:1 @P2 F2F.F16.F32 c2, c2; 868 | 869 | --:-:-:-:0 LOP.XOR readCs, readCs, 4x<8*64>; 870 | 871 | --:-:4:-:1 @P3 F2F.F16.F32 c3, c3; 872 | 873 | [+ 874 | return vector() ? q{ 875 | 03:-:-:-:2 @P0 BFI c0, c1, 0x1010, c0; 876 | 0c:-:-:-:2 @P0 BFI c1, c3, 0x1010, c2; 877 | 878 | --:1:-:-:1 @P0 STG.E.CG.64 [C], c; 879 | } : q{ 880 | 01:-:-:-:1 @P0 STG.E.U16 [C + 2x<0>], c0; 881 | 02:-:-:-:1 @P1 STG.E.U16 [C + 2x<1>], c1; 882 | 04:-:-:-:1 @P2 STG.E.U16 [C + 2x<2>], c2; 883 | 08:1:-:-:1 @P3 STG.E.U16 [C + 2x<3>], c3; 884 | }; 885 | +] 886 | 887 | // Restore beta preds 888 | --:-:-:-:1 @P5 R2P PR, preds, 0x0f; 889 | --:-:-:-:1 @!P5 R2P PR, RZ, 0x0f; 890 | 891 | 01:-:-:-:6 IADD C0.CC, C0, cdc8; 892 | --:-:-:-:0 IADD.X C1, C1, RZ; 893 | 894 | --:-:-:-:5 RET; 895 | -------------------------------------------------------------------------------- /sass/xgemm_128x128x8.sass: -------------------------------------------------------------------------------- 1 | # Kernel: xgemm_128x128x8 2 | 3 | [- 4 | our ($type, $A16, $B16, $C16); 5 | sub A16 { return $type eq 'h' || $A16 } 6 | sub B16 { return $type eq 'h' || $B16 } 7 | sub C16 { return $type eq 'h' || $C16 } 8 | 9 | our $dtypeA 
= A16() ? 'U16' : '32'; 10 | our $dtypeB = B16() ? 'U16' : '32'; 11 | our $dtypeC = C16() ? 'U16' : '32'; 12 | 13 | our $dshiftA = A16() ? '1' : '2'; 14 | our $dshiftB = B16() ? '1' : '2'; 15 | our $dshiftC = C16() ? '1' : '2'; 16 | 17 | our $dsizeA = A16() ? '2' : '4'; 18 | our $dsizeB = B16() ? '2' : '4'; 19 | our $dsizeC = C16() ? '2' : '4'; 20 | 21 | our $vsizeA = A16() ? '64' : '128'; 22 | our $vsizeB = B16() ? '64' : '128'; 23 | our $vsizeC = C16() ? '64' : '128'; 24 | 25 | sub dtypeA { return $dtypeA; } 26 | sub dtypeB { return $dtypeB; } 27 | sub dtypeC { return $dtypeC; } 28 | 29 | sub dsizeA { return $dsizeA; } 30 | sub dsizeB { return $dsizeB; } 31 | sub dsizeC { return $dsizeC; } 32 | 33 | sub dshiftA { return $dshiftA; } 34 | sub dshiftB { return $dshiftB; } 35 | sub dshiftC { return $dshiftC; } 36 | 37 | our ($outerContigA, $outerContigB, $NN, $NT, $TN, $TT); 38 | if ($NN) { $outerContigA = 0; $outerContigB = 1 } 39 | elsif ($NT) { $outerContigA = 0; $outerContigB = 0 } 40 | elsif ($TN) { $outerContigA = 1; $outerContigB = 1 } 41 | elsif ($TT) { $outerContigA = 1; $outerContigB = 0 } 42 | 43 | sub outerContigA { return $outerContigA; } 44 | sub outerContigB { return $outerContigB; } 45 | 46 | our ($vecA, $vecB, $vec); 47 | sub vecA { return $vecA || $vec; } 48 | sub vecB { return $vecB || $vec; } 49 | -] 50 | 51 | 52 | [+ 53 | return outerContigA() && outerContigB() ? 
q{ 54 | addr_zero : 4x<128*8*4> 55 | szShareA : (128*8) 56 | szShareB : (128*8) 57 | } : q{ 58 | addr_zero : 4x<(128*8 + 32)*4> 59 | szShareA : (128*8 + 32) 60 | szShareB : (128*8 + 32) 61 | }; 62 | +] 63 | 64 | param_C[0] : c[0x0][0x140] 65 | param_C[1] : c[0x0][0x144] 66 | param_A[0] : c[0x0][0x148] 67 | param_A[1] : c[0x0][0x14c] 68 | param_B[0] : c[0x0][0x150] 69 | param_B[1] : c[0x0][0x154] 70 | param_alpha : c[0x0][0x158] 71 | param_beta : c[0x0][0x15c] 72 | param_cda : c[0x0][0x160] 73 | param_cdb : c[0x0][0x164] 74 | param_cdc : c[0x0][0x168] 75 | param_m : c[0x0][0x16c] 76 | param_n : c[0x0][0x170] 77 | param_k : c[0x0][0x174] 78 | param_blk_a : c[0x0][0x178] 79 | param_blk_b : c[0x0][0x17c] 80 | 81 | 82 | 83 | 84 | 3, 2,11,10,19,18,27,26 : cx<0-7>y0 85 | 7, 6,15,14,23,22,31,30 : cx<0-7>y1 86 | 1, 0, 9, 8,17,16,25,24 : cx<0-7>y2 87 | 5, 4,13,12,21,20,29,28 : cx<0-7>y3 88 | 35,34,43,42,51,50,59,58 : cx<0-7>y4 89 | 39,38,47,46,55,54,63,62 : cx<0-7>y5 90 | 33,32,41,40,49,48,57,56 : cx<0-7>y6 91 | 37,36,45,44,53,52,61,60 : cx<0-7>y7 92 | 93 | 0-63 : czero<00-63> 94 | 64-79 : j0Ay<0-7>, j0Bx<0-7> 95 | 80-95 : j1Ay<0-7>, j1Bx<0-7> 96 | 97 | 64-95 ~ idx_ab, cda, cdb, idx_ab_f, idx_a, neg_blk_b, rcp_blk_b, idx_b, tidAX, txa, txa<1-3>, ta, xmad_ta, tidBX, txb, txb<1-3>, tb, xmad_tb, tid1, tid32_2, tid96, tid128 98 | 99 | 96-99 : tidAY, tidAY<1-3> 100 | 104-107 : tidBY, tidBY<1-3> 101 | 100-103 ~ predsA, partialK, partialA, partialB 102 | 108-111 ~ predsB 103 | 104 | 96-99 : loadA<0-3> 105 | 100-103 : loadA<4-7> 106 | 104-107 : loadB<0-3> 107 | 108-111 : loadB<4-7> 108 | 109 | 112-115 : trackA<0-1>, trackB<0-1> 110 | 111 | 116-123 ~ k, swapBuf, readAs, readBs, writeAs, writeBs, cda8, cdb8 112 | 124-127 ~ tid, idx_A, idx_B, writeCs 113 | 114 | 64-71 : track00C<0-1>, track04C<0-1>, track08C<0-1>, track12C<0-1> 115 | 72-79 : c<0-7> 116 | 80-87 ~ c<00|04|08|12>_00, c<00|04|08|12>_32 117 | 88-95 ~ b<00|04|08|12>_00, b<00|04|08|12>_32 118 | 96-123 ~ readCs, cx<00|32>, 
cy<00|04|08|12>, cdc, cdc1, cdc4, cdc13, tc, xmad_tc, tid_31, tid_32, tid_96, tid_128, alpha, beta 119 | 120 | 121 | 122 | --:-:5:-:1 I2F.F32.S32 rcp_blk_b, param_blk_b; 123 | --:-:1:-:1 S2R tid, SR_TID.X; 124 | --:-:2:-:1 S2R idx_ab, SR_CTAID.X; 125 | --:-:3:-:1 S2R idx_A, SR_CTAID.Z; 126 | --:-:4:-:1 S2R idx_B, SR_CTAID.Y; 127 | 128 | 129 | 10:-:5:-:1 MUFU.RCP rcp_blk_b, rcp_blk_b; 130 | 131 | --:-:-:-:1 MOV k, param_k; 132 | --:-:-:-:1 MOV cda, param_cda; 133 | --:-:-:-:1 MOV cdb, param_cdb; 134 | 135 | // If k is not a multiple of 8 we want to grab the partial amount on the first fetch. 136 | // If it is a multiple of 8 then make a full 8 line fetch. 137 | --:-:-:-:1 LOP.AND.Z P0, partialK, k, 7; 138 | --:-:-:-:1 @P0 MOV partialK, 8; 139 | --:-:-:-:1 IADD k, k, -partialK; 140 | 141 | # idx_a = idx_ab // blk_b 142 | 02:-:2:-:1 I2F.F32.S32 idx_ab_f, idx_ab; 143 | 12:-:-:-:1 FMUL idx_a, idx_ab_f, rcp_blk_b; 144 | --:-:-:-:1 FFMA idx_a, idx_a, 5.9604644775390625e-08, idx_a; 145 | --:-:2:-:1 F2I.S32.F32.TRUNC idx_a, idx_a; 146 | # idx_b = idx_AB % blk_b 147 | --:-:-:-:1 IADD neg_blk_b, RZ, -param_blk_b; 148 | 02:-:-:-:1 XMAD.S16.U16 idx_b, neg_blk_b, idx_a, idx_ab; 149 | 150 | # idx_A = idx_A * blk_a + idx_a 151 | # idx_B = idx_B * blk_b + idx_b 152 | 06:-:-:-:1 XMAD.U16.U16 idx_A, idx_A, param_blk_a, idx_a; 153 | 08:-:-:-:1 XMAD.U16.U16 idx_B, idx_B, param_blk_b, idx_b; 154 | 155 | --:-:-:-:1 STS.128 [addr_zero], RZ; 156 | [+ join '', map sprintf("--:-:-:-:1 LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15; +] 157 | 158 | [+ 159 | our $dshiftA; 160 | my $predsA = vecA() ? 
q{ 161 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa, param_m, PT; 162 | 163 | } : q{ 164 | --:-:-:-:1 IADD txa1, txa, 1; 165 | --:-:-:-:1 IADD txa2, txa, 2; 166 | --:-:-:-:1 IADD txa3, txa, 3; 167 | --:-:-:-:1 ISETP.LT.AND P2, PT, txa, param_m, PT; 168 | --:-:-:-:1 ISETP.LT.AND P3, PT, txa1, param_m, PT; 169 | --:-:-:-:1 ISETP.LT.AND P4, PT, txa2, param_m, PT; 170 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa3, param_m, PT; 171 | --:-:-:-:1 P2R predsA, PR, RZ, 0x3c; 172 | }; 173 | return outerContigA() ? qq{ 174 | // tidAX = (tid & 31) << 2 175 | // tidAY = tid >> 5 176 | 01:-:-:-:1 LOP.AND tidAX, tid, 31; 177 | --:-:-:-:1 SHL tidAX, tidAX, 2; 178 | 01:-:-:-:1 SHR.U32 tidAY, tid, 5; 179 | 180 | 181 | // trackA += (idx_A*128 + tidAX + cda*tidAY) * dsize 182 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 7; 183 | --:-:-:-:1 XMAD.LO2 ta, cda, tidAY, txa; 184 | --:-:-:-:1 SHL cda8, cda, 1x<$dshiftA + 3>;$predsA 185 | 186 | // writeAs = (tidAY*128 + tidAX) * 4 187 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 7; 188 | --:-:-:-:1 SHL writeAs, writeAs, 2; 189 | 190 | // partialA = partialK * cda 191 | --:-:-:-:1 XMAD.LO2 partialA, cda, partialK, RZ; 192 | 193 | } : q{ 194 | // tidAX = tid >> 1 195 | // tidAY = (tid & 1) << 2 196 | 01:-:-:-:1 SHR.U32 tidAX, tid, 1; 197 | 01:-:-:-:1 LOP.AND tidAY, tid, 1; 198 | --:-:-:-:1 SHL tidAY, tidAY, 2; 199 | 200 | // trackA += ((idx_A*128 + tidAX) * cda + tidAY) * dsize 201 | --:-:-:-:1 ISCADD txa, idx_A, tidAX, 7; 202 | --:-:-:-:1 XMAD.LO ta, cda, txa, tidAY, xmad_ta; 203 | 204 | --:-:-:-:1 ISETP.LT.AND P5, PT, txa, param_m, PT; 205 | 206 | // The extra shiftAX here is to avoid bank conflicts on write 207 | // shiftAX = tidAY * 4 208 | // writeAs = (tidAY*128 + tidAX + shiftAX) * 4 209 | --:-:-:-:1 ISCADD writeAs, tidAY, tidAX, 7; 210 | --:-:-:-:1 ISCADD writeAs, tidAY, writeAs, 2; 211 | --:-:-:-:1 SHL writeAs, writeAs, 2; 212 | 213 | }; 214 | +] 215 | --:-:-:-:1 LEA trackA0.CC, ta, param_A[0], [+ dshiftA() +]; 216 | --:-:-:-:1 LEA.HI.X trackA1, ta, 
param_A[1], RZ, [+ dshiftA() +]; 217 | 218 | 219 | 220 | [+ 221 | our $dshiftB; 222 | my $predsB = vecB() ? q{ 223 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb, param_n, PT; 224 | 225 | } : q{ 226 | --:-:-:-:1 IADD txb1, txb, 1; 227 | --:-:-:-:1 IADD txb2, txb, 2; 228 | --:-:-:-:1 IADD txb3, txb, 3; 229 | --:-:-:-:1 ISETP.LT.AND P2, PT, txb, param_n, PT; 230 | --:-:-:-:1 ISETP.LT.AND P3, PT, txb1, param_n, PT; 231 | --:-:-:-:1 ISETP.LT.AND P4, PT, txb2, param_n, PT; 232 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb3, param_n, PT; 233 | --:-:-:-:1 P2R predsB, PR, RZ, 0x5c; 234 | }; 235 | return outerContigB() ? qq{ 236 | // tidBX = (tid & 31) << 2 237 | // tidBY = tid >> 5 238 | 01:-:-:-:1 LOP.AND tidBX, tid, 31; 239 | --:-:-:-:1 SHL tidBX, tidBX, 2; 240 | 01:-:-:-:1 SHR.U32 tidBY, tid, 5; 241 | 242 | 243 | // trackB += (idx_B*128 + tidBX + cdb*tidBY) * dsize 244 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 7; 245 | --:-:-:-:1 XMAD.LO2 tb, cdb, tidBY, txb; 246 | --:-:-:-:1 SHL cdb8, cdb, 1x<$dshiftB + 3>;$predsB 247 | 248 | 249 | // writeBs = (tidBY*128 + tidBX) * 4 250 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 7; 251 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2; 252 | 253 | // partialB = partialK * cdb 254 | --:-:-:-:1 XMAD.LO2 partialB, cdb, partialK, RZ; 255 | 256 | } : q{ 257 | // tidBX = tid >> 1 258 | // tidBY = (tid & 1) << 2 259 | 01:-:-:-:1 SHR.U32 tidBX, tid, 1; 260 | 01:-:-:-:1 LOP.AND tidBY, tid, 1; 261 | --:-:-:-:1 SHL tidBY, tidBY, 2; 262 | 263 | // trackB += ((idx_B*128 + tidBX) * cdb + tidBY) * dsize 264 | --:-:-:-:1 ISCADD txb, idx_B, tidBX, 7; 265 | --:-:-:-:1 XMAD.LO tb, cdb, txb, tidBY, xmad_tb; 266 | 267 | --:-:-:-:1 ISETP.LT.AND P6, PT, txb, param_n, PT; 268 | 269 | 270 | // The extra shiftBX here is to avoid bank conflicts on write 271 | // shiftBX = tidBY * 4 272 | // writeBs = (tidBY*128 + tidBX + shiftBX) * 4 273 | --:-:-:-:1 ISCADD writeBs, tidBY, tidBX, 7; 274 | --:-:-:-:1 ISCADD writeBs, tidBY, writeBs, 2; 275 | --:-:-:-:1 ISCADD writeBs, writeBs, 4x, 2; 276 
| }; 277 | +] 278 | --:-:-:-:1 LEA trackB0.CC, tb, param_B[0], [+ dshiftB() +]; 279 | --:-:-:-:1 LEA.HI.X trackB1, tb, param_B[1], RZ, [+ dshiftB() +]; 280 | 281 | 282 | // readAs = ((tid & 16) >> 3) | (tid & 1) 283 | 01:-:-:-:1 LOP.AND tid1, tid, 1; 284 | 01:-:-:-:1 LOP.AND readAs, tid, 16; 285 | --:-:-:-:1 SHR.U32 readAs, readAs, 3; 286 | --:-:-:-:1 LOP.OR readAs, readAs, tid1; 287 | 288 | // readBs = (tid >> 1) & 7 289 | 01:-:-:-:1 BFE.U32 readBs, tid, 0x301; // 3 bits at position 1 290 | 291 | // writeCs = (readAs*64*8*4 + readAs*16*4 + readBs*4*4 + (tid & -32)*2*4 292 | 01:-:-:-:1 LOP.AND tid32_2, tid, -32; 293 | --:-:-:-:1 SHL tid32_2, tid32_2, 3; 294 | --:-:-:-:1 ISCADD writeCs, readAs, tid32_2, 11; 295 | --:-:-:-:1 ISCADD writeCs, readAs, writeCs, 6; 296 | --:-:-:-:1 ISCADD writeCs, readBs, writeCs, 4; 297 | 298 | // readAs = (readAs + ((tid & 96) >> 2)) * 16 299 | // readAs = readAs*16 + ((tid & 96)*4) 300 | 01:-:-:-:1 LOP.AND tid96, tid, 96; 301 | --:-:-:-:1 SHL tid96, tid96, 2; 302 | --:-:-:-:1 ISCADD readAs, readAs, tid96, 4; 303 | 304 | // readBs = (readBs + ((tid & 128) >> 3)) * 16 + 4x 305 | // readBs = readBs*16 + (tid & 128)*2 + 4x 306 | 01:-:-:-:1 LOP.AND tid128, tid, 128; 307 | --:-:-:-:1 SHL tid128, tid128, 1; 308 | --:-:-:-:1 ISCADD readBs, readBs, tid128, 4; 309 | --:-:-:-:1 IADD readBs, readBs, 4x; 310 | 311 | [+ 312 | return outerContigA() && outerContigB() ? '' : q{ 313 | --:-:-:-:1 MOV32I swapBuf, 4x; 314 | }; 315 | +] 316 | 317 | [+ 318 | our ($dsizeA, $vsizeA, $dtypeA); 319 | return 320 | outerContigA() ? 321 | vecA() ? 
qq{ 322 | 323 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5; 324 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA]; 325 | --:-:2:-:1 \@!P0 LDS.U.$vsizeA loadA, [addr_zero]; 326 | } : qq{ 327 | 328 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, PT; 329 | --:-:-:-:1 \@P0 R2P PR, predsA, 0x3c; 330 | --:-:-:-:1 \@!P0 R2P PR, RZ, 0x3c; 331 | --:-:2:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>]; 332 | --:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>]; 333 | --:-:2:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>]; 334 | --:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>]; 335 | --:-:-:-:1 \@!P2 MOV loadA0, RZ; 336 | --:-:-:-:1 \@!P3 MOV loadA1, RZ; 337 | --:-:-:-:1 \@!P4 MOV loadA2, RZ; 338 | --:-:-:-:1 \@!P5 MOV loadA3, RZ; } 339 | : 340 | # not outerContigA 341 | vecA() ? qq{ 342 | 343 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5; 344 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA]; 345 | --:-:2:-:1 \@!P0 LDS.U.$vsizeA loadA, [addr_zero]; 346 | 347 | } : qq{ 348 | 349 | --:-:-:-:1 IADD tidAY1, tidAY, 1; 350 | --:-:-:-:1 IADD tidAY2, tidAY, 2; 351 | --:-:-:-:1 IADD tidAY3, tidAY, 3; 352 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidAY, partialK, P5; 353 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidAY1, partialK, P5; 354 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidAY2, partialK, P5; 355 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidAY3, partialK, P5; 356 | --:-:2:-:1 \@P0 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>]; 357 | --:-:2:-:1 \@P1 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>]; 358 | --:-:2:-:1 \@P2 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>]; 359 | --:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>]; 360 | --:-:-:-:1 \@!P0 MOV loadA0, RZ; 361 | --:-:-:-:1 \@!P1 MOV loadA1, RZ; 362 | --:-:-:-:1 \@!P2 MOV loadA2, RZ; 363 | --:-:-:-:1 \@!P3 MOV loadA3, RZ; 364 | }; 365 | +] 366 | 367 | [+ 368 | our ($dsizeB, $vsizeB, $dtypeB); 369 | return 370 | outerContigB() ? 371 | vecB() ? 
qq{ 372 | 373 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, P6; 374 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB]; 375 | --:-:4:-:1 \@!P1 LDS.U.$vsizeB loadB, [addr_zero]; 376 | } : qq{ 377 | 378 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, PT; 379 | --:-:-:-:1 \@P1 R2P PR, predsB, 0x5c; 380 | --:-:-:-:1 \@!P1 R2P PR, RZ, 0x5c; 381 | --:-:4:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>]; 382 | --:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>]; 383 | --:-:4:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>]; 384 | --:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>]; 385 | --:-:-:-:1 \@!P2 MOV loadB0, RZ; 386 | --:-:-:-:1 \@!P3 MOV loadB1, RZ; 387 | --:-:-:-:1 \@!P4 MOV loadB2, RZ; 388 | --:-:-:-:1 \@!P6 MOV loadB3, RZ; 389 | } 390 | : 391 | # not outerContigB 392 | vecB() ? qq{ 393 | 394 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY, partialK, P6; 395 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB]; 396 | --:-:4:-:1 \@!P1 LDS.U.$vsizeB loadB, [addr_zero]; 397 | 398 | } : qq{ 399 | 400 | --:-:-:-:1 IADD tidBY1, tidBY, 1; 401 | --:-:-:-:1 IADD tidBY2, tidBY, 2; 402 | --:-:-:-:1 IADD tidBY3, tidBY, 3; 403 | --:-:-:-:1 ISETP.LT.AND P0, PT, tidBY, partialK, P6; 404 | --:-:-:-:1 ISETP.LT.AND P1, PT, tidBY1, partialK, P6; 405 | --:-:-:-:1 ISETP.LT.AND P2, PT, tidBY2, partialK, P6; 406 | --:-:-:-:1 ISETP.LT.AND P3, PT, tidBY3, partialK, P6; 407 | --:-:4:-:1 \@P0 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>]; 408 | --:-:4:-:1 \@P1 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>]; 409 | --:-:4:-:1 \@P2 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>]; 410 | --:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>]; 411 | --:-:-:-:1 \@!P0 MOV loadB0, RZ; 412 | --:-:-:-:1 \@!P1 MOV loadB1, RZ; 413 | --:-:-:-:1 \@!P2 MOV loadB2, RZ; 414 | --:-:-:-:1 \@!P3 MOV loadB3, RZ; 415 | }; 416 | +] 417 | 418 | [+ 419 | return outerContigA() && !vecA() ? 
q{ 420 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, 8, PT; 421 | } : q{ 422 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, 8, P5; 423 | }; 424 | +] 425 | [+ 426 | return outerContigB() && !vecB() ? q{ 427 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 8, PT; 428 | } : q{ 429 | --:-:-:-:1 ISETP.GE.AND P1, PT, k, 8, P6; 430 | }; 431 | +] 432 | 433 | --:-:-:-:0 IADD k, k, -8; 434 | [+ 435 | return A16() ? 436 | vecA() ? q{ 437 | 02:-:-:-:1 F2F.F32.F16 loadA3, loadA1.H1; 438 | --:-:-:-:1 F2F.F32.F16 loadA2, loadA1.H0; 439 | --:-:-:-:1 F2F.F32.F16 loadA1, loadA0.H1; 440 | --:-:2:-:1 F2F.F32.F16 loadA0, loadA0.H0; 441 | } : q{ 442 | 02:-:-:-:1 F2F.F32.F16 loadA3, loadA3; 443 | --:-:-:-:1 F2F.F32.F16 loadA2, loadA2; 444 | --:-:-:-:1 F2F.F32.F16 loadA1, loadA1; 445 | --:-:2:-:1 F2F.F32.F16 loadA0, loadA0; 446 | } 447 | : ''; 448 | +] 449 | [+ 450 | return B16() ? 451 | vecB() ? q{ 452 | 08:-:-:-:1 F2F.F32.F16 loadB3, loadB1.H1; 453 | --:-:-:-:1 F2F.F32.F16 loadB2, loadB1.H0; 454 | --:-:-:-:1 F2F.F32.F16 loadB1, loadB0.H1; 455 | --:-:4:-:1 F2F.F32.F16 loadB0, loadB0.H0; 456 | } : q{ 457 | 08:-:-:-:1 F2F.F32.F16 loadB3, loadB3; 458 | --:-:-:-:1 F2F.F32.F16 loadB2, loadB2; 459 | --:-:-:-:1 F2F.F32.F16 loadB1, loadB1; 460 | --:-:4:-:1 F2F.F32.F16 loadB0, loadB0; 461 | } 462 | : ''; 463 | +] 464 | [+ 465 | our $dshiftA; 466 | return outerContigA() ? qq{ 467 | 468 | 02:-:-:-:1 STS.128 [writeAs], loadA; 469 | 470 | --:-:-:-:6 LEA trackA0.CC, partialA, trackA0, $dshiftA; 471 | --:-:-:-:0 IADD.X trackA1, trackA1, RZ; 472 | 473 | } : qq{ 474 | 02:-:-:-:1 STS [writeAs + 4x<3*128>], loadA3; 475 | --:-:-:-:1 STS [writeAs + 4x<2*128>], loadA2; 476 | --:-:-:-:1 STS [writeAs + 4x<1*128>], loadA1; 477 | --:-:-:-:1 STS [writeAs + 4x<0*128>], loadA0; 478 | 479 | --:-:-:-:6 LEA trackA0.CC, partialK, trackA0, $dshiftA; 480 | --:-:-:-:0 IADD.X trackA1, trackA1, RZ; 481 | }; 482 | +] 483 | [+ 484 | our $dshiftB; 485 | return outerContigB() ? 
qq{ 486 | 487 | 08:-:-:-:1 STS.128 [writeBs], loadB; 488 | 489 | --:-:-:-:6 LEA trackB0.CC, partialB, trackB0, $dshiftB; 490 | --:-:-:-:0 IADD.X trackB1, trackB1, RZ; 491 | 492 | } : qq{ 493 | 08:-:-:-:1 STS [writeBs + 4x<3*128>], loadB3; 494 | --:-:-:-:1 STS [writeBs + 4x<2*128>], loadB2; 495 | --:-:-:-:1 STS [writeBs + 4x<1*128>], loadB1; 496 | --:-:-:-:1 STS [writeBs + 4x<0*128>], loadB0; 497 | 498 | --:-:-:-:6 LEA trackB0.CC, partialK, trackB0, $dshiftB; 499 | --:-:-:-:0 IADD.X trackB1, trackB1, RZ; 500 | }; 501 | +] 502 | 503 | --:-:-:-:5 BAR.SYNC 0; 504 | [+ 505 | return outerContigA() && outerContigB() ? q{ 506 | --:-:-:-:1 LOP.XOR writeBs, writeBs, 4x; 507 | --:-:-:-:0 LOP.XOR writeAs, writeAs, 4x; 508 | } : q{ 509 | --:-:-:-:1 IADD writeBs, writeBs, swapBuf; 510 | --:-:-:-:1 IADD writeAs, writeAs, swapBuf; 511 | --:-:-:-:0 IADD swapBuf, RZ, -swapBuf; 512 | }; 513 | +] 514 | --:-:-:-:1 LDS.U.128 j0Ay0, [readAs + 4x<0*128 + 00>]; 515 | --:-:-:-:1 LDS.U.128 j0Bx0, [readBs + 4x<0*128 + 00>]; 516 | --:-:-:-:1 LDS.U.128 j0Ay4, [readAs + 4x<0*128 + 16>]; 517 | --:-:1:-:1 LDS.U.128 j0Bx4, [readBs + 4x<0*128 + 32>]; 518 | [+ 519 | our ($dsizeA, $vsizeA, $dtypeA); 520 | return 521 | outerContigA() ? 522 | vecA() ? qq{ 523 | 524 | --:-:2:-:1 \@P0 LDG.E.CI.$vsizeA loadA, [trackA]; 525 | 526 | } : qq{ 527 | 528 | --:-:-:-:2 \@P0 R2P PR, predsA, 0x3c; 529 | --:-:-:Y:d \@!P0 R2P PR, RZ, 0x3c; 530 | --:-:-:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>]; 531 | --:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>]; 532 | --:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>]; 533 | --:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>]; 534 | } 535 | : 536 | vecA() ? 
qq{ 537 | 538 | --:-:3:-:1 \@P0 LDG.E.CI.$vsizeA loadA4, [trackA]; 539 | --:-:3:-:1 \@!P0 LDS.U.$vsizeA loadA4, [addr_zero]; 540 | 541 | } : qq{ 542 | 543 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA4, [trackA + ${dsizeA}x<0>]; 544 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA5, [trackA + ${dsizeA}x<1>]; 545 | --:-:-:-:1 \@P0 LDG.E.CI.$dtypeA loadA6, [trackA + ${dsizeA}x<2>]; 546 | --:-:3:-:1 \@P0 LDG.E.CI.$dtypeA loadA7, [trackA + ${dsizeA}x<3>]; 547 | --:-:3:-:1 \@!P0 LDS.U.128 loadA4, [addr_zero]; 548 | 549 | }; 550 | +] 551 | [+ 552 | our ($dsizeB, $vsizeB, $dtypeB); 553 | return 554 | outerContigB() ? 555 | vecB() ? qq{ 556 | 557 | --:-:4:-:1 \@P1 LDG.E.CI.$vsizeB loadB, [trackB]; 558 | 559 | } : qq{ 560 | 561 | --:-:-:-:2 \@P1 R2P PR, predsB, 0x5c; 562 | --:-:-:Y:d \@!P1 R2P PR, RZ, 0x5c; 563 | --:-:-:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>]; 564 | --:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>]; 565 | --:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>]; 566 | --:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>]; 567 | } 568 | : 569 | vecB() ? qq{ 570 | 571 | --:-:5:-:1 \@P1 LDG.E.CI.$vsizeB loadB4, [trackB]; 572 | --:-:5:-:1 \@!P1 LDS.U.$vsizeB loadB4, [addr_zero]; 573 | 574 | } : qq{ 575 | 576 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB4, [trackB + ${dsizeB}x<0>]; 577 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB5, [trackB + ${dsizeB}x<1>]; 578 | --:-:-:-:1 \@P1 LDG.E.CI.$dtypeB loadB6, [trackB + ${dsizeB}x<2>]; 579 | --:-:5:-:1 \@P1 LDG.E.CI.$dtypeB loadB7, [trackB + ${dsizeB}x<3>]; 580 | --:-:5:-:1 \@!P1 LDS.U.128 loadB4, [addr_zero]; 581 | 582 | }; 583 | +] 584 | --:-:-:-:2 PSETP.AND.AND P1, PT, !PT, !PT, !PT; 585 | 586 | LOOP: 587 | --:-:-:-:1 ISETP.GE.AND P0, PT, k, RZ, PT; 588 | [+ 589 | our ($dsizeA, $vsizeA, $dtypeA, $dsizeB, $vsizeB, $dtypeB); 590 | our %insert = 591 | ( 592 | ( outerContigA() && outerContigB() ? 
593 | () : ( 594 | j0c10 => "--:-:-:-:1 PSETP.AND.AND P1, PT, !P1, PT, PT;\n", 595 | ) 596 | ), 597 | 598 | ( outerContigA() ? 599 | ( 600 | ( vecA() ? 601 | ( 602 | j2c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, P5;\n", 603 | ) : ( 604 | j2c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, PT;\n", 605 | j2c26 => "--:-:-:-:1 \@!P2 R2P PR, RZ, 0x3c;\n", 606 | j2c27 => "--:-:-:-:1 \@P2 R2P PR, predsA, 0x3c;\n", 607 | ) 608 | ), 609 | 610 | ( A16() ? 611 | ( vecA() ? 612 | ( 613 | j2c44 => "02:-:-:-:1 \@P0 F2F.F32.F16 loadA3, loadA1.H1;\n", 614 | j2c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA2, loadA1.H0;\n", 615 | j2c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA1, loadA0.H1;\n", 616 | j2c56 => "--:-:2:-:1 \@P0 F2F.F32.F16 loadA0, loadA0.H0;\n", 617 | ) : ( 618 | j2c44 => "02:-:-:-:1 \@P0 F2F.F32.F16 loadA3, loadA3;\n", 619 | j2c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA2, loadA2;\n", 620 | j2c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadA1, loadA1;\n", 621 | j2c56 => "--:-:2:-:1 \@P0 F2F.F32.F16 loadA0, loadA0;\n", 622 | ) 623 | ) : () 624 | ), 625 | 626 | j3c8 => "02:2:-:-:1 \@P0 STS.128 [writeAs], loadA;\n", 627 | 628 | j3c10 => "--:-:-:-:1 \@P2 IADD trackA0.CC, trackA0, cda8;\n", 629 | j3c15 => "--:-:-:-:1 \@P2 IADD.X trackA1, trackA1, RZ;\n", 630 | 631 | ( vecA() ? 632 | ( 633 | j3c56 => "02:-:2:-:1 \@P2 LDG.E.CI.$vsizeA loadA, [trackA];\n", 634 | ) : ( 635 | j3c56 => "02:-:-:-:1 \@P2 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x<0>];\n", 636 | j3c58 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x<1>];\n", 637 | j3c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x<2>];\n", 638 | j3c62 => "--:-:2:-:1 \@P5 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x<3>];\n", 639 | ) 640 | ), 641 | 642 | # Not outerContigA 643 | ) : ( 644 | j2c13 => "--:-:-:-:1 PSETP.AND.AND P2, PT, PT, P1, P5;\n", 645 | j2c26 => "--:-:-:-:1 ISETP.GE.AND P3, PT, k, 8, P2;\n", 646 | j2c27 => "--:-:-:-:1 ISETP.GE.AND P4, PT, k, 16, P2;\n", 647 | 648 | ( A16() ? 
649 | ( 650 | ( vecA() ? 651 | ( 652 | j2c28 => "02:-:-:-:1 \@!P1 F2F.F32.F16 loadA3, loadA1.H1;\n", 653 | j2c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA2, loadA1.H0;\n", 654 | j2c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA1, loadA0.H1;\n", 655 | j2c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA0, loadA0.H0;\n", 656 | 657 | j2c44 => "04:-:-:-:1 \@P1 F2F.F32.F16 loadA3, loadA5.H1;\n", 658 | j2c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA2, loadA5.H0;\n", 659 | j2c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA1, loadA4.H1;\n", 660 | j2c56 => "--:-:2:-:1 \@P1 F2F.F32.F16 loadA0, loadA4.H0;\n", 661 | ) : ( 662 | j2c28 => "02:-:-:-:1 \@!P1 F2F.F32.F16 loadA3, loadA3;\n", 663 | j2c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA2, loadA2;\n", 664 | j2c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA1, loadA1;\n", 665 | j2c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadA0, loadA0;\n", 666 | 667 | j2c44 => "04:-:-:-:1 \@P1 F2F.F32.F16 loadA3, loadA7;\n", 668 | j2c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA2, loadA6;\n", 669 | j2c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadA1, loadA5;\n", 670 | j2c56 => "--:-:2:-:1 \@P1 F2F.F32.F16 loadA0, loadA4;\n", 671 | ), 672 | ), 673 | j3c8 => "02:-:-:-:1 \@P0 STS [writeAs + 4x<0*128>], loadA0;\n", 674 | j3c10 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<1*128>], loadA1;\n", 675 | j3c12 => "--:-:-:-:1 \@P0 STS [writeAs + 4x<2*128>], loadA2;\n", 676 | j3c14 => "--:2:-:-:1 \@P0 STS [writeAs + 4x<3*128>], loadA3;\n", 677 | 678 | ) : ( 679 | j2c54 => "02:-:-:-:1 \@!P1 STS [writeAs + 4x<0*128>], loadA0;\n", 680 | j2c56 => "--:-:-:-:1 \@!P1 STS [writeAs + 4x<1*128>], loadA1;\n", 681 | j2c58 => "--:-:-:-:1 \@!P1 STS [writeAs + 4x<2*128>], loadA2;\n", 682 | j2c60 => "--:2:-:-:1 \@!P1 STS [writeAs + 4x<3*128>], loadA3;\n", 683 | 684 | j3c8 => "04:-:-:-:1 \@P1 STS [writeAs + 4x<0*128>], loadA4;\n", 685 | j3c10 => "--:-:-:-:1 \@P1 STS [writeAs + 4x<1*128>], loadA5;\n", 686 | j3c12 => "--:-:-:-:1 \@P1 STS [writeAs + 4x<2*128>], loadA6;\n", 687 | j3c14 => "--:3:-:-:1 \@P1 STS [writeAs + 4x<3*128>], 
loadA7;\n", 688 | ) 689 | ), 690 | 691 | j3c15 => "--:-:-:-:1 \@P5 IADD trackA0.CC, trackA0, ${dsizeA}x<8>;\n", 692 | j3c20 => "--:-:-:-:1 \@P5 IADD.X trackA1, trackA1, RZ;\n", 693 | 694 | ( vecA() ? 695 | ( 696 | j3c60 => "02:-:2:-:1 \@P3 LDG.E.CI.$vsizeA loadA0, [trackA + ${dsizeA}x<0>];\n", 697 | j3c62 => "04:-:3:-:1 \@P4 LDG.E.CI.$vsizeA loadA4, [trackA + ${dsizeA}x<8>];\n", 698 | 699 | ) : ( 700 | j3c48 => "02:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA0, [trackA + ${dsizeA}x< 0>];\n", 701 | j3c50 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA1, [trackA + ${dsizeA}x< 1>];\n", 702 | j3c52 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeA loadA2, [trackA + ${dsizeA}x< 2>];\n", 703 | j3c54 => "--:-:2:-:1 \@P3 LDG.E.CI.$dtypeA loadA3, [trackA + ${dsizeA}x< 3>];\n", 704 | 705 | j3c56 => "04:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA4, [trackA + ${dsizeA}x< 8>];\n", 706 | j3c58 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA5, [trackA + ${dsizeA}x< 9>];\n", 707 | j3c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeA loadA6, [trackA + ${dsizeA}x<10>];\n", 708 | j3c62 => "--:-:3:-:1 \@P4 LDG.E.CI.$dtypeA loadA7, [trackA + ${dsizeA}x<11>];\n", 709 | ) 710 | ), 711 | ), 712 | ), 713 | 714 | ( outerContigB() ? 715 | ( 716 | ( vecB() ? 717 | ( 718 | j5c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, P6;\n", 719 | ) : ( 720 | j5c13 => "--:-:-:-:1 ISETP.GE.AND P2, PT, k, 8, PT;\n", 721 | j5c26 => "--:-:-:-:1 \@!P2 R2P PR, RZ, 0x5c;\n", 722 | j5c27 => "--:-:-:-:1 \@P2 R2P PR, predsB, 0x5c;\n", 723 | ) 724 | ), 725 | 726 | ( B16() ? 727 | ( vecB() ? 
728 | ( 729 | j5c44 => "08:-:-:-:1 \@P0 F2F.F32.F16 loadB3, loadB1.H1;\n", 730 | j5c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB2, loadB1.H0;\n", 731 | j5c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB1, loadB0.H1;\n", 732 | j5c56 => "--:-:4:-:1 \@P0 F2F.F32.F16 loadB0, loadB0.H0;\n", 733 | ) : ( 734 | j5c44 => "08:-:-:-:1 \@P0 F2F.F32.F16 loadB3, loadB3;\n", 735 | j5c48 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB2, loadB2;\n", 736 | j5c52 => "--:-:-:-:1 \@P0 F2F.F32.F16 loadB1, loadB1;\n", 737 | j5c56 => "--:-:4:-:1 \@P0 F2F.F32.F16 loadB0, loadB0;\n", 738 | ) 739 | ) : () 740 | ), 741 | 742 | j6c8 => "08:4:-:-:1 \@P0 STS.128 [writeBs], loadB;\n", 743 | 744 | j6c10 => "--:-:-:-:1 \@P2 IADD trackB0.CC, trackB0, cdb8;\n", 745 | j6c15 => "--:-:-:-:1 \@P2 IADD.X trackB1, trackB1, RZ;\n", 746 | 747 | ( vecB() ? 748 | ( 749 | j6c56 => "08:-:4:-:1 \@P2 LDG.E.CI.$vsizeB loadB, [trackB];\n", 750 | ) : ( 751 | j6c56 => "08:-:-:-:1 \@P2 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x<0>];\n", 752 | j6c58 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x<1>];\n", 753 | j6c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x<2>];\n", 754 | j6c62 => "--:-:4:-:1 \@P6 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x<3>];\n", 755 | ) 756 | ), 757 | 758 | # Not outerContigB 759 | ) : ( 760 | j5c13 => "--:-:-:-:1 PSETP.AND.AND P2, PT, PT, P1, P6;\n", 761 | j5c26 => "--:-:-:-:1 ISETP.GE.AND P3, PT, k, 8, P2;\n", 762 | j5c27 => "--:-:-:-:1 ISETP.GE.AND P4, PT, k, 16, P2;\n", 763 | 764 | ( B16() ? 765 | ( 766 | ( vecB() ? 
767 | ( 768 | j5c28 => "08:-:-:-:1 \@!P1 F2F.F32.F16 loadB3, loadB1.H1;\n", 769 | j5c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB2, loadB1.H0;\n", 770 | j5c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB1, loadB0.H1;\n", 771 | j5c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB0, loadB0.H0;\n", 772 | 773 | j5c44 => "10:-:-:-:1 \@P1 F2F.F32.F16 loadB3, loadB5.H1;\n", 774 | j5c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB2, loadB5.H0;\n", 775 | j5c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB1, loadB4.H1;\n", 776 | j5c56 => "--:-:4:-:1 \@P1 F2F.F32.F16 loadB0, loadB4.H0;\n", 777 | ) : ( 778 | j5c28 => "08:-:-:-:1 \@!P1 F2F.F32.F16 loadB3, loadB3;\n", 779 | j5c32 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB2, loadB2;\n", 780 | j5c36 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB1, loadB1;\n", 781 | j5c40 => "--:-:-:-:1 \@!P1 F2F.F32.F16 loadB0, loadB0;\n", 782 | 783 | j5c44 => "10:-:-:-:1 \@P1 F2F.F32.F16 loadB3, loadB7;\n", 784 | j5c48 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB2, loadB6;\n", 785 | j5c52 => "--:-:-:-:1 \@P1 F2F.F32.F16 loadB1, loadB5;\n", 786 | j5c56 => "--:-:4:-:1 \@P1 F2F.F32.F16 loadB0, loadB4;\n", 787 | ), 788 | ), 789 | j6c8 => "08:-:-:-:1 \@P0 STS [writeBs + 4x<0*128>], loadB0;\n", 790 | j6c10 => "--:-:-:-:1 \@P0 STS [writeBs + 4x<1*128>], loadB1;\n", 791 | j6c12 => "--:-:-:-:1 \@P0 STS [writeBs + 4x<2*128>], loadB2;\n", 792 | j6c14 => "--:4:-:-:1 \@P0 STS [writeBs + 4x<3*128>], loadB3;\n", 793 | 794 | ) : ( 795 | j5c54 => "08:-:-:-:1 \@!P1 STS [writeBs + 4x<0*128>], loadB0;\n", 796 | j5c56 => "--:-:-:-:1 \@!P1 STS [writeBs + 4x<1*128>], loadB1;\n", 797 | j5c58 => "--:-:-:-:1 \@!P1 STS [writeBs + 4x<2*128>], loadB2;\n", 798 | j5c60 => "--:4:-:-:1 \@!P1 STS [writeBs + 4x<3*128>], loadB3;\n", 799 | 800 | j6c8 => "10:-:-:-:1 \@P1 STS [writeBs + 4x<0*128>], loadB4;\n", 801 | j6c10 => "--:-:-:-:1 \@P1 STS [writeBs + 4x<1*128>], loadB5;\n", 802 | j6c12 => "--:-:-:-:1 \@P1 STS [writeBs + 4x<2*128>], loadB6;\n", 803 | j6c14 => "--:5:-:-:1 \@P1 STS [writeBs + 4x<3*128>], loadB7;\n", 804 | ) 805 | 
), 806 | 807 | j6c15 => "--:-:-:-:1 \@P6 IADD trackB0.CC, trackB0, ${dsizeB}x<8>;\n", 808 | j6c20 => "--:-:-:-:1 \@P6 IADD.X trackB1, trackB1, RZ;\n", 809 | 810 | ( vecB() ? 811 | ( 812 | j6c60 => "08:-:4:-:1 \@P3 LDG.E.CI.$vsizeB loadB0, [trackB + ${dsizeB}x<0>];\n", 813 | j6c62 => "10:-:5:-:1 \@P4 LDG.E.CI.$vsizeB loadB4, [trackB + ${dsizeB}x<8>];\n", 814 | 815 | ) : ( 816 | j6c48 => "08:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB0, [trackB + ${dsizeB}x< 0>];\n", 817 | j6c50 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB1, [trackB + ${dsizeB}x< 1>];\n", 818 | j6c52 => "--:-:-:-:1 \@P3 LDG.E.CI.$dtypeB loadB2, [trackB + ${dsizeB}x< 2>];\n", 819 | j6c54 => "--:-:4:-:1 \@P3 LDG.E.CI.$dtypeB loadB3, [trackB + ${dsizeB}x< 3>];\n", 820 | 821 | j6c56 => "10:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB4, [trackB + ${dsizeB}x< 8>];\n", 822 | j6c58 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB5, [trackB + ${dsizeB}x< 9>];\n", 823 | j6c60 => "--:-:-:-:1 \@P4 LDG.E.CI.$dtypeB loadB6, [trackB + ${dsizeB}x<10>];\n", 824 | j6c62 => "--:-:5:-:1 \@P4 LDG.E.CI.$dtypeB loadB7, [trackB + ${dsizeB}x<11>];\n", 825 | ) 826 | ), 827 | ), 828 | ), 829 | 830 | ( outerContigA() && outerContigB() ? 831 | ( 832 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" . 833 | "--:-:-:-:1 \@P0 LOP.XOR readAs, readAs, 4x;\n" . 834 | "--:-:-:-:1 \@P0 LOP.XOR readBs, readBs, 4x;\n" . 835 | "--:-:-:-:1 \@P0 LOP.XOR writeAs, writeAs, 4x;\n" . 836 | "--:-:-:-:1 \@P0 LOP.XOR writeBs, writeBs, 4x;\n" . 837 | "--:-:-:-:1 IADD k, k, -8;\n", 838 | ) : ( 839 | j6c63 => "--:-:-:-:5 BAR.SYNC 0;\n" . 840 | "--:-:-:-:1 \@P0 IADD readAs, readAs, -swapBuf;\n" . 841 | "--:-:-:-:1 \@P0 IADD readBs, readBs, -swapBuf;\n" . 842 | "--:-:-:-:1 \@P0 IADD writeAs, writeAs, swapBuf;\n" . 843 | "--:-:-:-:1 \@P0 IADD writeBs, writeBs, swapBuf;\n" . 844 | "--:-:-:-:1 \@P0 IADD swapBuf, RZ, -swapBuf;\n" . 
845 | "--:-:-:-:1 IADD k, k, -8;\n", 846 | ) 847 | ), 848 | 849 | j7c63 => "--:-:-:Y:5 \@P0 BRA.U LOOP;\n", 850 | ); 851 | 852 | my @cOrder; 853 | my @swirl = ([0,2],[1,2],[1,0],[0,0]); 854 | my @y = (0,1,4,5); 855 | foreach my $x (0,2,4,6) 856 | { 857 | foreach my $y (@y) 858 | { 859 | push @cOrder, [$x + $_->[0], $y + $_->[1]] foreach @swirl; 860 | } 861 | @y = reverse @y; 862 | } 863 | my $out = ''; 864 | foreach my $j (0 .. 7) 865 | { 866 | my $odd = $j & 1; 867 | my $nOdd = !$odd + 0; 868 | my $rsOffset = ($j + 1) % 8; 869 | my $rsPred = $j == 7 ? '@P0' : ' '; 870 | my $shiftA = outerContigA() || $rsOffset < 4 ? 0 : 16; 871 | my $shiftB = outerContigB() || $rsOffset < 4 ? 0 : 16; 872 | 873 | $insert{"j${j}c0"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy0, [readAs + 4x<%d*128 + 00 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftA; 874 | $insert{"j${j}c2"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dBx0, [readBs + 4x<%d*128 + 00 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftB; 875 | $insert{"j${j}c4"} = sprintf "--:-:-:-:1 %s LDS.U.128 j%dAy4, [readAs + 4x<%d*128 + 16 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftA; 876 | $insert{"j${j}c6"} = sprintf "--:-:1:-:1 %s LDS.U.128 j%dBx4, [readBs + 4x<%d*128 + 32 + %d>];\n", $rsPred, $nOdd, $rsOffset, $shiftB; 877 | 878 | foreach my $c (0 .. 63) 879 | { 880 | my ($x,$y) = @{$cOrder[$c]}; 881 | 882 | my $ins = $insert{"j${j}c$c"} || ''; 883 | 884 | my $stall = $ins =~ /LDS|I2I|I2F|F2I|F2F|LDG|STS|BAR|BRA/ ? 0 : 1; 885 | 886 | my $yield = $c == 32 && $stall ? 'Y' : '-'; 887 | 888 | my $wait = $c == 0 ? 
'01' : '--'; 889 | 890 | my $ctrl = "$wait:-:-:$yield:$stall"; 891 | 892 | $out .= sprintf "%s FFMA cx%dy%d, j%dBx%d, j%dAy%d, cx%dy%d;\n%s", $ctrl, $x,$y, $odd,$x, $odd,$y, $x,$y, $ins; 893 | } 894 | } 895 | return $out; 896 | +] 897 | 898 | 899 | --:-:-:-:1 MOV alpha, param_alpha; 900 | --:-:-:-:1 MOV beta, param_beta; 901 | 902 | // P4 = beta != 0 903 | --:-:-:-:1 ISETP.NE.AND P4, PT, RZ, param_beta, PT; 904 | 905 | --:-:-:-:1 SHR.U32 tid_32, tid, 5; 906 | --:-:-:-:1 LOP.AND tid_31, tid, 31; 907 | --:-:-:-:1 LOP.AND tid_96, tid, 96; 908 | --:-:-:-:1 LOP.AND tid_128, tid, 128; 909 | 910 | // readCs = (tid_32*64 + tid_31) * 4 911 | --:-:-:-:1 ISCADD readCs, tid_32, tid_31, 6; 912 | --:-:-:-:1 SHL readCs, readCs, 2; 913 | 914 | // cx = idx_B*128 + tid_128>>1 + tid_31; 915 | --:-:-:-:1 SHR.U32 cx00, tid_128, 1; 916 | --:-:-:-:1 LOP.OR cx00, tid_31, cx00; 917 | --:-:-:-:1 ISCADD cx00, idx_B, cx00, 7; 918 | --:-:-:-:1 IADD cx32, cx00, 32; 919 | 920 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, P4; 921 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, P4; 922 | 923 | // cy = idx_A*128 + tid_96 924 | --:-:-:-:1 ISCADD cy00, idx_A, tid_96, 7; 925 | --:-:-:-:1 IADD cy04, cy00, 4; 926 | --:-:-:-:1 IADD cy08, cy00, 8; 927 | --:-:-:-:1 IADD cy12, cy00, 12; 928 | 929 | --:-:-:-:1 MOV cdc, param_cdc; 930 | --:-:-:-:1 SHL cdc1, cdc, [+ dshiftC() + 0 +]; 931 | --:-:-:-:1 SHL cdc4, cdc, [+ dshiftC() + 2 +]; 932 | --:-:-:-:1 XMAD.LO2 cdc13, cdc, [+ dsizeC() * 13 +], RZ; 933 | 934 | // trackC += cy*cdc + cx; 935 | --:-:-:-:1 XMAD.LO tc, cy00, cdc, cx00, xmad_tc; 936 | 937 | --:-:-:-:1 LEA track00C0.CC, tc, param_C[0], [+ dshiftC() +]; 938 | --:-:-:-:1 LEA.HI.X track00C1, tc, param_C[1], RZ, [+ dshiftC() +]; 939 | --:-:-:-:1 IADD track04C0.CC, track00C0, cdc4; 940 | --:-:-:-:1 IADD.X track04C1, track00C1, RZ; 941 | --:-:-:-:1 IADD track08C0.CC, track04C0, cdc4; 942 | --:-:-:-:1 IADD.X track08C1, track04C1, RZ; 943 | --:-:-:-:1 IADD track12C0.CC, track08C0, cdc4; 944 | --:-:-:-:0 
IADD.X track12C1, track08C1, RZ; 945 | 946 | --:-:-:-:1 FMUL c0, cx0y0, alpha; 947 | --:-:-:-:1 FMUL c1, cx1y0, alpha; 948 | --:-:-:-:1 FMUL c2, cx2y0, alpha; 949 | --:-:-:-:1 FMUL c3, cx3y0, alpha; 950 | --:-:-:-:1 FMUL c4, cx4y0, alpha; 951 | --:-:-:-:1 FMUL c5, cx5y0, alpha; 952 | --:-:-:-:1 FMUL c6, cx6y0, alpha; 953 | --:-:-:-:1 FMUL c7, cx7y0, alpha; 954 | 955 | 956 | 957 | --:-:-:-:5 CAL STORE_C; 958 | 959 | [+ 960 | my $out; 961 | foreach my $y (1..7) 962 | { 963 | my $inc = $y == 4 ? 13 : 1; 964 | 965 | $out .= qq{ 966 | 967 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, P4; 968 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, P4; 969 | 970 | 01:-:-:-:1 IADD track00C0.CC, track00C0, cdc$inc; 971 | --:-:-:-:1 IADD.X track00C1, track00C1, RZ; 972 | 02:-:-:-:1 IADD track04C0.CC, track04C0, cdc$inc; 973 | --:-:-:-:1 IADD.X track04C1, track04C1, RZ; 974 | 04:-:-:-:1 IADD track08C0.CC, track08C0, cdc$inc; 975 | --:-:-:-:1 IADD.X track08C1, track08C1, RZ; 976 | 08:-:-:-:1 IADD track12C0.CC, track12C0, cdc$inc; 977 | --:-:-:-:0 IADD.X track12C1, track12C1, RZ; 978 | 979 | --:-:-:-:1 IADD cy00, cy00, $inc; 980 | --:-:-:-:1 IADD cy04, cy04, $inc; 981 | --:-:-:-:1 IADD cy08, cy08, $inc; 982 | --:-:-:-:1 IADD cy12, cy12, $inc; 983 | 984 | --:-:-:-:1 FMUL c0, cx0y$y, alpha; 985 | --:-:-:-:1 FMUL c1, cx1y$y, alpha; 986 | --:-:-:-:1 FMUL c2, cx2y$y, alpha; 987 | --:-:-:-:1 FMUL c3, cx3y$y, alpha; 988 | --:-:-:-:1 FMUL c4, cx4y$y, alpha; 989 | --:-:-:-:1 FMUL c5, cx5y$y, alpha; 990 | --:-:-:-:1 FMUL c6, cx6y$y, alpha; 991 | --:-:-:-:1 FMUL c7, cx7y$y, alpha; 992 | 993 | 994 | --:-:-:-:5 CAL STORE_C; 995 | }; 996 | } 997 | return $out; 998 | +] 999 | 1000 | --:-:-:-:5 EXIT; 1001 | 1002 | 1003 | STORE_C: 1004 | 1005 | 1006 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy00, param_m, P5; 1007 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy00, param_m, P6; 1008 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy04, param_m, P5; 1009 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy04, param_m, P6; 1010 | 1011 | --:-:-:-:1 
@!P0 MOV b00_00, RZ; 1012 | --:-:-:-:1 @!P1 MOV b00_32, RZ; 1013 | --:-:-:-:1 @!P2 MOV b04_00, RZ; 1014 | --:-:-:-:1 @!P3 MOV b04_32, RZ; 1015 | --:-:-:-:1 @P0 LDG.E.CI.[+ dtypeC() +] b00_00, [track00C + 1x<$dsizeC * 00>]; 1016 | --:-:-:-:1 @P1 LDG.E.CI.[+ dtypeC() +] b00_32, [track00C + 1x<$dsizeC * 32>]; 1017 | --:-:-:-:1 @P2 LDG.E.CI.[+ dtypeC() +] b04_00, [track04C + 1x<$dsizeC * 00>]; 1018 | --:-:5:-:1 @P3 LDG.E.CI.[+ dtypeC() +] b04_32, [track04C + 1x<$dsizeC * 32>]; 1019 | 1020 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy08, param_m, P5; 1021 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy08, param_m, P6; 1022 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy12, param_m, P5; 1023 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy12, param_m, P6; 1024 | 1025 | --:-:-:-:1 ISETP.LT.AND P5, PT, cx00, param_n, PT; 1026 | --:-:-:-:1 ISETP.LT.AND P6, PT, cx32, param_n, PT; 1027 | 1028 | --:-:-:-:1 @!P0 MOV b08_00, RZ; 1029 | --:-:-:-:1 @!P1 MOV b08_32, RZ; 1030 | --:-:-:-:1 @!P2 MOV b12_00, RZ; 1031 | --:-:-:-:1 @!P3 MOV b12_32, RZ; 1032 | --:-:-:-:1 @P0 LDG.E.CI.[+ dtypeC() +] b08_00, [track08C + 1x<$dsizeC * 00>]; 1033 | --:-:-:-:1 @P1 LDG.E.CI.[+ dtypeC() +] b08_32, [track08C + 1x<$dsizeC * 32>]; 1034 | --:-:-:-:1 @P2 LDG.E.CI.[+ dtypeC() +] b12_00, [track12C + 1x<$dsizeC * 00>]; 1035 | --:-:6:-:1 @P3 LDG.E.CI.[+ dtypeC() +] b12_32, [track12C + 1x<$dsizeC * 32>]; 1036 | 1037 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy00, param_m, P5; 1038 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy00, param_m, P6; 1039 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy04, param_m, P5; 1040 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy04, param_m, P6; 1041 | 1042 | 1043 | --:-:-:-:1 STS.128 [writeCs + 4x<00>], c0; 1044 | --:-:-:-:1 STS.128 [writeCs + 4x<32>], c4; 1045 | --:-:-:-:1 LDS c00_00, [readCs + 4x<0*8*64 + 00 + 0*16>]; 1046 | --:-:1:-:1 LDS c00_32, [readCs + 4x<0*8*64 + 32 + 0*16>]; 1047 | --:-:-:-:1 LDS c04_00, [readCs + 4x<1*8*64 + 00 + 1*16>]; 1048 | --:-:2:-:1 LDS c04_32, [readCs + 4x<1*8*64 + 32 + 1*16>]; 1049 | --:-:-:-:1 LDS c08_00, [readCs + 
4x<2*8*64 + 00 + 2*16>]; 1050 | --:-:3:-:1 LDS c08_32, [readCs + 4x<2*8*64 + 32 + 2*16>]; 1051 | --:-:-:-:1 LDS c12_00, [readCs + 4x<3*8*64 + 00 + 3*16>]; 1052 | --:-:4:-:1 LDS c12_32, [readCs + 4x<3*8*64 + 32 + 3*16>]; 1053 | 1054 | [+ 1055 | C16() ? q{ 1056 | 10:-:-:-:1 F2F.F32.F16 b00_00, b00_00; 1057 | --:-:-:-:1 F2F.F32.F16 b00_32, b00_32; 1058 | --:-:-:-:1 F2F.F32.F16 b04_00, b04_00; 1059 | --:-:5:-:1 F2F.F32.F16 b04_32, b04_32; 1060 | 20:-:-:-:1 F2F.F32.F16 b08_00, b08_00; 1061 | --:-:-:-:1 F2F.F32.F16 b08_32, b08_32; 1062 | --:-:-:-:1 F2F.F32.F16 b12_00, b12_00; 1063 | --:-:6:-:1 F2F.F32.F16 b12_32, b12_32; 1064 | } : ''; 1065 | +] 1066 | 11:-:-:-:1 FFMA c00_00, b00_00, beta, c00_00; 1067 | --:-:-:-:1 FFMA c00_32, b00_32, beta, c00_32; 1068 | 02:-:-:-:1 FFMA c04_00, b04_00, beta, c04_00; 1069 | --:-:-:-:1 FFMA c04_32, b04_32, beta, c04_32; 1070 | 24:-:-:-:1 FFMA c08_00, b08_00, beta, c08_00; 1071 | --:-:-:-:1 FFMA c08_32, b08_32, beta, c08_32; 1072 | 08:-:-:-:1 FFMA c12_00, b12_00, beta, c12_00; 1073 | --:-:-:-:1 FFMA c12_32, b12_32, beta, c12_32; 1074 | [+ 1075 | C16() ? 
q{ 1076 | --:-:-:-:1 F2F.F16.F32 c00_00, c00_00; 1077 | --:-:1:-:1 F2F.F16.F32 c00_32, c00_32; 1078 | --:-:-:-:1 F2F.F16.F32 c04_00, c04_00; 1079 | --:-:2:-:1 F2F.F16.F32 c04_32, c04_32; 1080 | --:-:-:-:1 F2F.F16.F32 c08_00, c08_00; 1081 | --:-:3:-:1 F2F.F16.F32 c08_32, c08_32; 1082 | --:-:-:-:1 F2F.F16.F32 c12_00, c12_00; 1083 | --:-:4:-:1 F2F.F16.F32 c12_32, c12_32; 1084 | } : ''; 1085 | +] 1086 | 1087 | 1088 | 01:-:-:-:1 @P0 STG.E.CG.[+ dtypeC() +] [track00C + 1x<$dsizeC * 00>], c00_00; 1089 | --:1:-:-:1 @P1 STG.E.CG.[+ dtypeC() +] [track00C + 1x<$dsizeC * 32>], c00_32; 1090 | 02:-:-:-:1 @P2 STG.E.CG.[+ dtypeC() +] [track04C + 1x<$dsizeC * 00>], c04_00; 1091 | --:2:-:-:1 @P3 STG.E.CG.[+ dtypeC() +] [track04C + 1x<$dsizeC * 32>], c04_32; 1092 | 1093 | --:-:-:-:1 ISETP.LT.AND P0, PT, cy08, param_m, P5; 1094 | --:-:-:-:1 ISETP.LT.AND P1, PT, cy08, param_m, P6; 1095 | --:-:-:-:1 ISETP.LT.AND P2, PT, cy12, param_m, P5; 1096 | --:-:-:-:1 ISETP.LT.AND P3, PT, cy12, param_m, P6; 1097 | 1098 | 04:-:-:-:1 @P0 STG.E.CG.[+ dtypeC() +] [track08C + 1x<$dsizeC * 00>], c08_00; 1099 | --:3:-:-:1 @P1 STG.E.CG.[+ dtypeC() +] [track08C + 1x<$dsizeC * 32>], c08_32; 1100 | 08:-:-:-:1 @P2 STG.E.CG.[+ dtypeC() +] [track12C + 1x<$dsizeC * 00>], c12_00; 1101 | --:4:-:-:1 @P3 STG.E.CG.[+ dtypeC() +] [track12C + 1x<$dsizeC * 32>], c12_32; 1102 | 1103 | 1104 | --:-:-:-:5 RET; 1105 | -------------------------------------------------------------------------------- /src/c_interface.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | typedef struct CUfunc_st *CUfunction; 11 | typedef struct CUmod_st *CUmodule; 12 | typedef struct CUstream_st *CUstream; 13 | typedef int CUdevice; 14 | 15 | //include the cuda function prototypes here so that we don't need to know 16 | //anything about cuda to compile this file - makes building with bazel 17 | //a million times 
easier 18 | extern "C" { 19 | 20 | int 21 | #ifdef _WIN32 22 | __stdcall 23 | #endif 24 | cuModuleLoadData(CUmodule *, const void *); 25 | 26 | int 27 | #ifdef _WIN32 28 | __stdcall 29 | #endif 30 | cuModuleGetFunction(CUfunction *, CUmodule, const char *); 31 | 32 | int 33 | #ifdef _WIN32 34 | __stdcall 35 | #endif 36 | cuDeviceGetAttribute(int *, int, CUdevice); 37 | 38 | int 39 | #ifdef _WIN32 40 | __stdcall 41 | #endif 42 | cuLaunchKernel(CUfunction, 43 | unsigned int, 44 | unsigned int, 45 | unsigned int, 46 | unsigned int, 47 | unsigned int, 48 | unsigned int, 49 | unsigned int, 50 | CUstream, 51 | void **, 52 | void **); 53 | 54 | int 55 | #ifdef _WIN32 56 | __stdcall 57 | #endif 58 | cuCtxGetDevice(CUdevice *); 59 | 60 | }; 61 | 62 | #include "include/c_interface.h" 63 | #include "include/kernel_headers.h" 64 | 65 | namespace { 66 | #include "include/static_kernel_information.h" 67 | 68 | std::mutex load_kernel_mutex_; 69 | 70 | std::unordered_map kernels_; 71 | 72 | bool loadKernelsHelper(const std::unordered_map& kernels) { 73 | for (auto kernel : kernels) { 74 | if (kernels_.count(kernel.first) > 0) 75 | continue; 76 | 77 | CUmodule module; 78 | 79 | int res = cuModuleLoadData(&module, kernel.second); 80 | if (res != 0) { 81 | std::cerr << "Failed to load " << kernel.first << " " << 82 | res << std::endl; 83 | return false; 84 | } 85 | 86 | CUfunction function; 87 | 88 | std::string kernel_name = kernel.first.substr(0, kernel.first.size() - 6); 89 | 90 | res = cuModuleGetFunction(&function, module, kernel_name.c_str()); 91 | if (res != 0) { 92 | std::cerr << "Failed to extract " << kernel_name << " " << 93 | res << std::endl; 94 | return false; 95 | } 96 | 97 | kernels_.insert(std::make_pair(kernel.first, function)); 98 | } 99 | 100 | return true; 101 | } 102 | 103 | bool loadKernels(int major) { 104 | std::lock_guard lock(load_kernel_mutex_); 105 | 106 | if (major == 5) 107 | return loadKernelsHelper(kernels_50); 108 | else if (major == 6) 109 | return 
loadKernelsHelper(kernels_60); 110 | else { 111 | std::cerr << "Arch must be 5 or 6" << std::endl; 112 | return false; 113 | } 114 | } 115 | 116 | std::tuple getDeviceProperties(CUdevice& device) { 117 | int major, minor; 118 | int res = cuDeviceGetAttribute(&major, 75, device); 119 | if (res != 0) 120 | return std::make_tuple(res, -1, -1); 121 | 122 | res = cuDeviceGetAttribute(&minor, 76, device); 123 | if (res != 0) 124 | return std::make_tuple(res, -1, -1); 125 | 126 | return std::make_tuple(0, major, minor); 127 | } 128 | 129 | std::pair closest_divisor(int val, int div) { 130 | if (div == 2) { 131 | if ((val & 1) == 0) { return std::make_pair(2, val >> 1); } 132 | else { return std::make_pair(1, val); } 133 | } 134 | else if (div == 4) { 135 | if ((val & 3) == 0) { return std::make_pair(4, val >> 2); } 136 | else if ((val % 3) == 0) { return std::make_pair(3, val / 3); } 137 | else if ((val % 5) == 0) { return std::make_pair(5, val / 5); } 138 | else if ((val & 1) == 0) { return std::make_pair(2, val >> 1); } 139 | else if ((val % 7) == 0) { return std::make_pair(7, val / 7); } 140 | else { return std::make_pair(1, val); } 141 | } 142 | else { 143 | return std::make_pair(1, val); 144 | } 145 | } 146 | 147 | std::string get_op_string(bool a_t, bool b_t) { 148 | if (!a_t && !b_t) return "NN"; 149 | else if (a_t && !b_t) return "TN"; 150 | else if (!a_t && b_t) return "NT"; 151 | else return "TT"; 152 | } 153 | 154 | bool gemm(std::string precision, void *A, void *B, void *C, 155 | bool a_t, bool b_t, 156 | int m, int n, int k, 157 | int lda, int ldb, int ldc, 158 | float alpha, float beta, 159 | CUstream stream, unsigned int grid, unsigned int shared) { 160 | std::string kernel_op = get_op_string(a_t, b_t); 161 | 162 | if (grid >= selections[precision][kernel_op].size()) 163 | return false; 164 | 165 | kernel_properties kp = selections[precision][kernel_op][grid]; 166 | 167 | if (shared >= kp.shared_sizes.size()) 168 | return false; 169 | 170 | bool vec4A, 
vec8A; 171 | bool vec4B, vec8B; 172 | if (a_t) { 173 | vec4A = (lda & 3) == 0 && (m & 3) == 0; //multiple of 4 174 | vec8A = (lda & 7) == 0 && (m & 7) == 0; //multiple of 8 175 | } 176 | else { 177 | vec4A = (lda & 3) == 0 && (k & 3) == 0; //multiple of 4 178 | vec8A = (lda & 7) == 0 && (k & 7) == 0; //multiple of 8 179 | } 180 | 181 | if (b_t) { 182 | vec4B = (ldb & 3) == 0 && (k & 3) == 0; //multiple of 4 183 | vec8B = (ldb & 7) == 0 && (k & 7) == 0; //multiple of 8 184 | } 185 | else { 186 | vec4B = (ldb & 3) == 0 && (n & 3) == 0; //multiple of 4 187 | vec8B = (ldb & 7) == 0 && (n & 7) == 0; //multiple of 8 188 | } 189 | 190 | bool vec4C = (ldc & 3) == 0 && (n & 3) == 0; 191 | 192 | bool vecA = (kp.vA == 4 && vec4A) || (kp.vA == 8 && vec8A); 193 | bool vecB = (kp.vB == 4 && vec4B) || (kp.vB == 8 && vec8B); 194 | bool vecC = kp.vC == 1 || vec4C; 195 | 196 | bool vec = vecA && vecB && vecC; 197 | 198 | CUdevice device; 199 | int res = cuCtxGetDevice(&device); 200 | if (res != 0) 201 | return false; 202 | 203 | bool success; 204 | int major, minor; 205 | 206 | std::tie(success, major, minor) = getDeviceProperties(device); 207 | 208 | if (success != 0) 209 | return false; 210 | 211 | std::string kernel_string; 212 | kernel_string.reserve(64); 213 | 214 | kernel_string += precision + "gemm_" + kp.tile_string + 215 | "_" + kernel_op; 216 | if (vec) 217 | kernel_string += "_vec"; 218 | 219 | if (major == 5) 220 | kernel_string += "_sm_50"; 221 | else if (major == 6) 222 | kernel_string += "_sm_60"; 223 | else 224 | return false; 225 | 226 | auto kernel = kernels_.find(kernel_string); 227 | if (kernel == kernels_.end()) { 228 | loadKernels(major); 229 | kernel = kernels_.find(kernel_string); 230 | } 231 | 232 | int blk_A = (m + kp.tile_m - 1) / kp.tile_m; 233 | int blk_B = (n + kp.tile_n - 1) / kp.tile_n; 234 | 235 | int blk_a, blk_b; 236 | std::tie(blk_a, blk_A) = closest_divisor(blk_A, kp.div); 237 | std::tie(blk_b, blk_B) = closest_divisor(blk_B, kp.div); 238 | 239 | 
if (blk_a == 1) 240 | std::tie(blk_a, blk_A) = std::make_pair(blk_A, 1); 241 | 242 | void *args[13] = {&C, &A, &B, &alpha, &beta, &lda, &ldb, &ldc, 243 | &m, &n, &k, &blk_a, &blk_b}; 244 | 245 | res = cuLaunchKernel(kernel->second, blk_a * blk_b, blk_B, blk_A, 246 | kp.threads, 1, 1, 247 | kp.shared_sizes[shared], stream, args, NULL); 248 | 249 | if (res != 0) { 250 | std::cerr << "Failed to execute " << kernel_string << " " << 251 | res << std::endl; 252 | return false; 253 | } 254 | 255 | return true; 256 | } 257 | 258 | }; 259 | 260 | bool get_grid_limits(char precision, bool a_t, bool b_t, unsigned int *grid) 261 | { 262 | std::string prec_string(1, precision); 263 | 264 | *grid = selections[prec_string][get_op_string(a_t, b_t)].size(); 265 | return true; 266 | } 267 | 268 | bool get_shared_limits(char precision, bool a_t, bool b_t, unsigned int grid, unsigned int *shared) { 269 | std::string prec_string(1, precision); 270 | 271 | if (grid >= selections[prec_string][get_op_string(a_t, b_t)].size()) 272 | return false; 273 | 274 | *shared = selections[prec_string][get_op_string(a_t, b_t)][grid].shared_sizes.size(); 275 | 276 | return true; 277 | } 278 | 279 | bool openai_sgemm(float *A, float *B, float *C, 280 | bool a_t, bool b_t, 281 | int m, int n, int k, 282 | int lda, int ldb, int ldc, 283 | float alpha, float beta, 284 | CUstream stream, unsigned int grid, unsigned int shared) { 285 | return gemm("s", 286 | static_cast(A), 287 | static_cast(B), 288 | static_cast(C), 289 | a_t, b_t, m, n, k, lda, ldb, ldc, 290 | alpha, beta, stream, grid, shared); 291 | } 292 | 293 | bool openai_hgemm(uint16_t *A, uint16_t *B, uint16_t *C, 294 | bool a_t, bool b_t, 295 | int m, int n, int k, 296 | int lda, int ldb, int ldc, 297 | float alpha, float beta, 298 | CUstream stream, unsigned int grid, unsigned int shared) { 299 | return gemm("h", 300 | static_cast(A), 301 | static_cast(B), 302 | static_cast(C), 303 | a_t, b_t, m, n, k, lda, ldb, ldc, 304 | alpha, beta, stream, 
grid, shared); 305 | } 306 | -------------------------------------------------------------------------------- /src/test.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | 6 | #include "include/c_interface.h" 7 | 8 | /* Simple program to call all possible kernel variants to make sure the are 9 | callable and produce sane results. Not meant to check the correctness of 10 | the underlying routines */ 11 | 12 | int main(void) { 13 | cudaFree(0); 14 | 15 | std::vector> ops = { {false, false}, {false, true}, 16 | {true, false}, {true, true} }; 17 | { 18 | float *A, *B, *C; 19 | const int size = 1024; 20 | 21 | cudaMalloc(&A, size * size * sizeof(float)); 22 | cudaMalloc(&B, size * size * sizeof(float)); 23 | cudaMalloc(&C, size * size * sizeof(float)); 24 | float *C_host = (float *)malloc(size * size * sizeof(float)); 25 | 26 | thrust::fill_n(thrust::device, A, size * size, 1.f); 27 | thrust::fill_n(thrust::device, B, size * size, 1.f); 28 | 29 | for (auto op : ops) { 30 | unsigned int grid; 31 | get_grid_limits('s', op.first, op.second, &grid); 32 | for (int g = 0; g < grid; ++g) { 33 | unsigned int shared; 34 | get_shared_limits('s', op.first, op.second, g, &shared); 35 | for (int s = 0; s < shared; ++s) { 36 | bool res = openai_sgemm(A, B, C, op.first, op.second, size, size, size, 37 | size, size, size, 1.0, 0.0, NULL, g, s); 38 | assert(res); 39 | cudaMemcpy(C_host, C, size * size * sizeof(float), cudaMemcpyDeviceToHost); 40 | 41 | assert(C_host[0] == 1024); 42 | } 43 | } 44 | } 45 | 46 | cudaFree(A); 47 | cudaFree(B); 48 | cudaFree(C); 49 | free(C_host); 50 | } 51 | 52 | { 53 | uint16_t *A, *B, *C; 54 | const int size = 1024; 55 | 56 | cudaMalloc(&A, size * size * sizeof(uint16_t)); 57 | cudaMalloc(&B, size * size * sizeof(uint16_t)); 58 | cudaMalloc(&C, size * size * sizeof(uint16_t)); 59 | uint16_t *C_host = (uint16_t *)malloc(size * size * sizeof(uint16_t)); 60 | 61 | 
thrust::fill_n(thrust::device, A, size * size, 0x3c00); 62 | thrust::fill_n(thrust::device, B, size * size, 0x3c00); 63 | 64 | for (auto op : ops) { 65 | unsigned int grid; 66 | get_grid_limits('h', op.first, op.second, &grid); 67 | for (int g = 0; g < grid; ++g) { 68 | unsigned int shared; 69 | get_shared_limits('h', op.first, op.second, g, &shared); 70 | for (int s = 0; s < shared; ++s) { 71 | bool res = openai_hgemm(A, B, C, op.first, op.second, size, size, size, 72 | size, size, size, 1.0, 0.0, NULL, g, s); 73 | assert(res); 74 | cudaMemcpy(C_host, C, size * size * sizeof(uint16_t), cudaMemcpyDeviceToHost); 75 | 76 | assert(C_host[0] == 25600); 77 | } 78 | } 79 | } 80 | 81 | cudaFree(A); 82 | cudaFree(B); 83 | cudaFree(C); 84 | free(C_host); 85 | } 86 | 87 | return 0; 88 | } 89 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import pycuda.driver as drv 5 | from neon.backends.nervanagpu import NervanaGPU 6 | 7 | from openai_gemm import matmul_test 8 | 9 | ng = NervanaGPU() 10 | print drv.Context.get_current().get_device().name() 11 | 12 | ones = 0 13 | out = 0 14 | 15 | # for i in range(1000): # np.float32, np.float16 16 | 17 | # matmul_test(ng, np.float32, "TN", 4096*4, 4096*4, 33, ones=ones, out=out) # update 18 | 19 | # if i % 100 == 0: print i 20 | 21 | # exit() 22 | 23 | small_1 = (1,2,3,4,5,6,7,8,9,16,32,64,65,72,120,127,128,192) 24 | medium_1 = (32,64,128,192,778,785,786,787,794) 25 | big_1 = (32,64,128,1532,1535,1536,1537,1540,3073,4095) 26 | 27 | small_2 = (8,16,32,64,72,96,120,128,192) 28 | medium_2 = (32,64,128,192,256,768-4,768-8,768,768+16,768+32) 29 | big_2 = (32,64,128,1536-12,1536-24,1536,1536+28,1536+32,3072,4096) 30 | 31 | for dtype in (np.float32, np.float16, ): # np.float32, np.float16 32 | print dtype 33 | 34 | for size in (small_1, small_2, medium_1, medium_2, 
big_1, big_2,): # small_1, small_2, medium_1, medium_2, big_1, big_2 35 | print size 36 | 37 | for K in size: 38 | print "K:", K 39 | 40 | for C in (size): 41 | print "C:", C 42 | 43 | for N in (size): 44 | 45 | matmul_test(ng, dtype, "NN", N, K, C, ones=ones, out=out) # fprop 46 | matmul_test(ng, dtype, "NT", N, C, K, ones=ones, out=out) # bprop 47 | matmul_test(ng, dtype, "TN", C, K, N, ones=ones, out=out) # update 48 | matmul_test(ng, dtype, "TT", K, N, C, ones=ones, out=out) # ------ 49 | --------------------------------------------------------------------------------