├── .gitignore
├── .vscode
│   └── c_cpp_properties.json
├── README.md
├── include
│   └── utils.h
├── interpolation.cpp
├── interpolation_kernel.cu
├── setup.py
└── test.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
{
    "configurations": [
        {
            "name": "Linux",
            "includePath": [
                "${workspaceFolder}/**",
                "/home/ubuntu/anaconda3/envs/cppcuda/include/python3.8",
                "/home/ubuntu/anaconda3/envs/cppcuda/lib/python3.8/site-packages/torch/include",
                "/home/ubuntu/anaconda3/envs/cppcuda/lib/python3.8/site-packages/torch/include/torch/csrc/api/include"
            ],
            "defines": [],
            "compilerPath": "/usr/bin/clang",
            "cStandard": "c17",
            "cppStandard": "c++14",
            "intelliSenseMode": "linux-clang-x64"
        }
    ],
    "version": 4
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-cppcuda-tutorial
Tutorial for writing custom PyTorch C++/CUDA kernels, applied to volume rendering (NeRF).

Tutorial playlist: https://www.youtube.com/watch?v=l_Rpk6CRJYI&list=PLDV2CyUo4q-LKuiNltBqCKdO9GH4SS_ec&ab_channel=AI%E8%91%B5

CUDA explanation in video 2: https://nyu-cds.github.io/python-gpu/02-cuda/

C++ API in video 3: https://pytorch.org/cppdocs/

Kernel launching: https://pytorch.org/tutorials/advanced/cpp_extension.html
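
## Quick usage (sketch)

A minimal sketch of building and calling the extension, assuming a CUDA-capable GPU and a PyTorch install that matches your CUDA toolkit; the module and function names below come from `setup.py` and `interpolation.cpp`. Build from the repo root with `pip install .`, then:

```python
import torch
import cppcuda_tutorial

N, F = 16, 32
feats = torch.rand(N, 8, F, device='cuda')      # features at the 8 cube corners
points = torch.rand(N, 3, device='cuda')*2 - 1  # local coordinates in [-1, 1]

# Forward pass only; to get gradients w.r.t. feats, wrap the fw/bw
# functions in a torch.autograd.Function as test.py does.
out = cppcuda_tutorial.trilinear_interpolation_fw(feats, points)
print(out.shape)  # torch.Size([16, 32])
```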
--------------------------------------------------------------------------------
/include/utils.h:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)


torch::Tensor trilinear_fw_cu(
    const torch::Tensor feats,
    const torch::Tensor points
);


torch::Tensor trilinear_bw_cu(
    const torch::Tensor dL_dfeat_interp,
    const torch::Tensor feats,
    const torch::Tensor points
);

--------------------------------------------------------------------------------
/interpolation.cpp:
--------------------------------------------------------------------------------
#include "utils.h"


torch::Tensor trilinear_interpolation_fw(
    const torch::Tensor feats,
    const torch::Tensor points
){
    CHECK_INPUT(feats);
    CHECK_INPUT(points);

    return trilinear_fw_cu(feats, points);
}


torch::Tensor trilinear_interpolation_bw(
    const torch::Tensor dL_dfeat_interp,
    const torch::Tensor feats,
    const torch::Tensor points
){
    CHECK_INPUT(dL_dfeat_interp);
    CHECK_INPUT(feats);
    CHECK_INPUT(points);

    return trilinear_bw_cu(dL_dfeat_interp, feats, points);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("trilinear_interpolation_fw", &trilinear_interpolation_fw);
    m.def("trilinear_interpolation_bw", &trilinear_interpolation_bw);
}
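
// Note: TORCH_EXTENSION_NAME is defined by torch.utils.cpp_extension at build
// time and expands to the extension name declared in setup.py ("cppcuda_tutorial"),
// which keeps the pybind11 module name in sync with the Python import name.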
m.def("trilinear_interpolation_fw", &trilinear_interpolation_fw); 30 | m.def("trilinear_interpolation_bw", &trilinear_interpolation_bw); 31 | } 32 | -------------------------------------------------------------------------------- /interpolation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | template 5 | __global__ void trilinear_fw_kernel( 6 | const torch::PackedTensorAccessor feats, 7 | const torch::PackedTensorAccessor points, 8 | torch::PackedTensorAccessor feat_interp 9 | ){ 10 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 12 | 13 | if (n>=feats.size(0) || f>=feats.size(2)) return; 14 | 15 | // point -1~1 16 | const scalar_t u = (points[n][0]+1)/2; 17 | const scalar_t v = (points[n][1]+1)/2; 18 | const scalar_t w = (points[n][2]+1)/2; 19 | 20 | const scalar_t a = (1-v)*(1-w); 21 | const scalar_t b = (1-v)*w; 22 | const scalar_t c = v*(1-w); 23 | const scalar_t d = 1-a-b-c; 24 | feat_interp[n][f] = (1-u)*(a*feats[n][0][f] + 25 | b*feats[n][1][f] + 26 | c*feats[n][2][f] + 27 | d*feats[n][3][f]) + 28 | u*(a*feats[n][4][f] + 29 | b*feats[n][5][f] + 30 | c*feats[n][6][f] + 31 | d*feats[n][7][f]); 32 | } 33 | 34 | 35 | torch::Tensor trilinear_fw_cu( 36 | const torch::Tensor feats, 37 | const torch::Tensor points 38 | ){ 39 | const int N = feats.size(0), F = feats.size(2); 40 | 41 | torch::Tensor feat_interp = torch::empty({N, F}, feats.options()); 42 | 43 | const dim3 threads(16, 16); 44 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 45 | 46 | AT_DISPATCH_FLOATING_TYPES(feats.type(), "trilinear_fw_cu", 47 | ([&] { 48 | trilinear_fw_kernel<<>>( 49 | feats.packed_accessor(), 50 | points.packed_accessor(), 51 | feat_interp.packed_accessor() 52 | ); 53 | })); 54 | 55 | return feat_interp; 56 | } 57 | 58 | 59 | template 60 | __global__ void trilinear_bw_kernel( 61 | const torch::PackedTensorAccessor dL_dfeat_interp, 62 | const torch::PackedTensorAccessor feats, 63 | const torch::PackedTensorAccessor points, 64 | torch::PackedTensorAccessor dL_dfeats 65 | ){ 66 | const int n = blockIdx.x * blockDim.x + threadIdx.x; 67 | const int f = blockIdx.y * blockDim.y + threadIdx.y; 68 | 69 | if (n>=feats.size(0) || f>=feats.size(2)) return; 70 | 71 | // point -1~1 72 | const scalar_t u = (points[n][0]+1)/2; 73 | const scalar_t v = (points[n][1]+1)/2; 74 | const scalar_t w = (points[n][2]+1)/2; 75 | 76 | const scalar_t a = (1-v)*(1-w); 77 | const scalar_t b = (1-v)*w; 78 | const scalar_t c = v*(1-w); 79 | const scalar_t d = 1-a-b-c; 80 | 81 | dL_dfeats[n][0][f] = (1-u)*a*dL_dfeat_interp[n][f]; 82 | dL_dfeats[n][1][f] = (1-u)*b*dL_dfeat_interp[n][f]; 83 | dL_dfeats[n][2][f] = (1-u)*c*dL_dfeat_interp[n][f]; 84 | dL_dfeats[n][3][f] = (1-u)*d*dL_dfeat_interp[n][f]; 85 | dL_dfeats[n][4][f] = u*a*dL_dfeat_interp[n][f]; 86 | dL_dfeats[n][5][f] = u*b*dL_dfeat_interp[n][f]; 87 | dL_dfeats[n][6][f] = u*c*dL_dfeat_interp[n][f]; 88 | dL_dfeats[n][7][f] = u*d*dL_dfeat_interp[n][f]; 89 | } 90 | 91 | 92 | torch::Tensor trilinear_bw_cu( 93 | const torch::Tensor dL_dfeat_interp, 94 | const torch::Tensor feats, 95 | const torch::Tensor points 96 | ){ 97 | const int N = feats.size(0), F = feats.size(2); 98 | 99 | torch::Tensor dL_dfeats = torch::empty({N, 8, F}, feats.options()); 100 | 101 | const dim3 threads(16, 16); 102 | const dim3 blocks((N+threads.x-1)/threads.x, (F+threads.y-1)/threads.y); 103 | 104 | AT_DISPATCH_FLOATING_TYPES(feats.type(), 
"trilinear_bw_cu", 105 | ([&] { 106 | trilinear_bw_kernel<<>>( 107 | dL_dfeat_interp.packed_accessor(), 108 | feats.packed_accessor(), 109 | points.packed_accessor(), 110 | dL_dfeats.packed_accessor() 111 | ); 112 | })); 113 | 114 | return dL_dfeats; 115 | } 116 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path as osp 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 5 | 6 | 7 | ROOT_DIR = osp.dirname(osp.abspath(__file__)) 8 | include_dirs = [osp.join(ROOT_DIR, "include")] 9 | 10 | sources = glob.glob('*.cpp')+glob.glob('*.cu') 11 | 12 | 13 | setup( 14 | name='cppcuda_tutorial', 15 | version='1.0', 16 | author='kwea123', 17 | author_email='kwea123@gmail.com', 18 | description='cppcuda_tutorial', 19 | long_description='cppcuda_tutorial', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='cppcuda_tutorial', 23 | sources=sources, 24 | include_dirs=include_dirs, 25 | extra_compile_args={'cxx': ['-O2'], 26 | 'nvcc': ['-O2']} 27 | ) 28 | ], 29 | cmdclass={ 30 | 'build_ext': BuildExtension 31 | } 32 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cppcuda_tutorial 3 | import time 4 | 5 | 6 | def trilinear_interpolation_py(feats, points): 7 | """ 8 | Inputs: 9 | feats: (N, 8, F) 10 | points: (N, 3) local coordinates in [-1, 1] 11 | 12 | Outputs: 13 | feats_interp: (N, F) 14 | """ 15 | u = (points[:, 0:1]+1)/2 16 | v = (points[:, 1:2]+1)/2 17 | w = (points[:, 2:3]+1)/2 18 | a = (1-v)*(1-w) 19 | b = (1-v)*w 20 | c = v*(1-w) 21 | d = 1-a-b-c 22 | 23 | feats_interp = (1-u)*(a*feats[:, 0] + 24 | b*feats[:, 1] + 25 | c*feats[:, 2] + 26 | d*feats[:, 3]) + \ 27 | u*(a*feats[:, 4] + 28 | b*feats[:, 5] + 29 | c*feats[:, 6] + 30 | d*feats[:, 7]) 31 | 32 | return feats_interp 33 | 34 | 35 | class Trilinear_interpolation_cuda(torch.autograd.Function): 36 | @staticmethod 37 | def forward(ctx, feats, points): 38 | feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points) 39 | 40 | ctx.save_for_backward(feats, points) 41 | 42 | return feat_interp 43 | 44 | @staticmethod 45 | def backward(ctx, dL_dfeat_interp): 46 | feats, points = ctx.saved_tensors 47 | 48 | dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points) 49 | 50 | return dL_dfeats, None 51 | 52 | 53 | if __name__ == '__main__': 54 | N = 65536; F = 256 55 | rand = torch.rand(N, 8, F, device='cuda') 56 | feats = rand.clone().requires_grad_() 57 | feats2 = rand.clone().requires_grad_() 58 | points = torch.rand(N, 3, device='cuda')*2-1 59 | 60 | t = time.time() 61 | out_cuda = Trilinear_interpolation_cuda.apply(feats2, points) 62 | torch.cuda.synchronize() 63 | print(' cuda fw time', time.time()-t, 's') 64 | 65 | t = time.time() 66 | out_py = trilinear_interpolation_py(feats, points) 67 | torch.cuda.synchronize() 68 | print('pytorch fw time', time.time()-t, 's') 69 | 70 | print('fw all close', torch.allclose(out_py, out_cuda)) 71 | 72 | t = time.time() 73 | loss2 = out_cuda.sum() 74 | loss2.backward() 75 | torch.cuda.synchronize() 76 | print(' cuda bw time', time.time()-t, 's') 77 | 78 | t = time.time() 79 | loss = out_py.sum() 80 | loss.backward() 81 | torch.cuda.synchronize() 82 | print('pytorch bw time', time.time()-t, 's') 83 | 84 | 
    print('bw all close', torch.allclose(feats.grad, feats2.grad))
--------------------------------------------------------------------------------