├── .gitignore
├── LICENSE
├── README.md
├── benchmark.py
├── cpp
│   ├── __init__.py
│   ├── dense.py
│   ├── linear.cpp
│   └── setup.py
├── cuda
│   ├── __init__.py
│   ├── dense.py
│   ├── linear.cpp
│   ├── setup.py
│   ├── sigmoid_cuda.cpp
│   └── sigmoid_cuda_kernel.cu
├── distributed.py
├── grad_check.py
├── parallel
│   ├── __init__.py
│   └── train.py
├── ptx
│   ├── compile.sh
│   ├── recompile.sh
│   └── sigmoid_cuda_kernal.cu
└── python
    ├── __init__.py
    └── dense.py
/.gitignore:
--------------------------------------------------------------------------------
# Compiled python
*.pyc
*.pyd

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# Eclipse Project settings
*.*project
.settings

# QtCreator files
*.user

# PyCharm files
.idea

# Visual Studio Code files
.vscode
.vs

# OSX dir files
.DS_Store

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2018-present, Zhi Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Parallel Optimization in PyTorch

In this project, we first write plain Python code and then gradually use C++ and CUDA to optimize the key operations. Finally, we implement distributed training with the ring-allreduce algorithm. I hope this project helps you learn PyTorch, ATen, CUDA and PTX.

## How to run

### Python Extensions

Check the grad.

```
# ./
python grad_check.py py
```

### C++ Extensions

Pybind11 is used for the Python/C++ bindings. Install these packages:

```
conda install pytest pybind11
```

Enter the C++ folder and compile the code.

```
# ./cpp
python setup.py install
```

Check the grad.

```
# ./
python grad_check.py cpp
```

### CUDA Extensions

Enter the CUDA folder and compile the code.

```
# ./cuda
python setup.py install
```

Check the grad.

```
# ./
python grad_check.py cuda
```

### PTX Example

Enter the PTX folder and compile the code.

```
# ./ptx
sh compile.sh
```

After changing the `sigmoid_cuda_kernal.ptx` file, recompile your code.

```
# ./ptx
sh recompile.sh
```

Test your result.

```
./sigmoid_cuda_kernal
```

### Benchmark

Select a mode to run the benchmark file.

```
python benchmark.py -m py
python benchmark.py -m cpp
python benchmark.py -m cuda
```

### Ring-AllReduce

```
python distributed.py -m py
python distributed.py -m cpp
python distributed.py -m cuda
```

## How to write

After reading the examples on the official PyTorch website, I felt that CUDA is still a little difficult for newcomers to learn, so I wrote a simple demo for people who are just getting started. What we want to build is a Dense layer like the one in TensorFlow. If this were not for teaching, you could simply use Linear plus an activation function, but here we start with Python and gradually use C++ and CUDA to optimize the key operations.

There are two steps to implementing a Python extension:

* Implement a Function that defines the forward and backward operations
* Implement a Module that initializes the parameters from the hyperparameters and then calls the Function to do the computation

For operations that PyTorch already provides, we don't need a custom Function; the Module will handle backward automatically. But to understand what is going on, let's write the Function ourselves. It is defined as follows:

```
class DenseFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        output = torch.sigmoid(output)
        ctx.save_for_backward(input, weight, bias, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, output = ctx.saved_tensors
        grad_sigmoid = (1.0 - output) * output
        grad_output = grad_sigmoid * grad_output
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        return grad_input, grad_weight, grad_bias
```

We need to remember that the number of inputs to forward equals the number of outputs of backward, and the number of outputs of forward equals the number of inputs to backward. Don't lose track of them.
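
A quick way to catch a mismatch between forward and backward is `torch.autograd.gradcheck`, which is exactly what `grad_check.py` in this repository does. Here is a minimal sketch of such a check (the shapes and tolerances below are illustrative values, not requirements):

```
import torch
from torch.autograd import gradcheck

from python.dense import DenseFunction  # or cpp.dense / cuda.dense

# gradcheck compares the analytical gradients from backward against numerical
# finite differences, so double precision is used to keep the comparison tight.
X = torch.randn(3, 17, dtype=torch.double, requires_grad=True)
W = torch.randn(3, 17, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)

if gradcheck(DenseFunction.apply, (X, W, b), eps=1e-6, atol=1e-4):
    print('Ok')
```
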
After completing the Function definition, the calculation is clear. All the Module has to do is initialize the trainable parameters from the hyperparameters.

```
class Dense(Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Dense, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = Parameter(torch.Tensor(output_features))
        else:
            self.register_parameter('bias', None)
        self.weight.data.uniform_(-0.1, 0.1)
        if bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return DenseFunction.apply(input, self.weight, self.bias)
```

Starting from the Python version of the custom layer, we now extract the linear part and accelerate it with C++.

```
import linear_cpp

class DenseFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        output = linear_cpp.forward(input, weight, bias)
        output = torch.sigmoid(output)
        ctx.save_for_backward(input, weight, bias, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, output = ctx.saved_variables
        grad_sigmoid = (1.0 - output) * output
        grad_output = grad_sigmoid * grad_output
        grad_input, grad_weight, grad_bias = linear_cpp.backward(grad_output, input, weight, bias)
        return grad_input, grad_weight, grad_bias
```

`linear_cpp` is a C++ extension module compiled and exposed to Python through pybind11. We hand the linear operation over to C++ by calling the forward and backward functions of `linear_cpp`. You will notice that the activation part is still done in Python (we will move it to CUDA later).

The C++ code is not difficult; we can write it directly with the `matmul` and `add` functions provided by ATen. Of course, PyTorch's own source code looks different, because it uses more efficient APIs.

```
at::Tensor linear_forward(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias={}) {
    auto output = at::matmul(input, weight.t());
    if (bias.defined()) {
        output.add_(bias);
    }
    return output;
}

std::vector<at::Tensor> linear_backword(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) {
    auto grad_input = at::matmul(grad_output, weight);
    auto grad_weight = at::matmul(grad_output.t(), input);
    auto grad_bias = bias.defined() ? grad_output.sum(0, /*keepdim=*/false) : at::Tensor{};
    return {
        grad_input,
        grad_weight,
        grad_bias
    };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &linear_forward, "linear forward");
    m.def("backward", &linear_backword, "linear backward");
}
```

Note that the C++ side cannot save the input and output during forward and keep them around for backward: it only accepts inputs and produces outputs, like a pure function. The bookkeeping of saving variables is done in Python, which passes the saved tensors back to C++ when backward is called.
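
Before moving on to CUDA, a practical note: the C++ code above is compiled into the `linear_cpp` module by `cpp/setup.py`, which looks essentially like this (comments added for explanation; running `python setup.py install` inside `cpp/` builds and installs it):

```
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='linear_cpp',
    ext_modules=[
        # Compile linear.cpp against the PyTorch/ATen headers into an
        # importable extension module named `linear_cpp`.
        CppExtension('linear_cpp', ['linear.cpp']),
    ],
    cmdclass={
        # BuildExtension supplies the compiler and linker flags that
        # PyTorch C++ extensions need.
        'build_ext': BuildExtension
    })
```
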
We left the activation function in Python in the previous section; now it is time to hand it over to CUDA. To call CUDA we still go through C++: a thin C++ wrapper checks the inputs and calls the CUDA functions to obtain the results.

```
at::Tensor sigmoid_forward(
    at::Tensor input) {
  CHECK_INPUT(input);
  return sigmoid_cuda_forward(input);
}

at::Tensor sigmoid_backward(
    at::Tensor grad_output,
    at::Tensor output) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(output);
  return sigmoid_cuda_backward(
      grad_output,
      output);
}
```

The data types on the C++ side are slightly different from those on the CUDA side. We use the `AT_DISPATCH_FLOATING_TYPES` macro to dispatch the C++ tensors to a kernel instantiated for the right scalar type, and to get the CUDA results back, without any manual type conversion.

```
at::Tensor sigmoid_cuda_forward(
    at::Tensor input) {
  auto output = at::zeros_like(input);
  const dim3 blocks(input.size(0), input.size(1));
  const int threads = 1;

  AT_DISPATCH_FLOATING_TYPES(input.type(), "sigmoid_forward_cuda", ([&] {
    sigmoid_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        output.data<scalar_t>());
  }));

  return output;
}
```

It is worth noting that the kernel entry point is declared `__global__` (so it can be launched from the host), while the helper it calls is declared `__device__` (so it runs on the device).

```
...
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__global__ void sigmoid_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    scalar_t* __restrict__ output) {
  const int index = blockIdx.x * blockDim.x + blockIdx.y;
  output[index] = sigmoid(input[index]);
}
...
```

The backward pass is wired up the same way. With that, the CUDA extension is complete.

Good luck!

## License

[MIT](http://opensource.org/licenses/MIT)

Copyright (c) 2018-present, Zhi Zhang

-------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import time 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import Module, Parameter 9 | from torch.nn import functional as F 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('-m', '--mode', choices=['py', 'cpp', 'cuda']) 13 | parser.add_argument('-e', '--epoch', type=int, default=100) 14 | parser.add_argument('-s', '--size', type=int, default=100) 15 | options = parser.parse_args() 16 | 17 | if options.mode == 'py': 18 | from python.dense import Dense 19 | elif options.mode == 'cpp': 20 | from cpp.dense import Dense 21 | elif options.mode == 'cuda': 22 | from cuda.dense import Dense 23 | 24 | inputs = torch.randn(options.size, 256) 25 | labels = torch.rand(options.size).mul(10).long() 26 | 27 | class Model(Module): 28 | 29 | def __init__(self): 30 | super(Model, self).__init__() 31 | self.dense1 = Dense(256, 64) 32 | self.dense2 = Dense(64, 16) 33 | self.dense3 = Dense(16, 10) 34 | 35 | def forward(self, x): 36 | x = self.dense1(x) 37 | x = self.dense2(x) 38 | x = self.dense3(x) 39 | return F.log_softmax(x, dim=1) 40 | 41 | model = Model() 42 | 43 | inputs = inputs.cuda() 44 | labels = labels.cuda() 45 | model = model.cuda() 46 | 47 | # dataparallel = DataParallel(model, device_ids=[0, 1, 2, 3]) 48 | 49 | criterion = nn.CrossEntropyLoss() 50 | # optimizer = torch.optim.SGD(dataparallel.module.parameters(), lr=1e-4) 51 |
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) 52 | 53 | forward_time = 0 54 | backward_time = 0 55 | for _ in range(options.epoch): 56 | optimizer.zero_grad() 57 | 58 | start = time.time() 59 | outputs = model(inputs) 60 | loss = criterion(outputs, labels) 61 | elapsed = time.time() - start 62 | forward_time += elapsed 63 | 64 | start = time.time() 65 | loss.backward() 66 | optimizer.step() 67 | elapsed = time.time() - start 68 | backward_time += elapsed 69 | 70 | print('Forward: {0:.3f} | Backward {1:.3f}'.format(forward_time, backward_time)) -------------------------------------------------------------------------------- /cpp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tczhangzhi/pytorch-parallel/8d8baf80dd48234386051d0bab616de5b55f8f5c/cpp/__init__.py -------------------------------------------------------------------------------- /cpp/dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Parameter 3 | from torch.autograd import Function 4 | 5 | import linear_cpp 6 | 7 | class DenseFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, input, weight, bias=None): 11 | output = linear_cpp.forward(input, weight, bias) 12 | output = torch.sigmoid(output) 13 | ctx.save_for_backward(input, weight, bias, output) 14 | return output 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | input, weight, bias, output = ctx.saved_variables 19 | grad_sigmoid = (1.0 - output) * output 20 | grad_output = grad_sigmoid * grad_output 21 | grad_input, grad_weight, grad_bias = linear_cpp.backward(grad_output, input, weight, bias) 22 | return grad_input, grad_weight, grad_bias 23 | 24 | class Dense(Module): 25 | 26 | def __init__(self, input_features, output_features, bias=True): 27 | super(Dense, self).__init__() 28 | self.input_features = input_features 29 | self.output_features = output_features 30 | self.weight = Parameter(torch.Tensor(output_features, input_features)) 31 | if bias: 32 | self.bias = Parameter(torch.Tensor(output_features)) 33 | else: 34 | self.register_parameter('bias', None) 35 | self.weight.data.uniform_(-0.1, 0.1) 36 | if bias is not None: 37 | self.bias.data.uniform_(-0.1, 0.1) 38 | 39 | def forward(self, input): 40 | return DenseFunction.apply(input, self.weight, self.bias) -------------------------------------------------------------------------------- /cpp/linear.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | at::Tensor linear_forward(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias={}) { 5 | auto output = at::matmul(input, weight.t()); 6 | if (bias.defined()) { 7 | output.add_(bias); 8 | } 9 | return output; 10 | } 11 | 12 | std::vector linear_backword(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) { 13 | auto grad_input = at::matmul(grad_output, weight); 14 | auto grad_weight = at::matmul(grad_output.t(), input); 15 | auto grad_bias = bias.defined() ? 
grad_output.sum(0, /*keepdim=*/false) : at::Tensor{}; 16 | return { 17 | grad_input, 18 | grad_weight, 19 | grad_bias 20 | }; 21 | } 22 | 23 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 24 | m.def("forward", &linear_forward, "linear forward"); 25 | m.def("backward", &linear_backword, "linear backward"); 26 | } -------------------------------------------------------------------------------- /cpp/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | setup( 5 | name='linear_cpp', 6 | ext_modules=[ 7 | CppExtension('linear_cpp', ['linear.cpp']), 8 | ], 9 | cmdclass={ 10 | 'build_ext': BuildExtension 11 | }) 12 | -------------------------------------------------------------------------------- /cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tczhangzhi/pytorch-parallel/8d8baf80dd48234386051d0bab616de5b55f8f5c/cuda/__init__.py -------------------------------------------------------------------------------- /cuda/dense.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Parameter 2 | from torch.autograd import Function 3 | 4 | import torch 5 | import linear_cpp 6 | import sigmoid_cuda 7 | 8 | class DenseFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, input, weight, bias=None): 12 | output = linear_cpp.forward(input, weight, bias) 13 | output = sigmoid_cuda.forward(output) 14 | ctx.save_for_backward(input, weight, bias, output) 15 | return output 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | input, weight, bias, output = ctx.saved_variables 20 | grad_sigmoid = sigmoid_cuda.backward(grad_output, output) 21 | grad_output = grad_sigmoid * grad_output 22 | grad_input, grad_weight, grad_bias = linear_cpp.backward(grad_output, input, weight, bias) 23 | return grad_input, grad_weight, grad_bias 24 | 25 | class Dense(Module): 26 | 27 | def __init__(self, input_features, output_features, bias=True): 28 | super(Dense, self).__init__() 29 | self.input_features = input_features 30 | self.output_features = output_features 31 | self.weight = Parameter(torch.Tensor(output_features, input_features)) 32 | if bias: 33 | self.bias = Parameter(torch.Tensor(output_features)) 34 | else: 35 | self.register_parameter('bias', None) 36 | self.weight.data.uniform_(-0.1, 0.1) 37 | if bias is not None: 38 | self.bias.data.uniform_(-0.1, 0.1) 39 | 40 | def forward(self, input): 41 | return DenseFunction.apply(input, self.weight, self.bias) -------------------------------------------------------------------------------- /cuda/linear.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | at::Tensor linear_forward(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias={}) { 5 | auto output = at::matmul(input, weight.t()); 6 | if (bias.defined()) { 7 | output.add_(bias); 8 | } 9 | return output; 10 | } 11 | 12 | std::vector linear_backword(const at::Tensor& grad_output, const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) { 13 | auto grad_input = at::matmul(grad_output, weight); 14 | auto grad_weight = at::matmul(grad_output.t(), input); 15 | auto grad_bias = bias.defined() ? 
grad_output.sum(0, /*keepdim=*/false) : at::Tensor{}; 16 | return { 17 | grad_input, 18 | grad_weight, 19 | grad_bias 20 | }; 21 | } 22 | 23 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 24 | m.def("forward", &linear_forward, "linear forward"); 25 | m.def("backward", &linear_backword, "linear backward"); 26 | } -------------------------------------------------------------------------------- /cuda/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 3 | 4 | setup( 5 | name='sigmoid_cuda_linear_cpp', 6 | ext_modules=[ 7 | CUDAExtension('sigmoid_cuda', [ 8 | 'sigmoid_cuda.cpp', 9 | 'sigmoid_cuda_kernel.cu', 10 | ]), 11 | CppExtension('linear_cpp', ['linear.cpp']) 12 | ], 13 | cmdclass={ 14 | 'build_ext': BuildExtension 15 | }) 16 | -------------------------------------------------------------------------------- /cuda/sigmoid_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | at::Tensor sigmoid_cuda_forward( 5 | at::Tensor input); 6 | 7 | at::Tensor sigmoid_cuda_backward( 8 | at::Tensor grad_output, 9 | at::Tensor output); 10 | 11 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | at::Tensor sigmoid_forward( 16 | at::Tensor input) { 17 | CHECK_INPUT(input); 18 | return sigmoid_cuda_forward(input); 19 | } 20 | 21 | at::Tensor sigmoid_backward( 22 | at::Tensor grad_output, 23 | at::Tensor output) { 24 | CHECK_INPUT(grad_output); 25 | CHECK_INPUT(output); 26 | return sigmoid_cuda_backward( 27 | grad_output, 28 | output); 29 | } 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("forward", &sigmoid_forward, "sigmoid forward (CUDA)"); 33 | m.def("backward", &sigmoid_backward, "sigmoid backward (CUDA)"); 34 | } 35 | -------------------------------------------------------------------------------- /cuda/sigmoid_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | namespace { 9 | template 10 | __device__ __forceinline__ scalar_t sigmoid(scalar_t z) { 11 | return 1.0 / (1.0 + exp(-z)); 12 | } 13 | 14 | template 15 | __device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) { 16 | return (1.0 - z) * z; 17 | } 18 | 19 | template 20 | __global__ void sigmoid_cuda_forward_kernel( 21 | const scalar_t* __restrict__ input, 22 | scalar_t* __restrict__ output) { 23 | const int index = blockIdx.x * blockDim.x + blockIdx.y; 24 | output[index] = sigmoid(input[index]); 25 | } 26 | 27 | template 28 | __global__ void sigmoid_cuda_backward_kernel( 29 | const scalar_t* __restrict__ grad_output, 30 | const scalar_t* __restrict__ output, 31 | scalar_t* __restrict__ new_grad_output) { 32 | const int index = blockIdx.x * blockDim.x + blockIdx.y; 33 | new_grad_output[index] = d_sigmoid(output[index] * grad_output[index]); 34 | } 35 | } // namespace 36 | 37 | at::Tensor sigmoid_cuda_forward( 38 | at::Tensor input) { 39 | auto output = at::zeros_like(input); 40 | const dim3 blocks(input.size(0), input.size(1)); 41 | const int threads = 1; 42 | 43 | AT_DISPATCH_FLOATING_TYPES(input.type(), "sigmoid_forward_cuda", ([&] { 44 | sigmoid_cuda_forward_kernel<<>>( 45 | input.data(), 46 | 
output.data()); 47 | })); 48 | 49 | return output; 50 | } 51 | 52 | at::Tensor sigmoid_cuda_backward( 53 | at::Tensor grad_output, 54 | at::Tensor output) { 55 | auto new_grad_output = at::zeros_like(grad_output); 56 | const dim3 blocks(grad_output.size(0), grad_output.size(1)); 57 | const int threads = 1; 58 | 59 | AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "sigmoid_backward_cuda", ([&] { 60 | sigmoid_cuda_backward_kernel<<>>( 61 | grad_output.data(), 62 | output.data(), 63 | new_grad_output.data()); 64 | })); 65 | 66 | return new_grad_output; 67 | } 68 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import time 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import Module, Parameter 9 | from torch.nn import functional as F 10 | from parallel.train import RingAllReduce 11 | from torch.utils.data import TensorDataset 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-m', '--mode', choices=['py', 'cpp', 'cuda']) 15 | parser.add_argument('-e', '--epoch', type=int, default=100) 16 | parser.add_argument('-s', '--size', type=int, default=100) 17 | options = parser.parse_args() 18 | 19 | if options.mode == 'py': 20 | from python.dense import Dense 21 | elif options.mode == 'cpp': 22 | from cpp.dense import Dense 23 | elif options.mode == 'cuda': 24 | from cuda.dense import Dense 25 | 26 | inputs = torch.randn(options.size, 256) 27 | labels = torch.rand(options.size).mul(10).long() 28 | 29 | dataset = TensorDataset(inputs, labels) 30 | 31 | class Model(Module): 32 | 33 | def __init__(self): 34 | super(Model, self).__init__() 35 | self.dense1 = Dense(256, 64) 36 | self.dense2 = Dense(64, 16) 37 | self.dense3 = Dense(16, 10) 38 | 39 | def forward(self, x): 40 | x = self.dense1(x) 41 | x = self.dense2(x) 42 | x = self.dense3(x) 43 | return F.log_softmax(x, dim=1) 44 | 45 | model = Model() 46 | 47 | criterion = nn.CrossEntropyLoss() 48 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) 49 | 50 | handler = RingAllReduce(model, criterion, optimizer, dataset, epoch=options.epoch) 51 | handler.train() -------------------------------------------------------------------------------- /grad_check.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import torch 6 | 7 | from torch.autograd import Variable, gradcheck 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('example', choices=['py', 'cpp', 'cuda']) 11 | parser.add_argument('-b', '--batch-size', type=int, default=3) 12 | parser.add_argument('-f', '--feature-size', type=int, default=17) 13 | parser.add_argument('-o', '--output-size', type=int, default=3) 14 | parser.add_argument('-c', '--cuda', action='store_true') 15 | options = parser.parse_args() 16 | 17 | if options.example == 'py': 18 | from python.dense import DenseFunction 19 | elif options.example == 'cpp': 20 | from cpp.dense import DenseFunction 21 | else: 22 | from cuda.dense import DenseFunction 23 | options.cuda = True 24 | 25 | X = torch.randn(options.batch_size, options.feature_size) 26 | W = torch.randn(options.output_size, options.feature_size) 27 | b = torch.randn(options.output_size) 28 | 29 | variables = [X, W, b] 30 | 31 | for i, var in 
enumerate(variables): 32 | if options.cuda: 33 | var = var.cuda() 34 | variables[i] = Variable(var.double(), requires_grad=True) 35 | 36 | if gradcheck(DenseFunction.apply, variables, eps=1e-6, atol=1e-4): 37 | print('Ok') 38 | -------------------------------------------------------------------------------- /parallel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tczhangzhi/pytorch-parallel/8d8baf80dd48234386051d0bab616de5b55f8f5c/parallel/__init__.py -------------------------------------------------------------------------------- /parallel/train.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from math import ceil 7 | from random import Random 8 | from torch.multiprocessing import Process 9 | from torch.autograd import Variable 10 | from torchvision import datasets, transforms 11 | 12 | class Partition(object): 13 | """ Dataset-like object, but only access a subset of it. """ 14 | 15 | def __init__(self, data, index): 16 | self.data = data 17 | self.index = index 18 | 19 | def __len__(self): 20 | return len(self.index) 21 | 22 | def __getitem__(self, index): 23 | data_idx = self.index[index] 24 | return self.data[data_idx] 25 | 26 | 27 | class DataPartitioner(object): 28 | """ Partitions a dataset into different chuncks. """ 29 | 30 | def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234): 31 | self.data = data 32 | self.partitions = [] 33 | rng = Random() 34 | rng.seed(seed) 35 | data_len = len(data) 36 | indexes = [x for x in range(0, data_len)] 37 | rng.shuffle(indexes) 38 | 39 | for frac in sizes: 40 | part_len = int(frac * data_len) 41 | self.partitions.append(indexes[0:part_len]) 42 | indexes = indexes[part_len:] 43 | 44 | def use(self, partition): 45 | return Partition(self.data, self.partitions[partition]) 46 | 47 | class RingAllReduce(object): 48 | 49 | def __init__(self, model, criterion, optimizer, dataset, epoch=100, addr='127.0.0.1', port='29500', backend='gloo'): 50 | self.model = model 51 | self.criterion = criterion 52 | self.optimizer = optimizer 53 | self.dataset = dataset 54 | self.epoch = epoch 55 | self.addr = addr 56 | self.port = port 57 | self.backend = backend 58 | 59 | def partition_dataset(self): 60 | size = dist.get_world_size() 61 | bsz = 128 // size 62 | partition_sizes = [1.0 / size for _ in range(size)] 63 | partition = DataPartitioner(self.dataset, partition_sizes) 64 | partition = partition.use(dist.get_rank()) 65 | train_set = torch.utils.data.DataLoader( 66 | partition, batch_size=bsz, shuffle=True) 67 | return train_set, bsz 68 | 69 | def average_gradients(self): 70 | """ Gradient averaging. 
""" 71 | size = float(dist.get_world_size()) 72 | for param in self.model.parameters(): 73 | dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0) 74 | param.grad.data /= size 75 | 76 | def run(self, rank, size): 77 | """ Distributed Synchronous SGD Example """ 78 | torch.manual_seed(1234) 79 | train_set, bsz = self.partition_dataset() 80 | optimizer = self.optimizer 81 | criterion = self.criterion 82 | 83 | num_batches = ceil(len(train_set.dataset) / float(bsz)) 84 | for epoch in range(self.epoch): 85 | epoch_loss = 0.0 86 | for data, target in train_set: 87 | data, target = Variable(data), Variable(target) 88 | optimizer.zero_grad() 89 | output = self.model(data) 90 | loss = criterion(output, target) 91 | epoch_loss += loss.data.item() 92 | loss.backward() 93 | self.average_gradients() 94 | optimizer.step() 95 | print('Rank ', 96 | dist.get_rank(), ', epoch ', epoch, ': ', 97 | epoch_loss / num_batches) 98 | 99 | def init_processes(self, rank, size, fn): 100 | """ Initialize the distributed environment. """ 101 | os.environ['MASTER_ADDR'] = self.addr 102 | os.environ['MASTER_PORT'] = self.port 103 | dist.init_process_group(self.backend, rank=rank, world_size=size) 104 | fn(rank, size) 105 | 106 | def train(self, size=2): 107 | processes = [] 108 | for rank in range(size): 109 | p = Process(target=self.init_processes, args=(rank, size, self.run)) 110 | p.start() 111 | processes.append(p) 112 | 113 | for p in processes: 114 | p.join() -------------------------------------------------------------------------------- /ptx/compile.sh: -------------------------------------------------------------------------------- 1 | nvcc -keep -o sigmoid_cuda_kernal sigmoid_cuda_kernal.cu -------------------------------------------------------------------------------- /ptx/recompile.sh: -------------------------------------------------------------------------------- 1 | # from here: nvcc -dryrun -o sigmoid_cuda_kernal sigmoid_cuda_kernal.cu --keep 2>dryrun.out 2 | ptxas -arch=sm_30 -m64 "sigmoid_cuda_kernal.ptx" -o "sigmoid_cuda_kernal.sm_30.cubin" 3 | fatbinary --create="sigmoid_cuda_kernal.fatbin" -64 "--image=profile=sm_30,file=sigmoid_cuda_kernal.sm_30.cubin" "--image=profile=compute_30,file=sigmoid_cuda_kernal.ptx" --embedded-fatbin="sigmoid_cuda_kernal.fatbin.c" --cuda 4 | gcc -E -x c++ -D__CUDACC__ -D__NVCC__ "-I/usr/local/cuda/bin/..//include" -D"__CUDACC_VER_BUILD__=176" -D"__CUDACC_VER_MINOR__=0" -D"__CUDACC_VER_MAJOR__=9" -include "cuda_runtime.h" -m64 "sigmoid_cuda_kernal.cu" > "sigmoid_cuda_kernal.cpp4.ii" 5 | cudafe++ --gnu_version=50500 --allow_managed --m64 --parse_templates --gen_c_file_name "sigmoid_cuda_kernal.cudafe1.cpp" --stub_file_name "sigmoid_cuda_kernal.cudafe1.stub.c" --module_id_file_name "sigmoid_cuda_kernal.module_id" "sigmoid_cuda_kernal.cpp4.ii" 6 | gcc -D__CUDA_ARCH__=300 -c -x c++ -DCUDA_DOUBLE_MATH_FUNCTIONS "-I/usr/local/cuda/bin/..//include" -m64 -o "sigmoid_cuda_kernal.o" "sigmoid_cuda_kernal.cudafe1.cpp" 7 | nvlink --arch=sm_30 --register-link-binaries="sigmoid_cuda_kernal_dlink.reg.c" -m64 "-L/usr/local/cuda/bin/..//lib64/stubs" "-L/usr/local/cuda/bin/..//lib64" -cpu-arch=X86_64 "sigmoid_cuda_kernal.o" -o "sigmoid_cuda_kernal_dlink.sm_30.cubin" 8 | fatbinary --create="sigmoid_cuda_kernal_dlink.fatbin" -64 -link "--image=profile=sm_30,file=sigmoid_cuda_kernal_dlink.sm_30.cubin" --embedded-fatbin="sigmoid_cuda_kernal_dlink.fatbin.c" 9 | gcc -c -x c++ -DFATBINFILE="\"sigmoid_cuda_kernal_dlink.fatbin.c\"" 
-DREGISTERLINKBINARYFILE="\"sigmoid_cuda_kernal_dlink.reg.c\"" -I. "-I/usr/local/cuda/bin/..//include" -D"__CUDACC_VER_BUILD__=176" -D"__CUDACC_VER_MINOR__=0" -D"__CUDACC_VER_MAJOR__=9" -m64 -o "sigmoid_cuda_kernal_dlink.o" "/usr/local/cuda/bin/crt/link.stub" 10 | g++ -m64 -o "sigmoid_cuda_kernal" -Wl,--start-group "sigmoid_cuda_kernal_dlink.o" "sigmoid_cuda_kernal.o" "-L/usr/local/cuda/bin/..//lib64/stubs" "-L/usr/local/cuda/bin/..//lib64" -lcudadevrt -lcudart_static -lrt -lpthread -ldl -Wl,--end-group -------------------------------------------------------------------------------- /ptx/sigmoid_cuda_kernal.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | __global__ void d_sigmoid(float *data){ 4 | *data = (1.0 - *data) * *data; 5 | } 6 | 7 | int main(){ 8 | float *d_data, h_data = 0; 9 | cudaMalloc((void **)&d_data, sizeof(float)); 10 | cudaMemcpy(d_data, &h_data, sizeof(float), cudaMemcpyHostToDevice); 11 | d_sigmoid<<<1,1>>>(d_data); 12 | cudaMemcpy(&h_data, d_data, sizeof(float), cudaMemcpyDeviceToHost); 13 | printf("data = %d\n", h_data); 14 | return 0; 15 | } -------------------------------------------------------------------------------- /python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tczhangzhi/pytorch-parallel/8d8baf80dd48234386051d0bab616de5b55f8f5c/python/__init__.py -------------------------------------------------------------------------------- /python/dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Parameter 3 | from torch.autograd import Function 4 | 5 | class DenseFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input, weight, bias=None): 9 | output = input.mm(weight.t()) 10 | if bias is not None: 11 | output += bias.unsqueeze(0).expand_as(output) 12 | output = torch.sigmoid(output) 13 | ctx.save_for_backward(input, weight, bias, output) 14 | return output 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | input, weight, bias, output = ctx.saved_tensors 19 | grad_sigmoid = (1.0 - output) * output 20 | grad_output = grad_sigmoid * grad_output 21 | grad_input = grad_weight = grad_bias = None 22 | if ctx.needs_input_grad[0]: 23 | grad_input = grad_output.mm(weight) 24 | if ctx.needs_input_grad[1]: 25 | grad_weight = grad_output.t().mm(input) 26 | if bias is not None and ctx.needs_input_grad[2]: 27 | grad_bias = grad_output.sum(0).squeeze(0) 28 | return grad_input, grad_weight, grad_bias 29 | 30 | class Dense(Module): 31 | 32 | def __init__(self, input_features, output_features, bias=True): 33 | super(Dense, self).__init__() 34 | self.input_features = input_features 35 | self.output_features = output_features 36 | self.weight = Parameter(torch.Tensor(output_features, input_features)) 37 | if bias: 38 | self.bias = Parameter(torch.Tensor(output_features)) 39 | else: 40 | self.register_parameter('bias', None) 41 | self.weight.data.uniform_(-0.1, 0.1) 42 | if bias is not None: 43 | self.bias.data.uniform_(-0.1, 0.1) 44 | 45 | def forward(self, input): 46 | return DenseFunction.apply(input, self.weight, self.bias) --------------------------------------------------------------------------------