├── README.md
├── conv_cuda
│   ├── conv.py
│   ├── conv_cuda.cpp
│   ├── jit.py
│   └── setup.py
├── requirements.txt
└── test.py

/README.md:
--------------------------------------------------------------------------------
# Convolution implemented as a CUDA extension for PyTorch

A convolution implementation based on a CUDA extension for PyTorch. The source code follows PyTorch's legacy, deliberately simple (and inefficient) implementation [here](https://github.com/pytorch/pytorch/blob/master/aten/src/THCUNN/generic/SpatialConvolutionMM.cu). See [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.

- Build the CUDA extension by going into the `conv_cuda/` folder and executing `python setup.py install`;
- or JIT-compile it by going into the `conv_cuda/` folder and calling `python jit.py`, which compiles and loads the extension on the fly;
- check the result of the convolution by running `python test.py`.

--------------------------------------------------------------------------------
/conv_cuda/conv.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
import conv_cuda


class ConvFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weights, bias, params):
        # params packs (dW, dH, padW, padH, is_bias) into a tensor so the
        # hyper-parameters can travel through save_for_backward
        dW, dH, padW, padH, is_bias = int(params[0]), int(params[1]), int(params[2]), int(params[3]), int(params[4])
        # square kernels are assumed throughout: weights is (out, in, k, k)
        kW, kH = weights.shape[2], weights.shape[3]

        outputs = conv_cuda.forward(input, weights, bias, kW, kH, dW, dH, padW, padH, is_bias)[0]

        ctx.save_for_backward(input, weights, bias, params)

        return outputs

    @staticmethod
    def backward(ctx, gradOutput):
        input, weights, bias, params = ctx.saved_tensors

        dW, dH, padW, padH, is_bias = int(params[0]), int(params[1]), int(params[2]), int(params[3]), int(params[4])
        kW, kH = weights.shape[2], weights.shape[3]

        # the extension checks for contiguity, but autograd may hand us an
        # expanded (non-contiguous) gradient, e.g. from .sum().backward()
        gradInput, gradWeight, gradBias = conv_cuda.backward(input, gradOutput.contiguous(), weights,
                                                             kW, kH, dW, dH, padW, padH, is_bias)
        # params itself needs no gradient
        return gradInput, gradWeight, gradBias, None


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1, is_bias=True):
        super(Conv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation  # stored for API symmetry; the extension hard-codes dilation=1
        self.is_bias = is_bias

        self.params = torch.tensor([stride, stride, padding, padding, is_bias],
                                   dtype=torch.float32).cuda()

        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size).cuda())
        self.bias = nn.Parameter(torch.empty(out_channels).cuda())

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
        if self.is_bias:
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        return ConvFunction.apply(input, self.weight, self.bias, self.params)
--------------------------------------------------------------------------------
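A quick way to sanity-check ConvFunction is to run the same input through torch.nn.functional.conv2d with the layer's own weights and compare both the outputs and the input gradients. This is a minimal sketch, not part of the repository; it assumes the extension has been built and that a CUDA device is available:

import torch
import torch.nn.functional as F
from conv_cuda.conv import Conv

torch.manual_seed(0)

conv = Conv(in_channels=2, out_channels=4, kernel_size=3, padding=1, stride=1)
x = torch.randn(1, 2, 8, 8, device='cuda', requires_grad=True)

out_custom = conv(x)                                                 # custom CUDA path
out_ref = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)   # reference path
print('forward max abs diff:', (out_custom - out_ref).abs().max().item())

# Pushing the same upstream gradient through both paths should yield
# (numerically) identical input gradients.
g = torch.randn_like(out_ref)
grad_custom, = torch.autograd.grad(out_custom, x, g, retain_graph=True)
grad_ref, = torch.autograd.grad(out_ref, x, g)
print('backward max abs diff:', (grad_custom - grad_ref).abs().max().item())
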
/conv_cuda/conv_cuda.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <torch/torch.h>
#include <iostream>
#include <vector>

// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Computes dL/dInput via col2im; its definition was lost from this dump, so
// only a declaration (inferred from the call site below) is given here.
torch::Tensor backward_gradInput(torch::Tensor input, torch::Tensor gradOutput,
                                 torch::Tensor weights,
                                 int64_t kW, int64_t kH, int64_t dW, int64_t dH,
                                 int64_t padW, int64_t padH);

std::vector<torch::Tensor> conv_forward(torch::Tensor input,
                                        torch::Tensor weights,
                                        torch::Tensor bias,
                                        int64_t kW, int64_t kH,
                                        int64_t dW, int64_t dH,
                                        int64_t padW, int64_t padH, bool is_bias) {

    CHECK_INPUT(input);
    CHECK_INPUT(weights);
    CHECK_INPUT(bias);

    // std::cout << ...
    // (the remainder of conv_forward is missing from this dump; per the Python
    // wrapper it returns the output tensor as element 0 of a vector)
}

std::vector<torch::Tensor> backward_gradParameters(torch::Tensor input,
                                                   torch::Tensor gradOutput,
                                                   torch::Tensor weights,
                                                   int64_t kW, int64_t kH,
                                                   int64_t dW, int64_t dH,
                                                   int64_t padW, int64_t padH,
                                                   bool is_bias){

    int64_t batch_size = input.size(0);
    int64_t nInputPlane = input.size(1);
    int64_t inputHeight = input.size(2);
    int64_t inputWidth = input.size(3);

    int64_t nOutputPlane = gradOutput.size(1);
    int64_t outputHeight = gradOutput.size(2);
    int64_t outputWidth = gradOutput.size(3);

    torch::Tensor gradWeights = torch::zeros(torch::IntArrayRef({weights.size(0), weights.size(1),
                                                                 weights.size(2), weights.size(3)})).cuda();
    torch::Tensor gradBias = torch::zeros(torch::IntArrayRef({nOutputPlane})).cuda();
    torch::Tensor ones = torch::ones(torch::IntArrayRef({outputHeight*outputWidth, 1})).cuda();

    torch::Tensor columns = torch::zeros(torch::IntArrayRef({nInputPlane*kW*kH, outputHeight*outputWidth})).cuda();

    for(int elt = 0; elt < batch_size; elt++){
        torch::Tensor gradOutput_n = gradOutput[elt];
        gradOutput_n = gradOutput_n.reshape(torch::IntArrayRef({nOutputPlane, outputHeight*outputWidth})).cuda();

        // columns: (outHeight * outWidth) x (inplanes * kW * kH) after the transpose
        columns = torch::im2col(input[elt].clone(), /*kernel_size=*/torch::IntArrayRef({kW, kH}),
                                /*dilation=*/torch::IntArrayRef({1, 1}),
                                /*padding=*/torch::IntArrayRef({padW, padH}),
                                /*stride=*/torch::IntArrayRef({dW, dH})).t().cuda();
        // dL/dW = dL/dY . X_cols, accumulated over the batch
        gradWeights.add_(gradOutput_n.mm(columns).reshape(torch::IntArrayRef({nOutputPlane, nInputPlane, kW, kH})).cuda(), 1);

        if(is_bias){
            // dL/db = row sums of dL/dY
            gradBias.add_(gradOutput_n.mm(ones).reshape(torch::IntArrayRef({nOutputPlane})), 1);
        }
    }
    return {gradWeights, gradBias};
}


std::vector<torch::Tensor> conv_backward(torch::Tensor input,
                                         torch::Tensor gradOutput,
                                         torch::Tensor weights,
                                         int64_t kW, int64_t kH,
                                         int64_t dW, int64_t dH,
                                         int64_t padW, int64_t padH,
                                         bool is_bias) {

    CHECK_INPUT(gradOutput);
    CHECK_INPUT(weights);
    CHECK_INPUT(input);

    torch::Tensor gradInput = backward_gradInput(input, gradOutput, weights, kW, kH, dW, dH, padW, padH);
    std::vector<torch::Tensor> gradParas = backward_gradParameters(input, gradOutput, weights, kW, kH, dW, dH, padW, padH, is_bias);

    torch::Tensor gradWeights = gradParas[0];
    torch::Tensor gradBias = gradParas[1];

    return {gradInput, gradWeights, gradBias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &conv_forward, "conv forward (CUDA)");
    m.def("backward", &conv_backward, "conv backward (CUDA)");
}
--------------------------------------------------------------------------------
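The backward above leans on the im2col trick: once the input is unrolled into columns, both the convolution and its weight gradient collapse into single matrix multiplies. A minimal pure-PyTorch sketch of that identity using torch.nn.functional.unfold (shapes and names are illustrative; a square 3x3 kernel with stride 1, padding 1, and dilation 1 is assumed, matching the extension's restrictions):

import torch
import torch.nn.functional as F

N, C_in, H, W, C_out, k = 1, 2, 5, 5, 3, 3
x = torch.randn(N, C_in, H, W)
w = torch.randn(C_out, C_in, k, k)

# unfold == im2col: (N, C_in*k*k, L) with L = out_h * out_w
cols = F.unfold(x, kernel_size=k, padding=1)

# the convolution becomes one matmul over the unfolded columns
out = w.view(C_out, -1).matmul(cols).view(N, C_out, H, W)
ref = F.conv2d(x, w, padding=1)
print('forward max abs diff:', (out - ref).abs().max().item())  # ~1e-6

# the weight gradient is the mirror image: dL/dW = dL/dY . cols^T,
# summed over the batch -- exactly what backward_gradParameters computes
g = torch.randn(N, C_out, H * W)  # some upstream gradient
grad_w = g.matmul(cols.transpose(1, 2)).sum(0).view(C_out, C_in, k, k)

w_ref = w.clone().requires_grad_(True)
F.conv2d(x, w_ref, padding=1).backward(g.view(N, C_out, H, W))
print('grad max abs diff:', (grad_w - w_ref.grad).abs().max().item())
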
/conv_cuda/jit.py:
--------------------------------------------------------------------------------
from torch.utils.cpp_extension import load

# JIT-compile the extension and print the resulting module's interface
conv_cuda = load('conv_cuda', sources=['conv_cuda.cpp'], verbose=True)
help(conv_cuda)

--------------------------------------------------------------------------------
/conv_cuda/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='conv_cuda',
    ext_modules=[
        CUDAExtension('conv_cuda',
                      sources=['conv_cuda.cpp']),  # add 'conv_cuda_kernel.cu' here for a custom kernel
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
torch
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from conv_cuda.conv import Conv

# the extension only accepts CUDA tensors, so the input must live on the GPU
test_data = torch.ones([1, 1, 3, 3]).cuda()

conv = Conv(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
conv_torch = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1).cuda()

print('conv based on PyTorch extension: \n', conv(test_data))
print('conv based on PyTorch: \n', conv_torch(test_data))
--------------------------------------------------------------------------------
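test.py only eyeballs the forward pass, and the two layers above are initialized independently. A stricter check is to copy the custom layer's parameters into a stock nn.Conv2d and compare the parameter gradients, which exercises the backward_gradParameters path in the extension. A minimal sketch, not part of the repository, assuming a CUDA device:

import torch
from conv_cuda.conv import Conv

x = torch.randn(1, 1, 5, 5, device='cuda')

conv = Conv(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1)
ref = torch.nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1).cuda()
with torch.no_grad():
    ref.weight.copy_(conv.weight)   # make both layers share identical parameters
    ref.bias.copy_(conv.bias)

conv(x).sum().backward()            # gradWeights/gradBias via the extension
ref(x).sum().backward()             # the same gradients via PyTorch autograd

print('weight grad max abs diff:', (conv.weight.grad - ref.weight.grad).abs().max().item())
print('bias grad max abs diff:', (conv.bias.grad - ref.bias.grad).abs().max().item())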