├── .dockerignore ├── Dockerfile ├── CODE_OF_CONDUCT.md ├── .gitignore ├── LICENSE ├── SUPPORT.md ├── SECURITY.md ├── src ├── 1d │ ├── 1d_utils.h │ ├── gpu_2r.cu │ ├── main.cu │ └── gpu_1r.cu ├── 2d │ ├── 2d_utils.h │ ├── main.cu │ └── gpu.cu ├── 3d │ ├── 3d_utils.h │ └── gpu_star.cu └── cudnn │ ├── conv_box2d49p.cu │ ├── conv_box2d9p.cu │ ├── conv_1d3p.cu │ ├── conv_1d5p.cu │ └── conv_box3d27p.cu ├── CMakeLists.txt └── README.md /.dockerignore: -------------------------------------------------------------------------------- 1 | .git 2 | .vscode -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 2 | 3 | LABEL maintainer="y_t_chen@outlook.com" 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y build-essential cmake && \ 7 | rm -rf /var/lib/apt/lists/* 8 | 9 | RUN ln -s /usr/bin/cmake /usr/local/bin/cmake 10 | 11 | COPY . /convStencil 12 | WORKDIR /convStencil -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.i 2 | *.ii 3 | *.gpu 4 | *.ptx 5 | *.cubin 6 | *.fatbin 7 | .vscode/ 8 | 9 | build/* 10 | 11 | # Prerequisites 12 | *.d 13 | 14 | # Compiled Object files 15 | *.slo 16 | *.lo 17 | *.o 18 | *.obj 19 | 20 | # Precompiled Headers 21 | *.gch 22 | *.pch 23 | 24 | # Compiled Dynamic libraries 25 | *.so 26 | *.dylib 27 | *.dll 28 | 29 | # Fortran module files 30 | *.mod 31 | *.smod 32 | 33 | # Compiled Static libraries 34 | *.lai 35 | *.la 36 | *.a 37 | *.lib 38 | 39 | # Executables 40 | *.exe 41 | *.out 42 | *.app -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/1d/1d_utils.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // #pragma once 4 | 5 | // #include 6 | // #include 7 | // #include 8 | // #include 9 | // #include 10 | // #include 11 | 12 | // #include 13 | // #include 14 | 15 | 16 | 17 | // memory alignment 18 | #define ALIGN_TO(A, B) (((A + B - 1) / B) * B) 19 | 20 | // device memory pitch alignment 21 | static const size_t device_alignment = 32; 22 | 23 | // void gpu_box_2d1r_step3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 24 | 25 | // __global__ void gpu_box_2d1r_step3_kernel (const double * __restrict__ in, double * __restrict__ out); 26 | 27 | // #include 28 | // #include 29 | 30 | #define DATA_TYPE double 31 | 32 | #define TENSOR_CORE_M 8 33 | 34 | #pragma once 35 | #define CUDAKERNELCHECK(expr) \ 36 | do \ 37 | { \ 38 | expr; \ 39 | \ 40 | cudaError_t __err = cudaGetLastError(); \ 41 | if (__err != cudaSuccess) \ 42 | { \ 43 | printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, cudaGetErrorString(__err)); \ 44 | abort(); \ 45 | } \ 46 | } while (0) 47 | 48 | 49 | #include 50 | 51 | #define CUDA_CHECK(call) \ 52 | do \ 53 | { \ 54 | const cudaError_t error_code = call; \ 55 | if (error_code != cudaSuccess) \ 56 | { \ 57 | printf("CUDA Error:\n"); \ 58 | printf(" File: %s\n", __FILE__); \ 59 | printf(" Line: %d\n", __LINE__); \ 60 | printf(" Error code: %d\n", error_code); \ 61 | printf(" Error text: %s\n", \ 62 | cudaGetErrorString(error_code)); \ 63 | exit(1); \ 64 | } \ 65 | } while (0) 66 | 67 | // #pragma once 68 | 69 | enum Shape 70 | { 71 | star_1d1r, 72 | star_1d2r, 73 | }; 74 | 75 | void gpu_1d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 76 | 77 | void gpu_1d1r_breakdown4(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 78 | 79 | void gpu_1d1r_breakdown3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 80 | 81 | void 
gpu_1d1r_breakdown2(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 82 | 83 | void gpu_1d1r_breakdown1(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 84 | 85 | void gpu_star_1d2r_step2(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_n); 86 | -------------------------------------------------------------------------------- /src/2d/2d_utils.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | // #pragma once 4 | 5 | // #include 6 | // #include 7 | // #include 8 | // #include 9 | // #include 10 | // #include 11 | 12 | // #include 13 | // #include 14 | 15 | 16 | 17 | // memory alignment 18 | #define ALIGN_TO(A, B) (((A + B - 1) / B) * B) 19 | 20 | // device memory pitch alignment 21 | static const size_t device_alignment = 32; 22 | 23 | // void gpu_box_2d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 24 | 25 | // __global__ void gpu_box_2d1r_step3_kernel (const double * __restrict__ in, double * __restrict__ out); 26 | 27 | // #include 28 | // #include 29 | 30 | #define DATA_TYPE double 31 | 32 | #define TENSOR_CORE_M 8 33 | 34 | #pragma once 35 | #define CUDAKERNELCHECK(expr) \ 36 | do \ 37 | { \ 38 | expr; \ 39 | \ 40 | cudaError_t __err = cudaGetLastError(); \ 41 | if (__err != cudaSuccess) \ 42 | { \ 43 | printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, cudaGetErrorString(__err)); \ 44 | abort(); \ 45 | } \ 46 | } while (0) 47 | 48 | 49 | #include 50 | 51 | #define CUDA_CHECK(call) \ 52 | do \ 53 | { \ 54 | const cudaError_t error_code = call; \ 55 | if (error_code != cudaSuccess) \ 56 | { \ 57 | printf("CUDA Error:\n"); \ 58 | printf(" File: %s\n", __FILE__); \ 59 | printf(" Line: %d\n", __LINE__); \ 60 | printf(" Error code: %d\n", error_code); \ 61 | printf(" Error text: %s\n", \ 62 | cudaGetErrorString(error_code)); \ 63 | exit(1); \ 64 | } \ 65 | } while (0) 66 | 67 | // #pragma once 68 | 69 | enum Shape 70 | { 71 | star_2d1r, 72 | box_2d1r, 73 | star_2d3r, 74 | box_2d3r, 75 | }; 76 | 77 | void gpu_box_2d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 78 | 79 | void gpu_box_2d3r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 80 | 81 | void gpu_box_2d1r_breakdown4(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 82 | 83 | void gpu_box_2d1r_breakdown3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 84 | 85 | void gpu_box_2d1r_breakdown2(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 86 | 87 | void gpu_box_2d1r_breakdown1(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); -------------------------------------------------------------------------------- /src/3d/3d_utils.h: 
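The 1D, 2D, and 3D `*_utils.h` headers share the same error-checking helpers. As a hedged illustration (not code from the repository — the buffer, kernel, and launch-configuration names are placeholders), `CUDA_CHECK` is meant to wrap CUDA runtime calls and `CUDAKERNELCHECK` to wrap kernel launches:

```
// Hypothetical host-side snippet showing how the macros from the *_utils.h
// headers are meant to be used; names and sizes are placeholders.
double *buf_d = nullptr;
size_t n = 1024;                                // placeholder problem size
size_t padded = ALIGN_TO(n, device_alignment);  // round up to the 32-element pitch
CUDA_CHECK(cudaMalloc(&buf_d, padded * sizeof(double)));
CUDA_CHECK(cudaMemset(buf_d, 0, padded * sizeof(double)));
dim3 grid(1), block(256);
// A kernel launch is checked with cudaGetLastError() right after it runs
// (some_stencil_kernel is a placeholder, not a kernel from this repository):
CUDAKERNELCHECK((some_stencil_kernel<<<grid, block>>>(buf_d)));
CUDA_CHECK(cudaDeviceSynchronize());
```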
-------------------------------------------------------------------------------- 1 | 2 | 3 | // #pragma once 4 | 5 | // #include 6 | // #include 7 | // #include 8 | // #include 9 | // #include 10 | // #include 11 | 12 | // #include 13 | // #include 14 | 15 | 16 | 17 | // memory alignment 18 | #define ALIGN_TO(A, B) (((A + B - 1) / B) * B) 19 | 20 | // device memory pitch alignment 21 | static const size_t device_alignment = 32; 22 | 23 | // void gpu_box_2d1r_step3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int time, const int input_m, const int input_n); 24 | 25 | // __global__ void gpu_box_2d1r_step3_kernel (const double * __restrict__ in, double * __restrict__ out); 26 | 27 | // #include 28 | // #include 29 | 30 | #define DATA_TYPE double 31 | 32 | #define TENSOR_CORE_M 8 33 | 34 | #pragma once 35 | #define CUDAKERNELCHECK(expr) \ 36 | do \ 37 | { \ 38 | expr; \ 39 | \ 40 | cudaError_t __err = cudaGetLastError(); \ 41 | if (__err != cudaSuccess) \ 42 | { \ 43 | printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, cudaGetErrorString(__err)); \ 44 | abort(); \ 45 | } \ 46 | } while (0) 47 | 48 | 49 | #include 50 | 51 | #define CUDA_CHECK(call) \ 52 | do \ 53 | { \ 54 | const cudaError_t error_code = call; \ 55 | if (error_code != cudaSuccess) \ 56 | { \ 57 | printf("CUDA Error:\n"); \ 58 | printf(" File: %s\n", __FILE__); \ 59 | printf(" Line: %d\n", __LINE__); \ 60 | printf(" Error code: %d\n", error_code); \ 61 | printf(" Error text: %s\n", \ 62 | cudaGetErrorString(error_code)); \ 63 | exit(1); \ 64 | } \ 65 | } while (0) 66 | 67 | // #pragma once 68 | 69 | enum Shape 70 | { 71 | box_3d1r, 72 | star_3d1r, 73 | }; 74 | 75 | void gpu_box_3d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); 76 | 77 | void gpu_star_3d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); 78 | 79 | void gpu_box_3d1r_breakdown4(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); 80 | 81 | void gpu_box_3d1r_breakdown3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); 82 | 83 | void gpu_box_3d1r_breakdown2(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); 84 | 85 | void gpu_box_3d1r_breakdown1(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n); -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc) 3 | project(spmm_coo VERSION 0.01 LANGUAGES CXX CUDA) 4 | set(CMAKE_VERBOSE_MAKEFILE ON) 5 | 6 | set(CMAKE_CXX_STANDARD 14) 7 | set(CMAKE_CUDA_STANDARD 14) 8 | enable_language(CXX CUDA) 9 | 10 | # Define the include DIRs 11 | include_directories(${CUDA_INCLUDE_DIRS}) 12 | include_directories(/usr/local/cuda/include) 13 | 14 | # Define the 
link libraries 15 | link_directories(${CUDA_LIBRARY_DIRS}) 16 | link_directories(/usr/local/cuda/lib64) 17 | add_executable(convstencil_1d src/1d/main.cu src/1d/gpu_1r.cu src/1d/gpu_2r.cu src/1d/1d_utils.h) 18 | # target_link_libraries( gpu_box2d1r ${CUDA_cusparse_LIBRARY} ${CUDA_cublas_LIBRARY} ) 19 | set_target_properties( 20 | convstencil_1d 21 | PROPERTIES 22 | CUDA_SEPARABLE_COMPILATION ON 23 | CUDA_ARCHITECTURES "80") 24 | target_compile_options(convstencil_1d PRIVATE $<$:-O3 -lineinfo --use_fast_math --gpu-architecture=sm_80>) 25 | 26 | 27 | add_executable(cudnn_1d3p src/cudnn/conv_1d3p.cu) 28 | target_link_libraries( cudnn_1d3p ${CUDA_cudnn_LIBRARY}) 29 | set_target_properties( 30 | cudnn_1d3p 31 | PROPERTIES 32 | CUDA_SEPARABLE_COMPILATION ON 33 | CUDA_ARCHITECTURES "80") 34 | target_compile_options(cudnn_1d3p PRIVATE $<$:-O3 -lineinfo -lcudnn --use_fast_math --gpu-architecture=sm_80>) 35 | target_link_libraries(cudnn_1d3p ${CUDA_LIBRARIES} cudnn) 36 | 37 | add_executable(cudnn_1d5p src/cudnn/conv_1d5p.cu) 38 | target_link_libraries( cudnn_1d5p ${CUDA_cudnn_LIBRARY}) 39 | set_target_properties( 40 | cudnn_1d5p 41 | PROPERTIES 42 | CUDA_SEPARABLE_COMPILATION ON 43 | CUDA_ARCHITECTURES "80") 44 | target_compile_options(cudnn_1d5p PRIVATE $<$:-O3 -lineinfo -lcudnn --use_fast_math --gpu-architecture=sm_80>) 45 | target_link_libraries(cudnn_1d5p ${CUDA_LIBRARIES} cudnn) 46 | 47 | add_executable(cudnn_box2d49p src/cudnn/conv_box2d49p.cu) 48 | target_link_libraries( cudnn_box2d49p ${CUDA_cudnn_LIBRARY}) 49 | set_target_properties( 50 | cudnn_box2d49p 51 | PROPERTIES 52 | CUDA_SEPARABLE_COMPILATION ON 53 | CUDA_ARCHITECTURES "80") 54 | target_compile_options(cudnn_box2d49p PRIVATE $<$:-O3 -lineinfo -lcudnn --use_fast_math --gpu-architecture=sm_80>) 55 | target_link_libraries(cudnn_box2d49p ${CUDA_LIBRARIES} cudnn) 56 | 57 | add_executable(cudnn_box2d9p src/cudnn/conv_box2d9p.cu) 58 | target_link_libraries( cudnn_box2d9p ${CUDA_cudnn_LIBRARY}) 59 | set_target_properties( 60 | cudnn_box2d9p 61 | PROPERTIES 62 | CUDA_SEPARABLE_COMPILATION ON 63 | CUDA_ARCHITECTURES "80") 64 | target_compile_options(cudnn_box2d9p PRIVATE $<$:-O3 -lineinfo -lcudnn --use_fast_math --gpu-architecture=sm_80>) 65 | target_link_libraries(cudnn_box2d9p ${CUDA_LIBRARIES} cudnn) 66 | 67 | add_executable(cudnn_box3d27p src/cudnn/conv_box3d27p.cu) 68 | target_link_libraries( cudnn_box2d9p ${CUDA_cudnn_LIBRARY}) 69 | set_target_properties( 70 | cudnn_box3d27p 71 | PROPERTIES 72 | CUDA_SEPARABLE_COMPILATION ON 73 | CUDA_ARCHITECTURES "80") 74 | target_compile_options(cudnn_box3d27p PRIVATE $<$:-O3 -lineinfo -lcudnn --use_fast_math --gpu-architecture=sm_80>) 75 | target_link_libraries(cudnn_box3d27p ${CUDA_LIBRARIES} cudnn) 76 | 77 | add_executable(convstencil_2d src/2d/main.cu src/2d/gpu.cu src/2d/2d_utils.h) 78 | # target_link_libraries( gpu_box2d1r ${CUDA_cusparse_LIBRARY} ${CUDA_cublas_LIBRARY} ) 79 | set_target_properties( 80 | convstencil_2d 81 | PROPERTIES 82 | CUDA_SEPARABLE_COMPILATION ON 83 | CUDA_ARCHITECTURES "80") 84 | target_compile_options(convstencil_2d PRIVATE $<$:-O3 -lineinfo --use_fast_math --gpu-architecture=sm_80>) 85 | 86 | add_executable(convstencil_3d src/3d/main.cu src/3d/gpu_box.cu src/3d/gpu_star.cu src/3d/3d_utils.h) 87 | # target_link_libraries( gpu_box2d1r ${CUDA_cusparse_LIBRARY} ${CUDA_cublas_LIBRARY} ) 88 | set_target_properties( 89 | convstencil_3d 90 | PROPERTIES 91 | CUDA_SEPARABLE_COMPILATION ON 92 | CUDA_ARCHITECTURES "80") 93 | target_compile_options(convstencil_3d PRIVATE $<$:-O3 
-lineinfo --use_fast_math --gpu-architecture=sm_80>) 94 | 95 | set(CMAKE_CUDA_ARCHITECTURES 80) 96 | # add_subdirectory(breakdown) 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvStencil 2 | 3 | > ConvStencil: Transform Stencil Computation to Matrix Multiplication on Tensor Cores 4 | 5 | ## Abstract 6 | 7 | This artifact contains the source code of ConvStencil, a novel stencil computing system to transform stencil computation to matrix multiplication on Tensor Cores efficiently. 8 | 9 | ## Prerequisites 10 | 11 | - Hardware 12 | - x86-64 CPU 13 | - a single NVIDIA A100 GPU 14 | - Software (attached in the docker image) 15 | - CUDA - 12.2 (Tested). Lower versions down to CUDA 11.0 are also supported, but it may affect the performance. 16 | - GCC - above 9.4.0. You may also try to use icx or clang. 17 | - cuDNN - above 8.0 18 | 19 | ## Getting Code 20 | The code can be downloaded using git: 21 | ``` 22 | git clone https://github.com/microsoft/ConvStencil.git 23 | ``` 24 | 25 | ## Compile 26 | 27 | Use the following commands: 28 | ``` 29 | mkdir -p build 30 | cd build 31 | cmake .. 32 | make all -j24 33 | ``` 34 | 35 | ## Usage 36 | 37 | You can run `convstencil` in the following input format. 38 | ``` 39 | convstencil_program shape input_size time_interation_size options 40 | ``` 41 | - `convstencil_program` can be chosen from `convstencil_1d`, `convstencil_2d`, and `convstencil_3d` for different dimensions. 42 | - `shape` can be chosen by the different dimension: 43 | - `1d1r` and `1d2r` for 1D 44 | - `star2d1r`, `box2d1r`, `star2d3r` and `box2d3r` for 2D 45 | - `star3d1r` and `box3d1r` for 3D 46 | - `input_size` depends on the number of dimensions; the number of inputs required is equal to the number of dimensions. 47 | - `time_interation_size` is the iteration time. 48 | - `options`: 49 | - `--help` prints the help information. 50 | - `--custom` inputs the custom stencil kernel weights. 51 | 52 | ## Contact 53 | 54 | If you have any questions, please send an email to the author at kunli@microsoft.com. 55 | 56 | ## Reference 57 | 58 | 59 | Yuetao Chen, Kun Li, Yuhao Wang, Donglin Bai, Lei Wang, Lingxiao Ma, Liang Yuan, Yunquan Zhang, Ting Cao, Mao Yang. [ConvStencil: Transform Stencil Computation to Matrix Multiplication on Tensor Cores](https://doi.org/10.1145/3627535.3638476). In *ACM SIGPLAN Symposium on Principles and Practice of Parallel Programming (PPoPP)*, pp. 333–347, 2024. 60 | 61 | If you use our code, please cite our paper: 62 | ``` 63 | @inproceedings{10.1145/3627535.3638476, 64 | author = {Chen, Yuetao and Li, Kun and Wang, Yuhao and Bai, Donglin and Wang, Lei and Ma, Lingxiao and Yuan, Liang and Zhang, Yunquan and Cao, Ting and Yang, Mao}, 65 | title = {ConvStencil: Transform Stencil Computation to Matrix Multiplication on Tensor Cores}, 66 | year = {2024}, 67 | isbn = {9798400704352}, 68 | publisher = {Association for Computing Machinery}, 69 | address = {New York, NY, USA}, 70 | url = {https://doi.org/10.1145/3627535.3638476}, 71 | doi = {10.1145/3627535.3638476}, 72 | booktitle = {Proceedings of the 29th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming}, 73 | pages = {333–347}, 74 | series = {PPoPP '24} 75 | } 76 | ``` 77 | 78 | ## Contributing 79 | 80 | This project welcomes contributions and suggestions. 
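To make the Compile and Usage sections above concrete, a typical session might look like the following (the grid sizes and iteration counts are illustrative placeholders, not benchmark settings from the paper):

```
mkdir -p build && cd build
cmake .. && make all -j24
./convstencil_1d 1d1r 1048576 100
./convstencil_2d box2d1r 8192 8192 100
./convstencil_3d star3d1r 256 256 256 10
```

Each driver reports the elapsed time and the achieved GStencil/s, as the 1D host code shown below does.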
Most contributions require you to agree to a 81 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 82 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 83 | 84 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 85 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 86 | provided by the bot. You will only need to do this once across all repos using our CLA. 87 | 88 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 89 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 90 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 91 | 92 | ## Trademarks 93 | 94 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 95 | trademarks or logos is subject to and must follow 96 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 97 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 98 | Any use of third-party trademarks or logos are subject to those third-party's policies. 99 | -------------------------------------------------------------------------------- /src/1d/gpu_2r.cu: -------------------------------------------------------------------------------- 1 | //!支持任意大小 2 | // corner case 64 2 3 | #include 4 | #include 5 | #include "1d_utils.h" 6 | #include 7 | 8 | using namespace nvcuda; 9 | 10 | #define BLOCK_SIZE_COL 1024//tune 11 | #define HALO 4 12 | #define D_BLOCK_SIZE_COL (BLOCK_SIZE_COL + HALO * 2) //2*HALO 8 13 | #define PAD 14 | #define SM_SIZE_ROW (D_BLOCK_SIZE_COL / 8) 15 | #define UNIT_LENGTH 7 16 | #define TENSOR_CORE_M 8 17 | #define WARP_PER_BLOCK 8 18 | #define MMA_NUM 2 19 | #define IDX(x, y, ldm) ((x) * (ldm) + (y)) 20 | 21 | extern __constant__ double param_matrix_d[2 * 8 * TENSOR_CORE_M]; 22 | 23 | __global__ void gpu_star_1d2r_step2_kernel(const double *__restrict__ in, double *__restrict__ out) 24 | { 25 | __shared__ double sharedmem[SM_SIZE_ROW * 8]; 26 | 27 | int begin = blockIdx.x * BLOCK_SIZE_COL; 28 | int laneid = threadIdx.x % 32; 29 | 30 | int tid = threadIdx.x; 31 | int totalThreads = blockDim.x; 32 | 33 | for (int i = tid; i < D_BLOCK_SIZE_COL; i += totalThreads) 34 | { 35 | // if ( i < D_BLOCK_SIZE_COL) 36 | // if ( i < D_BLOCK_SIZE_COL - 2 * HALO) 37 | // sharedmem[lookup_table[i]] = in[begin + i]; 38 | sharedmem[i] = in[begin + i]; 39 | // sharedmem[i / 8 * 8 + i % 8] = in[begin + i]; 40 | // if (i >= 2 * HALO) 41 | // sharedmem[1][(i - 8) / 8 * 8 + (i - 8) % 8] = in[begin + i]; 42 | } 43 | 44 | __syncthreads(); 45 | 46 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 47 | #pragma unroll 48 | for (int i = 0; i < MMA_NUM; i++) 49 | { 50 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 51 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 2 * 4 * 8 + i * 32, 8); 52 | } 53 | 54 | nvcuda::wmma::fragment acc_frag; 55 | 56 | nvcuda::wmma::fragment in_frag; 57 | int warp_id=threadIdx.x/32; 58 | // 行数/warp数 59 | #pragma unroll 60 | for (int row = 2*8*warp_id; row < (warp_id+1)*8*2; row += TENSOR_CORE_M) //得是8的倍数 8*4 61 | { 62 | 
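// Each iteration of this loop produces one 8x8 tile of outputs for the
// current warp: the accumulator fragment is cleared, the 8x8 shared-memory
// tile starting at `row` is multiplied by the first parameter matrix in two
// 8x4 k-slices (MMA_NUM = 2), and the tile starting at `row + 1` is
// multiplied by the second parameter matrix into the same accumulator, so
// two adjacent rows of the reshaped input contribute to each output tile.
// The accumulated 8x8 result is then stored back to `out`.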
nvcuda::wmma::fill_fragment(acc_frag, 0.0); 63 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 64 | { 65 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem + IDX(row, compute_idx * 4, 8), 8); 66 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 67 | } 68 | 69 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 70 | { 71 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem + IDX(row+1, compute_idx * 4, 8), 8); 72 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 73 | } 74 | 75 | nvcuda::wmma::store_matrix_sync(out + begin + row / 8 * 64 + HALO , acc_frag, TENSOR_CORE_M, nvcuda::wmma::mem_row_major); 76 | } 77 | } 78 | 79 | /** 80 | * @param in input array pointer 81 | * @param out output array pointer 82 | * @param params parameter array pointer 83 | * 84 | */ 85 | void gpu_star_1d2r_step2(const double *__restrict__ in, double *__restrict__ out, const double *__restrict__ params, const int times, const int input_n) 86 | { 87 | double param_matrix_h[2][8 * 8] = {}; 88 | 89 | // Initialize parameter matrix 90 | 91 | for (int row = 0; row < 8; row++) // kernel size 7 92 | for (int col = 0; col <= row; ++col) 93 | param_matrix_h[0][row * 8 + col] = params[row - col]; 94 | 95 | for (int row = 0; row < 8; row++) 96 | for (int col = row; col < 8; ++col) 97 | param_matrix_h[1][row * 8 + col] = params[row + 8 - col]; 98 | 99 | // for(int i=0;i<8;++i){ 100 | // for(int j=0;j<8;++j) 101 | // printf("%8.3f ",param_matrix_h[0][i*8+j]); 102 | // printf("\n"); 103 | // } 104 | 105 | // printf("\n"); 106 | // for(int i=0;i<8;++i){ 107 | // for(int j=0;j<8;++j) 108 | // printf("%8.3f ",param_matrix_h[1][i*8+j]); 109 | // printf("\n"); 110 | // } 111 | 112 | 113 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 8 * sizeof(double))); 114 | 115 | const int cols = input_n + 2 * HALO ; 116 | const size_t array_size = cols * sizeof(double); 117 | double *array_d[2]; 118 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 119 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 120 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 121 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 122 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 123 | 124 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 125 | dim3 grid_config(BLOCK_N); 126 | dim3 block_config(32 * WARP_PER_BLOCK); 127 | 128 | // int lookup_table_h[D_BLOCK_SIZE_COL]; 129 | // for(int j=0;j>>(array_d[i % 2] , array_d[(i + 1) % 2]))); 146 | } 147 | CUDA_CHECK(cudaDeviceSynchronize()); 148 | 149 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 150 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 151 | 152 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 153 | // std::cout << secs << std::endl; 154 | printf("GStencil/s = %f\n", ((double)input_n * times * 2) / secs / 1e9); 155 | 156 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] , array_size - sizeof(double), cudaMemcpyDeviceToHost)); 157 | 158 | return; 159 | } 160 | -------------------------------------------------------------------------------- /src/cudnn/conv_box2d49p.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDNN(expression) \ 7 | { \ 8 | cudnnStatus_t status = (expression); \ 9 | if (status != CUDNN_STATUS_SUCCESS) { \ 10 
| std::cerr << "Error on line " << __LINE__ << ": " \ 11 | << cudnnGetErrorString(status) << std::endl;\ 12 | std::exit(EXIT_FAILURE); \ 13 | } \ 14 | } 15 | 16 | int main() { 17 | cudnnHandle_t cudnn; 18 | CHECK_CUDNN(cudnnCreate(&cudnn)); 19 | 20 | int H = 10000; 21 | int W = 10000; 22 | int T = 100; 23 | double *input_data_h; 24 | input_data_h = (double*)malloc(1 * 1 * H * W * sizeof(double)); 25 | 26 | for (int i = 0; i < H * W; i++) { 27 | input_data_h[i] = 1.0f; 28 | } 29 | 30 | double *data[2]; 31 | double *input_data; 32 | cudaMalloc(&input_data, 1 * 1 * H * W * sizeof(double)); 33 | cudaMemcpy(input_data, input_data_h, 1 * 1 * H * W * sizeof(double), cudaMemcpyHostToDevice); 34 | data[0] = input_data; 35 | 36 | cudnnTensorDescriptor_t input_descriptor; 37 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); 38 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, 39 | /*format=*/CUDNN_TENSOR_NHWC, 40 | /*dataType=*/CUDNN_DATA_DOUBLE, 41 | /*batch_size=*/1, 42 | /*channels=*/1, 43 | /*image_height=*/H, 44 | /*image_width=*/W)); 45 | 46 | double *filter_data_h; 47 | filter_data_h = (double*)malloc(1 * 1 * 7 * 7 * sizeof(double)); 48 | 49 | for (int i = 0; i < 7 * 7; i++) { 50 | filter_data_h[i] = 0.1111f; 51 | } 52 | 53 | double *filter_data; 54 | cudaMalloc(&filter_data, 1 * 1 * 7 * 7 * sizeof(double)); 55 | cudaMemcpy(filter_data, filter_data_h, 1 * 1 * 7 * 7 * sizeof(double), cudaMemcpyHostToDevice); 56 | 57 | cudnnFilterDescriptor_t filter_descriptor; 58 | CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_descriptor)); 59 | CHECK_CUDNN(cudnnSetFilter4dDescriptor(filter_descriptor, 60 | /*dataType=*/CUDNN_DATA_DOUBLE, 61 | /*format=*/CUDNN_TENSOR_NCHW, 62 | /*out_channels=*/1, 63 | /*in_channels=*/1, 64 | /*kernel_height=*/7, 65 | /*kernel_width=*/7)); 66 | 67 | cudnnConvolutionDescriptor_t convolution_descriptor; 68 | CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); 69 | CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, 70 | /*pad_height=*/3, 71 | /*pad_width=*/3, 72 | /*vertical_stride=*/1, 73 | /*horizontal_stride=*/1, 74 | /*dilation_height=*/1, 75 | /*dilation_width=*/1, 76 | /*mode=*/CUDNN_CROSS_CORRELATION, 77 | /*computeType=*/CUDNN_DATA_DOUBLE)); 78 | CHECK_CUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); 79 | 80 | // 计算输出数据尺寸 81 | int batch_size{0}, channels{0}, height{0}, width{0}; 82 | CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, 83 | input_descriptor, 84 | filter_descriptor, 85 | &batch_size, 86 | &channels, 87 | &height, 88 | &width)); 89 | 90 | double *output_data_h; 91 | output_data_h = (double*)malloc(batch_size * channels * height * width * sizeof(double)); 92 | 93 | double *output_data; 94 | cudaMalloc(&output_data, batch_size * channels * height * width * sizeof(double)); 95 | data[1] = output_data; 96 | 97 | cudnnTensorDescriptor_t output_descriptor; 98 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); 99 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, 100 | /*format=*/CUDNN_TENSOR_NHWC, 101 | /*dataType=*/CUDNN_DATA_DOUBLE, 102 | /*batch_size=*/batch_size, 103 | /*channels=*/channels, 104 | /*image_height=*/height, 105 | /*image_width=*/width)); 106 | 107 | double alpha = 1.0f, beta = 0.0f; 108 | cudnnConvolutionFwdAlgo_t convolution_algorithm = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; 109 | // CHECK_CUDNN( 110 | // cudnnFindConvolutionForwardAlgorithm(cudnn, 111 | // input_descriptor, 112 | 
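// The autotuning call here is left commented out; the benchmark uses the
// fixed convolution_algorithm chosen a few lines above
// (CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM), so every run times the
// same cuDNN algorithm.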
// filter_descriptor, 113 | // convolution_descriptor, 114 | // output_descriptor, 115 | // CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 116 | // /*memoryLimitInBytes=*/0, 117 | // &convolution_algorithm)); 118 | 119 | size_t workspace_bytes{0}; 120 | CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn, 121 | input_descriptor, 122 | filter_descriptor, 123 | convolution_descriptor, 124 | output_descriptor, 125 | convolution_algorithm, 126 | &workspace_bytes)); 127 | 128 | void* d_workspace{nullptr}; 129 | cudaMalloc(&d_workspace, workspace_bytes); 130 | 131 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 132 | 133 | for (int t = 0; t < T; t++) { 134 | CHECK_CUDNN(cudnnConvolutionForward(cudnn, 135 | &alpha, 136 | input_descriptor, 137 | data[t % 2], 138 | filter_descriptor, 139 | filter_data, 140 | convolution_descriptor, 141 | convolution_algorithm, 142 | d_workspace, 143 | workspace_bytes, 144 | &beta, 145 | output_descriptor, 146 | data[(t + 1) % 2])); 147 | } 148 | cudaDeviceSynchronize() ; 149 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 150 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 151 | 152 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 153 | printf("GStencil/s = %f\n", ((double)H * W * T) / secs / 1e9); 154 | 155 | cudaMemcpy(output_data_h, output_data, batch_size * channels * height * width * sizeof(double), cudaMemcpyDeviceToHost); 156 | 157 | 158 | 159 | 160 | cudnnDestroyTensorDescriptor(input_descriptor); 161 | cudnnDestroyTensorDescriptor(output_descriptor); 162 | cudnnDestroyFilterDescriptor(filter_descriptor); 163 | cudnnDestroyConvolutionDescriptor(convolution_descriptor); 164 | cudnnDestroy(cudnn); 165 | 166 | cudaFree(input_data); 167 | cudaFree(filter_data); 168 | cudaFree(output_data); 169 | cudaFree(d_workspace); 170 | 171 | return 0; 172 | } 173 | -------------------------------------------------------------------------------- /src/cudnn/conv_box2d9p.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDNN(expression) \ 7 | { \ 8 | cudnnStatus_t status = (expression); \ 9 | if (status != CUDNN_STATUS_SUCCESS) { \ 10 | std::cerr << "Error on line " << __LINE__ << ": " \ 11 | << cudnnGetErrorString(status) << std::endl;\ 12 | std::exit(EXIT_FAILURE); \ 13 | } \ 14 | } 15 | 16 | int main() { 17 | cudnnHandle_t cudnn; 18 | CHECK_CUDNN(cudnnCreate(&cudnn)); 19 | 20 | int H = 10000; 21 | int W = 10000; 22 | int T = 10000; 23 | double *input_data_h; 24 | input_data_h = (double*)malloc(1 * 1 * H * W * sizeof(double)); 25 | 26 | for (int i = 0; i < H * W; i++) { 27 | input_data_h[i] = 1.0f; 28 | } 29 | 30 | double *data[2]; 31 | double *input_data; 32 | cudaMalloc(&input_data, 1 * 1 * H * W * sizeof(double)); 33 | cudaMemcpy(input_data, input_data_h, 1 * 1 * H * W * sizeof(double), cudaMemcpyHostToDevice); 34 | data[0] = input_data; 35 | 36 | cudnnTensorDescriptor_t input_descriptor; 37 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); 38 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, 39 | /*format=*/CUDNN_TENSOR_NHWC, 40 | /*dataType=*/CUDNN_DATA_DOUBLE, 41 | /*batch_size=*/1, 42 | /*channels=*/1, 43 | /*image_height=*/H, 44 | /*image_width=*/W)); 45 | 46 | double *filter_data_h; 47 | filter_data_h = (double*)malloc(1 * 1 * 3 * 3 * sizeof(double)); 48 | 49 | for (int i = 0; i < 3 * 3; i++) { 50 | 
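// Uniform 3x3 box-filter weights; 0.1111 approximates 1/9, so the nine taps
// sum to roughly one.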
filter_data_h[i] = 0.1111f; 51 | } 52 | 53 | double *filter_data; 54 | cudaMalloc(&filter_data, 1 * 1 * 3 * 3 * sizeof(double)); 55 | cudaMemcpy(filter_data, filter_data_h, 1 * 1 * 3 * 3 * sizeof(double), cudaMemcpyHostToDevice); 56 | 57 | cudnnFilterDescriptor_t filter_descriptor; 58 | CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_descriptor)); 59 | CHECK_CUDNN(cudnnSetFilter4dDescriptor(filter_descriptor, 60 | /*dataType=*/CUDNN_DATA_DOUBLE, 61 | /*format=*/CUDNN_TENSOR_NCHW, 62 | /*out_channels=*/1, 63 | /*in_channels=*/1, 64 | /*kernel_height=*/3, 65 | /*kernel_width=*/3)); 66 | 67 | cudnnConvolutionDescriptor_t convolution_descriptor; 68 | CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); 69 | CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, 70 | /*pad_height=*/1, 71 | /*pad_width=*/1, 72 | /*vertical_stride=*/1, 73 | /*horizontal_stride=*/1, 74 | /*dilation_height=*/1, 75 | /*dilation_width=*/1, 76 | /*mode=*/CUDNN_CROSS_CORRELATION, 77 | /*computeType=*/CUDNN_DATA_DOUBLE)); 78 | CHECK_CUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); 79 | 80 | // 计算输出数据尺寸 81 | int batch_size{0}, channels{0}, height{0}, width{0}; 82 | CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, 83 | input_descriptor, 84 | filter_descriptor, 85 | &batch_size, 86 | &channels, 87 | &height, 88 | &width)); 89 | 90 | double *output_data_h; 91 | output_data_h = (double*)malloc(batch_size * channels * height * width * sizeof(double)); 92 | 93 | double *output_data; 94 | cudaMalloc(&output_data, batch_size * channels * height * width * sizeof(double)); 95 | data[1] = output_data; 96 | 97 | cudnnTensorDescriptor_t output_descriptor; 98 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); 99 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, 100 | /*format=*/CUDNN_TENSOR_NHWC, 101 | /*dataType=*/CUDNN_DATA_DOUBLE, 102 | /*batch_size=*/batch_size, 103 | /*channels=*/channels, 104 | /*image_height=*/height, 105 | /*image_width=*/width)); 106 | 107 | double alpha = 1.0f, beta = 0.0f; 108 | cudnnConvolutionFwdAlgo_t convolution_algorithm = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; 109 | // CHECK_CUDNN( 110 | // cudnnFindConvolutionForwardAlgorithm(cudnn, 111 | // input_descriptor, 112 | // filter_descriptor, 113 | // convolution_descriptor, 114 | // output_descriptor, 115 | // CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 116 | // /*memoryLimitInBytes=*/0, 117 | // &convolution_algorithm)); 118 | 119 | size_t workspace_bytes{0}; 120 | CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn, 121 | input_descriptor, 122 | filter_descriptor, 123 | convolution_descriptor, 124 | output_descriptor, 125 | convolution_algorithm, 126 | &workspace_bytes)); 127 | 128 | void* d_workspace{nullptr}; 129 | cudaMalloc(&d_workspace, workspace_bytes); 130 | 131 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 132 | 133 | for (int t = 0; t < T; t++) { 134 | CHECK_CUDNN(cudnnConvolutionForward(cudnn, 135 | &alpha, 136 | input_descriptor, 137 | data[t % 2], 138 | filter_descriptor, 139 | filter_data, 140 | convolution_descriptor, 141 | convolution_algorithm, 142 | d_workspace, 143 | workspace_bytes, 144 | &beta, 145 | output_descriptor, 146 | data[(t + 1) % 2])); 147 | } 148 | cudaDeviceSynchronize() ; 149 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 150 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << 
std::endl; 151 | 152 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 153 | printf("GStencil/s = %f\n", ((double)H * W * T) / secs / 1e9); 154 | 155 | cudaMemcpy(output_data_h, output_data, batch_size * channels * height * width * sizeof(double), cudaMemcpyDeviceToHost); 156 | 157 | 158 | 159 | 160 | cudnnDestroyTensorDescriptor(input_descriptor); 161 | cudnnDestroyTensorDescriptor(output_descriptor); 162 | cudnnDestroyFilterDescriptor(filter_descriptor); 163 | cudnnDestroyConvolutionDescriptor(convolution_descriptor); 164 | cudnnDestroy(cudnn); 165 | 166 | cudaFree(input_data); 167 | cudaFree(filter_data); 168 | cudaFree(output_data); 169 | cudaFree(d_workspace); 170 | 171 | return 0; 172 | } 173 | -------------------------------------------------------------------------------- /src/1d/main.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "1d_utils.h" 4 | #include 5 | 6 | const char *ShapeStr[4] = { 7 | "1d1r", 8 | "1d2r" 9 | }; 10 | 11 | #define FILL_RANDOM 12 | // #define FILL_INDEX 13 | 14 | // #define CHECK_ERROR 15 | const double tolerance = 1e-7; 16 | __constant__ double param_matrix_d[2 * 8 * TENSOR_CORE_M]; 17 | // #define WRITE_OUTPUT 18 | 19 | int HALO; 20 | 21 | void save_to_txt(double *arr, int cols, const char *filename) 22 | { 23 | FILE *file = fopen(filename, "w"); 24 | if (file == NULL) 25 | { 26 | printf("Error opening file!\n"); 27 | return; 28 | } 29 | 30 | for (int i = 0; i < cols ; i++) 31 | { 32 | fprintf(file, "%d %.0f\n", i , arr[i]); 33 | } 34 | 35 | fclose(file); 36 | } 37 | 38 | void naive_1d(double *in, double *out, double *param, int N, int halo)//halo=r 39 | { 40 | #pragma unroll 41 | for (int j = halo; j < N + halo; j++) 42 | { 43 | out[j] = 0; 44 | for (int k = -halo; k <= halo; ++k) 45 | out[j] += param[k + halo] * in[j + k]; 46 | } 47 | } 48 | 49 | void printHelp() 50 | { 51 | const char *helpMessage = 52 | "Program name: convstencil_1d\n" 53 | "Usage: convstencil_2 shape input_size time_iteration_size [Options]\n" 54 | "Shape: 1d1r or 1d2r\n" 55 | "Options:\n" 56 | " --help Display this help message and exit\n" 57 | " --custom If you want to use costum parameters, please use this option and input your parameters like 0.2 0.2 0.2 0.2 0.2 if the shape is star2d1r\n"; 58 | printf("%s\n", helpMessage); 59 | } 60 | 61 | int main(int argc, char *argv[]) 62 | { 63 | if (argc < 4) 64 | { 65 | printHelp(); 66 | return 1; 67 | } 68 | 69 | // configurable settings 70 | std::string arg1 = argv[1]; 71 | 72 | Shape compute_shape; 73 | if(arg1 == "1d1r"){ 74 | compute_shape=star_1d1r; 75 | } 76 | else if(arg1 == "1d2r"){ 77 | compute_shape=star_1d2r; 78 | } else { 79 | printHelp(); 80 | return 1; 81 | } 82 | 83 | int n = 0; 84 | int times = 0; 85 | 86 | try 87 | { 88 | n = std::stoi(argv[2]); 89 | times = std::stoi(argv[3]); 90 | } 91 | catch (const std::invalid_argument &e) 92 | { 93 | std::cerr << "Invalid argument: cannot convert the parameter(s) to integer.\n"; 94 | return 1; 95 | } 96 | catch (const std::out_of_range &e) 97 | { 98 | std::cerr << "Argument out of range: the parameter(s) is(are) too large.\n"; 99 | return 1; 100 | } 101 | 102 | double param_1d1r[7] = {}; 103 | double param_1d2r[9] = {}; 104 | 105 | bool breakdown = false; 106 | 107 | for (int i = 0; i < 7; i++) 108 | { 109 | param_1d1r[i] = i + 1; 110 | } 111 | 112 | 113 | for (int i = 0; i < 9; i++) 114 | { 115 | param_1d2r[i] = i + 1; 116 | } 117 | 118 | if (argc == 5 && std::string(argv[4]) == 
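// The --custom branch below reads the raw stencil weights and composes them
// with themselves before launching, so several stencil steps collapse into
// one wider kernel: three applications of a 3-point 1d1r stencil become one
// 7-point kernel, and two applications of a 5-point 1d2r stencil become one
// 9-point kernel. For example, weights {1, 1, 1} for 1d1r expand to
// {1, 3, 6, 7, 6, 3, 1} (the coefficients of (1 + x + x^2)^3), and weights
// {1, 1, 1, 1, 1} for 1d2r expand to {1, 2, 3, 4, 5, 4, 3, 2, 1}.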
"--custom") { 119 | int num_param = 0; 120 | if (arg1 == "1d1r") { 121 | num_param = 3; 122 | } else if (arg1 == "1d2r") { 123 | num_param = 5; 124 | } 125 | printf("Please enter %d parameters:\n", num_param); 126 | double values[num_param]; 127 | for (int i = 0; i < num_param; i++) 128 | { 129 | int readNum = scanf("%lf", &values[i]); 130 | if (readNum != 1) 131 | return 1; 132 | } 133 | if (arg1 == "1d1r") { 134 | param_1d1r[0] = values[0] * values[0] * values[0]; 135 | param_1d1r[1] = 3 * values[0] * values[0] * values[1]; 136 | param_1d1r[2] = 3 * values[0] * values[0] * values[2] + 3 * values[0] * values[1] * values[1]; 137 | param_1d1r[3] = 6 * values[0] * values[1] * values[2] + values[1] * values[1] * values[1]; 138 | param_1d1r[4] = 3 * values[0] * values[2] * values[2] + 3 * values[1] * values[1] * values[2]; 139 | param_1d1r[5] = 3 * values[1] * values[2] * values[2]; 140 | param_1d1r[6] = values[2] * values[2] * values[2]; 141 | } else if (arg1 == "1d2r") { 142 | param_1d2r[0] = values[0] * values[0]; 143 | param_1d2r[1] = 2 * values[0] * values[1]; 144 | param_1d2r[2] = 2 * values[0] * values[2] + values[1] * values[1]; 145 | param_1d2r[3] = 2 * values[0] * values[3] + 2 * values[1] * values[2]; 146 | param_1d2r[4] = 2 * values[0] * values[4] + 2 * values[1] * values[3] + values[2] * values[2]; 147 | param_1d2r[5] = 2 * values[1] * values[4] + 2 * values[2] * values[3]; 148 | param_1d2r[6] = 2 * values[2] * values[4] + values[3] * values[3]; 149 | param_1d2r[7] = 2 * values[3] * values[4]; 150 | param_1d2r[8] = values[4] * values[4]; 151 | } 152 | } 153 | 154 | if (argc == 5 && std::string(argv[4]) == "--breakdown") { 155 | breakdown = true; 156 | } 157 | 158 | double *param; 159 | 160 | switch (compute_shape) 161 | { 162 | case star_1d1r: 163 | param = param_1d1r; 164 | HALO = 3; 165 | break; 166 | case star_1d2r: 167 | param=param_1d2r; 168 | HALO= 4; 169 | break; 170 | } 171 | 172 | // print brief info 173 | 174 | printf("INFO: shape = %s, n = %d, times = %d\n", ShapeStr[compute_shape], n, times); 175 | 176 | int cols = n + 2 * HALO; //+1 177 | 178 | size_t input_size = (unsigned long)cols * sizeof(double); 179 | 180 | // allocate space 181 | 182 | double *input = (double *)malloc(input_size + sizeof(double)); // alignment for tensor core 183 | double *output = (double *)malloc(input_size + sizeof(double)); 184 | 185 | // fill input matrix 186 | 187 | #if defined(FILL_RANDOM) 188 | #pragma unroll 189 | for (int i = 0; i < cols + 1; i++) 190 | { 191 | input[i] = (double)(rand() % 10000); 192 | } 193 | #elif defined(FILL_INDEX) 194 | if(compute_shape==star_1d1r_step3){ 195 | for (int i = 0; i < cols + 1; i++) 196 | { 197 | if (i < HALO + 1 || i > cols - HALO)//+1为了对齐 198 | input[i] = 0; 199 | else 200 | { 201 | input[i] = i + 1 - (HALO+1); 202 | // printf("%d %lf\n",i,input[i]); 203 | } 204 | } 205 | } 206 | else{ 207 | for (int i = 0; i < cols ; i++) 208 | { 209 | if (i < HALO || i > cols - HALO -1) 210 | input[i] = 0; 211 | else 212 | { 213 | input[i] = i + 1 - HALO; 214 | } 215 | // printf("%d %lf\n",i,input[i]); 216 | } 217 | 218 | } 219 | #endif 220 | 221 | switch (compute_shape) 222 | { 223 | case star_1d1r: 224 | if (breakdown) { 225 | gpu_1d1r_breakdown1(input, output, param, times, n); 226 | gpu_1d1r_breakdown2(input, output, param, times, n); 227 | gpu_1d1r_breakdown3(input, output, param, times, n); 228 | gpu_1d1r_breakdown4(input, output, param, times, n); 229 | } 230 | gpu_1d1r(input, output, param, times, n); 231 | break; 232 | case star_1d2r: 233 | 
gpu_star_1d2r_step2(input, output, param, times, n); 234 | break; 235 | } 236 | 237 | // check result correctness 238 | 239 | #if defined(CHECK_ERROR) 240 | printf("\nChecking ... \n"); 241 | double *naive[2]; 242 | naive[0] = (double *)malloc(input_size); 243 | naive[1] = (double *)malloc(input_size); 244 | 245 | for (int i = 0; i < cols; i++) 246 | { 247 | if(compute_shape==star_1d2r_step2)naive[0][i] = input[i]; 248 | else naive[0][i]=input[i+1]; 249 | naive[1][i] = 0; 250 | // printf("%lf ",naive[0][i]); 251 | } 252 | 253 | int t = 0; 254 | 255 | for (; t < times; t++) 256 | { 257 | naive_1d(naive[t % 2], naive[(t + 1) % 2], param, n, HALO); 258 | } 259 | 260 | 261 | printf("Comparing naive and output\n"); 262 | for (int col = HALO; col < cols - HALO; col++) 263 | { 264 | if (std::fabs(naive[t % 2][col] - output[col]) > 1e-7) 265 | { 266 | printf("col = %d, naive = %lf, output = %lf\n", col, naive[t % 2][col], output[col]); 267 | } 268 | } 269 | #endif 270 | 271 | // write to file 272 | 273 | #ifdef WRITE_OUTPUT 274 | printf("Writing output_gpu.txt\n"); 275 | save_to_txt(output, cols, "output_gpu.txt"); 276 | #if defined(CHECK_ERROR) 277 | save_to_txt(naive[t % 2], cols, "output_naive.txt"); 278 | #endif 279 | #endif 280 | 281 | // free space 282 | free(output); 283 | free(input); 284 | 285 | return 0; 286 | } -------------------------------------------------------------------------------- /src/cudnn/conv_1d3p.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDNN(expression) \ 7 | { \ 8 | cudnnStatus_t status = (expression); \ 9 | if (status != CUDNN_STATUS_SUCCESS) { \ 10 | std::cerr << "Error on line " << __LINE__ << ": " \ 11 | << cudnnGetErrorString(status) << std::endl;\ 12 | std::exit(EXIT_FAILURE); \ 13 | } \ 14 | } 15 | 16 | int main() { 17 | cudnnHandle_t cudnn; 18 | CHECK_CUDNN(cudnnCreate(&cudnn)); 19 | 20 | int H = 1; 21 | int W = 10240000; 22 | int T = 100000; 23 | double *input_data_h; 24 | input_data_h = (double*)malloc(1 * 1 * H * W * sizeof(double)); 25 | 26 | for (int i = 0; i < H * W; i++) { 27 | input_data_h[i] = 1.0f; 28 | } 29 | 30 | double *data[2]; 31 | double *input_data; 32 | cudaMalloc(&input_data, 1 * 1 * H * W * sizeof(double)); 33 | cudaMemcpy(input_data, input_data_h, 1 * 1 * H * W * sizeof(double), cudaMemcpyHostToDevice); 34 | data[0] = input_data; 35 | 36 | cudnnTensorDescriptor_t input_descriptor; 37 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); 38 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, 39 | /*format=*/CUDNN_TENSOR_NHWC, 40 | /*dataType=*/CUDNN_DATA_DOUBLE, 41 | /*batch_size=*/1, 42 | /*channels=*/1, 43 | /*image_height=*/H, 44 | /*image_width=*/W)); 45 | 46 | double *filter_data_h; 47 | int filter_H = 1; 48 | int filter_W = 3; 49 | filter_data_h = (double*)malloc(1 * 1 * filter_H * filter_W * sizeof(double)); 50 | 51 | for (int i = 0; i < filter_H * filter_W; i++) { 52 | filter_data_h[i] = 0.1111f; 53 | } 54 | 55 | double *filter_data; 56 | cudaMalloc(&filter_data, 1 * 1 * filter_H * filter_W * sizeof(double)); 57 | cudaMemcpy(filter_data, filter_data_h, 1 * 1 * filter_H * filter_W * sizeof(double), cudaMemcpyHostToDevice); 58 | 59 | cudnnFilterDescriptor_t filter_descriptor; 60 | CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_descriptor)); 61 | CHECK_CUDNN(cudnnSetFilter4dDescriptor(filter_descriptor, 62 | /*dataType=*/CUDNN_DATA_DOUBLE, 63 | /*format=*/CUDNN_TENSOR_NCHW, 64 | /*out_channels=*/1, 65 | 
/*in_channels=*/1, 66 | /*kernel_height=*/filter_H, 67 | /*kernel_width=*/filter_W)); 68 | 69 | cudnnConvolutionDescriptor_t convolution_descriptor; 70 | CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); 71 | CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, 72 | /*pad_height=*/0, 73 | /*pad_width=*/1, 74 | /*vertical_stride=*/1, 75 | /*horizontal_stride=*/1, 76 | /*dilation_height=*/1, 77 | /*dilation_width=*/1, 78 | /*mode=*/CUDNN_CROSS_CORRELATION, 79 | /*computeType=*/CUDNN_DATA_DOUBLE)); 80 | CHECK_CUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); 81 | 82 | int batch_size{0}, channels{0}, height{0}, width{0}; 83 | CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, 84 | input_descriptor, 85 | filter_descriptor, 86 | &batch_size, 87 | &channels, 88 | &height, 89 | &width)); 90 | 91 | double *output_data_h; 92 | output_data_h = (double*)malloc(batch_size * channels * height * width * sizeof(double)); 93 | 94 | double *output_data; 95 | cudaMalloc(&output_data, batch_size * channels * height * width * sizeof(double)); 96 | data[1] = output_data; 97 | 98 | cudnnTensorDescriptor_t output_descriptor; 99 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); 100 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, 101 | /*format=*/CUDNN_TENSOR_NHWC, 102 | /*dataType=*/CUDNN_DATA_DOUBLE, 103 | /*batch_size=*/batch_size, 104 | /*channels=*/channels, 105 | /*image_height=*/height, 106 | /*image_width=*/width)); 107 | 108 | double alpha = 1.0f, beta = 0.0f; 109 | cudnnConvolutionFwdAlgo_t convolution_algorithm = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; 110 | // CHECK_CUDNN( 111 | // cudnnFindConvolutionForwardAlgorithm(cudnn, 112 | // input_descriptor, 113 | // filter_descriptor, 114 | // convolution_descriptor, 115 | // output_descriptor, 116 | // CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 117 | // /*memoryLimitInBytes=*/0, 118 | // &convolution_algorithm)); 119 | 120 | size_t workspace_bytes{0}; 121 | CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn, 122 | input_descriptor, 123 | filter_descriptor, 124 | convolution_descriptor, 125 | output_descriptor, 126 | convolution_algorithm, 127 | &workspace_bytes)); 128 | 129 | void* d_workspace{nullptr}; 130 | cudaMalloc(&d_workspace, workspace_bytes); 131 | 132 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 133 | 134 | for (int t = 0; t < T; t++) { 135 | CHECK_CUDNN(cudnnConvolutionForward(cudnn, 136 | &alpha, 137 | input_descriptor, 138 | data[t % 2], 139 | filter_descriptor, 140 | filter_data, 141 | convolution_descriptor, 142 | convolution_algorithm, 143 | d_workspace, 144 | workspace_bytes, 145 | &beta, 146 | output_descriptor, 147 | data[(t + 1) % 2])); 148 | } 149 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 150 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 151 | 152 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 153 | printf("GStencil/s = %f\n", ((double)H * W * T) / secs / 1e9); 154 | 155 | cudaMemcpy(output_data_h, output_data, batch_size * channels * height * width * sizeof(double), cudaMemcpyDeviceToHost); 156 | 157 | cudnnDestroyTensorDescriptor(input_descriptor); 158 | cudnnDestroyTensorDescriptor(output_descriptor); 159 | cudnnDestroyFilterDescriptor(filter_descriptor); 160 | cudnnDestroyConvolutionDescriptor(convolution_descriptor); 161 | 
cudnnDestroy(cudnn); 162 | 163 | cudaFree(input_data); 164 | cudaFree(filter_data); 165 | cudaFree(output_data); 166 | cudaFree(d_workspace); 167 | 168 | return 0; 169 | } 170 | -------------------------------------------------------------------------------- /src/cudnn/conv_1d5p.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDNN(expression) \ 7 | { \ 8 | cudnnStatus_t status = (expression); \ 9 | if (status != CUDNN_STATUS_SUCCESS) { \ 10 | std::cerr << "Error on line " << __LINE__ << ": " \ 11 | << cudnnGetErrorString(status) << std::endl;\ 12 | std::exit(EXIT_FAILURE); \ 13 | } \ 14 | } 15 | 16 | int main() { 17 | cudnnHandle_t cudnn; 18 | CHECK_CUDNN(cudnnCreate(&cudnn)); 19 | 20 | int H = 1; 21 | int W = 10240000; 22 | int T = 100000; 23 | double *input_data_h; 24 | input_data_h = (double*)malloc(1 * 1 * H * W * sizeof(double)); 25 | 26 | for (int i = 0; i < H * W; i++) { 27 | input_data_h[i] = 1.0f; 28 | } 29 | 30 | double *data[2]; 31 | double *input_data; 32 | cudaMalloc(&input_data, 1 * 1 * H * W * sizeof(double)); 33 | cudaMemcpy(input_data, input_data_h, 1 * 1 * H * W * sizeof(double), cudaMemcpyHostToDevice); 34 | data[0] = input_data; 35 | 36 | cudnnTensorDescriptor_t input_descriptor; 37 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); 38 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(input_descriptor, 39 | /*format=*/CUDNN_TENSOR_NHWC, 40 | /*dataType=*/CUDNN_DATA_DOUBLE, 41 | /*batch_size=*/1, 42 | /*channels=*/1, 43 | /*image_height=*/H, 44 | /*image_width=*/W)); 45 | 46 | double *filter_data_h; 47 | int filter_H = 1; 48 | int filter_W = 5; 49 | filter_data_h = (double*)malloc(1 * 1 * filter_H * filter_W * sizeof(double)); 50 | 51 | for (int i = 0; i < filter_H * filter_W; i++) { 52 | filter_data_h[i] = 0.1111f; 53 | } 54 | 55 | double *filter_data; 56 | cudaMalloc(&filter_data, 1 * 1 * filter_H * filter_W * sizeof(double)); 57 | cudaMemcpy(filter_data, filter_data_h, 1 * 1 * filter_H * filter_W * sizeof(double), cudaMemcpyHostToDevice); 58 | 59 | cudnnFilterDescriptor_t filter_descriptor; 60 | CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_descriptor)); 61 | CHECK_CUDNN(cudnnSetFilter4dDescriptor(filter_descriptor, 62 | /*dataType=*/CUDNN_DATA_DOUBLE, 63 | /*format=*/CUDNN_TENSOR_NCHW, 64 | /*out_channels=*/1, 65 | /*in_channels=*/1, 66 | /*kernel_height=*/filter_H, 67 | /*kernel_width=*/filter_W)); 68 | 69 | cudnnConvolutionDescriptor_t convolution_descriptor; 70 | CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); 71 | CHECK_CUDNN(cudnnSetConvolution2dDescriptor(convolution_descriptor, 72 | /*pad_height=*/0, 73 | /*pad_width=*/1, 74 | /*vertical_stride=*/1, 75 | /*horizontal_stride=*/1, 76 | /*dilation_height=*/1, 77 | /*dilation_width=*/1, 78 | /*mode=*/CUDNN_CROSS_CORRELATION, 79 | /*computeType=*/CUDNN_DATA_DOUBLE)); 80 | CHECK_CUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); 81 | 82 | int batch_size{0}, channels{0}, height{0}, width{0}; 83 | CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(convolution_descriptor, 84 | input_descriptor, 85 | filter_descriptor, 86 | &batch_size, 87 | &channels, 88 | &height, 89 | &width)); 90 | 91 | double *output_data_h; 92 | output_data_h = (double*)malloc(batch_size * channels * height * width * sizeof(double)); 93 | 94 | double *output_data; 95 | cudaMalloc(&output_data, batch_size * channels * height * width * sizeof(double)); 96 | data[1] 
= output_data; 97 | 98 | cudnnTensorDescriptor_t output_descriptor; 99 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&output_descriptor)); 100 | CHECK_CUDNN(cudnnSetTensor4dDescriptor(output_descriptor, 101 | /*format=*/CUDNN_TENSOR_NHWC, 102 | /*dataType=*/CUDNN_DATA_DOUBLE, 103 | /*batch_size=*/batch_size, 104 | /*channels=*/channels, 105 | /*image_height=*/height, 106 | /*image_width=*/width)); 107 | 108 | double alpha = 1.0f, beta = 0.0f; 109 | cudnnConvolutionFwdAlgo_t convolution_algorithm = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; 110 | // CHECK_CUDNN( 111 | // cudnnFindConvolutionForwardAlgorithm(cudnn, 112 | // input_descriptor, 113 | // filter_descriptor, 114 | // convolution_descriptor, 115 | // output_descriptor, 116 | // CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 117 | // /*memoryLimitInBytes=*/0, 118 | // &convolution_algorithm)); 119 | 120 | size_t workspace_bytes{0}; 121 | CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn, 122 | input_descriptor, 123 | filter_descriptor, 124 | convolution_descriptor, 125 | output_descriptor, 126 | convolution_algorithm, 127 | &workspace_bytes)); 128 | 129 | void* d_workspace{nullptr}; 130 | cudaMalloc(&d_workspace, workspace_bytes); 131 | 132 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 133 | 134 | for (int t = 0; t < T; t++) { 135 | CHECK_CUDNN(cudnnConvolutionForward(cudnn, 136 | &alpha, 137 | input_descriptor, 138 | data[t % 2], 139 | filter_descriptor, 140 | filter_data, 141 | convolution_descriptor, 142 | convolution_algorithm, 143 | d_workspace, 144 | workspace_bytes, 145 | &beta, 146 | output_descriptor, 147 | data[(t + 1) % 2])); 148 | } 149 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 150 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 151 | 152 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 153 | printf("GStencil/s = %f\n", ((double)H * W * T) / secs / 1e9); 154 | 155 | cudaMemcpy(output_data_h, output_data, batch_size * channels * height * width * sizeof(double), cudaMemcpyDeviceToHost); 156 | 157 | cudnnDestroyTensorDescriptor(input_descriptor); 158 | cudnnDestroyTensorDescriptor(output_descriptor); 159 | cudnnDestroyFilterDescriptor(filter_descriptor); 160 | cudnnDestroyConvolutionDescriptor(convolution_descriptor); 161 | cudnnDestroy(cudnn); 162 | 163 | cudaFree(input_data); 164 | cudaFree(filter_data); 165 | cudaFree(output_data); 166 | cudaFree(d_workspace); 167 | 168 | return 0; 169 | } 170 | -------------------------------------------------------------------------------- /src/cudnn/conv_box3d27p.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define CUDA_CHECK(call) \ 7 | do \ 8 | { \ 9 | const cudaError_t error_code = call; \ 10 | if (error_code != cudaSuccess) \ 11 | { \ 12 | printf("CUDA Error:\n"); \ 13 | printf(" File: %s\n", __FILE__); \ 14 | printf(" Line: %d\n", __LINE__); \ 15 | printf(" Error code: %d\n", error_code); \ 16 | printf(" Error text: %s\n", \ 17 | cudaGetErrorString(error_code)); \ 18 | exit(1); \ 19 | } \ 20 | } while (0) 21 | 22 | #define CHECK_CUDNN(expression) \ 23 | { \ 24 | cudnnStatus_t status = (expression); \ 25 | if (status != CUDNN_STATUS_SUCCESS) { \ 26 | std::cerr << "Error on line " << __LINE__ << ": " \ 27 | << cudnnGetErrorString(status) << std::endl;\ 28 | std::exit(EXIT_FAILURE); \ 29 | } \ 30 | } 31 | 32 | int main() { 33 | 
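    // Benchmark sketch: a 27-point box stencil on a 512^3 double-precision grid,
    // expressed as a repeated 3x3x3 cuDNN Nd cross-correlation. Each of the T
    // iterations reads one of the two data[] buffers and writes the other
    // (ping-pong), so the timed loop measures steady-state convolution throughput.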
cudnnHandle_t cudnn; 34 | CHECK_CUDNN(cudnnCreate(&cudnn)); 35 | 36 | // 输入数据(N=1, C=3, H=8, W=8) 37 | int H = 512; 38 | int W = 512; 39 | int L = 512; 40 | int T = 512; 41 | double *input_data_h; 42 | input_data_h = (double*)malloc(1 * 1 * H * W * L * sizeof(double)); 43 | 44 | for (int i = 0; i < H * W * L; i++) { 45 | input_data_h[i] = 1.0f; 46 | } 47 | 48 | double *data[2]; 49 | double *input_data; 50 | CUDA_CHECK(cudaMalloc(&input_data, 1 * 1 * H * W * L * sizeof(double))); 51 | CUDA_CHECK(cudaMemcpy(input_data, input_data_h, 1 * 1 * H * W * L * sizeof(double), cudaMemcpyHostToDevice)); 52 | data[0] = input_data; 53 | 54 | cudnnTensorDescriptor_t input_descriptor; 55 | CHECK_CUDNN(cudnnCreateTensorDescriptor(&input_descriptor)); 56 | int dims[5] = {1, 1, H, W, L}; 57 | int strides[5] = {1*H*W*L, H*W*L, W*L, L, 1}; 58 | CHECK_CUDNN(cudnnSetTensorNdDescriptor(input_descriptor, 59 | // /*format=*/CUDNN_TENSOR_NHWC, 60 | /*dataType=*/CUDNN_DATA_DOUBLE, 61 | 5, 62 | dims, 63 | strides)); 64 | 65 | // 卷积滤波器(K=2, C=3, H=5, W=5) 66 | double *filter_data_h; 67 | int kernel_size = 3; 68 | filter_data_h = (double*)malloc(1 * 1 * kernel_size * kernel_size * kernel_size * sizeof(double)); 69 | 70 | for (int i = 0; i < kernel_size * kernel_size * kernel_size; i++) { 71 | filter_data_h[i] = (double)1/27; 72 | } 73 | 74 | double *filter_data; 75 | CUDA_CHECK(cudaMalloc(&filter_data, 1 * 1 * kernel_size * kernel_size * kernel_size * sizeof(double))); 76 | CUDA_CHECK(cudaMemcpy(filter_data, filter_data_h, 1 * 1 * kernel_size * kernel_size * kernel_size * sizeof(double), cudaMemcpyHostToDevice)); 77 | 78 | cudnnFilterDescriptor_t filter_descriptor; 79 | CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_descriptor)); 80 | int kernelDims[5] = {1, 1, kernel_size, kernel_size, kernel_size}; 81 | CHECK_CUDNN(cudnnSetFilterNdDescriptor(filter_descriptor, 82 | /*dataType=*/CUDNN_DATA_DOUBLE, 83 | /*format=*/CUDNN_TENSOR_NCHW, 84 | 5, 85 | kernelDims)); 86 | 87 | // 卷积描述符 88 | cudnnConvolutionDescriptor_t convolution_descriptor; 89 | CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&convolution_descriptor)); 90 | int pad = 1; 91 | int padA[3] = {pad, pad, pad}; 92 | int filterStrideA[3] = {1, 1, 1}; 93 | int dilationA[3] = {1, 1, 1}; 94 | CHECK_CUDNN(cudnnSetConvolutionNdDescriptor(convolution_descriptor, 95 | 3, 96 | padA, 97 | filterStrideA, 98 | dilationA, 99 | CUDNN_CROSS_CORRELATION, 100 | CUDNN_DATA_DOUBLE)); 101 | CHECK_CUDNN(cudnnSetConvolutionMathType(convolution_descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); 102 | 103 | // 计算输出数据尺寸 104 | // int batch_size{0}, channels{0}, height{0}, width{0}; 105 | int outputDims[5]; 106 | CHECK_CUDNN(cudnnGetConvolutionNdForwardOutputDim(convolution_descriptor, 107 | input_descriptor, 108 | filter_descriptor, 109 | 5, 110 | outputDims)); 111 | 112 | // 输出数据 113 | // double *output_data_h; 114 | // output_data_h = (double*)malloc(batch_size * channels * height * width * sizeof(double)); 115 | 116 | // double *output_data; 117 | // cudaMalloc(&output_data, batch_size * channels * height * width * sizeof(double)); 118 | // data[1] = output_data; 119 | 120 | int outputStrides[5] = {outputDims[1]*outputDims[2]*outputDims[3]*outputDims[4], outputDims[2]*outputDims[3]*outputDims[4], outputDims[3]*outputDims[4], outputDims[4], 1}; 121 | 122 | cudnnTensorDescriptor_t output_descriptor; 123 | cudnnCreateTensorDescriptor(&output_descriptor); 124 | cudnnSetTensorNdDescriptor(output_descriptor, 125 | CUDNN_DATA_DOUBLE, 126 | 5, 127 | outputDims, 128 | outputStrides); 129 | 130 | 
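    // The output tensor descriptor uses fully packed strides computed from the
    // dimensions returned by cudnnGetConvolutionNdForwardOutputDim. With pad = 1
    // and a 3x3x3 filter the output has the same 512^3 extent as the input, which
    // is what allows data[0] and data[1] to swap roles across iterations.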
double *output_data; 131 | CUDA_CHECK(cudaMalloc(&output_data, outputDims[0]*outputDims[1]*outputDims[2]*outputDims[3]*outputDims[4] * sizeof(double))); 132 | data[1] = output_data; 133 | 134 | double *output_data_h; 135 | output_data_h = (double*)malloc(outputDims[0]*outputDims[1]*outputDims[2]*outputDims[3]*outputDims[4] * sizeof(double)); 136 | // 执行卷积前向传播 137 | double alpha = 1.0f, beta = 0.0f; 138 | cudnnConvolutionFwdAlgo_t convolution_algorithm = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; 139 | // CHECK_CUDNN( 140 | // cudnnFindConvolutionForwardAlgorithm(cudnn, 141 | // input_descriptor, 142 | // filter_descriptor, 143 | // convolution_descriptor, 144 | // output_descriptor, 145 | // CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 146 | // /*memoryLimitInBytes=*/0, 147 | // &convolution_algorithm)); 148 | 149 | size_t workspace_bytes{0}; 150 | CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(cudnn, 151 | input_descriptor, 152 | filter_descriptor, 153 | convolution_descriptor, 154 | output_descriptor, 155 | convolution_algorithm, 156 | &workspace_bytes)); 157 | 158 | void* d_workspace{nullptr}; 159 | CUDA_CHECK(cudaMalloc(&d_workspace, workspace_bytes)); 160 | 161 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 162 | 163 | for (int t = 0; t < T; t++) { 164 | CHECK_CUDNN(cudnnConvolutionForward(cudnn, 165 | &alpha, 166 | input_descriptor, 167 | data[t % 2], 168 | filter_descriptor, 169 | filter_data, 170 | convolution_descriptor, 171 | convolution_algorithm, 172 | d_workspace, 173 | workspace_bytes, 174 | &beta, 175 | output_descriptor, 176 | data[(t + 1) % 2])); 177 | } 178 | CUDA_CHECK(cudaDeviceSynchronize()); 179 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 180 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 181 | 182 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 183 | std::cout << secs << std::endl; 184 | printf("GStencil/s = %f\n", ((double)H * W * L * T) / secs / 1e9); 185 | 186 | CUDA_CHECK(cudaMemcpy(output_data_h, output_data, outputDims[0]*outputDims[1]*outputDims[2]*outputDims[3]*outputDims[4] * sizeof(double), cudaMemcpyDeviceToHost)); 187 | // for (int i = 500 * 500 * 499; i < 500 * 500 * 500; i++) { 188 | // std::cout << output_data_h[i] << " "; 189 | // } 190 | // std::cout << std::endl; 191 | // std::cout << height << " " << width << std::endl; 192 | 193 | 194 | // 释放所有资源 195 | cudnnDestroyTensorDescriptor(input_descriptor); 196 | cudnnDestroyTensorDescriptor(output_descriptor); 197 | cudnnDestroyFilterDescriptor(filter_descriptor); 198 | cudnnDestroyConvolutionDescriptor(convolution_descriptor); 199 | cudnnDestroy(cudnn); 200 | 201 | cudaFree(input_data); 202 | cudaFree(filter_data); 203 | cudaFree(output_data); 204 | cudaFree(d_workspace); 205 | 206 | return 0; 207 | } 208 | -------------------------------------------------------------------------------- /src/3d/gpu_star.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "3d_utils.h" 4 | #include 5 | 6 | using namespace nvcuda; 7 | 8 | #define BLOCK_SIZE_ROW 8 9 | #define BLOCK_SIZE_COL 64 10 | #define HALO 3 11 | #define UNIT_LENGTH 7 12 | #define D_BLOCK_SIZE_COL (BLOCK_SIZE_COL + HALO * 2) 13 | #define D_BLOCK_SIZE_ROW (BLOCK_SIZE_ROW + HALO * 2) 14 | #define PAD 2 15 | #define SM_SIZE_COL (UNIT_LENGTH * D_BLOCK_SIZE_ROW + PAD) 16 | #define SM_SIZE_ROW (D_BLOCK_SIZE_COL / (UNIT_LENGTH + 1)) 17 | 
#define SM_DIFF (SM_SIZE_ROW * SM_SIZE_COL - D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL) 18 | #define WARP_PER_BLOCK 8 19 | #define COL_PER_WARP (BLOCK_SIZE_ROW / WARP_PER_BLOCK * UNIT_LENGTH) 20 | #define TENSOR_CORE_M 8 21 | #define MMA_NUM 13 22 | #define IDX2D(x, y, ldm) ((x) * (ldm) + (y)) 23 | #define IDX3D(x, y, z, rows, cols) ((x) * (rows) * (cols) + (y) * (cols) + (z)) 24 | 25 | __constant__ double param_star_matrix_d[2 * 52 * TENSOR_CORE_M]; 26 | __constant__ double param_one_d[1]; 27 | __constant__ double param_five_d[5]; 28 | __constant__ double param_thirteen_d[13]; 29 | 30 | void copy_temp(double * __restrict__ temp_para, const double * __restrict__ params) { 31 | temp_para[IDX2D(0, 3, UNIT_LENGTH)] = params[19 + 0]; 32 | temp_para[IDX2D(1, 2, UNIT_LENGTH)] = params[19 + 1]; 33 | temp_para[IDX2D(1, 3, UNIT_LENGTH)] = params[19 + 2]; 34 | temp_para[IDX2D(1, 4, UNIT_LENGTH)] = params[19 + 3]; 35 | temp_para[IDX2D(2, 1, UNIT_LENGTH)] = params[19 + 4]; 36 | temp_para[IDX2D(2, 2, UNIT_LENGTH)] = params[19 + 5]; 37 | temp_para[IDX2D(2, 3, UNIT_LENGTH)] = params[19 + 6]; 38 | temp_para[IDX2D(2, 4, UNIT_LENGTH)] = params[19 + 7]; 39 | temp_para[IDX2D(2, 5, UNIT_LENGTH)] = params[19 + 8]; 40 | temp_para[IDX2D(3, 0, UNIT_LENGTH)] = params[19 + 9]; 41 | temp_para[IDX2D(3, 1, UNIT_LENGTH)] = params[19 + 10]; 42 | temp_para[IDX2D(3, 2, UNIT_LENGTH)] = params[19 + 11]; 43 | temp_para[IDX2D(3, 3, UNIT_LENGTH)] = params[19 + 12]; 44 | temp_para[IDX2D(3, 4, UNIT_LENGTH)] = params[19 + 13]; 45 | temp_para[IDX2D(3, 5, UNIT_LENGTH)] = params[19 + 14]; 46 | temp_para[IDX2D(3, 6, UNIT_LENGTH)] = params[19 + 15]; 47 | temp_para[IDX2D(4, 1, UNIT_LENGTH)] = params[19 + 16]; 48 | temp_para[IDX2D(4, 2, UNIT_LENGTH)] = params[19 + 17]; 49 | temp_para[IDX2D(4, 3, UNIT_LENGTH)] = params[19 + 18]; 50 | temp_para[IDX2D(4, 4, UNIT_LENGTH)] = params[19 + 19]; 51 | temp_para[IDX2D(4, 5, UNIT_LENGTH)] = params[19 + 20]; 52 | temp_para[IDX2D(5, 2, UNIT_LENGTH)] = params[19 + 21]; 53 | temp_para[IDX2D(5, 3, UNIT_LENGTH)] = params[19 + 22]; 54 | temp_para[IDX2D(5, 4, UNIT_LENGTH)] = params[19 + 23]; 55 | temp_para[IDX2D(6, 3, UNIT_LENGTH)] = params[19 + 24]; 56 | } 57 | 58 | __forceinline__ __device__ void load_original_data(double * __restrict__ data, const double * __restrict__ in, const int h, const int rows, const int cols) { 59 | int begin = IDX3D(h, blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL, rows, cols); 60 | int tid = threadIdx.x; 61 | int total_threads = blockDim.x; 62 | for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += total_threads) { 63 | int row = i / D_BLOCK_SIZE_COL; 64 | int col = i % D_BLOCK_SIZE_COL; 65 | data[i] = in[begin + IDX2D(row, col, cols)]; 66 | } 67 | __syncthreads(); 68 | } 69 | 70 | __forceinline__ __device__ void load_shared_data(double * __restrict__ data, const double * __restrict__ in, const int h, const int rows, const int cols, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 71 | int tid = threadIdx.x; 72 | int total_threads = blockDim.x; 73 | for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += total_threads) { 74 | data[IDX2D(0, lookup_table1[i], SM_SIZE_ROW * SM_SIZE_COL)] = in[i]; 75 | data[IDX2D(1, lookup_table2[i], SM_SIZE_ROW * SM_SIZE_COL)] = in[i]; 76 | } 77 | __syncthreads(); 78 | } 79 | 80 | __forceinline__ __device__ void load_trans_data(double * __restrict__ data, const double * __restrict__ in, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 81 | int tid = 
threadIdx.x; 82 | int total_threads = blockDim.x; 83 | for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += total_threads) { 84 | data[IDX2D(0, lookup_table1[i], SM_SIZE_ROW * SM_SIZE_COL)] = in[i]; 85 | data[IDX2D(1, lookup_table2[i], SM_SIZE_ROW * SM_SIZE_COL)] = in[i]; 86 | } 87 | __syncthreads(); 88 | } 89 | 90 | __forceinline__ __device__ void compute_one_point(double * __restrict__ data, double * __restrict__ out) { 91 | int tid = threadIdx.x; 92 | int total_threads = blockDim.x; 93 | for (int i = tid; i < BLOCK_SIZE_ROW * BLOCK_SIZE_COL; i += total_threads) { 94 | int row = i / BLOCK_SIZE_COL; 95 | int col = i % BLOCK_SIZE_COL; 96 | out[i] = param_one_d[0] * data[IDX2D(row + HALO, col + HALO, D_BLOCK_SIZE_COL)]; 97 | } 98 | } 99 | 100 | __forceinline__ __device__ void compute_five_point(double * __restrict__ data, double * __restrict__ out) { 101 | int tid = threadIdx.x; 102 | int total_threads = blockDim.x; 103 | for (int i = tid; i < BLOCK_SIZE_ROW * BLOCK_SIZE_COL; i += total_threads) { 104 | int row = i / BLOCK_SIZE_COL; 105 | int col = i % BLOCK_SIZE_COL; 106 | out[i] = 107 | param_five_d[0] * data[IDX2D(HALO + row - 1, HALO + col, D_BLOCK_SIZE_COL)] + 108 | param_five_d[1] * data[IDX2D(HALO + row, HALO + col - 1, D_BLOCK_SIZE_COL)] + 109 | param_five_d[2] * data[IDX2D(HALO + row, HALO + col, D_BLOCK_SIZE_COL)] + 110 | param_five_d[3] * data[IDX2D(HALO + row, HALO + col + 1, D_BLOCK_SIZE_COL)] + 111 | param_five_d[4] * data[IDX2D(HALO + row + 1, HALO + col, D_BLOCK_SIZE_COL)]; 112 | } 113 | } 114 | 115 | __forceinline__ __device__ void compute_thirteen_point(double * __restrict__ data, double * __restrict__ out) { 116 | int tid = threadIdx.x; 117 | int total_threads = blockDim.x; 118 | for (int i = tid; i < BLOCK_SIZE_ROW * BLOCK_SIZE_COL; i += total_threads) { 119 | int row = i / BLOCK_SIZE_COL; 120 | int col = i % BLOCK_SIZE_COL; 121 | out[i] = 122 | param_thirteen_d[0] * data[IDX2D(HALO + row - 2, HALO + col, D_BLOCK_SIZE_COL)] + 123 | param_thirteen_d[1] * data[IDX2D(HALO + row - 1, HALO + col - 1, D_BLOCK_SIZE_COL)] + 124 | param_thirteen_d[2] * data[IDX2D(HALO + row - 1, HALO + col, D_BLOCK_SIZE_COL)] + 125 | param_thirteen_d[3] * data[IDX2D(HALO + row - 1, HALO + col + 1, D_BLOCK_SIZE_COL)] + 126 | param_thirteen_d[4] * data[IDX2D(HALO + row, HALO + col - 2, D_BLOCK_SIZE_COL)] + 127 | param_thirteen_d[5] * data[IDX2D(HALO + row, HALO + col - 1, D_BLOCK_SIZE_COL)] + 128 | param_thirteen_d[6] * data[IDX2D(HALO + row, HALO + col, D_BLOCK_SIZE_COL)] + 129 | param_thirteen_d[7] * data[IDX2D(HALO + row, HALO + col + 1, D_BLOCK_SIZE_COL)] + 130 | param_thirteen_d[8] * data[IDX2D(HALO + row, HALO + col + 2, D_BLOCK_SIZE_COL)] + 131 | param_thirteen_d[9] * data[IDX2D(HALO + row + 1, HALO + col - 1, D_BLOCK_SIZE_COL)] + 132 | param_thirteen_d[10] * data[IDX2D(HALO + row + 1, HALO + col, D_BLOCK_SIZE_COL)] + 133 | param_thirteen_d[11] * data[IDX2D(HALO + row + 1, HALO + col + 1, D_BLOCK_SIZE_COL)] + 134 | param_thirteen_d[12] * data[IDX2D(HALO + row + 2, HALO + col, D_BLOCK_SIZE_COL)]; 135 | } 136 | } 137 | 138 | __forceinline__ __device__ void compute_tensorcore(double * __restrict__ data, double * __restrict__ out, const int ldm, const int warp_id) { 139 | wmma::fragment param_frag[2][MMA_NUM]; 140 | #pragma unroll 141 | for (int i = 0; i < MMA_NUM; i++) { 142 | wmma::load_matrix_sync(param_frag[0][i], param_star_matrix_d + i * 32, 8); 143 | wmma::load_matrix_sync(param_frag[1][i], param_star_matrix_d + 52 * 8 + i * 32, 8); 144 | } 145 | 146 | wmma::fragment acc_frag; 
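    // Each warp accumulates one output tile: MMA_NUM (13) matrix products against
    // param_frag[0] followed by 13 against param_frag[1], covering the two halves
    // of the 52-column stencil parameter matrix held in constant memory.
    // (The fragment shapes are presumably the 8x8x4 double-precision WMMA tiles,
    // consistent with TENSOR_CORE_M = 8 and the stride-4 column stepping of the loads.)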
147 | 148 | wmma::fragment in_frag; 149 | 150 | for (int col = warp_id * COL_PER_WARP; col < warp_id * COL_PER_WARP + COL_PER_WARP; col += UNIT_LENGTH) { 151 | wmma::fill_fragment(acc_frag, 0.0); 152 | #pragma unroll 153 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 154 | wmma::load_matrix_sync(in_frag, data + IDX2D(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL); 155 | wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 156 | } 157 | #pragma unroll 158 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 159 | wmma::load_matrix_sync(in_frag, data + SM_SIZE_ROW * SM_SIZE_COL + IDX2D(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL); 160 | wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 161 | } 162 | wmma::store_matrix_sync(out + IDX2D(col / UNIT_LENGTH, 0, BLOCK_SIZE_COL), acc_frag, TENSOR_CORE_M, wmma::mem_row_major); 163 | } 164 | __syncthreads(); 165 | } 166 | 167 | __forceinline__ __device__ void add(double * __restrict__ data1, double * __restrict__ data2, double * __restrict__ data3, double * __restrict__ data4, double * __restrict__ data5, double * __restrict__ data6, double * __restrict__ data7, double * __restrict__ out, const int cols) { 168 | int tid = threadIdx.x; 169 | int total_threads = blockDim.x; 170 | for (int i = tid; i < BLOCK_SIZE_ROW * BLOCK_SIZE_COL; i += total_threads) { 171 | int row = i / BLOCK_SIZE_COL; 172 | int col = i % BLOCK_SIZE_COL; 173 | out[IDX2D(row, col, cols)] = data1[i] + data2[i] + data3[i] + data4[i] + data5[i] + data6[i] + data7[i]; 174 | } 175 | } 176 | 177 | 178 | __global__ void gpu_star_3d1r_step3_kernel (const double * __restrict__ in, double * __restrict__ out, const int heights, const int rows, const int cols, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 179 | // __shared__ double data[D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL]; 180 | // __shared__ double trans[2][SM_SIZE_ROW * SM_SIZE_COL]; 181 | // __shared__ double intermediate[19][BLOCK_SIZE_ROW * BLOCK_SIZE_COL]; 182 | extern __shared__ double data[]; 183 | double * trans = &data[D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL]; 184 | double * intermediate = &data[D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL + 2 * SM_SIZE_ROW * SM_SIZE_COL]; 185 | 186 | int begin = IDX2D(blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL, cols); 187 | int warp_id = threadIdx.x / 32; 188 | // int tid = threadIdx.x; 189 | // int total_threads = blockDim.x; 190 | 191 | load_original_data(data, in, 0, rows, cols); 192 | compute_one_point(data, intermediate); 193 | load_original_data(data, in, 1, rows, cols); 194 | compute_one_point(data, intermediate + BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 195 | compute_five_point(data, intermediate + 7 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 196 | load_original_data(data, in, 2, rows, cols); 197 | compute_one_point(data, intermediate + 2 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 198 | compute_five_point(data, intermediate + 8 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 199 | compute_thirteen_point(data, intermediate + 12 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 200 | load_original_data(data, in, 3, rows, cols); 201 | compute_one_point(data, intermediate + 3 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 202 | compute_five_point(data, intermediate + 9 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 203 | compute_thirteen_point(data, intermediate + 13 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 204 | load_trans_data(trans, data, lookup_table1, lookup_table2); 205 | compute_tensorcore(trans, intermediate + 16 * BLOCK_SIZE_ROW * 
BLOCK_SIZE_COL, SM_SIZE_COL, warp_id); 206 | load_original_data(data, in, 4, rows, cols); 207 | compute_one_point(data, intermediate + 4 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 208 | compute_five_point(data, intermediate + 10 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 209 | compute_thirteen_point(data, intermediate + 14 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 210 | load_trans_data(trans, data, lookup_table1, lookup_table2); 211 | compute_tensorcore(trans, intermediate + 17 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, SM_SIZE_COL, warp_id); 212 | load_original_data(data, in, 5, rows, cols); 213 | compute_one_point(data, intermediate + 5 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 214 | compute_five_point(data, intermediate + 11 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 215 | compute_thirteen_point(data, intermediate + 15 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 216 | load_trans_data(trans, data, lookup_table1, lookup_table2); 217 | compute_tensorcore(trans, intermediate + 18 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, SM_SIZE_COL, warp_id); 218 | for (int h = 6; h < heights + 6; h++) { 219 | load_original_data(data, in, h, rows, cols); 220 | compute_one_point(data, intermediate + (h % 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 221 | add( 222 | intermediate + ((h - 6) % 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 223 | intermediate + ((h - 6) % 5 + 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 224 | intermediate + ((h - 6) % 4 + 12) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 225 | intermediate + ((h - 6) % 3 + 16) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 226 | intermediate + ((h - 4) % 4 + 12) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 227 | intermediate + ((h - 2) % 5 + 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 228 | intermediate + (h % 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, 229 | out + (h - 3) * rows * cols + begin + IDX2D(HALO, HALO, cols), 230 | cols); 231 | compute_five_point(data, intermediate + ((h - 6) % 5 + 7) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 232 | compute_thirteen_point(data, intermediate + ((h - 6) % 4 + 12) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL); 233 | load_trans_data(trans, data, lookup_table1, lookup_table2); 234 | compute_tensorcore(trans, intermediate + ((h - 6) % 3 + 16) * BLOCK_SIZE_ROW * BLOCK_SIZE_COL, SM_SIZE_COL, warp_id); 235 | } 236 | } 237 | 238 | void gpu_star_3d1r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_h, const int input_m, const int input_n) { 239 | double param_matrix_h[2][52 * 8] = {0.0}; 240 | 241 | // Initialize parameter matrix 242 | CUDA_CHECK(cudaMemcpyToSymbol(param_one_d, params, sizeof(double))); 243 | CUDA_CHECK(cudaMemcpyToSymbol(param_five_d, params + 1, 5 * sizeof(double))); 244 | CUDA_CHECK(cudaMemcpyToSymbol(param_thirteen_d, params + 6, 13 * sizeof(double))); 245 | 246 | double temp_para[49] = {0.0}; 247 | copy_temp(temp_para, params); 248 | 249 | for (int col = 0; col < TENSOR_CORE_M; col++) { 250 | for(int i = 0; i < UNIT_LENGTH; i++) { 251 | for(int j = 0; j < UNIT_LENGTH; j++) { 252 | if (j >= col) { 253 | param_matrix_h[0][(i * UNIT_LENGTH + j) * 8 + col] = temp_para[i * UNIT_LENGTH + j - col]; 254 | } 255 | } 256 | } 257 | } 258 | for (int col = 0; col < TENSOR_CORE_M; col++) { 259 | for(int i = 0; i < UNIT_LENGTH; i++) { 260 | for(int j = 0; j < UNIT_LENGTH; j++) { 261 | if (j < col) { 262 | param_matrix_h[1][(i * UNIT_LENGTH + j) * 8 + col] = temp_para[i * UNIT_LENGTH + j - col + 7]; 263 | } 264 | } 265 | } 266 | } 267 | 268 | CUDA_CHECK(cudaMemcpyToSymbol(param_star_matrix_d, param_matrix_h, 2 * 8 * 52 * sizeof(double))); 269 | 270 | const int 
heights = input_h + 2 * HALO; 271 | const int rows = input_m + 2 * HALO; 272 | const int cols = input_n + 2 * HALO; 273 | const size_t array_size = heights * rows * cols * sizeof(double); 274 | double *array_d[2]; 275 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 276 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 277 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 278 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 279 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 280 | 281 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 282 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 283 | dim3 grid_config(BLOCK_M, BLOCK_N); 284 | // dim3 grid_config(1, 1); 285 | dim3 block_config(32 * WARP_PER_BLOCK); 286 | int sm_size = (D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL + 2 * SM_SIZE_ROW * SM_SIZE_COL + 19 * BLOCK_SIZE_ROW * BLOCK_SIZE_COL) * sizeof(double); 287 | CUDA_CHECK(cudaFuncSetAttribute(gpu_star_3d1r_step3_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sm_size)); 288 | 289 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 290 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 291 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 292 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 293 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 294 | lookup_table1_h[i][j] = IDX2D(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 295 | } else { 296 | lookup_table1_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 297 | } 298 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 299 | lookup_table2_h[i][j] = IDX2D((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 300 | } else { 301 | lookup_table2_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 302 | } 303 | } 304 | } 305 | 306 | int * lookup_table1_d; 307 | int * lookup_table2_d; 308 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 309 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 310 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 311 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 312 | 313 | int i = 0; 314 | 315 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 316 | 317 | for(; i < times; i++) { 318 | CUDAKERNELCHECK((gpu_star_3d1r_step3_kernel<<>>(array_d[i % 2], array_d[(i + 1) % 2], input_h, rows, cols, lookup_table1_d, lookup_table2_d))); 319 | } 320 | CUDA_CHECK(cudaDeviceSynchronize()); 321 | 322 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 323 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 324 | 325 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 326 | std::cout << secs << std::endl; 327 | printf("GStencil/s = %f\n", ((double)input_m * input_n * input_h * times * 3) / secs / 1e9); 328 | 329 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 330 | 331 | return; 332 | } -------------------------------------------------------------------------------- /src/2d/main.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // #include "type.h" 4 | // #include "../utils.h" 5 | 
#include "2d_utils.h" 6 | // #include "mix/mix.h" 7 | // #include "cpu/cpu.h" 8 | // #include "gpu/gpu.h" 9 | // #include "heat/heat.h" 10 | 11 | const char *ShapeStr[5] = { 12 | "star_2d1r", 13 | "box_2d1r", 14 | "star_2d3r", 15 | "box_2d3r", 16 | }; 17 | 18 | // Fill the matrix with random numbers or indices 19 | #define FILL_RANDOM 20 | // #define FILL_INDEX 21 | 22 | // Check the correctness of the result or not 23 | // #define CHECK_ERROR 24 | const double tolerance = 1e-7; 25 | 26 | #define IDX(x, y, ldm) ((x) * (ldm) + (y)) 27 | #define ABS(x, y) (((x) > (y)) ? ((x) - (y)) : ((y) - (x))) 28 | 29 | // Write the output to file or not 30 | // #define WRITE_OUTPUT 31 | 32 | /* Global variable */ 33 | int NY; 34 | int XSLOPE, YSLOPE; 35 | 36 | void save_to_txt(double *arr, int rows, int cols, const char *filename) 37 | { 38 | FILE *file = fopen(filename, "w"); 39 | if (file == NULL) 40 | { 41 | printf("Error opening file!\n"); 42 | return; 43 | } 44 | 45 | for (int i = 0; i < rows; i++) 46 | { 47 | for (int j = 0; j < cols; j++) 48 | { 49 | fprintf(file, "%.0f\t", arr[IDX(i, j, cols)]); 50 | } 51 | fprintf(file, "\n"); 52 | } 53 | 54 | fclose(file); 55 | } 56 | 57 | void naive_box2d1r(double *in, double *out, double *param, const int input_m, const int input_n) 58 | { 59 | for (int row = 3; row < input_m - 3; row++) 60 | { 61 | for (int col = 4; col < input_n - 4; col++) 62 | { 63 | out[IDX(row, col, input_n)] = 64 | param[0] * in[IDX(row - 3, col - 3, input_n)] + 65 | param[1] * in[IDX(row - 3, col - 2, input_n)] + 66 | param[2] * in[IDX(row - 3, col - 1, input_n)] + 67 | param[3] * in[IDX(row - 3, col, input_n)] + 68 | param[4] * in[IDX(row - 3, col + 1, input_n)] + 69 | param[5] * in[IDX(row - 3, col + 2, input_n)] + 70 | param[6] * in[IDX(row - 3, col + 3, input_n)] + 71 | param[7] * in[IDX(row - 2, col - 3, input_n)] + 72 | param[8] * in[IDX(row - 2, col - 2, input_n)] + 73 | param[9] * in[IDX(row - 2, col - 1, input_n)] + 74 | param[10] * in[IDX(row - 2, col, input_n)] + 75 | param[11] * in[IDX(row - 2, col + 1, input_n)] + 76 | param[12] * in[IDX(row - 2, col + 2, input_n)] + 77 | param[13] * in[IDX(row - 2, col + 3, input_n)] + 78 | param[14] * in[IDX(row - 1, col - 3, input_n)] + 79 | param[15] * in[IDX(row - 1, col - 2, input_n)] + 80 | param[16] * in[IDX(row - 1, col - 1, input_n)] + 81 | param[17] * in[IDX(row - 1, col, input_n)] + 82 | param[18] * in[IDX(row - 1, col + 1, input_n)] + 83 | param[19] * in[IDX(row - 1, col + 2, input_n)] + 84 | param[20] * in[IDX(row - 1, col + 3, input_n)] + 85 | param[21] * in[IDX(row, col - 3, input_n)] + 86 | param[22] * in[IDX(row, col - 2, input_n)] + 87 | param[23] * in[IDX(row, col - 1, input_n)] + 88 | param[24] * in[IDX(row, col, input_n)] + 89 | param[25] * in[IDX(row, col + 1, input_n)] + 90 | param[26] * in[IDX(row, col + 2, input_n)] + 91 | param[27] * in[IDX(row, col + 3, input_n)] + 92 | param[28] * in[IDX(row + 1, col - 3, input_n)] + 93 | param[29] * in[IDX(row + 1, col - 2, input_n)] + 94 | param[30] * in[IDX(row + 1, col - 1, input_n)] + 95 | param[31] * in[IDX(row + 1, col, input_n)] + 96 | param[32] * in[IDX(row + 1, col + 1, input_n)] + 97 | param[33] * in[IDX(row + 1, col + 2, input_n)] + 98 | param[34] * in[IDX(row + 1, col + 3, input_n)] + 99 | param[35] * in[IDX(row + 2, col - 3, input_n)] + 100 | param[36] * in[IDX(row + 2, col - 2, input_n)] + 101 | param[37] * in[IDX(row + 2, col - 1, input_n)] + 102 | param[38] * in[IDX(row + 2, col, input_n)] + 103 | param[39] * in[IDX(row + 2, col + 1, input_n)] + 104 | 
param[40] * in[IDX(row + 2, col + 2, input_n)] + 105 | param[41] * in[IDX(row + 2, col + 3, input_n)] + 106 | param[42] * in[IDX(row + 3, col - 3, input_n)] + 107 | param[43] * in[IDX(row + 3, col - 2, input_n)] + 108 | param[44] * in[IDX(row + 3, col - 1, input_n)] + 109 | param[45] * in[IDX(row + 3, col, input_n)] + 110 | param[46] * in[IDX(row + 3, col + 1, input_n)] + 111 | param[47] * in[IDX(row + 3, col + 2, input_n)] + 112 | param[48] * in[IDX(row + 3, col + 3, input_n)]; 113 | } 114 | } 115 | } 116 | 117 | void printHelp() 118 | { 119 | const char *helpMessage = 120 | "Program name: convstencil_2d\n" 121 | "Usage: convstencil_2d shape input_size_of_first_dimension input_size_of_second_dimension time_iteration_size [Options]\n" 122 | "Shape: box2d1r or star2d1r or box2d3r or star2d3r\n" 123 | "Options:\n" 124 | " --help Display this help message and exit\n" 125 | " --custom If you want to use costum parameters, please use this option and input your parameters like 0.2 0.2 0.2 0.2 0.2 if the shape is star2d1r\n"; 126 | printf("%s\n", helpMessage); 127 | } 128 | 129 | int main(int argc, char *argv[]) 130 | { 131 | if (argc < 5) 132 | { 133 | printHelp(); 134 | return 1; 135 | } 136 | 137 | // configurable settings 138 | Shape compute_shape; 139 | std::string arg1 = argv[1]; 140 | if (arg1 == "box2d1r") 141 | { 142 | compute_shape = box_2d1r; 143 | } 144 | else if (arg1 == "star2d1r") 145 | { 146 | compute_shape = star_2d1r; 147 | } 148 | else if (arg1 == "star2d3r") 149 | { 150 | compute_shape = star_2d3r; 151 | } 152 | else if (arg1 == "box2d3r") 153 | { 154 | compute_shape = box_2d3r; 155 | } 156 | else 157 | { 158 | printHelp(); 159 | return 1; 160 | } 161 | 162 | int m = 0; 163 | int n = 0; 164 | int times = 0; 165 | 166 | try 167 | { 168 | m = std::stoi(argv[2]); 169 | n = std::stoi(argv[3]); 170 | times = std::stoi(argv[4]); 171 | } 172 | catch (const std::invalid_argument &e) 173 | { 174 | std::cerr << "Invalid argument: cannot convert the parameter(s) to integer.\n"; 175 | return 1; 176 | } 177 | catch (const std::out_of_range &e) 178 | { 179 | std::cerr << "Argument out of range: the parameter(s) is(are) too large.\n"; 180 | return 1; 181 | } 182 | 183 | double param_1r[9] = {0.0}; 184 | bool breakdown = false; 185 | if (argc == 6 && std::string(argv[5]) == "--custom") 186 | { 187 | int num_param = 9; 188 | if (arg1 == "box2d1r") 189 | { 190 | num_param = 9; 191 | } 192 | else if (arg1 == "star2d1r") 193 | { 194 | num_param = 5; 195 | } 196 | printf("Please enter %d parameters:\n", num_param); 197 | double values[num_param]; 198 | for (int i = 0; i < num_param; i++) 199 | { 200 | int readNum = scanf("%lf", &values[i]); 201 | if (readNum != 1) 202 | return 1; 203 | } 204 | if (num_param == 9) 205 | { 206 | for (int i = 0; i < 9; i++) 207 | { 208 | param_1r[i] = values[i]; 209 | } 210 | } 211 | else 212 | { 213 | param_1r[1] = values[0]; 214 | param_1r[3] = values[1]; 215 | param_1r[4] = values[2]; 216 | param_1r[5] = values[3]; 217 | param_1r[7] = values[4]; 218 | } 219 | } 220 | 221 | if (argc == 6 && std::string(argv[5]) == "--breakdown") { 222 | breakdown = true; 223 | } 224 | 225 | double param_box_2d1r[49] = {0.0}; 226 | double param_star_2d1r[49] = {0.0}; 227 | 228 | for (int i = 0; i < 49; i++) 229 | { 230 | param_box_2d1r[i] = 0.021; 231 | } 232 | 233 | param_box_2d1r[16] = (3 * param_1r[0] * param_1r[0] * param_1r[8] + 6 * param_1r[0] * param_1r[1] * param_1r[7] + 6 * param_1r[0] * param_1r[2] * param_1r[6] + 6 * param_1r[0] * param_1r[3] * param_1r[5] + 3 * 
param_1r[0] * param_1r[4] * param_1r[4] + 3 * param_1r[1] * param_1r[1] * param_1r[6] + 6 * param_1r[1] * param_1r[3] * param_1r[4] + 3 * param_1r[2] * param_1r[3] * param_1r[3]); 234 | param_box_2d1r[15] = (3 * param_1r[0] * param_1r[0] * param_1r[7] + 6 * param_1r[0] * param_1r[1] * param_1r[6] + 6 * param_1r[0] * param_1r[3] * param_1r[4] + 3 * param_1r[1] * param_1r[3] * param_1r[3]); 235 | param_box_2d1r[14] = (3 * param_1r[0] * param_1r[0] * param_1r[6] + 3 * param_1r[0] * param_1r[3] * param_1r[3]); 236 | param_box_2d1r[17] = (6 * param_1r[0] * param_1r[1] * param_1r[8] + 6 * param_1r[0] * param_1r[2] * param_1r[7] + 6 * param_1r[0] * param_1r[4] * param_1r[5] + 3 * param_1r[1] * param_1r[1] * param_1r[7] + 6 * param_1r[1] * param_1r[2] * param_1r[6] + 6 * param_1r[1] * param_1r[3] * param_1r[5] + 3 * param_1r[1] * param_1r[4] * param_1r[4] + 6 * param_1r[2] * param_1r[3] * param_1r[4]); 237 | param_box_2d1r[18] = (6 * param_1r[0] * param_1r[2] * param_1r[8] + 3 * param_1r[0] * param_1r[0] * param_1r[5] + 3 * param_1r[1] * param_1r[1] * param_1r[8] + 6 * param_1r[1] * param_1r[2] * param_1r[7] + 6 * param_1r[1] * param_1r[4] * param_1r[5] + 3 * param_1r[2] * param_1r[2] * param_1r[6] + 6 * param_1r[2] * param_1r[3] * param_1r[5] + 3 * param_1r[2] * param_1r[4] * param_1r[4]); 238 | param_box_2d1r[19] = (6 * param_1r[1] * param_1r[2] * param_1r[8] + 3 * param_1r[1] * param_1r[1] * param_1r[5] + 3 * param_1r[2] * param_1r[2] * param_1r[7] + 6 * param_1r[2] * param_1r[4] * param_1r[5]); 239 | param_box_2d1r[20] = (3 * param_1r[2] * param_1r[2] * param_1r[8] + 3 * param_1r[2] * param_1r[5] * param_1r[5]); 240 | param_box_2d1r[9] = (3 * param_1r[0] * param_1r[0] * param_1r[5] + 6 * param_1r[0] * param_1r[1] * param_1r[4] + 6 * param_1r[0] * param_1r[2] * param_1r[3] + 3 * param_1r[1] * param_1r[1] * param_1r[3]); 241 | param_box_2d1r[8] = (3 * param_1r[0] * param_1r[0] * param_1r[4] + 6 * param_1r[0] * param_1r[1] * param_1r[3]); 242 | param_box_2d1r[7] = 3 * param_1r[0] * param_1r[0] * param_1r[3]; 243 | param_box_2d1r[10] = (6 * param_1r[0] * param_1r[1] * param_1r[5] + 6 * param_1r[0] * param_1r[2] * param_1r[4] + 3 * param_1r[1] * param_1r[1] * param_1r[4] + 6 * param_1r[1] * param_1r[2] * param_1r[3]); 244 | param_box_2d1r[11] = (6 * param_1r[0] * param_1r[2] * param_1r[5] + 3 * param_1r[1] * param_1r[1] * param_1r[5] + 6 * param_1r[1] * param_1r[2] * param_1r[4] + 3 * param_1r[2] * param_1r[2] * param_1r[3]); 245 | param_box_2d1r[12] = (6 * param_1r[1] * param_1r[2] * param_1r[5] + 3 * param_1r[2] * param_1r[2] * param_1r[4]); 246 | param_box_2d1r[13] = 3 * param_1r[2] * param_1r[2] * param_1r[5]; 247 | param_box_2d1r[2] = (3 * param_1r[0] * param_1r[0] * param_1r[2] + 3 * param_1r[0] * param_1r[1] * param_1r[1]); 248 | param_box_2d1r[1] = 3 * param_1r[0] * param_1r[0] * param_1r[1]; 249 | param_box_2d1r[0] = param_1r[0] * param_1r[0] * param_1r[0]; 250 | param_box_2d1r[3] = (6 * param_1r[0] * param_1r[1] * param_1r[2] + param_1r[1] * param_1r[1] * param_1r[1]); 251 | param_box_2d1r[4] = (3 * param_1r[0] * param_1r[2] * param_1r[2] + 3 * param_1r[1] * param_1r[1] * param_1r[2]); 252 | param_box_2d1r[5] = 3 * param_1r[1] * param_1r[2] * param_1r[2]; 253 | param_box_2d1r[6] = param_1r[2] * param_1r[2] * param_1r[2]; 254 | param_box_2d1r[23] = (6 * param_1r[0] * param_1r[3] * param_1r[8] + 6 * param_1r[0] * param_1r[4] * param_1r[7] + 6 * param_1r[0] * param_1r[5] * param_1r[6] + 6 * param_1r[1] * param_1r[3] * param_1r[7] + 6 * param_1r[1] * param_1r[4] * param_1r[6] + 6 * 
param_1r[2] * param_1r[3] * param_1r[6] + 3 * param_1r[3] * param_1r[3] * param_1r[5] + 3 * param_1r[3] * param_1r[4] * param_1r[4]); 255 | param_box_2d1r[22] = (6 * param_1r[0] * param_1r[3] * param_1r[7] + 6 * param_1r[0] * param_1r[4] * param_1r[6] + 6 * param_1r[1] * param_1r[3] * param_1r[6] + 3 * param_1r[3] * param_1r[3] * param_1r[4]); 256 | param_box_2d1r[21] = (6 * param_1r[0] * param_1r[3] * param_1r[6] + param_1r[3] * param_1r[3] * param_1r[3]); 257 | param_box_2d1r[24] = (6 * param_1r[0] * param_1r[4] * param_1r[8] + 6 * param_1r[0] * param_1r[5] * param_1r[7] + 6 * param_1r[1] * param_1r[3] * param_1r[8] + 6 * param_1r[1] * param_1r[4] * param_1r[7] + 6 * param_1r[1] * param_1r[5] * param_1r[6] + 6 * param_1r[2] * param_1r[3] * param_1r[7] + 6 * param_1r[2] * param_1r[4] * param_1r[6] + 6 * param_1r[3] * param_1r[4] * param_1r[5] + pow(param_1r[4], 3)); 258 | param_box_2d1r[25] = (6 * param_1r[0] * param_1r[5] * param_1r[8] + 6 * param_1r[1] * param_1r[4] * param_1r[8] + 6 * param_1r[1] * param_1r[5] * param_1r[7] + 6 * param_1r[2] * param_1r[3] * param_1r[8] + 6 * param_1r[2] * param_1r[4] * param_1r[7] + 6 * param_1r[2] * param_1r[5] * param_1r[6] + 3 * param_1r[3] * param_1r[5] * param_1r[5] + 3 * param_1r[4] * param_1r[4] * param_1r[5]); 259 | param_box_2d1r[26] = (6 * param_1r[1] * param_1r[5] * param_1r[8] + 6 * param_1r[2] * param_1r[4] * param_1r[8] + 6 * param_1r[2] * param_1r[5] * param_1r[7] + 3 * param_1r[4] * param_1r[5] * param_1r[5]); 260 | param_box_2d1r[27] = (6 * param_1r[2] * param_1r[5] * param_1r[8] + param_1r[5] * param_1r[5] * param_1r[5]); 261 | param_box_2d1r[30] = (6 * param_1r[0] * param_1r[6] * param_1r[8] + 3 * param_1r[0] * param_1r[7] * param_1r[7] + 6 * param_1r[1] * param_1r[6] * param_1r[7] + 3 * param_1r[2] * param_1r[6] * param_1r[6] + 3 * param_1r[3] * param_1r[3] * param_1r[8] + 6 * param_1r[3] * param_1r[4] * param_1r[7] + 6 * param_1r[3] * param_1r[5] * param_1r[6] + 3 * param_1r[4] * param_1r[6] * param_1r[6]); 262 | param_box_2d1r[29] = (6 * param_1r[0] * param_1r[6] * param_1r[7] + 3 * param_1r[1] * param_1r[6] * param_1r[6] + 3 * param_1r[3] * param_1r[3] * param_1r[7] + 6 * param_1r[3] * param_1r[4] * param_1r[6]); 263 | param_box_2d1r[28] = (3 * param_1r[0] * param_1r[6] * param_1r[6] + 3 * param_1r[3] * param_1r[3] * param_1r[6]); 264 | param_box_2d1r[31] = (6 * param_1r[0] * param_1r[7] * param_1r[8] + 6 * param_1r[1] * param_1r[6] * param_1r[8] + 3 * param_1r[1] * param_1r[7] * param_1r[7] + 6 * param_1r[2] * param_1r[6] * param_1r[7] + 6 * param_1r[3] * param_1r[4] * param_1r[8] + 6 * param_1r[3] * param_1r[5] * param_1r[7] + 3 * param_1r[4] * param_1r[7] * param_1r[7] + 6 * param_1r[4] * param_1r[5] * param_1r[6]); 265 | param_box_2d1r[32] = (3 * param_1r[0] * param_1r[8] * param_1r[8] + 6 * param_1r[1] * param_1r[7] * param_1r[8] + 6 * param_1r[2] * param_1r[6] * param_1r[8] + 3 * param_1r[2] * param_1r[7] * param_1r[7] + 6 * param_1r[3] * param_1r[5] * param_1r[8] + 3 * param_1r[4] * param_1r[8] * param_1r[8] + 6 * param_1r[4] * param_1r[5] * param_1r[7] + 3 * param_1r[5] * param_1r[6] * param_1r[6]); 266 | param_box_2d1r[33] = (3 * param_1r[1] * param_1r[8] * param_1r[8] + 6 * param_1r[2] * param_1r[7] * param_1r[8] + 6 * param_1r[4] * param_1r[5] * param_1r[8] + 3 * param_1r[5] * param_1r[7] * param_1r[7]); 267 | param_box_2d1r[34] = (3 * param_1r[2] * param_1r[8] * param_1r[8] + 3 * param_1r[5] * param_1r[8] * param_1r[8]); 268 | param_box_2d1r[37] = (6 * param_1r[3] * param_1r[6] * param_1r[8] + 3 * param_1r[3] * 
param_1r[7] * param_1r[7] + 6 * param_1r[4] * param_1r[6] * param_1r[7] + 3 * param_1r[5] * param_1r[6] * param_1r[6]); 269 | param_box_2d1r[36] = (6 * param_1r[3] * param_1r[6] * param_1r[7] + 3 * param_1r[4] * param_1r[6] * param_1r[6]); 270 | param_box_2d1r[35] = 3 * param_1r[3] * param_1r[3] * param_1r[6]; 271 | param_box_2d1r[38] = (6 * param_1r[3] * param_1r[7] * param_1r[8] + 6 * param_1r[4] * param_1r[6] * param_1r[8] + 3 * param_1r[4] * param_1r[7] * param_1r[7] + 6 * param_1r[5] * param_1r[6] * param_1r[7]); 272 | param_box_2d1r[39] = (3 * param_1r[3] * param_1r[8] * param_1r[8] + 6 * param_1r[4] * param_1r[7] * param_1r[8] + 6 * param_1r[5] * param_1r[6] * param_1r[8] + 3 * param_1r[5] * param_1r[7] * param_1r[7]); 273 | param_box_2d1r[40] = (3 * param_1r[4] * param_1r[8] * param_1r[8] + 6 * param_1r[5] * param_1r[7] * param_1r[8]); 274 | param_box_2d1r[41] = 3 * param_1r[5] * param_1r[5] * param_1r[8]; 275 | param_box_2d1r[44] = (3 * param_1r[6] * param_1r[6] * param_1r[8] + 3 * param_1r[6] * param_1r[7] * param_1r[7]); 276 | param_box_2d1r[43] = 3 * param_1r[6] * param_1r[6] * param_1r[7]; 277 | param_box_2d1r[42] = param_1r[6] * param_1r[6] * param_1r[6]; 278 | param_box_2d1r[45] = (6 * param_1r[6] * param_1r[7] * param_1r[8] + param_1r[7] * param_1r[7] * param_1r[7]); 279 | param_box_2d1r[46] = (3 * param_1r[6] * param_1r[8] * param_1r[8] + 3 * param_1r[7] * param_1r[7] * param_1r[8]); 280 | param_box_2d1r[47] = 3 * param_1r[7] * param_1r[8] * param_1r[8]; 281 | param_box_2d1r[48] = param_1r[8] * param_1r[8] * param_1r[8]; 282 | 283 | double *param; 284 | int halo; 285 | switch (compute_shape) 286 | { 287 | case box_2d1r: 288 | param = param_box_2d1r; 289 | halo = 3; 290 | break; 291 | case star_2d1r: 292 | param = param_star_2d1r; 293 | halo = 3; 294 | break; 295 | case star_2d3r: 296 | param = param_star_2d1r; 297 | halo = 3; 298 | break; 299 | case box_2d3r: 300 | param = param_box_2d1r; 301 | halo = 3; 302 | break; 303 | } 304 | 305 | // print brief info 306 | printf("INFO: shape = %s, m = %d, n = %d, times = %d\n", ShapeStr[compute_shape], m, n, times); 307 | 308 | int rows = m + 2 * halo; 309 | int cols = n + 2 * halo + 2; 310 | NY = n; 311 | size_t matrix_size = (unsigned long)rows * cols * sizeof(double); 312 | 313 | // allocate space 314 | 315 | double *matrix = (double *)malloc(matrix_size); 316 | double *output = (double *)malloc(matrix_size); 317 | 318 | // fill input matrix 319 | 320 | #if defined(FILL_RANDOM) 321 | #pragma unroll 322 | for (int i = 0; i < rows * cols; i++) 323 | { 324 | matrix[i] = (double)(rand() % 100); 325 | } 326 | #elif defined(FILL_INDEX) 327 | for (int i = 0; i < rows; i++) 328 | { 329 | for (int j = 1; j < cols - 1; j++) 330 | { 331 | matrix[i * cols + j] = (double)(i * (cols - 2) + j); 332 | } 333 | } 334 | #else 335 | for (int i = 0; i < rows; i++) 336 | { 337 | for (int j = 0; j < cols - 1; j++) 338 | { 339 | matrix[i * cols + j] = 1.0; 340 | } 341 | } 342 | // std::fill_n(matrix, rows * cols, 1.0); 343 | #endif 344 | 345 | switch (compute_shape) 346 | { 347 | case box_2d1r: 348 | case star_2d1r: 349 | if (breakdown) 350 | { 351 | gpu_box_2d1r_breakdown1(matrix, output, param, times, m, n); 352 | gpu_box_2d1r_breakdown2(matrix, output, param, times, m, n); 353 | gpu_box_2d1r_breakdown3(matrix, output, param, times, m, n); 354 | gpu_box_2d1r_breakdown4(matrix, output, param, times, m, n); 355 | gpu_box_2d1r(matrix, output, param, times, m, n); 356 | 357 | } 358 | else 359 | { 360 | gpu_box_2d1r(matrix, output, 361 | param, times, 362 
| m, n); 363 | } 364 | break; 365 | case star_2d3r: 366 | case box_2d3r: 367 | gpu_box_2d3r(matrix, output, 368 | param, times, 369 | m, n); 370 | break; 371 | } 372 | 373 | // check result correctness 374 | 375 | #if defined(CHECK_ERROR) 376 | printf("\nChecking ... \n"); 377 | double *naive[2]; 378 | naive[0] = (double *)malloc(matrix_size); 379 | naive[1] = (double *)malloc(matrix_size); 380 | 381 | for (int i = 0; i < rows * cols; i++) 382 | { 383 | naive[0][i] = matrix[i]; 384 | naive[1][i] = 0; 385 | } 386 | 387 | int t = 0; 388 | if (compute_shape == box_2d1r_step3) 389 | { 390 | for (; t < times; t++) 391 | { 392 | naive_box2d1r(naive[t % 2], naive[(t + 1) % 2], param, rows, cols); 393 | } 394 | } 395 | printf("Comparing naive and output\n"); 396 | for (int row = 0; row < rows; row++) 397 | { 398 | for (int col = 0; col < cols; col++) 399 | { 400 | if (ABS(naive[t % 2][IDX(row, col, cols)], output[IDX(row, col, cols)]) > 1e-7) 401 | { 402 | printf("row = %d, col = %d, naive = %lf, output = %lf\n", row, col, naive[t % 2][IDX(row, col, cols)], output[IDX(row, col, cols)]); 403 | } 404 | } 405 | } 406 | #endif 407 | 408 | // write to file 409 | 410 | #ifdef WRITE_OUTPUT 411 | #ifdef RUN_GPU 412 | printf("Writing output_gpu.txt\n"); 413 | save_to_txt(output, rows, cols, "output_gpu.txt"); 414 | save_to_txt(naive[t % 2], rows, cols, "output_naive.txt"); 415 | #endif 416 | #endif 417 | 418 | // free space 419 | free(output); 420 | free(matrix); 421 | 422 | return 0; 423 | } -------------------------------------------------------------------------------- /src/1d/gpu_1r.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "1d_utils.h" 4 | #include 5 | 6 | using namespace nvcuda; 7 | 8 | #define BLOCK_SIZE_COL 1024//tune 9 | #define HALO 3 10 | #define D_BLOCK_SIZE_COL (BLOCK_SIZE_COL + HALO * 2) 11 | #define PAD 0 12 | #define SM_SIZE_ROW (D_BLOCK_SIZE_COL / 8) 13 | #define UNIT_LENGTH 7 14 | #define TENSOR_CORE_M 8 15 | #define WARP_PER_BLOCK 8//tune 16 | #define MMA_NUM 2 17 | #define IDX(x, y, ldm) ((x) * (ldm) + (y)) 18 | #define SM_SIZE_ROW (D_BLOCK_SIZE_COL / 8) 19 | #define SM_SIZE_COL (7+PAD) 20 | #define SM_SIZE_COL2 (7+PAD2) 21 | #define PAD2 2 22 | 23 | 24 | extern __constant__ double param_matrix_d[2 * 8 * TENSOR_CORE_M]; 25 | // extern __constant__ double param_matrix_d[2 * UNIT_LENGTH * UNIT_LENGTH]; 26 | __global__ void gpu_1d1r_kernel(const double *__restrict__ in, double *__restrict__ out, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) 27 | { 28 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL]; 29 | 30 | int begin = blockIdx.x * BLOCK_SIZE_COL+1; 31 | int laneid = threadIdx.x % 32; 32 | 33 | int tid = threadIdx.x; 34 | int totalThreads = blockDim.x; 35 | 36 | for (int i = tid; i < D_BLOCK_SIZE_COL; i += totalThreads) { 37 | sharedmem[0][lookup_table1[i]] = in[begin + i]; 38 | sharedmem[1][lookup_table2[i]] = in[begin + i]; 39 | } 40 | 41 | __syncthreads(); 42 | 43 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 44 | #pragma unroll 45 | for (int i = 0; i < MMA_NUM; i++) 46 | { 47 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 48 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 2 * 4 * 8 + i * 32, 8); 49 | } 50 | 51 | nvcuda::wmma::fragment acc_frag; 52 | 53 | nvcuda::wmma::fragment in_frag; 54 | int warp_id=threadIdx.x/32; 55 | #pragma unroll 56 | for (int row = 2*8*warp_id; row < (warp_id+1)*8*2; row += 
TENSOR_CORE_M) 57 | { 58 | #pragma unroll 59 | 60 | nvcuda::wmma::fill_fragment(acc_frag, 0.0); 61 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 62 | { 63 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[0] + IDX(row, compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL); 64 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 65 | } 66 | 67 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 68 | { 69 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(row, compute_idx * 4, SM_SIZE_COL),SM_SIZE_COL); 70 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 71 | } 72 | 73 | nvcuda::wmma::store_matrix_sync(out + begin + row / 8 * 64 + HALO , acc_frag, TENSOR_CORE_M, nvcuda::wmma::mem_row_major); //+1为了对齐 74 | } 75 | } 76 | 77 | __global__ void breakdown4_kernel(const double *__restrict__ in, double *__restrict__ out, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) 78 | { 79 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL2]; 80 | 81 | int begin = blockIdx.x * BLOCK_SIZE_COL+1; 82 | int laneid = threadIdx.x % 32; 83 | 84 | int tid = threadIdx.x; 85 | int totalThreads = blockDim.x; 86 | 87 | for (int i = tid; i < D_BLOCK_SIZE_COL; i += totalThreads) { 88 | if (lookup_table1[i] != -1) { 89 | sharedmem[0][lookup_table1[i]] = in[begin + i]; 90 | } 91 | if (lookup_table2[i] != -1) { 92 | sharedmem[1][lookup_table2[i]] = in[begin + i]; 93 | } 94 | } 95 | 96 | __syncthreads(); 97 | 98 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 99 | #pragma unroll 100 | for (int i = 0; i < MMA_NUM; i++) 101 | { 102 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 103 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 2 * 4 * 8 + i * 32, 8); 104 | } 105 | 106 | nvcuda::wmma::fragment acc_frag; 107 | 108 | nvcuda::wmma::fragment in_frag; 109 | int warp_id=threadIdx.x/32; 110 | 111 | #pragma unroll 112 | for (int row = 2*8*warp_id; row < (warp_id+1)*8*2; row += TENSOR_CORE_M) 113 | { 114 | #pragma unroll 115 | nvcuda::wmma::fill_fragment(acc_frag, 0.0); 116 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 117 | { 118 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[0] + IDX(row, compute_idx * 4, SM_SIZE_COL2), SM_SIZE_COL2); 119 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 120 | } 121 | 122 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 123 | { 124 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(row, compute_idx * 4, SM_SIZE_COL2),SM_SIZE_COL2); 125 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 126 | } 127 | 128 | nvcuda::wmma::store_matrix_sync(out + begin + row / 8 * 64 + HALO , acc_frag, TENSOR_CORE_M, nvcuda::wmma::mem_row_major); //+1为了对齐 129 | } 130 | } 131 | 132 | __global__ void breakdown3_kernel(const double *__restrict__ in, double *__restrict__ out, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) 133 | { 134 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL]; 135 | 136 | int begin = blockIdx.x * BLOCK_SIZE_COL+1; 137 | int laneid = threadIdx.x % 32; 138 | 139 | int tid = threadIdx.x; 140 | int totalThreads = blockDim.x; 141 | 142 | for (int i = tid; i < D_BLOCK_SIZE_COL; i += totalThreads) { 143 | if (lookup_table1[i] != -1) { 144 | sharedmem[0][lookup_table1[i]] = in[begin + i]; 145 | } 146 | if (lookup_table2[i] != -1) { 147 | 
sharedmem[1][lookup_table2[i]] = in[begin + i]; 148 | } 149 | } 150 | 151 | __syncthreads(); 152 | 153 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 154 | #pragma unroll 155 | for (int i = 0; i < MMA_NUM; i++) 156 | { 157 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 158 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 2 * 4 * 8 + i * 32, 8); 159 | } 160 | 161 | nvcuda::wmma::fragment acc_frag; 162 | 163 | nvcuda::wmma::fragment in_frag; 164 | int warp_id=threadIdx.x/32; 165 | #pragma unroll 166 | for (int row = 2*8*warp_id; row < (warp_id+1)*8*2; row += TENSOR_CORE_M) 167 | { 168 | #pragma unroll 169 | nvcuda::wmma::fill_fragment(acc_frag, 0.0); 170 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 171 | { 172 | 173 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[0] + IDX(row, compute_idx * 4, 7), 7); 174 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 175 | } 176 | 177 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) 178 | { 179 | nvcuda::wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(row, compute_idx * 4, 7), 7); 180 | nvcuda::wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 181 | } 182 | 183 | nvcuda::wmma::store_matrix_sync(out + begin + row / 8 * 64 + HALO , acc_frag, TENSOR_CORE_M, nvcuda::wmma::mem_row_major); //+1为了对齐 184 | } 185 | } 186 | 187 | __global__ void breakdown1_kernel(const double *__restrict__ in, double *__restrict__ out, double* __restrict__ la,double* __restrict__ lb) 188 | { 189 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL]; 190 | 191 | int begin = blockIdx.x * BLOCK_SIZE_COL+1; 192 | int gbegin=(blockIdx.x)*SM_SIZE_COL*SM_SIZE_ROW; 193 | int x=threadIdx.x; 194 | 195 | 196 | for(int col=x;col 2 * HALO) { 201 | lb[gbegin+IDX((col - UNIT_LENGTH) / (UNIT_LENGTH + 1), (col - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL)]=in[begin+col]; 202 | } 203 | } 204 | 205 | 206 | __syncthreads(); 207 | 208 | for(int row=x;row 2 * HALO) { 253 | lookup_table2_h[j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 254 | } else { 255 | lookup_table2_h[j] = SM_SIZE_ROW*SM_SIZE_COL - 1; 256 | } 257 | } 258 | 259 | int * lookup_table1_d; 260 | int * lookup_table2_d; 261 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_COL * sizeof(int))); 262 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_COL * sizeof(int))); 263 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 264 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 265 | 266 | const int cols = input_n + 2 * HALO + 1; // 1 for address alighment 267 | const size_t array_size = cols * sizeof(double); 268 | double *array_d[2]; 269 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 270 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 271 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 272 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 273 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 274 | 275 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 276 | dim3 grid_config(BLOCK_N); 277 | dim3 block_config(32 * WARP_PER_BLOCK); 278 | 279 | CUDA_CHECK(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); 280 | 281 | // timing 282 | int i = 0; 283 | 284 | std::chrono::steady_clock::time_point begin = 
std::chrono::steady_clock::now(); 285 | 286 | for (; i < times; i++) 287 | { 288 | CUDAKERNELCHECK((gpu_1d1r_kernel<<>>(array_d[i % 2] , array_d[(i + 1) % 2], lookup_table1_d, lookup_table2_d))); // 为了对齐空了4个 289 | } 290 | CUDA_CHECK(cudaDeviceSynchronize()); 291 | 292 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 293 | std::cout << "ConvStencil(1D): " << std::endl; 294 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 295 | 296 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 297 | printf("GStencil/s = %f\n", ((double)input_n * times * 3) / secs / 1e9); 298 | 299 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] + 1, array_size - sizeof(double), cudaMemcpyDeviceToHost)); 300 | 301 | return; 302 | } 303 | 304 | 305 | void gpu_1d1r_breakdown4(const double *__restrict__ in, double *__restrict__ out, const double *__restrict__ params, const int times, const int input_n) 306 | { 307 | double param_matrix_h[2][8 * 8] = {}; 308 | 309 | // Initialize parameter matrix 310 | 311 | for (int row = 0; row < 7; row++) // kernel size 7 312 | for (int col = 0; col <= row; ++col) 313 | param_matrix_h[0][row * 8 + col] = params[row - col]; 314 | 315 | for (int row = 0; row < 7; row++) 316 | for (int col = row + 1; col < 8; ++col) 317 | param_matrix_h[1][row * 8 + col] = params[row + 7 - col]; 318 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 8 * sizeof(double))); 319 | 320 | int lookup_table1_h[D_BLOCK_SIZE_COL]; 321 | int lookup_table2_h[D_BLOCK_SIZE_COL]; 322 | 323 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 324 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 325 | lookup_table1_h[j] = IDX(j / (UNIT_LENGTH + 1), j % (UNIT_LENGTH + 1), SM_SIZE_COL2);//9 326 | } else { 327 | lookup_table1_h[j] = - 1; 328 | } 329 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 330 | lookup_table2_h[j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL2); 331 | } else { 332 | lookup_table2_h[j] = - 1; 333 | } 334 | } 335 | 336 | int * lookup_table1_d; 337 | int * lookup_table2_d; 338 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_COL * sizeof(int))); 339 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_COL * sizeof(int))); 340 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 341 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 342 | 343 | const int cols = input_n + 2 * HALO + 1; // 1 for address alighment 344 | const size_t array_size = cols * sizeof(double); 345 | double *array_d[2]; 346 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 347 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 348 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 349 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 350 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 351 | 352 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 353 | dim3 grid_config(BLOCK_N); 354 | dim3 block_config(32 * WARP_PER_BLOCK); 355 | 356 | int smem_size=SM_SIZE_ROW * SM_SIZE_COL2 * 2 * sizeof(double); 357 | // CUDA_CHECK(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); 358 | 359 | // timing 360 | int i = 0; 361 | 362 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 363 | 364 | for (; i < times; i++) 365 | { 366 | 
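        // One stencil time step per launch; the i % 2 / (i + 1) % 2 indices ping-pong
        // between the two device arrays so each kernel reads the previous step's output.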
CUDAKERNELCHECK((breakdown4_kernel<<>>(array_d[i % 2] , array_d[(i + 1) % 2], lookup_table1_d, lookup_table2_d))); 367 | } 368 | CUDA_CHECK(cudaDeviceSynchronize()); 369 | 370 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 371 | std::cout << "Experiment - Breakdown(1D) 4: " << std::endl; 372 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 373 | 374 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 375 | printf("GStencil/s = %f\n\n", ((double)input_n * times * 3) / secs / 1e9); 376 | 377 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] + 1, array_size - sizeof(double), cudaMemcpyDeviceToHost)); 378 | 379 | return; 380 | } 381 | 382 | 383 | void gpu_1d1r_breakdown3(const double *__restrict__ in, double *__restrict__ out, const double *__restrict__ params, const int times, const int input_n) 384 | { 385 | double param_matrix_h[2][8 * 8] = {}; 386 | 387 | // Initialize parameter matrix 388 | 389 | for (int row = 0; row < 7; row++) // kernel size 7 390 | for (int col = 0; col <= row; ++col) 391 | param_matrix_h[0][row * 8 + col] = params[row - col]; 392 | 393 | for (int row = 0; row < 7; row++) 394 | for (int col = row + 1; col < 8; ++col) 395 | param_matrix_h[1][row * 8 + col] = params[row + 7 - col]; 396 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 8 * sizeof(double))); 397 | 398 | int lookup_table1_h[D_BLOCK_SIZE_COL]; 399 | int lookup_table2_h[D_BLOCK_SIZE_COL]; 400 | 401 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 402 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 403 | lookup_table1_h[j] = IDX(j / (UNIT_LENGTH + 1), j % (UNIT_LENGTH + 1), SM_SIZE_COL);//9 404 | } else { 405 | lookup_table1_h[j] = - 1;//去掉*7 差得更少 得填到padding区域 406 | } 407 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 408 | lookup_table2_h[j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 409 | } else { 410 | lookup_table2_h[j] = - 1; 411 | } 412 | } 413 | 414 | int * lookup_table1_d; 415 | int * lookup_table2_d; 416 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_COL * sizeof(int))); 417 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_COL * sizeof(int))); 418 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 419 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 420 | 421 | const int cols = input_n + 2 * HALO + 1; // 1 for address alighment 422 | const size_t array_size = cols * sizeof(double); 423 | double *array_d[2]; 424 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 425 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 426 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 427 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 428 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 429 | 430 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 431 | dim3 grid_config(BLOCK_N); 432 | dim3 block_config(32 * WARP_PER_BLOCK); 433 | 434 | CUDA_CHECK(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); 435 | 436 | // timing 437 | int i = 0; 438 | 439 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 440 | 441 | for (; i < times; i++) 442 | { 443 | CUDAKERNELCHECK((breakdown3_kernel<<>>(array_d[i % 2] , array_d[(i + 1) % 2], lookup_table1_d, lookup_table2_d))); // 为了对齐空了4个 444 | } 445 | 
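    // The time loop above ping-pongs between the two device buffers: each step
    // reads array_d[i % 2] and writes array_d[(i + 1) % 2], so no copy is needed
    // between iterations and the final result ends up in array_d[times % 2]
    // (read back below as array_d[i % 2], since i == times after the loop).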
CUDA_CHECK(cudaDeviceSynchronize()); 446 | 447 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 448 | std::cout << "Experiment - Breakdown(1D) 3: " << std::endl; 449 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 450 | 451 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 452 | printf("GStencil/s = %f\n\n", ((double)input_n * times * 3) / secs / 1e9); 453 | 454 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] + 1, array_size - sizeof(double), cudaMemcpyDeviceToHost)); 455 | 456 | return; 457 | } 458 | 459 | __global__ void breakdown2_kernel(const double *__restrict__ in, double *__restrict__ out, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) 460 | { 461 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL]; 462 | 463 | int begin = blockIdx.x * BLOCK_SIZE_COL+1; 464 | int x=threadIdx.x; 465 | 466 | 467 | for(int col=x;col 2 * HALO) { 472 | sharedmem[1][IDX((col - UNIT_LENGTH) / (UNIT_LENGTH + 1), (col - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL)]= in[begin + col]; 473 | } 474 | } 475 | 476 | 477 | __syncthreads(); 478 | 479 | for(int row=x;row 2 * HALO) { 517 | lookup_table2_h[j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 518 | } else { 519 | lookup_table2_h[j] = SM_SIZE_ROW*SM_SIZE_COL - 1; 520 | } 521 | } 522 | 523 | int * lookup_table1_d; 524 | int * lookup_table2_d; 525 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_COL * sizeof(int))); 526 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_COL * sizeof(int))); 527 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 528 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 529 | 530 | const int cols = input_n + 2 * HALO + 1; // 1 for address alighment 531 | const size_t array_size = cols * sizeof(double); 532 | double *array_d[2]; 533 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 534 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 535 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 536 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 537 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 538 | 539 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 540 | dim3 grid_config(BLOCK_N); 541 | dim3 block_config(32 * WARP_PER_BLOCK); 542 | 543 | CUDA_CHECK(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault)); 544 | 545 | // timing 546 | int i = 0; 547 | 548 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 549 | 550 | for (; i < times; i++) 551 | { 552 | CUDAKERNELCHECK((breakdown2_kernel<<>>(array_d[i % 2] , array_d[(i + 1) % 2], lookup_table1_d, lookup_table2_d))); // 为了对齐空了4个 553 | } 554 | CUDA_CHECK(cudaDeviceSynchronize()); 555 | 556 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 557 | std::cout << "Experiment - Breakdown(1D) 2: " << std::endl; 558 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 559 | 560 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 561 | printf("GStencil/s = %f\n\n", ((double)input_n * times * 3) / secs / 1e9); 562 | 563 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] + 1, array_size - sizeof(double), cudaMemcpyDeviceToHost)); 564 | 565 | return; 566 | } 567 | 568 | void 
gpu_1d1r_breakdown1(const double *__restrict__ in, double *__restrict__ out, const double *__restrict__ params, const int times, const int input_n)
569 | {
570 |     double param_matrix_h[2][7 * 7] = {};
571 | 
572 |     // Initialize parameter matrix
573 | 
574 |     // last column left empty
575 |     for (int row = 0; row < UNIT_LENGTH; row++) // kernel size 7
576 |         for (int col = 0; col <= row; ++col)   // lower triangle only; col > row would index params out of bounds
577 |             param_matrix_h[0][row * UNIT_LENGTH + col] = params[row - col];
578 | 
579 |     // first column left empty
580 |     for (int row = 0; row < UNIT_LENGTH; row++)
581 |         for (int col = row; col < UNIT_LENGTH; ++col)
582 |             param_matrix_h[1][row * UNIT_LENGTH + col] = params[row + 6 - col];//+7
583 | 
584 |     // copy the parameter matrices into constant memory
585 |     CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 7 * 7 * sizeof(double)));
586 | 
587 |     int lookup_table1_h[D_BLOCK_SIZE_COL];
588 |     int lookup_table2_h[D_BLOCK_SIZE_COL];
589 | 
590 |     // every eighth column is carved out of the layout
591 |     for (int j = 0; j < D_BLOCK_SIZE_COL; j++) {
592 |         if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) {
593 |             lookup_table1_h[j] = IDX(j / (UNIT_LENGTH + 1), j % (UNIT_LENGTH + 1), SM_SIZE_COL);//9
594 |         } else {
595 |             lookup_table1_h[j] = SM_SIZE_ROW * SM_SIZE_COL - 1;
596 |         }
597 |         if ((j + 2) % 8 != 0 && j > 2 * HALO) {
598 |             lookup_table2_h[j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL);
599 |         } else {
600 |             lookup_table2_h[j] = SM_SIZE_ROW * SM_SIZE_COL - 1;
601 |         }
602 |     }
603 | 
604 |     int * lookup_table1_d;
605 |     int * lookup_table2_d;
606 |     CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_COL * sizeof(int)));
607 |     CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_COL * sizeof(int)));
608 |     CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice));
609 |     CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice));
610 | 
611 |     const int cols = input_n + 2 * HALO + 1; // 1 for address alignment
612 |     const size_t array_size = cols * sizeof(double);
613 |     double *array_d[2];
614 |     CUDA_CHECK(cudaMalloc(&array_d[0], array_size));
615 |     CUDA_CHECK(cudaMalloc(&array_d[1], array_size));
616 |     CUDA_CHECK(cudaMemset(array_d[0], 0, array_size));
617 |     CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice));
618 |     CUDA_CHECK(cudaMemset(array_d[1], 0, array_size));
619 | 
620 |     const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL;
621 |     dim3 grid_config(BLOCK_N);
622 |     dim3 block_config(32 * WARP_PER_BLOCK);
623 | 
624 |     double* stencil2row[2];
625 |     CUDA_CHECK(cudaMalloc(&stencil2row[0], BLOCK_N*SM_SIZE_COL*SM_SIZE_ROW*sizeof(double)));//*0.75?
626 |     CUDA_CHECK(cudaMalloc(&stencil2row[1], BLOCK_N*SM_SIZE_COL*SM_SIZE_ROW*sizeof(double)));
627 | 
628 |     CUDA_CHECK(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeDefault));
629 | 
630 |     // timing
631 |     int i = 0;
632 | 
633 |     std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
634 | 
635 |     for (; i < times; i++)
636 |     {
637 |         CUDAKERNELCHECK((breakdown1_kernel<<<grid_config, block_config>>>(array_d[i % 2], array_d[(i + 1) % 2], stencil2row[0], stencil2row[1]))); // 4 entries left empty for alignment
638 |     }
639 |     CUDA_CHECK(cudaDeviceSynchronize());
640 | 
641 |     std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
642 |     std::cout << "Experiment - Breakdown(1D) 1: " << std::endl;
643 |     std::cout << "Time = " << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl;
644 | 
645 |     double secs = std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1e6;
646 |     printf("GStencil/s = %f\n\n", ((double)input_n * times * 3) / secs / 1e9);
647 | 
648 |     CUDA_CHECK(cudaMemcpy(out, array_d[i % 2] + 1, array_size - sizeof(double), cudaMemcpyDeviceToHost));
649 | 
650 |     return;
651 | }
--------------------------------------------------------------------------------
/src/2d/gpu.cu:
--------------------------------------------------------------------------------
1 | #include <mma.h>
2 | // #include
3 | // #include
4 | // #include "../utils.h"
5 | #include <iostream>
6 | #include "2d_utils.h"
7 | #include <chrono>
8 | 
9 | using namespace nvcuda;
10 | 
11 | #define BLOCK_SIZE_ROW 32
12 | #define BLOCK_SIZE_COL 64
13 | #define HALO 3
14 | #define D_BLOCK_SIZE_COL (BLOCK_SIZE_COL + HALO * 2)
15 | #define D_BLOCK_SIZE_ROW (BLOCK_SIZE_ROW + HALO * 2)
16 | #define PAD 2
17 | #define SM_SIZE_COL (7 * D_BLOCK_SIZE_ROW + PAD)
18 | #define SM_SIZE_ROW (D_BLOCK_SIZE_COL / 8)
19 | #define UNIT_LENGTH 7
20 | #define TENSOR_CORE_M 8
21 | #define IDX(x, y, ldm) ((x) * (ldm) + (y))
22 | #define WARP_PER_BLOCK 8
23 | // #define ACCS_PER_WARP (BLOCK_SIZE_COL * BLOCK_SIZE_ROW / 64 / WARP_PER_BLOCK)
24 | #define MMA_NUM 13
25 | #define ceild(n,d) (((n)-1)/(d) + 1)
26 | 
27 | __constant__ double param_matrix_d[2 * 52 * TENSOR_CORE_M];
28 | 
29 | 
30 | __global__ void kernel2d (const double * __restrict__ in, double * __restrict__ out, const int ldm, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) {
31 |     __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL];
32 |     int begin = IDX(blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL + 1, ldm);
33 |     int tid = threadIdx.x;
34 |     int totalThreads = blockDim.x;
35 | #pragma unroll
36 |     for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += totalThreads) {
37 |         int row = i / D_BLOCK_SIZE_COL;
38 |         int col = i % D_BLOCK_SIZE_COL;
39 |         sharedmem[0][lookup_table1[i]] = in[begin + IDX(row, col, ldm)];
40 |         sharedmem[1][lookup_table2[i]] = in[begin + IDX(row, col, ldm)];
41 |     }
42 |     __syncthreads();
43 | 
44 | 
45 |     int warp_id = threadIdx.x / 32;
46 | 
47 |     nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 8, 8, 4, double, nvcuda::wmma::row_major> param_frag[2][MMA_NUM]; // m8n8k4 is the WMMA shape for double
48 | #pragma unroll
49 |     for (int i = 0; i < MMA_NUM; i++) {
50 |         nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8);
51 |         nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 52 * 8 + i * 32, 8);
52 |     }
53 | 
54 |     wmma::fragment<wmma::accumulator, 8, 8, 4, double> acc_frag;
55 |     wmma::fragment<wmma::matrix_a, 8, 8, 4, double, wmma::row_major> in_frag;
56 |     for (int col = warp_id * 28; col < warp_id * 28 + 28; col += UNIT_LENGTH) {
57 |         wmma::fill_fragment(acc_frag, 0.0);
58 | #pragma unroll
59 |         for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) {
60 |             wmma::load_matrix_sync(in_frag, sharedmem[0] + 
IDX(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL); 61 | wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 62 | } 63 | #pragma unroll 64 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 65 | wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL); 66 | wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 67 | } 68 | wmma::store_matrix_sync(out + begin + IDX(HALO + col / 7, HALO, ldm), acc_frag, TENSOR_CORE_M, wmma::mem_row_major); 69 | } 70 | } 71 | 72 | __global__ void breakdown4_kernel (const double * __restrict__ in, double * __restrict__ out, const int ldm, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 73 | __shared__ double sharedmem[2][SM_SIZE_ROW * SM_SIZE_COL]; 74 | int begin = IDX(blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL + 1, ldm); 75 | int tid = threadIdx.x; 76 | int totalThreads = blockDim.x; 77 | #pragma unroll 78 | for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += totalThreads) { 79 | int row = i / D_BLOCK_SIZE_COL; 80 | int col = i % D_BLOCK_SIZE_COL; 81 | 82 | if (lookup_table1[i] != -1) { 83 | sharedmem[0][lookup_table1[i]] = in[begin + IDX(row, col, ldm)]; 84 | } 85 | if (lookup_table2[i] != -1) { 86 | sharedmem[1][lookup_table2[i]] = in[begin + IDX(row, col, ldm)]; 87 | } 88 | } 89 | __syncthreads(); 90 | 91 | int warp_id = threadIdx.x / 32; 92 | 93 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 94 | #pragma unroll 95 | for (int i = 0; i < MMA_NUM; i++) { 96 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 97 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 52 * 8 + i * 32, 8); 98 | } 99 | 100 | wmma::fragment acc_frag; 101 | wmma::fragment in_frag; 102 | for (int col = warp_id * 4*7; col < warp_id *4*7 + 4*7; col += UNIT_LENGTH) { 103 | wmma::fill_fragment(acc_frag, 0.0); 104 | #pragma unroll 105 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 106 | wmma::load_matrix_sync(in_frag, sharedmem[0] + IDX(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL);//1+ 107 | wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 108 | } 109 | #pragma unroll 110 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 111 | wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(0, col + compute_idx * 4, SM_SIZE_COL), SM_SIZE_COL);//1+ 112 | wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 113 | } 114 | wmma::store_matrix_sync(out + begin + IDX(HALO + col / 7, HALO, ldm), acc_frag, TENSOR_CORE_M, wmma::mem_row_major); 115 | } 116 | } 117 | 118 | __global__ void breakdown3_kernel(const double * __restrict__ in, double * __restrict__ out, const int ldm, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 119 | __shared__ double sharedmem[2][SM_SIZE_ROW * (SM_SIZE_COL - PAD)]; 120 | int begin = IDX(blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL + 1, ldm); 121 | int tid = threadIdx.x; 122 | int totalThreads = blockDim.x; 123 | #pragma unroll 124 | for (int i = tid; i < D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL; i += totalThreads) { 125 | int row = i / D_BLOCK_SIZE_COL; 126 | int col = i % D_BLOCK_SIZE_COL; 127 | 128 | if (lookup_table1[i] != -1) { 129 | sharedmem[0][lookup_table1[i]] = in[begin + IDX(row, col, ldm)]; 130 | } 131 | if (lookup_table2[i] != -1) { 132 | sharedmem[1][lookup_table2[i]] = in[begin + IDX(row, col, ldm)]; 133 | } 134 | } 135 | 
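    // At this point the thread block has copied its haloed input tile into the
    // two rearranged shared-memory images: sharedmem[0]/[1] hold the layouts that
    // feed the two halves of the split parameter matrix, with destination offsets
    // taken from lookup_table1/lookup_table2; entries mapped to -1 belong to
    // neither layout and are simply skipped in this variant.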
__syncthreads(); 136 | 137 | int warp_id = threadIdx.x / 32; 138 | 139 | nvcuda::wmma::fragment param_frag[2][MMA_NUM]; 140 | #pragma unroll 141 | for (int i = 0; i < MMA_NUM; i++) { 142 | nvcuda::wmma::load_matrix_sync(param_frag[0][i], param_matrix_d + i * 32, 8); 143 | nvcuda::wmma::load_matrix_sync(param_frag[1][i], param_matrix_d + 52 * 8 + i * 32, 8); 144 | } 145 | 146 | wmma::fragment acc_frag; 147 | wmma::fragment in_frag; 148 | for (int col = warp_id * 4*7; col < warp_id *4*7 + 4*7; col += UNIT_LENGTH) { 149 | wmma::fill_fragment(acc_frag, 0.0); 150 | #pragma unroll 151 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 152 | // if(threadIdx.x%32==0)printf("%d\n",IDX(0, col + compute_idx * 4, (SM_SIZE_COL - PAD))); 153 | wmma::load_matrix_sync(in_frag, sharedmem[0] + IDX(0, col + compute_idx * 4, (SM_SIZE_COL - PAD)), (SM_SIZE_COL - PAD));//1+ 154 | wmma::mma_sync(acc_frag, in_frag, param_frag[0][compute_idx], acc_frag); 155 | } 156 | #pragma unroll 157 | for (int compute_idx = 0; compute_idx < MMA_NUM; compute_idx++) { 158 | wmma::load_matrix_sync(in_frag, sharedmem[1] + IDX(0, col + compute_idx * 4, (SM_SIZE_COL - PAD)), (SM_SIZE_COL - PAD));//1+ 159 | wmma::mma_sync(acc_frag, in_frag, param_frag[1][compute_idx], acc_frag); 160 | } 161 | wmma::store_matrix_sync(out + begin + IDX(HALO + col / 7, HALO, ldm), acc_frag, TENSOR_CORE_M, wmma::mem_row_major); 162 | } 163 | } 164 | 165 | __global__ void breakdown2_kernel (const double * __restrict__ in, double * __restrict__ out, const int ldm, const int * __restrict__ lookup_table1, const int * __restrict__ lookup_table2) { 166 | __shared__ double sharedmem[2][SM_SIZE_ROW * (SM_SIZE_COL - PAD)]; 167 | 168 | int begin = IDX(blockIdx.x * BLOCK_SIZE_ROW, blockIdx.y * BLOCK_SIZE_COL + 1, ldm);//分块 169 | 170 | 171 | int x=threadIdx.x;//0~7 172 | int y=threadIdx.y;//0~31 173 | 174 | int tid=threadIdx.x+threadIdx.y*blockDim.x; 175 | for(int i=tid;i 2 * HALO) { 182 | sharedmem[1][IDX((col - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * row + (col - UNIT_LENGTH) % (UNIT_LENGTH + 1), (SM_SIZE_COL - PAD))]= in[begin + IDX(row, col, ldm)]; 183 | } 184 | } 185 | __syncthreads(); 186 | 187 | for(int row=x;row 2 * HALO) { 219 | lb[gbegin+IDX((col - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * row + (col - UNIT_LENGTH) % (UNIT_LENGTH + 1), (SM_SIZE_COL - PAD))]=in[begin + IDX(row, col, ldm)]; 220 | } 221 | } 222 | 223 | __syncthreads(); 224 | for(int row=x;row= col) { 253 | param_matrix_h[0][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col]; 254 | } 255 | } 256 | } 257 | } 258 | for (int col = 0; col < TENSOR_CORE_M; col++) { 259 | for(int i = 0; i < UNIT_LENGTH; i++) { 260 | for(int j = 0; j < UNIT_LENGTH; j++) { 261 | if (j < col) { 262 | param_matrix_h[1][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col + 7]; 263 | } 264 | } 265 | } 266 | } 267 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 52 * sizeof(double))); 268 | 269 | const int rows = input_m + 2 * HALO; 270 | const int cols = input_n + 2 * HALO + 2; 271 | const size_t array_size = rows * cols * sizeof(double); 272 | double * array_d[2]; 273 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 274 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 275 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 276 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 277 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 278 | 279 | 280 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / 
BLOCK_SIZE_ROW; 281 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 282 | dim3 grid_config(BLOCK_M, BLOCK_N); 283 | // dim3 grid_config(1, 1); 284 | dim3 block_config(32 * WARP_PER_BLOCK); 285 | 286 | // Lookup table 287 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 288 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 289 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 290 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 291 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 292 | lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 293 | } else { 294 | lookup_table1_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 295 | } 296 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 297 | lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 298 | } else { 299 | lookup_table2_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 300 | } 301 | } 302 | } 303 | 304 | 305 | int * lookup_table1_d; 306 | int * lookup_table2_d; 307 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 308 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 309 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 310 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 311 | 312 | 313 | // timing 314 | int i = 0; 315 | 316 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 317 | 318 | for(; i < times; i++) { 319 | CUDAKERNELCHECK((kernel2d<<>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d))); 320 | } 321 | CUDA_CHECK(cudaDeviceSynchronize()); 322 | 323 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 324 | std::cout << "ConvStencil(2D): " << std::endl; 325 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 326 | 327 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 328 | 329 | printf("GStencil/s = %f\n", ((double)input_m * input_n * times * 3) / secs / 1e9); 330 | 331 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 332 | 333 | return; 334 | } 335 | 336 | void gpu_box_2d3r(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_m, const int input_n) { 337 | double param_matrix_h[2][52 * 8] = {0.0}; 338 | 339 | // Initialize parameter matrix 340 | for (int col = 0; col < TENSOR_CORE_M; col++) { 341 | for(int i = 0; i < UNIT_LENGTH; i++) { 342 | for(int j = 0; j < UNIT_LENGTH; j++) { 343 | if (j >= col) { 344 | param_matrix_h[0][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col]; 345 | } 346 | } 347 | } 348 | } 349 | for (int col = 0; col < TENSOR_CORE_M; col++) { 350 | for(int i = 0; i < UNIT_LENGTH; i++) { 351 | for(int j = 0; j < UNIT_LENGTH; j++) { 352 | if (j < col) { 353 | param_matrix_h[1][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col + 7]; 354 | } 355 | } 356 | } 357 | } 358 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 52 * sizeof(double))); 359 | 360 | const int rows = input_m + 2 * HALO; 361 | const int cols = input_n + 2 * HALO + 2; 362 | const size_t array_size = rows * cols * 
sizeof(double); 363 | double * array_d[2]; 364 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 365 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 366 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 367 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 368 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 369 | 370 | 371 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 372 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 373 | dim3 grid_config(BLOCK_M, BLOCK_N); 374 | // dim3 grid_config(1, 1); 375 | dim3 block_config(32 * WARP_PER_BLOCK); 376 | 377 | // Lookup table 378 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 379 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 380 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 381 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 382 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 383 | lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 384 | } else { 385 | lookup_table1_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 386 | } 387 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 388 | lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 389 | } else { 390 | lookup_table2_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 391 | } 392 | } 393 | } 394 | 395 | 396 | int * lookup_table1_d; 397 | int * lookup_table2_d; 398 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 399 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 400 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 401 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 402 | 403 | 404 | // timing 405 | int i = 0; 406 | 407 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 408 | 409 | for(; i < times; i++) { 410 | CUDAKERNELCHECK((kernel2d<<>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d))); 411 | } 412 | CUDA_CHECK(cudaDeviceSynchronize()); 413 | 414 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 415 | std::cout << "ConvStencil(2D): " << std::endl; 416 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 417 | 418 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 419 | 420 | printf("GStencil/s = %f\n", ((double)input_m * input_n * times) / secs / 1e9); 421 | 422 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 423 | 424 | return; 425 | } 426 | 427 | 428 | /** 429 | * @param in input array pointer 430 | * @param out output array pointer 431 | * @param params parameter array pointer (length 49) 432 | * 433 | */ 434 | void gpu_box_2d1r_breakdown4(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_m, const int input_n) { 435 | double param_matrix_h[2][52 * 8] = {0.0}; 436 | 437 | // Initialize parameter matrix 438 | for (int col = 0; col < TENSOR_CORE_M; col++) { 439 | for(int i = 0; i < UNIT_LENGTH; i++) { 440 | for(int j = 0; j < UNIT_LENGTH; j++) { 441 | if (j >= col) { 442 | param_matrix_h[0][(i * UNIT_LENGTH + j) * 8 + col] = 
params[i * UNIT_LENGTH + j - col]; 443 | } 444 | } 445 | } 446 | } 447 | for (int col = 0; col < TENSOR_CORE_M; col++) { 448 | for(int i = 0; i < UNIT_LENGTH; i++) { 449 | for(int j = 0; j < UNIT_LENGTH; j++) { 450 | if (j < col) { 451 | param_matrix_h[1][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col + 7]; 452 | } 453 | } 454 | } 455 | } 456 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 52 * sizeof(double))); 457 | 458 | const int rows = input_m + 2 * HALO; 459 | const int cols = input_n + 2 * HALO + 2; 460 | const size_t array_size = rows * cols * sizeof(double); 461 | double * array_d[2]; 462 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 463 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 464 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 465 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 466 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 467 | 468 | 469 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 470 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 471 | dim3 grid_config(BLOCK_M, BLOCK_N); 472 | // dim3 grid_config(1, 1); 473 | dim3 block_config(32 * WARP_PER_BLOCK); 474 | 475 | // Lookup table 476 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 477 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 478 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 479 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 480 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 481 | lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 482 | } else { 483 | lookup_table1_h[i][j] = - 1; 484 | } 485 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 486 | lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 487 | } else { 488 | lookup_table2_h[i][j] = - 1; 489 | } 490 | } 491 | } 492 | 493 | 494 | int * lookup_table1_d; 495 | int * lookup_table2_d; 496 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 497 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 498 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 499 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 500 | 501 | 502 | // timing 503 | int i = 0; 504 | 505 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 506 | 507 | for(; i < times; i++) { 508 | CUDAKERNELCHECK((breakdown4_kernel<<>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d))); 509 | } 510 | CUDA_CHECK(cudaDeviceSynchronize()); 511 | 512 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 513 | std::cout << "Experiment - Breakdown(2D) 4: " << std::endl; 514 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 515 | 516 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 517 | 518 | printf("GStencil/s = %f\n\n", ((double)input_m * input_n * times * 3) / secs / 1e9); 519 | 520 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 521 | 522 | return; 523 | } 524 | 525 | /** 526 | * @param in input array pointer 527 | * @param out output array pointer 528 | * 
@param params parameter array pointer (length 49) 529 | * 530 | */ 531 | void gpu_box_2d1r_breakdown3(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_m, const int input_n) { 532 | double param_matrix_h[2][52 * 8] = {0.0}; 533 | 534 | // Initialize parameter matrix 535 | for (int col = 0; col < TENSOR_CORE_M; col++) { 536 | for(int i = 0; i < UNIT_LENGTH; i++) { 537 | for(int j = 0; j < UNIT_LENGTH; j++) { 538 | if (j >= col) { 539 | param_matrix_h[0][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col]; 540 | } 541 | } 542 | } 543 | } 544 | for (int col = 0; col < TENSOR_CORE_M; col++) { 545 | for(int i = 0; i < UNIT_LENGTH; i++) { 546 | for(int j = 0; j < UNIT_LENGTH; j++) { 547 | if (j < col) { 548 | param_matrix_h[1][(i * UNIT_LENGTH + j) * 8 + col] = params[i * UNIT_LENGTH + j - col + 7]; 549 | } 550 | } 551 | } 552 | } 553 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 8 * 52 * sizeof(double))); 554 | 555 | const int rows = input_m + 2 * HALO; 556 | const int cols = input_n + 2 * HALO +2; 557 | const size_t array_size = rows * cols * sizeof(double); 558 | double * array_d[2]; 559 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 560 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 561 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 562 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 563 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 564 | 565 | 566 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 567 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 568 | dim3 grid_config(BLOCK_M, BLOCK_N); 569 | // dim3 grid_config(1, 1); 570 | dim3 block_config(32 * WARP_PER_BLOCK); 571 | 572 | // Lookup table 573 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 574 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 575 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 576 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 577 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 578 | lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 579 | } else { 580 | lookup_table1_h[i][j] = - 1; 581 | } 582 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 583 | lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 584 | } else { 585 | lookup_table2_h[i][j] = - 1; 586 | } 587 | } 588 | } 589 | 590 | int * lookup_table1_d; 591 | int * lookup_table2_d; 592 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 593 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 594 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 595 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 596 | 597 | // timing 598 | int i = 0; 599 | 600 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 601 | 602 | for(; i < times; i++) { 603 | CUDAKERNELCHECK((breakdown3_kernel<<>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d))); 604 | } 605 | CUDA_CHECK(cudaDeviceSynchronize()); 606 | 607 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 608 | std::cout 
<< "Experiment - Breakdown(2D) 3: " << std::endl; 609 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 610 | 611 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 612 | 613 | printf("GStencil/s = %f\n\n", ((double)input_m * input_n * times * 3) / secs / 1e9); 614 | 615 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 616 | 617 | return; 618 | } 619 | 620 | void gpu_box_2d1r_breakdown2(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_m, const int input_n) { 621 | double param_matrix_h[2][49 * 7] = {0.0}; 622 | 623 | // Initialize parameter matrix 624 | for (int col = 0; col < UNIT_LENGTH ; col++) { 625 | for(int i = 0; i < UNIT_LENGTH; i++) { 626 | for(int j = 0; j < UNIT_LENGTH; j++) { 627 | if (j >= col) { 628 | param_matrix_h[0][(i * UNIT_LENGTH + j) * UNIT_LENGTH + col] = params[i * UNIT_LENGTH + j - col];//(i*UNIT_LENGTH+j,col) 629 | } 630 | } 631 | } 632 | } 633 | for (int col = 0; col < UNIT_LENGTH ; col++) { 634 | for(int i = 0; i < UNIT_LENGTH; i++) { 635 | for(int j = 0; j < UNIT_LENGTH; j++) { 636 | if (j <= col) { 637 | param_matrix_h[1][(i * UNIT_LENGTH + j) * UNIT_LENGTH + col] = params[i * UNIT_LENGTH + j - col + 6]; 638 | } 639 | } 640 | } 641 | } 642 | 643 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 7 * 49 * sizeof(double))); 644 | 645 | const int rows = input_m + 2 * HALO; 646 | const int cols = input_n + 2 * HALO+1 ; 647 | const size_t array_size = rows * cols * sizeof(double); 648 | double * array_d[2]; 649 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 650 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 651 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 652 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 653 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 654 | 655 | 656 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 657 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 658 | dim3 grid_config(BLOCK_M, BLOCK_N); 659 | dim3 block_config(8,32); 660 | 661 | // Lookup table 662 | int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 663 | int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL]; 664 | for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) { 665 | for (int j = 0; j < D_BLOCK_SIZE_COL; j++) { 666 | if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) { 667 | lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL); 668 | } else { 669 | lookup_table1_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 670 | } 671 | if ((j + 2) % 8 != 0 && j > 2 * HALO) { 672 | lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL); 673 | } else { 674 | lookup_table2_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1; 675 | } 676 | } 677 | } 678 | 679 | int * lookup_table1_d; 680 | int * lookup_table2_d; 681 | CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 682 | CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int))); 683 | CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 684 | CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice)); 685 | 
686 | // timing 687 | int i = 0; 688 | 689 | std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); 690 | 691 | for(; i < times; i++) { 692 | CUDAKERNELCHECK((breakdown2_kernel<<>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d))); 693 | } 694 | CUDA_CHECK(cudaDeviceSynchronize()); 695 | 696 | std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 697 | std::cout << "Experiment - Breakdown(2D) 2: " << std::endl; 698 | std::cout << "Time = " << std::chrono::duration_cast(end - begin).count() << "[ms]" << std::endl; 699 | 700 | double secs = std::chrono::duration_cast(end - begin).count() / 1e6; 701 | 702 | printf("GStencil/s = %f\n\n", ((double)input_m * input_n * times * 3) / secs / 1e9); 703 | 704 | CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost)); 705 | 706 | return; 707 | } 708 | 709 | void gpu_box_2d1r_breakdown1(const double * __restrict__ in, double * __restrict__ out, const double * __restrict__ params, const int times, const int input_m, const int input_n) { 710 | double param_matrix_h[2][49 * 7] = {0.0}; 711 | 712 | // Initialize parameter matrix 713 | for (int col = 0; col < UNIT_LENGTH ; col++) { 714 | for(int i = 0; i < UNIT_LENGTH; i++) { 715 | for(int j = 0; j < UNIT_LENGTH; j++) { 716 | if (j >= col) { 717 | param_matrix_h[0][(i * UNIT_LENGTH + j) * UNIT_LENGTH + col] = params[i * UNIT_LENGTH + j - col];//(i*UNIT_LENGTH+j,col) 718 | } 719 | } 720 | } 721 | } 722 | for (int col = 0; col < UNIT_LENGTH ; col++) { 723 | for(int i = 0; i < UNIT_LENGTH; i++) { 724 | for(int j = 0; j < UNIT_LENGTH; j++) { 725 | if (j <= col) { 726 | param_matrix_h[1][(i * UNIT_LENGTH + j) * UNIT_LENGTH + col] = params[i * UNIT_LENGTH + j - col + 6]; 727 | } 728 | } 729 | } 730 | } 731 | 732 | CUDA_CHECK(cudaMemcpyToSymbol(param_matrix_d, param_matrix_h, 2 * 7 * 49 * sizeof(double))); 733 | 734 | const int rows = input_m + 2 * HALO; 735 | // const int cols = input_n + 2 * HALO + 2; 736 | const int cols = input_n + 2 * HALO+1 ; 737 | const size_t array_size = rows * cols * sizeof(double); 738 | double * array_d[2]; 739 | CUDA_CHECK(cudaMalloc(&array_d[0], array_size)); 740 | CUDA_CHECK(cudaMalloc(&array_d[1], array_size)); 741 | CUDA_CHECK(cudaMemset(array_d[0], 0, array_size)); 742 | CUDA_CHECK(cudaMemcpy(array_d[0], in, array_size, cudaMemcpyHostToDevice)); 743 | CUDA_CHECK(cudaMemset(array_d[1], 0, array_size)); 744 | 745 | 746 | 747 | const int BLOCK_M = (input_m + BLOCK_SIZE_ROW - 1) / BLOCK_SIZE_ROW; 748 | const int BLOCK_N = (input_n + BLOCK_SIZE_COL - 1) / BLOCK_SIZE_COL; 749 | dim3 grid_config(BLOCK_M, BLOCK_N); 750 | dim3 block_config(8,32); 751 | 752 | double* stencil2row[2]; 753 | CUDA_CHECK(cudaMalloc(&stencil2row[0], BLOCK_M*BLOCK_N*SM_SIZE_COL*SM_SIZE_ROW*sizeof(double)));//*0.75? 
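    // stencil2row[0]/[1] are global-memory scratch buffers, one SM_SIZE_ROW x
    // SM_SIZE_COL tile per thread block (BLOCK_M * BLOCK_N tiles in total); the
    // breakdown1 variant stages the two rearranged input layouts here (the la/lb
    // arguments of breakdown1_kernel) instead of in shared memory.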
754 |     CUDA_CHECK(cudaMalloc(&stencil2row[1], BLOCK_M*BLOCK_N*SM_SIZE_COL*SM_SIZE_ROW*sizeof(double)));
755 | 
756 |     // Lookup table
757 |     int lookup_table1_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL];
758 |     int lookup_table2_h[D_BLOCK_SIZE_ROW][D_BLOCK_SIZE_COL];
759 |     for (int i = 0; i < D_BLOCK_SIZE_ROW; i++) {
760 |         for (int j = 0; j < D_BLOCK_SIZE_COL; j++) {
761 |             if ((j + 1) % 8 != 0 && j < D_BLOCK_SIZE_COL - 2 * HALO - 1) {
762 |                 lookup_table1_h[i][j] = IDX(j / (UNIT_LENGTH + 1), UNIT_LENGTH * i + j % (UNIT_LENGTH + 1), SM_SIZE_COL);
763 |             } else {
764 |                 lookup_table1_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1;
765 |             }
766 |             if ((j + 2) % 8 != 0 && j > 2 * HALO) {
767 |                 lookup_table2_h[i][j] = IDX((j - UNIT_LENGTH) / (UNIT_LENGTH + 1), UNIT_LENGTH * i + (j - UNIT_LENGTH) % (UNIT_LENGTH + 1), SM_SIZE_COL);
768 |             } else {
769 |                 lookup_table2_h[i][j] = SM_SIZE_ROW * SM_SIZE_COL - 1;
770 |             }
771 |         }
772 |     }
773 |     int * lookup_table1_d;
774 |     int * lookup_table2_d;
775 |     CUDA_CHECK(cudaMalloc(&lookup_table1_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int)));
776 |     CUDA_CHECK(cudaMalloc(&lookup_table2_d, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int)));
777 |     CUDA_CHECK(cudaMemcpy(lookup_table1_d, lookup_table1_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice));
778 |     CUDA_CHECK(cudaMemcpy(lookup_table2_d, lookup_table2_h, D_BLOCK_SIZE_ROW * D_BLOCK_SIZE_COL * sizeof(int), cudaMemcpyHostToDevice));
779 | 
780 |     // timing
781 |     int i = 0;
782 | 
783 |     std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
784 | 
785 |     for(; i < times; i++) {
786 |         CUDAKERNELCHECK((breakdown1_kernel<<<grid_config, block_config>>>(array_d[i % 2], array_d[(i + 1) % 2], cols, lookup_table1_d, lookup_table2_d, stencil2row[0], stencil2row[1])));
787 |     }
788 |     CUDA_CHECK(cudaDeviceSynchronize());
789 | 
790 |     std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
791 |     std::cout << "Experiment - Breakdown(2D) 1: " << std::endl;
792 |     std::cout << "Time = " << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl;
793 | 
794 |     double secs = std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() / 1e6;
795 | 
796 |     printf("GStencil/s = %f\n\n", ((double)input_m * input_n * times * 3) / secs / 1e9);
797 | 
798 |     CUDA_CHECK(cudaMemcpy(out, array_d[i % 2], array_size, cudaMemcpyDeviceToHost));
799 | 
800 |     return;
801 | }
802 | 
--------------------------------------------------------------------------------
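Every kernel in gpu.cu follows the same tensor-core pattern: stage the rearranged input tile, then accumulate a chain of double-precision 8x8x4 WMMA multiplies against the two halves of the constant parameter matrix. A minimal, self-contained sketch of that pattern is given below; it assumes the m8n8k4 FP64 shape (the only WMMA shape for double, sm_80 and newer) with row-major operands, and the buffer names and the ILLU_K reduction length are illustrative, not taken from the sources above.

#include <mma.h>
using namespace nvcuda;

#define ILLU_K 16  // illustrative reduction length (any multiple of 4)

// One warp computes one 8x8 double tile: C = A (8 x ILLU_K) * B (ILLU_K x 8).
__global__ void wmma_fp64_sketch(const double *a, const double *b, double *c) {
    wmma::fragment<wmma::matrix_a, 8, 8, 4, double, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 8, 8, 4, double, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, 8, 8, 4, double> acc_frag;
    wmma::fill_fragment(acc_frag, 0.0);

    for (int k = 0; k < ILLU_K; k += 4) {
        wmma::load_matrix_sync(a_frag, a + k, ILLU_K);   // 8x4 slice of row-major A
        wmma::load_matrix_sync(b_frag, b + k * 8, 8);    // 4x8 slice of row-major B
        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
    wmma::store_matrix_sync(c, acc_frag, 8, wmma::mem_row_major);
}

Compiled for sm_80 or newer and launched with a single warp, this is the same load_matrix_sync / mma_sync / store_matrix_sync sequence that kernel2d runs MMA_NUM times per 8-row strip of its tile, with the A operand coming from the rearranged shared-memory image and the B operand from param_matrix_d in constant memory.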