├── .gitignore ├── .vscode └── c_cpp_properties.json ├── README.md ├── include └── utils.h ├── interpolation.cpp ├── interpolation_kernel.cu ├── setup.py ├── test.py └── test_rji.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/**", 7 | "/home/ubuntu/anaconda3/envs/cppcuda/include/python3.8", 8 | "/home/ubuntu/anaconda3/envs/cppcuda/lib/python3.8/site-packages/torch/include", 9 | "/home/ubuntu/anaconda3/envs/cppcuda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include" 10 | ], 11 | "defines": [], 12 | "compilerPath": "/usr/bin/clang", 13 | "cStandard": "c17", 14 | "cppStandard": "c++14", 15 | "intelliSenseMode": "linux-clang-x64" 16 | } 17 | ], 18 | "version": 4 19 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CUDA编程学习:自定义Pytorch+cpp/cuda extension 2 | 3 | - [CUDA编程学习:自定义Pytorch+cpp/cuda extension](#cuda编程学习自定义pytorchcppcuda-extension) 4 | - [学习背景](#学习背景) 5 | - [适用对象与场景](#适用对象与场景) 6 | - [Pytorch和CUDA的关系](#pytorch和cuda的关系) 7 | - [Python调用C++函数(桥梁)](#python调用c函数桥梁) 8 | - [Building with `setuptools`](#building-with-setuptools) 9 | - [JIT Compiling Extensions](#jit-compiling-extensions) 10 | - [CUDA加速的原理](#cuda加速的原理) 11 | - [三线性插值问题定义](#三线性插值问题定义) 12 | - [C++调用CUDA函数](#c调用cuda函数) 13 | - [三线性插值CUDA实现](#三线性插值cuda实现) 14 | - [scalar\_t类型](#scalar_t类型) 15 | - [accessors](#accessors) 16 | - [模板函数](#模板函数) 17 | - [foward验证与比较](#foward验证与比较) 18 | - [CUDA反向传播](#cuda反向传播) 19 | - [定义CUDA函数](#定义cuda函数) 20 | - [核函数实现微分计算](#核函数实现微分计算) 21 | - [PYBIND11绑定函数](#pybind11绑定函数) 22 | - [torch.autograd.Function封装](#torchautogradfunction封装) 23 | - [backward验证与比较](#backward验证与比较) 24 | - [参考](#参考) 25 | 26 | 27 | ## 学习背景 28 | 29 | 虽然说PyTorch提供了丰富的与神经网络、张量代数、数据处理等相关的操作,但是有时候你可能需要**更定制化的操作**,比如使用论文中的新型激活函数,或者实现作为研究一部分开发的操作。在PyTorch中,最简单的集成自定义操作的方式是在Python中编写,通过扩展Function和Module来实现,这使我们可以充分利用自动微分`autograd`的功能,**然而有时候代码在模型中被频繁调用或者调用代价比较大,我们就可能需要在C++中进行实现。另一个可能的原因可能需要依赖于其他的C++库,为了解决这种情况,PyTorch提供了一种非常简单的方式来编写自定义C++扩展。** 30 | 31 | 简单介绍一下Pytorch C++的API部分,主要有以下五部分 32 | 33 | 1. **ATen:** 作为基础张量和数学操作库,所有其他接口都构建在其上。 34 | 2. **Autograd:** 通过自动微分增强了ATen,记录张量上的操作以形成自动微分图。 35 | 3. **C++ Frontend:** 提供了用于神经网络和机器学习模型的高级纯C++建模接口。 36 | 4. **TorchScript:** 是一个可以由TorchScript编译器理解、编译和序列化的PyTorch模型表示。 37 | 5. **C++ Extensions:** 用于扩展Python API的自定义C++和CUDA例程。 38 | 39 | 这些块组合形成了一个C++库,可用于张量计算和具有高效的GPU加速以及快速CPU性能的动态神经网络。 40 | 41 | > 在这部分中,ATen是一个基础张量库,几乎所有PyTorch的Python和C++接口都构建在其上。Autograd是C++ API的一部分,用于为ATen张量类添加自动微分功能。我们编写C++的扩展的时候,我们实际上是基于ATen进行操作和书写的。 42 | 43 | 44 | 45 | ## 适用对象与场景 46 | 47 | 实际上pytorch+cuda是为了加速pytorch的计算,如果pytorch的计算已经可以满足了,就可以跳过这一部分,因为本身pytorch也已经蕴含了很多的函数 48 | 49 | - **非平行运算 non parallel computation**:在这样的场景下,比如现在一个batch里面,都是平行运算,所以这时候可以直接用pytorch进行实现,但是在NeRF的体渲染volume rendering中,我们就可以知道,每一条射线可能采样的点都是不一样的,如果我们去用for循环就可能需要花比较多的时间,这时候就需要cuda的存在。 50 | 51 | - **大量的串列计算 lots of sequential computation**:比如神经网络的卷积层的计算的,比如在forward中,经常会出现以下这样的情况 52 | 53 | ```python 54 | x = f1(x) 55 | x = f2(x) 56 | ... 57 | x = fn(x) 58 | ``` 59 | 60 | 如果在层数比较小的时候,这样是可以得到不错的结果的,但是层数比较大的时候,不断的内存访问其实会减慢速度,这时候就需要CUDA来进行加速,比如我们可以将所有的f变成一个函数F,融合了所有的函数后,我们就可以进行一次运算得到最后的结果,在层数大的时候能得到很大的提升。 61 | 62 | 在这一部分的学习中,主要还是在第一个场景,非平行运算,特别是NeRF的体渲染部分,这一部分的学习和加速还是非常重要的,值得学习。 63 | 64 | 65 | 66 | ## Pytorch和CUDA的关系 67 | 68 | 一般来说,是pytorch -> C++ > cuda,也就是pytorch调用C++,然后C++再调用cuda,所以C++作为的是一个桥梁,所以比较重要的cuda,而不是C++,利用cuda进行平行的计算。 69 | 70 | ![Pytorch和CUDA的关系](https://img-blog.csdnimg.cn/direct/64f3116d776247ae975479b252554c0a.png) 71 | 72 | 73 | 74 | 75 | 76 | ## Python调用C++函数(桥梁) 77 | 78 | 首先声明一下,我的文件夹格式如下: 79 | 80 | ```bash 81 | pytorch-cppcuda-tutorial/ 82 | test.py 83 | setup.py 84 | interpolation.cpp 85 | interpolation_kernel.cu 86 | include/ 87 | utils.h 88 | ``` 89 | 90 | **有一个问题可能我们会疑惑很久,就是python是怎么调用C++和CUDA的,这里面根据课程简单来讲一下,以三线性插值为例子** 91 | 92 | 首先,我们定义一个简单的函数。这个函数接受两个参数,分别是特征和点,然后直接返回特征。在这里,我们将看到一个核心的东西,即 `PYBIND11_MODULE`。这是 Python 调用 C++ 函数的关键部分。这个函数会在Python执行`import`语句时被调用,其接受两个参数,第一个参数为模块名称,这里我们直接将`trilinear_interpolation`填入,稍候可以在Python中使用`import cppcuda_tutorial`导入该模块;第二个参数`m`是创建Python关联代码的主接口,其类型为`py::module_`。`module_::def()`用于生成能够将`trilinear_interpolation`函数暴露给Python的代码,其第一个参数为**字符串**,将会成为Python中调用的函数名;第二个参数是**C++函数**的引用;第三个参数是**说明字符串**,在Python中可以使用`help(trilinear_interpolation)`查看。比如下面的例子中,C++ 中的函数 `trilinear_interpolation` 对应 Python 中的 `trilinear_interpolation`。 93 | 94 | > ****是一站式头文件,包含写入C++扩展所需的所有PyTorch操作,包括: 95 | > 96 | > - ATen库是用于张量计算的主要API, 97 | > - pybind11,是为C++代码创建Python绑定的方式 98 | > - 管理ATen和pybind11之间交互细节的头文件 99 | > 100 | > PyTorch的张量和变量接口是从ATen库自动生成的,因此几乎可以将Python实现1:1转换为C++。所有计算的主要数据类型将是torch::Tensor。 101 | 102 | ```C++ 103 | #include 104 | 105 | torch::Tensor trilinear_interpolation( 106 | torch::Tensor feats, 107 | torch::Tensor point 108 | ){ 109 | return feats; 110 | } 111 | 112 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 113 | m.def("trilinear_interpolation", &trilinear_interpolation); 114 | } 115 | ``` 116 | 117 | **注意:TORCH_EXTENSION_NAME,torch扩展构建会将其定义为在setup.py脚本中为扩展指定的名称。比如这里为“TORCH_EXTENSION_NAME“,两者之间的不匹配可能会导致严重且难以跟踪的问题。** 118 | 119 | 120 | 121 | C++扩展一般有两种方式 122 | 123 | - 通过`setuptools`“提前”构建 124 | - 通过`torch.utils.cpp_extension.load()`“实时”构建 125 | 126 | ### Building with `setuptools` 127 | 128 | 接下来先试用`setuptools`进行构建,编写一个 `setup.py` 文件,主要用于定义和说明一些重要的信息。其中关键的参数包括: 129 | 130 | - `name`:Python 调用的包的名称。 131 | - `ext_modules` 的 `sources`:需要编译的 C++ 源文件,如果有多个 C++ 文件,需要列举所有。 132 | - `cmdclass`:用BuildExtension执行许多必需的配置步骤和检查,并在混合C++/CUDA扩展的情况下处理混合编译。 133 | 134 | ```C++ 135 | from setuptools import setup 136 | from torch.utils.cpp_extension import CppExtension, BuildExtension 137 | 138 | 139 | setup( 140 | name='cppcuda_tutorial', 141 | version='1.0', 142 | author='xxx', 143 | author_email='xxx@gmail.com', 144 | description='cppcuda example', 145 | long_description='cppcuda example', 146 | ext_modules=[ 147 | CppExtension( 148 | name='cppcuda_tutorial', 149 | sources=['interpolation.cpp']) 150 | ], 151 | cmdclass={ 152 | 'build_ext': BuildExtension 153 | } 154 | ) 155 | ``` 156 | 157 | ### JIT Compiling Extensions 158 | 159 | 除了上述的`setuptools`的方法,接下来介绍即时编译(JIT)机制构建C++扩展。JIT编译机制通过调用PyTorch API中的一个简单函数`torch.utils.cpp_extension.load()`,为你提供了一种即时编译和加载扩展的方式。 160 | 161 | ```python 162 | from torch.utils.cpp_extension import load 163 | 164 | cppcuda_tutorial = load(name="cppcuda_tutorial", 165 | # extra_include_paths=include_dirs, 166 | sources=['interpolation.cpp'],) 167 | ``` 168 | 169 | 在这里,实际提供的是域setuptools相同的信息。在后台,这将执行以下操作: 170 | 171 | 1. 创建一个临时目录`/tmp/torch_extensions/cppcuda_tutorial`, 172 | 2. 向该临时目录发出Ninja构建文件, 173 | 3. 将你的源文件编译成一个共享库, 174 | 4. 将这个共享库导入为Python模块。 175 | 176 | 实际上,如果将`verbose=True`传递给`cpp_extension.load()`,你将得到有关该过程的信息: 177 | 178 | ```bash 179 | Using /path/.cache/torch_extensions/py310_cu113 as PyTorch extensions root... 180 | Detected CUDA files, patching ldflags 181 | Emitting ninja build file /path/.cache/torch_extensions/py310_cu113/cppcuda_tutorial/build.ninja... 182 | Building extension module cppcuda_tutorial... 183 | Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N) 184 | [1/2] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cppcuda_tutorial -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/path/workdirs/pytorch-cppcuda-tutorial/include -isystem /path/anaconda3/envs/cppcuda/lib/python3.10/site-packages/torch/include -isystem /path/anaconda3/envs/cppcuda/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /path/anaconda3/envs/cppcuda/lib/python3.10/site-packages/torch/include/TH -isystem /path/anaconda3/envs/cppcuda/lib/python3.10/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /path/anaconda3/envs/cppcuda/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++14 -c /path/workdirs/pytorch-cppcuda-tutorial/interpolation_kernel.cu -o interpolation_kernel.cuda.o 185 | [2/2] c++ interpolation.o interpolation_kernel.cuda.o -shared -L/path/anaconda3/envs/cppcuda/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda_cu -ltorch_cuda_cpp -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cppcuda_tutorial.so 186 | ``` 187 | 188 | 完成这一步后,如果使用`setuptools`进行构建,我们可以使用 `pip` 进行安装。如果在当前文件夹下,直接运行 `pip install .` 即可完成安装或者我们也可以使用`python set.py install`,安装成功后应该会显示以下结果: 189 | 190 | ```bash 191 | Processing path/pytorch-cppcuda-tutorial 192 | Preparing metadata (setup.py) ... done 193 | Building wheels for collected packages: cppcuda-tutorial 194 | Building wheel for cppcuda-tutorial (setup.py) ... done 195 | Created wheel for cppcuda-tutorial: filename=cppcuda_tutorial-1.0-cp310-cp310-linux_x86_64.whl size=74123 sha256=3029b98bd3b49bed65f42640e60932c38379f15db48a5187fe40610b525307c9 196 | Stored in directory: /path/.cache/pip/wheels/65/53/4a/5e2c10d11e3a657b9efae376ccce3277e5535d691dd4659883 197 | Successfully built cppcuda-tutorial 198 | Installing collected packages: cppcuda-tutorial 199 | Successfully installed cppcuda-tutorial-1.0 200 | ``` 201 | 202 | 完成以上步骤后,我们可以编写一个 `test.py` 文件来测试代码的正确性。只要能够成功运行,就代表一切正常。 203 | 204 | ```python 205 | import torch 206 | import cppcuda_tutorial # 位置需要在import torch后面 207 | 208 | 209 | feats = torch.ones(2) 210 | point = torch.zeros(2) 211 | 212 | # 调用函数 213 | out = cppcuda_tutorial.trilinear_interpolation(feats, point) 214 | 215 | print(out) 216 | ``` 217 | 218 | 这里面要注意的就是,首席爱你要导入torch,解析动态链接器必须看到的一些符号 219 | 220 | 221 | 222 | ## CUDA加速的原理 223 | 224 | 首先介绍一个CUDA程序实现的流程 225 | 226 | 1. 把数据从CPU内存拷贝到GPU内存 227 | 2. 调用核函数对存储在GPU内存中的数据进行操作 228 | 3. 将数据从GPU内存传送回CPU内存 229 | 230 | ![CUDA编程入门极简教程- 知乎](https://pic2.zhimg.com/v2-2959e07a36a8dc8f59280f53b43eb9d1_b.jpg) 231 | 232 | 233 | 234 | 如下图所示,在利用CUDA加速的时候,图的左边是CPU,右边是GPU,我们需要把数据从CPU传到GPU中。在GPU中,就会生成对应的Grid来进行计算,每个Grid里面又有很多的block,从block中看又有很多的线程thread进行运算。我们也可以想象一下,如果计算一个矩阵的加法,我们可以让每个thread做对应的元素的相加,这样就可以大大加快计算速度,达到并行的效果。所以这之中内核(kernel)是CUDA编程中一个重要的部分,其代码在GPU上运行,比如矩阵乘法,我们就可以写一个加法的核函数,然后串行执行核函数,这样我们就快速能完成CUDA代码的编写,而不用在创建和管理大量的GPU线程时拘泥于细节。 235 | 236 | ![Thread Mapping](https://nyu-cds.github.io/python-gpu/fig/02-threadmapping.png) 237 | 238 | 239 | 240 | 所以我们可以发现,实际上CUDA的计算是`Grid`——》`Block`——》`Thread`,然后用多个`Thread`进行计算,这里面可能会疑惑,为什么不是直接`Grid`——》`Thread`,实际上是因为硬件的限制是`Block`上限是$(2^{31}-1)*2^{16}*2^{16}$,`Thread`的上限是1024,所以这样的组合设计能够利用好更多的`Thread`,这也是为什么GPU速度那么快的原因。 241 | 242 | 243 | 244 | ## 三线性插值问题定义 245 | 246 | 有关于线性插值和三线性插值的介绍,可以从这部分资料进行了解,[https://zhuanlan.zhihu.com/p/77496615](https://zhuanlan.zhihu.com/p/77496615),[https://blog.csdn.net/webzhuce/article/details/86585489](https://blog.csdn.net/webzhuce/article/details/86585489),这样我们就知道三线性插值的概念,和大概的思路,这样我们就可以进行一个CUDA的实现了。 247 | 248 | ![三线性插值](https://img-blog.csdnimg.cn/20190121221016700.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlYnpodWNl,size_16,color_FFFFFF,t_70) 249 | 250 | 从三线性插值的概念我们可以知道,我们需要传入两个参数 251 | 252 | - feats(N, 8, F):N个立方体,每个立方体有8个点,每个点有F个特征 253 | - Points(N,3):N个点,每个点的坐标分别是xyz,一共有三个维度 254 | 255 | 我们也可以知道输出的参数为`feat_interp(N, F)`,也就是插值后的结果 256 | 257 | 从上述定义我们就可以知道,我们有两个部分可以进行平行运算,分别是N和F,因为它们是独立的,不会相互影响计算。 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | ## C++调用CUDA函数 266 | 267 | 首先我们先看看,怎么使用CUDA去进行编程,首先CUDA的代码是`cu`结尾的,我们通过编写CUDA来进行一个计算加速。我们还是先按之前的方法,看看如何使用CUDA进行编程先,这里面有几个注意的点: 268 | 269 | - 需要编写CUDA函数 270 | - 需要在头文件`.h`中去定义需要使用的函数,包括一些常用的测试函数。 271 | - 修改CUDA的`setup.py` 272 | 273 | 接下来一步一步来,首先写一个`interpolation_kernel.cu`函数,也就是一个CUDA函数,后面我们可以用C++调用CUDA,这里面还是直接返回feats 274 | 275 | ```c++ 276 | #include 277 | 278 | torch::Tensor trilinear_fw_cu( 279 | torch::Tensor feats, 280 | torch::Tensor points 281 | ){ 282 | return feats; 283 | } 284 | ``` 285 | 286 | 接下来,我们就需要在头文件`utils.h `中去定义我们在文件中需要使用的函数,类似于原本C++的一个声明和定义函数 287 | 288 | ```c++ 289 | #include 290 | 291 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 292 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 293 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 294 | 295 | // 声明和定义函数 296 | torch::Tensor trilinear_fw_cu( 297 | torch::Tensor feats, 298 | torch::Tensor points 299 | ); 300 | ``` 301 | 302 | 这样编写以后,我们的cpp的代码就可以调用CUDA对函数来进行调用,但是由于我们使用CUDA的函数,所以这里面我们还要用到`CHECK_INPUT`函数来判断是否在GPU上,也就是一个检测,并且内存是否连续,因为后续要进行一个并行的计算。 303 | 304 | ```C++ 305 | #include "utils.h" 306 | 307 | 308 | torch::Tensor trilinear_interpolation( 309 | torch::Tensor feats, 310 | torch::Tensor points 311 | ){ 312 | CHECK_INPUT(feats); 313 | CHECK_INPUT(points); 314 | 315 | return trilinear_fw_cu(feats, points); 316 | } 317 | 318 | 319 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 320 | m.def("trilinear_interpolation", &trilinear_interpolation); 321 | } 322 | ``` 323 | 324 | 完成上述编写之后,我们最后就只剩下`setup.py`函数需要修改,其实需要修改的东西非常有限,只需要将上述的`CPPExtension`改为`CUDAExtension`,也就是改成CUDA的编译,这里面还有比较好的方法就是,之前我们source需要自己写,但是当我们有很多个文件的时候,我们就可以自动检索文件夹下的cpp和cu文件,进行build即可得到最后的结果。 325 | 326 | ```python 327 | import glob 328 | import os.path as osp 329 | from setuptools import setup 330 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 331 | 332 | 333 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 334 | include_dirs = [osp.join(ROOT_DIR, "include")] # 得到include文件夹下所有的头文件.h 335 | 336 | sources = glob.glob('*.cpp')+glob.glob('*.cu') # 得到当前文件夹下所有cu文件和cpp文件 337 | 338 | 339 | setup( 340 | name='cppcuda_tutorial', 341 | version='1.0', 342 | author='xxx', 343 | author_email='xxx@gmail.com', 344 | description='cppcuda_tutorial', 345 | long_description='cppcuda_tutorial', 346 | ext_modules=[ 347 | CUDAExtension( 348 | name='cppcuda_tutorial', 349 | sources=sources, 350 | include_dirs=include_dirs, 351 | extra_compile_args={'cxx': ['-O2'], 352 | 'nvcc': ['-O2']} 353 | ) 354 | ], 355 | cmdclass={ 356 | 'build_ext': BuildExtension 357 | } 358 | ) 359 | ``` 360 | 361 | 安装过后,我们就可以测试是否使用CUDA进行计算,唯一不同的是,由于我们是使用CUDA进行计算,所以我们要把数据转到CUDA中即可 362 | 363 | ```python 364 | import torch 365 | import cppcuda_tutorial 366 | 367 | 368 | if __name__ == '__main__': 369 | 370 | feats = torch.ones(2, device='cuda') 371 | points = torch.zeros(2, device='cuda') 372 | 373 | out = cppcuda_tutorial.trilinear_interpolation(feats, points) 374 | 375 | print(out) 376 | ``` 377 | 378 | > 在这里,可能第一次学习会觉得比较麻烦,但是实际上有一些函数,比如CHECK的函数和setup.py的函数,只要写了一次以后,之后都是可以参考复用的,不用重复写 379 | 380 | 381 | 382 | ## 三线性插值CUDA实现 383 | 384 | 接下来就是主要的三线性插值的CUDA实现了,在前面的CUDA加速中有说到,实际上我们是希望在每一个thread都执行一个单元的计算,在三线性插值中,我们可以知道,我们两个部分需要并行,分别是`N`和`F`两个部分,也就是立方体的个数和特征的个数。 385 | 386 | 我们先看看需要进行编写的函数,然后一步一步的来解释和探索,以下是更新后的`forward`函数 387 | 388 | ```C++ 389 | torch::Tensor trilinear_fw_cu( 390 | torch::Tensor feats, 391 | torch::Tensor points 392 | ){ 393 | const int N = feats.size(0), F = feats.size(2); 394 | // 等价于 feat_interp = torch.zeros(N, F, dtype = torch.float32, device = "cuda:0") 395 | torch::Tensor feat_interp = torch::zeros({N, F}, feats.options()); 396 | 397 | const dim3 threads(16, 16); // 128,256,512 398 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 399 | 400 | AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_fw_cu", 401 | ([&] { 402 | trilinear_fw_kernel<<>>( 403 | feats.packed_accessor(), 404 | points.packed_accessor(), 405 | feat_interp.packed_accessor() 406 | ); 407 | })); 408 | 409 | return feat_interp; 410 | } 411 | ``` 412 | 413 | 第5行我们得到了对应的维度,分别是N和F,这也是我们最后需要返回值`feat_interp`的维度。 414 | 415 | 第7行我们初始化了变量`feat_interp`,这里面是初始化为zero,在里面还有一个参数是`feats.options()`,在CUDA编程中,`feats.options()`表示获取`feats`张量的选项(options)。选项包括张量的数据类型、设备(设备指定为CUDA或CPU)以及其他相关的配置信息。通过使用`feats.options()`,可以确保新创建的`feat_interp`张量与`feats`张量具有相同的选项,以便在相同的设备上进行操作,并保持一致性。 416 | 417 | 简单来说,就是保持一致的设备等等,这样就方便后续在同一个设备进行计算,和pytorch需要放在cpu和cuda上是一样的,除此之外,还有一些另外的写法,比如是创建一个整型的,可以写成如下,一样的意思。 418 | 419 | ```c++ 420 | torch::zeros({N,F},torch::dtype(torch::kInt32).device(feats,device)); 421 | ``` 422 | 423 | 第9行和第10行就是定义上述提过的`threads`和`blocks`了,dim3是NVIDIA的CUDA编程中一种自定义的整型向量类型,基于用于指定维度的uint3,dim3类型最终设置的是一个三维向量,三维参数分别为x,y,z。在并行中,通常只支持三个并行,比如这里的N和F刚刚好就是两个并行,这里设置为16x16的线程,一般可以是128,256,512,不一定使用越多越好,这里面只是给了一个例子。 424 | 425 | 第10行中是定义了`blocks`的计算,`blocks`的个数实际上是计算的得到的,如下图所示,如果N=10,F=20,我们最后的输出就是10x20,这样我们就会用一个16x16的block去覆盖这整个矩阵,我们会发现大概需要2个矩阵,所以我们的`blocks`就是(2,1),从下图也可以看出来,所以上述公式就是计算`block`的个数,这样我们就可以用每一个`thread`去计算,这样就能大大加快速度。 426 | 427 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/3dfa0f2ed94b4d9d9297a90e55a59662.png) 428 | 429 | 第12~19行就是CUDA的核心函数,这里面就是一个启动核函数的部分,后面会提到核函数的编写,这里也是一个框架的部分,`AT_DISPATCH_FLOATING_TYPES` 是处理核函数的启动(使用 `<<<...>>>` 表示),它一般有三个参数 430 | 431 | - 一个类型 feats.type() 432 | - 一个名称 "trilinear_fw_cu",用于错误消息 433 | - 一个 lambda 函数,是一个模版函数template,类型别名为 `scalar_t` 434 | 435 | 在这里面可以看出处理的是float类型的数据,如果想对所有类型进行操作而不仅仅是浮点类型(Float 和 Double),可以使用 `AT_DISPATCH_ALL_TYPES`。 436 | 437 | ### scalar_t类型 438 | 439 | 在函数之中,我们可以看到我们有三个input,其中两个是三线性插值的input,一个是output,为什么呢,是因为其实这个函数是没有返回值的,所以说实际上是在函数里面计算后复制在output之中,最后进行返回。 440 | 441 | 我们来仔细了解了一下具体函数的编写是什么意思,首先是`scalar_t`其实是一种类型,他可以表示任何类型,包括整型,浮点型等等,如果我们确定数据是float,我们也可以直接将`scalar_t`写为`float`,那么他可能就只能处理浮点型的数据了,从下面也可以看到`scalar_t`的一个简单的实现。 442 | 443 | ```python 444 | switch (tensor.type().scalarType()) { 445 | case torch::ScalarType::Double: 446 | return function(tensor.data()); 447 | case torch::ScalarType::Float: 448 | return function(tensor.data()); 449 | ... 450 | } 451 | ``` 452 | 453 | --- 454 | 455 | ### accessors 456 | 457 | 在CUDA计算的时候,还有一个问题,在 CUDA 核函数内部,虽然我们能正确处理数据,直接使用高级类型`scalar_t`不可知的张量将非常低效,因为这是以易用性和可读性为代价的,特别是对于高维数据。 458 | 459 | 比如说,在数据中,我们这里有(N, F)个数据,那我们有没有快速的方法去读取到`feat_interp[i][j]`的数据呢,特别是有些一般是三个维度的,比如(bs,row,index)这样的,并且有时候我们还需要知道stride才能快速索引到位置,比如`gates.data()[n*3*state_size + row*state_size + column]` 460 | 461 | 在这里面,我们可能就需要用到一个`ATen`提供的`accessors`,他可以动态检查确保张量具有指定的类型和维度数量,器提供了一个 API,用于高效地访问张量元素,而无需转换为单个指针,就可以高效访问 cpu 张量上的数据,cuda我们就可以用`packed_accessor`。 462 | 463 | ```c++ 464 | torch::Tensor foo = torch::rand({12, 12}); 465 | 466 | // 确定 foo 是二维的并且包含浮点数。 467 | auto foo_a = foo.accessor(); 468 | float trace = 0; 469 | 470 | for(int i = 0; i < foo_a.size(0); i++) { 471 | // 使用访问器 foo_a 来获取张量数据。 472 | trace += foo_a[i][i]; 473 | } 474 | ``` 475 | 476 | 所以我们在核函数内部看到了`packed_accessor`,这一部分就是做这样一件事情,**不过值得注意的是,只有对torch的向量我们需要这样的操作,如果是bool等,我们是不需要处理的。** 477 | 478 | --- 479 | 480 | ### 模板函数 481 | 482 | 上述有提到,实际上我们的`trilinear_fw_kernel`是一个模板函数,我们接下来看一下具体的实现,我们利用`scalar_t`对其进行实例化,我们在这里再解释一下这个模板函数的参数部分: 483 | 484 | - 首先是 `scalar_t`,它是一个模板参数,代表张量的数据类型。在这个上下文中,通常会使用 `float` 或 `double` 作为 `scalar_t`,具体取决于张量的数据类型。 485 | 486 | - 接下来是 `3`,它表示张量的维度数量。在这个例子中,我们的feats的维度是3,所以维度数量为 3。 487 | 488 | - 然后是 `torch::RestrictPtrTraits`,它是一个模板参数,用于指定指针的限定符。`__restrict__` 关键字在 CUDA 中用于指示指针是唯一的,并且没有别名。这有助于编译器进行优化,提高代码的性能。 489 | 490 | - 最后是 `PackedTensorAccessor`,它是一个访问器(accessor)的变体,用于存储大小和步幅信息,可以使得在访问器对象传递给 CUDA 核函数时,内存传输的数据量也更小。 491 | 492 | 接下来我们仔细分析里面代码的细节,首先介绍一下这个`__global__`,他实际意义如下,如下图所示 493 | 494 | - `__global__`表示CPU上定义,GPU上执行,是CUDA的关键字 495 | 496 | - `__device__` GPU定 义,GPU执行 497 | 498 | - `__host__` CPU定义,CPU执行 499 | 500 | ![CUDA编程之快速入门- 最难不过二叉树- 博客园](https://img2018.cnblogs.com/blog/1093303/201809/1093303-20180919123125957-1702896390.png) 501 | 502 | 接下里分析函数的主体部分,主要做两件事情: 503 | 504 | 1. 为每个`threads`进行编号 505 | 2. 去除不必要的`threads` 506 | 507 | 在使用`threads`计算的时候,实际上每一个`threads`都有一个对应的编号,计算方式如第7,8行所示,实际上就是block的x*block的个数+block的y就能得到最后的结果 508 | 509 | 除了编号之外,还有去除不必要的`threads`,因为有一部分是没有覆盖到的,比如如下图的黄色部分就是不必要的`threads`,所以在第10行进行判断,如果超过范围,直接return不计算 510 | 511 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/6daed7e700bb4fb6927b4a8a89b8a1ea.png) 512 | 513 | 最后就是上述说明的三线性插值的做法了,先进行一个标准化,然后代入公式进行计算,最后将值写入feat_interp中,就完成了整个模板函数的编写,大功告成!!! 514 | 515 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20190121221044883.png) 516 | 517 | ```C++ 518 | template 519 | __global__ void trilinear_fw_kernel( 520 | const torch::PackedTensorAccessor feats, 521 | const torch::PackedTensorAccessor points, 522 | torch::PackedTensorAccessor feat_interp 523 | ){ 524 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 525 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 526 | 527 | if (n>=feats.size(0) || f>=feats.size(2)) return; 528 | 529 | // point -1~1 530 | const scalar_t u = (points[n][0]+1)/2; 531 | const scalar_t v = (points[n][1]+1)/2; 532 | const scalar_t w = (points[n][2]+1)/2; 533 | 534 | const scalar_t a = (1-v)*(1-w); 535 | const scalar_t b = (1-v)*w; 536 | const scalar_t c = v*(1-w); 537 | const scalar_t d = 1-a-b-c; 538 | feat_interp[n][f] = (1-u)*(a*feats[n][0][f] + 539 | b*feats[n][1][f] + 540 | c*feats[n][2][f] + 541 | d*feats[n][3][f]) + 542 | u*(a*feats[n][4][f] + 543 | b*feats[n][5][f] + 544 | c*feats[n][6][f] + 545 | d*feats[n][7][f]); 546 | } 547 | ``` 548 | 549 | ### foward验证与比较 550 | 551 | 经过`python setup.py install`以后(每次修改后都要重新运行`setup.py`),我们就可以进行运行了,在这里面为了验证结果的正确性和与python进行比较,用python实现三线性插值的算法,比较两者的结果和时间效率,`test.py`如下: 552 | 553 | ```python 554 | import torch 555 | import cppcuda_tutorial 556 | import time 557 | 558 | 559 | def trilinear_interpolation_py(feats, points): 560 | """ 561 | Inputs: 562 | feats: (N, 8, F) 563 | points: (N, 3) local coordinates in [-1, 1] 564 | 565 | Outputs: 566 | feats_interp: (N, F) 567 | """ 568 | u = (points[:, 0:1]+1)/2 569 | v = (points[:, 1:2]+1)/2 570 | w = (points[:, 2:3]+1)/2 571 | a = (1-v)*(1-w) 572 | b = (1-v)*w 573 | c = v*(1-w) 574 | d = 1-a-b-c 575 | 576 | feats_interp = (1-u)*(a*feats[:, 0] + 577 | b*feats[:, 1] + 578 | c*feats[:, 2] + 579 | d*feats[:, 3]) + \ 580 | u*(a*feats[:, 4] + 581 | b*feats[:, 5] + 582 | c*feats[:, 6] + 583 | d*feats[:, 7]) 584 | 585 | return feats_interp 586 | 587 | 588 | if __name__ == '__main__': 589 | N = 65536; F = 256 590 | feats = torch.rand(N, 8, F, device='cuda').requires_grad_() 591 | points = torch.rand(N, 3, device='cuda')*2-1 592 | 593 | t = time.time() 594 | out_cuda = cppcuda_tutorial.trilinear_interpolation(feats, points) 595 | torch.cuda.synchronize() 596 | print(' cuda time', time.time()-t, 's') 597 | 598 | t = time.time() 599 | out_py = trilinear_interpolation_py(feats, points) 600 | torch.cuda.synchronize() 601 | print('pytorch time', time.time()-t, 's') 602 | 603 | print(torch.allclose(out_py, out_cuda)) # 判断两者的差异 604 | ``` 605 | 606 | 经过运行和计算后,我们会得到以下结果,我们可以看到,CUDA是明显比Pytorch更快的,并且两者的计算结果也是一样的,如果需要更好的计算两者的速度的话,可能需要进行循环运行计算取平均更有可信度。 607 | 608 | ```bash 609 | cuda time 0.02436351776123047 s 610 | pytorch time 0.04364943504333496 s 611 | True 612 | ``` 613 | 614 | ## CUDA反向传播 615 | 616 | 在上述实验中,当我们尝试添加自动求导的梯度计算时,使用`requires_grad_`,我们会发现通过CUDA返回的值实际上不会自动进行梯度计算(autograd)。然而,如果我们在Python中进行计算,它会自动进行梯度计算。 617 | 618 | 然而,在实际应用中,神经网络经常需要计算损失函数,并使用梯度下降等优化算法来不断优化参数。但是,在CUDA编程中,C++扩展API并没有提供自动求导(autograd)的方法。因此,我们必须自己实现反向传播的代码,计算每个输入的导数,并将其封装在`torch.autograd.Function`中。 619 | 620 | 在CUDA编程中,实现反向传播的代码通常包括以下步骤: 621 | 622 | 1. 在C++扩展中,创建一个新类,继承自`torch::autograd::Function`,用于定义前向传播和反向传播操作。 623 | 2. 在新类中,重写`forward()`方法,定义前向传播的操作。这些操作将使用CUDA执行计算,并返回结果,其实就是上述的cuda的部分。 624 | 3. 在新类中,重写`backward()`方法,定义反向传播的操作。这些操作将计算输入张量的梯度,并传递给上一层。 625 | 4. 在CUDA和C++中,编写对应的`forward`和`backward`函数,计算前向传播和微分。 626 | 5. 在Python代码中,使用这个自定义函数执行前向传播,并通过调用`backward()`方法执行反向传播。 627 | 6. 在反向传播过程中,梯度将通过CUDA计算,并在每个层之间传递,从而计算出每个输入的导数。 628 | 629 | 通过这种方式,我们可以在CUDA编程中手动实现反向传播,并获得每个输入的梯度,以便进行优化算法的参数更新。尽管需要手动编写反向传播代码,但这使我们能够在CUDA扩展中自定义梯度计算,并与PyTorch的自动求导机制无缝配合。 630 | 631 | 我们还是把三线性插值作为我们的一个例子,编写对应的反向传播,根据问题设定,我们可以知道我们的Points是固定的,所以我们需要对我们的Feats进行求微分。我们以简单的双线性插值来学习一下,怎么求微分,这里面其实涉及高数的知识,当然我们也可以把函数交给一些数学网站帮我们求得结果。**从下图我们可以看到,`f`是双线性插值的结果,我们可以得到对应的四个导数,我们会发现实际上,他们的微分是对应的系数,推导在三线性插值也是一样的,所以他们对应的微分也就是对应的前缀。** 632 | 633 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/c35a354c6fcb49678872728cf04e6306.png) 634 | 635 | 636 | 637 | 在计算反向传播前,我们往往会有一个对应的一个损失`Loss`,然后再进行求微分,这里面其实就用到了高数里面的链式法则,使用链式法则我们就可以得到`L`对每一个`feat`的微分,明白了双线性插值的计算,我们也可以推导在三线性插值中。 638 | 639 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/777b69824dcf4ca0ba257613067fa9be.png) 640 | 641 | ### 定义CUDA函数 642 | 643 | 明白了理论的计算,我们就可以进行对应的实现,首先我们可以编写对应的反向传播的CUDA函数,这一部分实际上和前向传播是一样的,首先我们还是定义反向传播函数,在这里和前向传播函数的不同就是对名字进行了修改,除此之外加入了两个参数,分别是`dL_dfeats`参数和`dL_dfeats`。简单解释一下这些参数,在问题的设定中,我们的`feats`的维度是的(N, 8, F),所以我们的微分`dL_dfeats`的维度是和`feats`是一样的,然后再加入反向传播的核函数中即可;`dL_dfeat_interp`则是根据函数已知的,所以不用计算,直接传参数。 644 | 645 | ```C++ 646 | torch::Tensor trilinear_bw_cu( 647 | const torch::Tensor dL_dfeat_interp, 648 | const torch::Tensor feats, 649 | const torch::Tensor points 650 | ){ 651 | const int N = feats.size(0), F = feats.size(2); 652 | 653 | torch::Tensor dL_dfeats = torch::empty({N, 8, F}, feats.options()); 654 | 655 | const dim3 threads(16, 16); 656 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 657 | 658 | AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_bw_cu", 659 | ([&] { 660 | trilinear_bw_kernel<<>>( 661 | dL_dfeat_interp.packed_accessor(), 662 | feats.packed_accessor(), 663 | points.packed_accessor(), 664 | dL_dfeats.packed_accessor() 665 | ); 666 | })); 667 | 668 | return dL_dfeats; 669 | } 670 | ``` 671 | 672 | ### 核函数实现微分计算 673 | 674 | 接下来就是主要的核函数的实现,在这一部分我们就需要实现微分的计算,在前面已经介绍了双线性插值的微分的计算,推导在三线性插值是一样的,我们可以根据前向传播的代码,这一部分只需要保留前面的系数✖️对应位置的`dL_dfeat_interp`就可以得到最后的微分值,这一部分跟上述的推导是一模一样的。 675 | 676 | ```C++ 677 | template 678 | __global__ void trilinear_bw_kernel( 679 | const torch::PackedTensorAccessor dL_dfeat_interp, 680 | const torch::PackedTensorAccessor feats, 681 | const torch::PackedTensorAccessor points, 682 | torch::PackedTensorAccessor dL_dfeats 683 | ){ 684 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 685 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 686 | 687 | if (n>=feats.size(0) || f>=feats.size(2)) return; 688 | 689 | // point -1~1 690 | const scalar_t u = (points[n][0]+1)/2; 691 | const scalar_t v = (points[n][1]+1)/2; 692 | const scalar_t w = (points[n][2]+1)/2; 693 | 694 | const scalar_t a = (1-v)*(1-w); 695 | const scalar_t b = (1-v)*w; 696 | const scalar_t c = v*(1-w); 697 | const scalar_t d = 1-a-b-c; 698 | 699 | dL_dfeats[n][0][f] = (1-u)*a*dL_dfeat_interp[n][f]; 700 | dL_dfeats[n][1][f] = (1-u)*b*dL_dfeat_interp[n][f]; 701 | dL_dfeats[n][2][f] = (1-u)*c*dL_dfeat_interp[n][f]; 702 | dL_dfeats[n][3][f] = (1-u)*d*dL_dfeat_interp[n][f]; 703 | dL_dfeats[n][4][f] = u*a*dL_dfeat_interp[n][f]; 704 | dL_dfeats[n][5][f] = u*b*dL_dfeat_interp[n][f]; 705 | dL_dfeats[n][6][f] = u*c*dL_dfeat_interp[n][f]; 706 | dL_dfeats[n][7][f] = u*d*dL_dfeat_interp[n][f]; 707 | } 708 | ``` 709 | 710 | ### PYBIND11绑定函数 711 | 712 | 写好了反向传播函数之后,不要忘记绑定函数,这样我们才能在最后的python中调用对应的函数。 713 | 714 | ```C++ 715 | torch::Tensor trilinear_interpolation_fw( 716 | const torch::Tensor feats, 717 | const torch::Tensor points 718 | ){ 719 | CHECK_INPUT(feats); 720 | CHECK_INPUT(points); 721 | 722 | return trilinear_fw_cu(feats, points); 723 | } 724 | 725 | 726 | torch::Tensor trilinear_interpolation_bw( 727 | const torch::Tensor dL_dfeat_interp, 728 | const torch::Tensor feats, 729 | const torch::Tensor points 730 | ){ 731 | CHECK_INPUT(dL_dfeat_interp); 732 | CHECK_INPUT(feats); 733 | CHECK_INPUT(points); 734 | 735 | return trilinear_bw_cu(dL_dfeat_interp, feats, points); 736 | } 737 | 738 | 739 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 740 | m.def("trilinear_interpolation_fw", &trilinear_interpolation_fw); 741 | m.def("trilinear_interpolation_bw", &trilinear_interpolation_bw); 742 | } 743 | ``` 744 | 745 | ### torch.autograd.Function封装 746 | 747 | 为了使用pytorch的`autograd`我们还差最后一步,就是使用`torch.autograd.Function`进行封装,不然的话不能进行反向传播,会出现一些奇奇怪怪的bug。 748 | 749 | 在下面的代码中,首先,我们需要定义`forward`和`backward`函数,**记得我们都需要定义`@staticmethod`装饰器,这个是一定要的。**接下来我们就可以开始完善`forward`和`backward`函数。在两个函数中,实际上我们就是调用C++扩展写好的函数,这里面唯一一个需要注意的就是`ctx`,实际上这里是`context`的缩写,这里就是表示有什么数据需要进行保存在反向传播中使用到,因为在`backward`我们还要传入对应的`feats`和`points`,所以在这里这两个参数都需要`save_for_backward`。 750 | 751 | 最后的`backward`就更简单了,传入的参数与`forward`返回的参数进行对应,接着我们从`ctx`取出需要用到的参数,从`ctx.saved_tensors`中取出,后续只需要调用对应的C++函数即可,在这里面我们返回了两个参数,分别是`dL_dfeats, None`,这一部分是因为实际上是因为,我们有两个参数,分别`feats, points`,而我们并没有对`points`进行计算微分,所以这里就返回None。 752 | 753 | ```python 754 | class Trilinear_interpolation_cuda(torch.autograd.Function): 755 | @staticmethod 756 | def forward(ctx, feats, points): 757 | feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points) 758 | 759 | ctx.save_for_backward(feats, points) 760 | 761 | return feat_interp 762 | 763 | @staticmethod 764 | def backward(ctx, dL_dfeat_interp): 765 | feats, points = ctx.saved_tensors 766 | 767 | dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points) 768 | 769 | return dL_dfeats, None 770 | ``` 771 | 772 | ### backward验证与比较 773 | 774 | 和上述一样,经过`python setup.py install`以后(每次修改后都要重新运行`setup.py`),我们就可以进行运行了,在这里面为了验证结果的正确性和与pytorch本身的反向传播进行比较,比较两者的结果和时间效率,`test.py`的主函数如下: 775 | 776 | ```python 777 | class Trilinear_interpolation_cuda(torch.autograd.Function): 778 | @staticmethod 779 | def forward(ctx, feats, points): 780 | feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points) 781 | 782 | ctx.save_for_backward(feats, points) 783 | 784 | return feat_interp 785 | 786 | @staticmethod 787 | def backward(ctx, dL_dfeat_interp): 788 | feats, points = ctx.saved_tensors 789 | 790 | dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points) 791 | 792 | return dL_dfeats, None 793 | 794 | 795 | if __name__ == '__main__': 796 | N = 65536; F = 256 797 | rand = torch.rand(N, 8, F, device='cuda') 798 | feats = rand.clone().requires_grad_() 799 | feats2 = rand.clone().requires_grad_() 800 | points = torch.rand(N, 3, device='cuda')*2-1 801 | 802 | t = time.time() 803 | # 调用CUDA计算 804 | out_cuda = Trilinear_interpolation_cuda.apply(feats2, points) 805 | torch.cuda.synchronize() 806 | print(' cuda fw time', time.time()-t, 's') 807 | 808 | t = time.time() 809 | out_py = trilinear_interpolation_py(feats, points) 810 | torch.cuda.synchronize() 811 | print('pytorch fw time', time.time()-t, 's') 812 | 813 | print('fw all close', torch.allclose(out_py, out_cuda)) 814 | 815 | t = time.time() 816 | # CUDA反向传播 817 | loss2 = out_cuda.sum() 818 | loss2.backward() 819 | torch.cuda.synchronize() 820 | print(' cuda bw time', time.time()-t, 's') 821 | 822 | t = time.time() 823 | loss = out_py.sum() 824 | loss.backward() 825 | torch.cuda.synchronize() 826 | print('pytorch bw time', time.time()-t, 's') 827 | 828 | print('bw all close', torch.allclose(feats.grad, feats2.grad)) 829 | ``` 830 | 831 | 经过运行和计算后,我们会得到以下结果,我们可以看到,CUDA和Pytorch前向传播相差不大,但是对于反向传播的效率可以看得出来,结果大概差了10倍所有,CUDA的反向传播还是有一个较为明显的效率提升的。 832 | 833 | ```bash 834 | cuda fw time 0.0033109188079833984 s 835 | pytorch fw time 0.004142045974731445 s 836 | fw all close True 837 | cuda bw time 0.004648447036743164 s 838 | pytorch bw time 0.04614758491516113 s 839 | bw all close True 840 | ``` 841 | 842 | 843 | 844 | ## 参考 845 | 846 | 视频资料: [https://www.youtube.com/watch?v=l_Rpk6CRJYI&list=PLDV2CyUo4q-LKuiNltBqCKdO9GH4SS_ec&ab_channel=AI%E8%91%B5](https://www.youtube.com/watch?v=l_Rpk6CRJYI&list=PLDV2CyUo4q-LKuiNltBqCKdO9GH4SS_ec&ab_channel=AI葵) 847 | 848 | Github:[https://github.com/kwea123/pytorch-cppcuda-tutorial](https://github.com/kwea123/pytorch-cppcuda-tutorial) 849 | 850 | Pytorch官方资料:[https://pytorch.org/cppdocs/](https://pytorch.org/cppdocs/),[https://pytorch.org/tutorials/advanced/cpp_extension.html](https://pytorch.org/tutorials/advanced/cpp_extension.html) 851 | 852 | CUDA doc:[https://nyu-cds.github.io/python-gpu/02-cuda/](https://nyu-cds.github.io/python-gpu/02-cuda/) -------------------------------------------------------------------------------- /include/utils.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 5 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 6 | 7 | 8 | torch::Tensor trilinear_fw_cu( 9 | const torch::Tensor feats, 10 | const torch::Tensor points 11 | ); 12 | 13 | 14 | torch::Tensor trilinear_bw_cu( 15 | const torch::Tensor dL_dfeat_interp, 16 | const torch::Tensor feats, 17 | const torch::Tensor points 18 | ); -------------------------------------------------------------------------------- /interpolation.cpp: -------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | 3 | 4 | torch::Tensor trilinear_interpolation_fw( 5 | const torch::Tensor feats, 6 | const torch::Tensor points 7 | ){ 8 | CHECK_INPUT(feats); 9 | CHECK_INPUT(points); 10 | 11 | return trilinear_fw_cu(feats, points); 12 | } 13 | 14 | 15 | torch::Tensor trilinear_interpolation_bw( 16 | const torch::Tensor dL_dfeat_interp, 17 | const torch::Tensor feats, 18 | const torch::Tensor points 19 | ){ 20 | CHECK_INPUT(dL_dfeat_interp); 21 | CHECK_INPUT(feats); 22 | CHECK_INPUT(points); 23 | 24 | return trilinear_bw_cu(dL_dfeat_interp, feats, points); 25 | } 26 | 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 29 | m.def("trilinear_interpolation_fw", &trilinear_interpolation_fw); 30 | m.def("trilinear_interpolation_bw", &trilinear_interpolation_bw); 31 | } 32 | -------------------------------------------------------------------------------- /interpolation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | template 5 | __global__ void trilinear_fw_kernel( 6 | const torch::PackedTensorAccessor feats, 7 | const torch::PackedTensorAccessor points, 8 | torch::PackedTensorAccessor feat_interp 9 | ){ 10 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 12 | 13 | if (n>=feats.size(0) || f>=feats.size(2)) return; 14 | 15 | // point -1~1 16 | const scalar_t u = (points[n][0]+1)/2; 17 | const scalar_t v = (points[n][1]+1)/2; 18 | const scalar_t w = (points[n][2]+1)/2; 19 | 20 | const scalar_t a = (1-v)*(1-w); 21 | const scalar_t b = (1-v)*w; 22 | const scalar_t c = v*(1-w); 23 | const scalar_t d = 1-a-b-c; 24 | feat_interp[n][f] = (1-u)*(a*feats[n][0][f] + 25 | b*feats[n][1][f] + 26 | c*feats[n][2][f] + 27 | d*feats[n][3][f]) + 28 | u*(a*feats[n][4][f] + 29 | b*feats[n][5][f] + 30 | c*feats[n][6][f] + 31 | d*feats[n][7][f]); 32 | } 33 | 34 | 35 | torch::Tensor trilinear_fw_cu( 36 | const torch::Tensor feats, 37 | const torch::Tensor points 38 | ){ 39 | const int N = feats.size(0), F = feats.size(2); 40 | 41 | torch::Tensor feat_interp = torch::empty({N, F}, feats.options()); 42 | 43 | const dim3 threads(16, 16); 44 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 45 | 46 | AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_fw_cu", 47 | ([&] { 48 | trilinear_fw_kernel<<>>( 49 | feats.packed_accessor(), 50 | points.packed_accessor(), 51 | feat_interp.packed_accessor() 52 | ); 53 | })); 54 | 55 | return feat_interp; 56 | } 57 | 58 | 59 | template 60 | __global__ void trilinear_bw_kernel( 61 | const torch::PackedTensorAccessor dL_dfeat_interp, 62 | const torch::PackedTensorAccessor feats, 63 | const torch::PackedTensorAccessor points, 64 | torch::PackedTensorAccessor dL_dfeats 65 | ){ 66 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 67 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 68 | 69 | if (n>=feats.size(0) || f>=feats.size(2)) return; 70 | 71 | // point -1~1 72 | const scalar_t u = (points[n][0]+1)/2; 73 | const scalar_t v = (points[n][1]+1)/2; 74 | const scalar_t w = (points[n][2]+1)/2; 75 | 76 | const scalar_t a = (1-v)*(1-w); 77 | const scalar_t b = (1-v)*w; 78 | const scalar_t c = v*(1-w); 79 | const scalar_t d = 1-a-b-c; 80 | 81 | dL_dfeats[n][0][f] = (1-u)*a*dL_dfeat_interp[n][f]; 82 | dL_dfeats[n][1][f] = (1-u)*b*dL_dfeat_interp[n][f]; 83 | dL_dfeats[n][2][f] = (1-u)*c*dL_dfeat_interp[n][f]; 84 | dL_dfeats[n][3][f] = (1-u)*d*dL_dfeat_interp[n][f]; 85 | dL_dfeats[n][4][f] = u*a*dL_dfeat_interp[n][f]; 86 | dL_dfeats[n][5][f] = u*b*dL_dfeat_interp[n][f]; 87 | dL_dfeats[n][6][f] = u*c*dL_dfeat_interp[n][f]; 88 | dL_dfeats[n][7][f] = u*d*dL_dfeat_interp[n][f]; 89 | } 90 | 91 | 92 | torch::Tensor trilinear_bw_cu( 93 | const torch::Tensor dL_dfeat_interp, 94 | const torch::Tensor feats, 95 | const torch::Tensor points 96 | ){ 97 | const int N = feats.size(0), F = feats.size(2); 98 | 99 | torch::Tensor dL_dfeats = torch::empty({N, 8, F}, feats.options()); 100 | 101 | const dim3 threads(16, 16); 102 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 103 | 104 | AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_bw_cu", 105 | ([&] { 106 | trilinear_bw_kernel<<>>( 107 | dL_dfeat_interp.packed_accessor(), 108 | feats.packed_accessor(), 109 | points.packed_accessor(), 110 | dL_dfeats.packed_accessor() 111 | ); 112 | })); 113 | 114 | return dL_dfeats; 115 | } 116 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | 10 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 11 | 12 | 13 | setup( 14 | name='cppcuda_tutorial', 15 | version='1.0', 16 | author='kwea123', 17 | author_email='kwea123@gmail.com', 18 | description='cppcuda_tutorial', 19 | long_description='cppcuda_tutorial', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='cppcuda_tutorial', 23 | sources=sources, 24 | include_dirs=include_dirs, 25 | extra_compile_args={'cxx': ['-O2'], 26 | 'nvcc': ['-O2']} 27 | ) 28 | ], 29 | cmdclass={ 30 | 'build_ext': BuildExtension 31 | } 32 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cppcuda_tutorial 3 | import time 4 | 5 | 6 | def trilinear_interpolation_py(feats, points): 7 | """ 8 | Inputs: 9 | feats: (N, 8, F) 10 | points: (N, 3) local coordinates in [-1, 1] 11 | 12 | Outputs: 13 | feats_interp: (N, F) 14 | """ 15 | u = (points[:, 0:1]+1)/2 16 | v = (points[:, 1:2]+1)/2 17 | w = (points[:, 2:3]+1)/2 18 | a = (1-v)*(1-w) 19 | b = (1-v)*w 20 | c = v*(1-w) 21 | d = 1-a-b-c 22 | 23 | feats_interp = (1-u)*(a*feats[:, 0] + 24 | b*feats[:, 1] + 25 | c*feats[:, 2] + 26 | d*feats[:, 3]) + \ 27 | u*(a*feats[:, 4] + 28 | b*feats[:, 5] + 29 | c*feats[:, 6] + 30 | d*feats[:, 7]) 31 | 32 | return feats_interp 33 | 34 | 35 | class Trilinear_interpolation_cuda(torch.autograd.Function): 36 | @staticmethod 37 | def forward(ctx, feats, points): 38 | feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points) 39 | 40 | ctx.save_for_backward(feats, points) 41 | 42 | return feat_interp 43 | 44 | @staticmethod 45 | def backward(ctx, dL_dfeat_interp): 46 | feats, points = ctx.saved_tensors 47 | 48 | dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points) 49 | 50 | return dL_dfeats, None 51 | 52 | 53 | if __name__ == '__main__': 54 | N = 65536; F = 256 55 | rand = torch.rand(N, 8, F, device='cuda') 56 | feats = rand.clone().requires_grad_() 57 | feats2 = rand.clone().requires_grad_() 58 | points = torch.rand(N, 3, device='cuda')*2-1 59 | 60 | t = time.time() 61 | out_cuda = Trilinear_interpolation_cuda.apply(feats2, points) 62 | torch.cuda.synchronize() 63 | print(' cuda fw time', time.time()-t, 's') 64 | 65 | t = time.time() 66 | out_py = trilinear_interpolation_py(feats, points) 67 | torch.cuda.synchronize() 68 | print('pytorch fw time', time.time()-t, 's') 69 | 70 | print('fw all close', torch.allclose(out_py, out_cuda)) 71 | 72 | t = time.time() 73 | loss2 = out_cuda.sum() 74 | loss2.backward() 75 | torch.cuda.synchronize() 76 | print(' cuda bw time', time.time()-t, 's') 77 | 78 | t = time.time() 79 | loss = out_py.sum() 80 | loss.backward() 81 | torch.cuda.synchronize() 82 | print('pytorch bw time', time.time()-t, 's') 83 | 84 | print('bw all close', torch.allclose(feats.grad, feats2.grad)) -------------------------------------------------------------------------------- /test_rji.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | import torch 4 | import time 5 | from torch.utils.cpp_extension import load 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | 10 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 11 | cppcuda_tutorial = load(name="cppcuda_tutorial", 12 | sources=sources, 13 | extra_include_paths=include_dirs,) 14 | 15 | def trilinear_interpolation_py(feats, points): 16 | """ 17 | Inputs: 18 | feats: (N, 8, F) 19 | points: (N, 3) local coordinates in [-1, 1] 20 | 21 | Outputs: 22 | feats_interp: (N, F) 23 | """ 24 | u = (points[:, 0:1]+1)/2 25 | v = (points[:, 1:2]+1)/2 26 | w = (points[:, 2:3]+1)/2 27 | a = (1-v)*(1-w) 28 | b = (1-v)*w 29 | c = v*(1-w) 30 | d = 1-a-b-c 31 | 32 | feats_interp = (1-u)*(a*feats[:, 0] + 33 | b*feats[:, 1] + 34 | c*feats[:, 2] + 35 | d*feats[:, 3]) + \ 36 | u*(a*feats[:, 4] + 37 | b*feats[:, 5] + 38 | c*feats[:, 6] + 39 | d*feats[:, 7]) 40 | 41 | return feats_interp 42 | 43 | 44 | class Trilinear_interpolation_cuda(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, feats, points): 47 | feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points) 48 | 49 | ctx.save_for_backward(feats, points) 50 | 51 | return feat_interp 52 | 53 | @staticmethod 54 | def backward(ctx, dL_dfeat_interp): 55 | feats, points = ctx.saved_tensors 56 | 57 | dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points) 58 | 59 | return dL_dfeats, None 60 | 61 | 62 | if __name__ == '__main__': 63 | N = 65536; F = 256 64 | rand = torch.rand(N, 8, F, device='cuda') 65 | feats = rand.clone().requires_grad_() 66 | feats2 = rand.clone().requires_grad_() 67 | points = torch.rand(N, 3, device='cuda')*2-1 68 | 69 | t = time.time() 70 | out_cuda = Trilinear_interpolation_cuda.apply(feats2, points) 71 | torch.cuda.synchronize() 72 | print(' cuda fw time', time.time()-t, 's') 73 | 74 | t = time.time() 75 | out_py = trilinear_interpolation_py(feats, points) 76 | torch.cuda.synchronize() 77 | print('pytorch fw time', time.time()-t, 's') 78 | 79 | print('fw all close', torch.allclose(out_py, out_cuda)) 80 | 81 | t = time.time() 82 | loss2 = out_cuda.sum() 83 | loss2.backward() 84 | torch.cuda.synchronize() 85 | print(' cuda bw time', time.time()-t, 's') 86 | 87 | t = time.time() 88 | loss = out_py.sum() 89 | loss.backward() 90 | torch.cuda.synchronize() 91 | print('pytorch bw time', time.time()-t, 's') 92 | 93 | print('bw all close', torch.allclose(feats.grad, feats2.grad)) --------------------------------------------------------------------------------