├── .gitignore
├── README.md
├── setup.cfg
├── setup.py
├── src
│   ├── masked_softmax_cuda.cpp
│   └── masked_softmax_cuda_kernel.cu
├── tcop
│   ├── __init__.py
│   └── masked_softmax.py
└── tests
    └── test_masked_softmax.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
dist
.eggs
*.egg-info
.pytest_cache
build
*.so

*.swp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tcop-pytorch

This is my collection of custom CUDA operators for PyTorch. It currently contains only `MaskedSoftmax`.

## Requirements

I tested this operator in the following environment:

- PyTorch 1.0
- CUDA 9.2

## How to Install

```
$ python setup.py install
```

## How to Test

```
$ python setup.py test
```

## How to Use

### MaskedSoftmax

```MaskedSoftmax.apply(input, mask, scale)```

`input` is a `[batch_size, head_size, q, k]` tensor of attention scores, `mask` is an int32 `[batch_size, q]` tensor holding the number of valid key positions for each query (positions past that count get probability 0), and `scale` is a scalar multiplied into the input before the softmax.

You can find details about this operator in [my blog post](https://tunz.kr/post/5).

```python
import torch
from tcop.masked_softmax import MaskedSoftmax

x = torch.tensor([[0.3, 0.2, 0.1],
                  [0.3, 0.4, 0.5]]).cuda()
x = x.view(1, 1, 2, 3)  # [batch_size, head_size, q, k]

mask = torch.tensor([2, 1], dtype=torch.int32).cuda()
mask = mask.view(1, 2)  # [batch_size, q]

scale = 0.5
MaskedSoftmax.apply(x, mask, scale)
# tensor([[[[0.5125, 0.4875, 0.0000],
#           [1.0000, 0.0000, 0.0000]]]], device='cuda:0')
```
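If you already have a boolean padding mask (1 marks key positions to ignore), you can derive the int32 mask the same way the unit tests do. Below is a minimal, illustrative sketch (`scores`, `pad_mask`, and `lengths` are made-up names); note that the operator assumes the masked keys sit at the end of each row, because the mask only stores a per-query count of valid keys:

```python
import torch
from tcop.masked_softmax import MaskedSoftmax

# Attention scores: [batch_size, head_size, q, k]
scores = torch.randn(2, 4, 3, 5).cuda()

# Boolean padding mask: [batch_size, q, k], 1 marks keys to ignore.
pad_mask = torch.tensor([[[0, 0, 0, 1, 1]] * 3,
                         [[0, 0, 0, 0, 1]] * 3], dtype=torch.float).cuda()

# MaskedSoftmax expects the number of valid keys per query as int32 [batch_size, q].
lengths = pad_mask.size(2) - pad_mask.sum(dim=2, dtype=torch.int32)

scale = 0.5
out = MaskedSoftmax.apply(scores, lengths, scale)

# Plain-PyTorch reference, built the same way as in tests/test_masked_softmax.py.
ref = torch.nn.functional.softmax(
    scores * scale
    + torch.zeros_like(scores).masked_fill_(pad_mask.unsqueeze(1).byte(), -1e9),
    dim=3)
assert (torch.abs(out - ref) < 1e-6).all()
```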
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[aliases]
test=pytest

[tool:pytest]
addopts = --verbose
python_files = tests/*.py

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='tcop-pytorch',
    packages=['tcop'],
    ext_modules=[
        CUDAExtension('masked_softmax_cuda', [
            'src/masked_softmax_cuda.cpp',
            'src/masked_softmax_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
    setup_requires=["pytest-runner"],
    tests_require=["pytest"])

--------------------------------------------------------------------------------
/src/masked_softmax_cuda.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

std::vector<at::Tensor> masked_softmax_cuda_forward(
    at::Tensor input,
    at::Tensor mask,
    at::Tensor scale);

std::vector<at::Tensor> masked_softmax_cuda_backward(
    at::Tensor grad,
    at::Tensor output,
    at::Tensor mask,
    at::Tensor scale);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> masked_softmax_forward(
    at::Tensor input,
    at::Tensor mask,
    at::Tensor scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(mask);

  return masked_softmax_cuda_forward(input, mask, scale);
}

std::vector<at::Tensor> masked_softmax_backward(
    at::Tensor grad,
    at::Tensor output,
    at::Tensor mask,
    at::Tensor scale) {
  CHECK_INPUT(grad);
  CHECK_INPUT(output);
  CHECK_INPUT(mask);

  return masked_softmax_cuda_backward(
      grad,
      output,
      mask,
      scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &masked_softmax_forward, "MaskedSoftmax forward (CUDA)");
  m.def("backward", &masked_softmax_backward, "MaskedSoftmax backward (CUDA)");
}

--------------------------------------------------------------------------------
/src/masked_softmax_cuda_kernel.cu:
--------------------------------------------------------------------------------
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <cfloat>
#include <cmath>
#include <vector>

using namespace at;

namespace {
template <typename scalar_t>
__global__ void __launch_bounds__(32) masked_softmax_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const int* __restrict__ mask,
    scalar_t* __restrict__ output,
    unsigned int hidden_size,
    unsigned int m0,
    unsigned int m1,
    scalar_t scale) {

  const int tid = threadIdx.x;
  // One warp-sized block handles one softmax row indexed by (batch, head, query).
  const unsigned int ibase = blockIdx.x * gridDim.y * gridDim.z * hidden_size +
                             blockIdx.y * gridDim.z * hidden_size +
                             blockIdx.z * hidden_size;

  // The mask may be broadcast over the batch (m0 == 1) and/or query (m1 == 1) dims.
  const unsigned int mask_offset = blockIdx.x * (m0 > 1 ? m1 : 0) +
                                   blockIdx.z * (m1 > 1 ? 1 : 0);
  unsigned int mask_size = min(static_cast<unsigned int>(mask[mask_offset]),
                               hidden_size);
  unsigned shfl_mask = __ballot_sync(0xffffffff, threadIdx.x < mask_size);

  // Warp-level reduction of the row maximum for numerical stability.
  scalar_t max_x = -FLT_MAX;
  for (unsigned int i = tid; i < mask_size; i += blockDim.x) {
    max_x = fmaxf(max_x, input[ibase + i] * scale);
  }
  for (unsigned int i = 16; i > 0; i >>= 1) {
    max_x = max(max_x, __shfl_xor_sync(shfl_mask, max_x, i));
  }

  // Warp-level reduction of the exponential sum.
  scalar_t exp_sum = 0;
  for (unsigned int i = tid; i < mask_size; i += blockDim.x) {
    exp_sum += std::exp(input[ibase + i] * scale - max_x);
  }
  for (unsigned int i = 16; i > 0; i >>= 1) {
    exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
  }

  // Positions at or beyond mask_size keep their initial value of zero.
  for (unsigned int i = tid; i < mask_size; i += blockDim.x) {
    output[ibase + i] = std::exp(input[ibase + i] * scale - max_x) / exp_sum;
  }
}

// d_input = output * (grad_output - sum(grad_output * output)) * scale
template <typename scalar_t>
__global__ void __launch_bounds__(32) masked_softmax_cuda_backward_kernel(
    scalar_t* __restrict__ d_input,
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ output,
    const int* __restrict__ mask,
    unsigned int hidden_size,
    unsigned int m0,
    unsigned int m1,
    scalar_t scale) {
  const int tid = threadIdx.x;
  const unsigned int ibase = blockIdx.x * gridDim.y * gridDim.z * hidden_size +
                             blockIdx.y * gridDim.z * hidden_size +
                             blockIdx.z * hidden_size;

  const unsigned int mask_offset = blockIdx.x * (m0 > 1 ? m1 : 0) +
                                   blockIdx.z * (m1 > 1 ? 1 : 0);
  unsigned int mask_size = min(static_cast<unsigned int>(mask[mask_offset]),
                               hidden_size);
  unsigned shfl_mask = __ballot_sync(0xffffffff, threadIdx.x < mask_size);

  // Warp-level reduction of sum(grad_output * output) over the valid positions.
  scalar_t grad_sum = 0;
  for (unsigned int i = tid; i < mask_size; i += blockDim.x) {
    scalar_t o = output[ibase + i];
    grad_sum += grad_output[ibase + i] * o;
  }
  for (unsigned int i = 16; i > 0; i >>= 1) {
    grad_sum += __shfl_xor_sync(shfl_mask, grad_sum, i);
  }

  for (unsigned int i = tid; i < mask_size; i += blockDim.x) {
    scalar_t o = output[ibase + i];
    d_input[ibase + i] = o * (grad_output[ibase + i] - grad_sum) * scale;
  }
}
} // namespace

std::vector<at::Tensor> masked_softmax_cuda_forward(
    at::Tensor input,
    at::Tensor mask,
    at::Tensor scale) {
  AT_CHECK(input.dim() == 4, "input has an incorrect shape");
  AT_CHECK(mask.dim() == 2, "mask has an incorrect shape");
  AT_CHECK(mask.size(0) == 1 || mask.size(0) == input.size(0),
           "mask dim #0 has an incorrect shape");
  AT_CHECK(mask.size(1) == 1 || mask.size(1) == input.size(2),
           "mask dim #1 has an incorrect shape");

  auto output = at::zeros_like(input);

  const int threads = 32;
  const dim3 blocks(input.size(0), input.size(1), input.size(2));

  AT_DISPATCH_FLOATING_TYPES(input.type(), "masked_softmax_forward_cuda", ([&] {
    masked_softmax_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        mask.data<int>(),
        output.data<scalar_t>(),
        input.size(3),
        mask.size(0),
        mask.size(1),
        scale.item<scalar_t>());
  }));

  return {output};
}

std::vector<at::Tensor> masked_softmax_cuda_backward(
    at::Tensor grad_output,
    at::Tensor output,
    at::Tensor mask,
    at::Tensor scale) {
  auto d_input = at::zeros_like(output);

  const int threads = 32;
  const dim3 blocks(output.size(0), output.size(1), output.size(2));

  AT_DISPATCH_FLOATING_TYPES(output.type(), "masked_softmax_backward_cuda", ([&] {
    masked_softmax_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_input.data<scalar_t>(),
        grad_output.data<scalar_t>(),
        output.data<scalar_t>(),
        mask.data<int>(),
        output.size(3),
        mask.size(0),
        mask.size(1),
        scale.item<scalar_t>());
  }));

  return {d_input};
}

--------------------------------------------------------------------------------
/tcop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tunz/tcop-pytorch/fe5dada36964085850d5a50405498c193fb5c426/tcop/__init__.py

--------------------------------------------------------------------------------
/tcop/masked_softmax.py:
--------------------------------------------------------------------------------
import torch
import masked_softmax_cuda

# pylint: disable=arguments-differ


class MaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale = torch.tensor(scale)

        output = masked_softmax_cuda.forward(
            inputs.contiguous(), mask.contiguous(), scale)[0]
        ctx.save_for_backward(output, mask, scale)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        d_input = masked_softmax_cuda.backward(
            grad_output.contiguous(), *ctx.saved_tensors)[0]
        # No gradients for mask and scale.
        return d_input, None, None
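As a usage illustration, here is a minimal sketch of plugging `MaskedSoftmax` into scaled dot-product attention. The helper name `attention` and the tensors `q`, `k`, `v`, and `key_lengths` are illustrative, and the layout follows the `[batch_size, head_size, q, k]` convention from the README:

```python
import torch
from tcop.masked_softmax import MaskedSoftmax


def attention(q, k, v, key_lengths):
    """Scaled dot-product attention using the fused MaskedSoftmax kernel.

    q, k, v:      [batch_size, head_size, length, depth] CUDA tensors.
    key_lengths:  int32 [batch_size, q_len] (or broadcastable) number of valid
                  keys per query; keys past that count receive weight 0.
    """
    scale = q.size(-1) ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1))  # [batch, heads, q, k]
    weights = MaskedSoftmax.apply(scores, key_lengths, scale)
    return torch.matmul(weights, v)


# Batch of 2, 4 heads, 3 queries over 5 keys of depth 8; the first sequence
# has 4 valid keys, the second has all 5.
q = torch.randn(2, 4, 3, 8).cuda()
k = torch.randn(2, 4, 5, 8).cuda()
v = torch.randn(2, 4, 5, 8).cuda()
key_lengths = torch.tensor([[4], [5]], dtype=torch.int32).cuda()  # [batch, 1]
out = attention(q, k, v, key_lengths)  # [2, 4, 3, 8]
```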
--------------------------------------------------------------------------------
/tests/test_masked_softmax.py:
--------------------------------------------------------------------------------
import unittest

import torch

from tcop.masked_softmax import MaskedSoftmax


class TestMaskedSoftmax(unittest.TestCase):

    def test_forward(self):
        inputs = torch.tensor([[0.1, 0.2, 0.3],
                               [0.2, 0.3, 0.4],
                               [-0.3, -0.4, -0.5]])
        mask = torch.tensor([[0, 0, 0],
                             [0, 0, 0],
                             [0, 0, 0]], dtype=torch.float)

        inputs = inputs.view(1, 1, 3, 3).cuda()
        mask = mask.view(1, 3, 3).cuda()

        with torch.no_grad():
            expected = torch.nn.functional.softmax(inputs, dim=3)

        with torch.no_grad():
            mask = mask.size(2) - mask.sum(dim=2, dtype=torch.int32)
            output = MaskedSoftmax.apply(inputs, mask, 1.0)
        # torch.set_printoptions(precision=10)
        # print(output)
        # print(expected)
        self.assertTrue((output == expected).all())

    def _forward_test(self, inputs, mask, k, scale, debug=False):
        with torch.no_grad():
            x = inputs * scale
            if k is not None:
                x = torch.matmul(x, k)
            t_mask = mask.unsqueeze(1).byte()
            x = x + torch.zeros_like(x).masked_fill_(t_mask, -1e9)
            expected = torch.nn.functional.softmax(x, dim=3)

        with torch.no_grad():
            mask = mask.size(2) - mask.sum(dim=2, dtype=torch.int32)
            x = torch.matmul(inputs, k) if k is not None else inputs
            output = MaskedSoftmax.apply(x, mask, scale)
        if debug:
            torch.set_printoptions(precision=10)
            print("output", output)
            print("expected", expected)
            print("diff", torch.abs(output - expected))
        # We do not use exact equality here because MaskedSoftmax does not
        # multiply the scale first, so there are small precision differences.
        self.assertTrue((torch.abs(output - expected) < 1e-7).all())
        # self.assertTrue((output == expected).all())

    def test_forward_mask(self):
        inputs = torch.tensor([[0.1, 0.2, 0.3],
                               [-0.2, -0.3, -0.4],
                               [0.5, 0.4, 0.3]])
        mask = torch.tensor([[0, 1, 1],
                             [0, 0, 1],
                             [0, 0, 0]], dtype=torch.float)
        k = torch.tensor([[0.5, 2, 4],
                          [-0.2, -1, -0.4],
                          [0.5, 0.3, 0.3]])

        inputs = inputs.view(1, 1, 3, 3).cuda()
        mask = mask.view(1, 3, 3).cuda()
        k = k.view(1, 1, 3, 3).cuda()
        scale = 0.12

        self._forward_test(inputs, mask, k, scale)

    def test_forward_long(self):
        inputs = torch.tensor([1000] + list(range(69)), dtype=torch.float)
        mask = torch.tensor([0] * 70, dtype=torch.float)

        inputs = inputs.view(1, 1, 1, 70).cuda()
        mask = mask.view(1, 1, 70).cuda()
        scale = 0.1

        self._forward_test(inputs, mask, None, scale)

    def test_forward_last_all_mask(self):
        inputs = torch.tensor(list(range(64)), dtype=torch.float)
        mask = torch.tensor([0] * 32 + [1] * 32, dtype=torch.float)

        inputs = inputs.view(1, 1, 1, 64).cuda()
        mask = mask.view(1, 1, 64).cuda()
        scale = 0.1

        self._forward_test(inputs, mask, None, scale)

    def test_forward_multi_batch(self):
        inputs = torch.tensor([list(range(4)) * 4] * 12, dtype=torch.float)
        mask = torch.tensor([[0, 0, 1, 1]] * 4, dtype=torch.float)

        inputs = inputs.view(4, 3, 4, 4).cuda()
        mask = mask.view(4, 1, 4).cuda()
        scale = 0.1

        self._forward_test(inputs, mask, None, scale, debug=True)

    def test_forward_mini_seq(self):
        inputs = torch.tensor([list(range(64))] * 3, dtype=torch.float)
        mask = torch.tensor([[0] * 32 + [1] * 32] * 3,
                            dtype=torch.float)

        inputs = inputs.view(1, 1, 3, 64).cuda()
        mask = mask.view(1, 3, 64).cuda()
        scale = 0.1

        self._forward_test(inputs, mask, None, scale)

    def test_forward_mini_batch_seq(self):
        inputs = torch.tensor([[list(range(64))] * 3] * 4, dtype=torch.float)
        mask = torch.tensor([[[0] * 32 + [1] * 32] * 3] * 4, dtype=torch.float)

        inputs = inputs.view(4, 1, 3, 64).cuda()
        mask = mask.view(4, 3, 64).cuda()
        scale = 0.1

        self._forward_test(inputs, mask, None, scale)

    def test_backward(self):
        inputs1 = torch.tensor([[0.3, 0.2, 0.1],
                                [0.2, 0.3, 0.4],
                                [0.3, 0.4, 0.5]], requires_grad=True)
        inputs2 = torch.tensor([[0.3, 0.2, 0.1],
                                [0.2, 0.3, 0.4],
                                [0.3, 0.4, 0.5]], requires_grad=True)
        mask = torch.tensor([[0, 1, 1],
                             [0, 0, 1],
                             [0, 0, 0]], dtype=torch.float)
        k = torch.tensor([[0.5, 2, 4],
                          [-0.2, -1, -0.4],
                          [0.5, 0.3, 0.3]], requires_grad=True)

        inputs1_cuda = inputs1.view(1, 1, 3, 3).cuda()
        inputs2_cuda = inputs2.view(1, 1, 3, 3).cuda()
        mask = mask.view(1, 3, 3).cuda()
        k = k.view(1, 1, 3, 3).cuda()
        scale = 0.1

        x = inputs1_cuda * scale
        x = torch.matmul(x, k)
        x = x + torch.zeros_like(x).masked_fill_(mask.byte(), -1e9)
        expected = torch.nn.functional.softmax(x, dim=3)
        loss = torch.mean(expected)
        loss.backward()

        x = torch.matmul(inputs2_cuda, k)
        mask = mask.size(2) - mask.sum(dim=2, dtype=torch.int32)
        output = MaskedSoftmax.apply(x, mask, scale)
        loss = torch.mean(output)
        loss.backward()

        # torch.set_printoptions(precision=10)
        # print(output)
        # print(expected)
        # print(inputs1.grad)
        # print(inputs2.grad)
        self.assertTrue((output == expected).all())
        self.assertTrue((torch.abs(inputs1.grad - inputs2.grad) < 1e-8).all())


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
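To complement the unit tests, the backward kernel can also be sanity-checked with `torch.autograd.gradcheck`. A minimal sketch, assuming double-precision inputs (the extension dispatches over all floating types via `AT_DISPATCH_FLOATING_TYPES`); the tensor names here are illustrative:

```python
import torch
from torch.autograd import gradcheck

from tcop.masked_softmax import MaskedSoftmax

# Small [batch, heads, q, k] input in float64 so finite differences are stable.
x = torch.randn(1, 1, 2, 8, dtype=torch.float64, device='cuda',
                requires_grad=True)
# Number of valid keys for each of the two queries.
mask = torch.tensor([[5, 8]], dtype=torch.int32, device='cuda')

# Gradients are only checked w.r.t. x; mask and scale are non-differentiable.
assert gradcheck(lambda t: MaskedSoftmax.apply(t, mask, 0.5), (x,),
                 eps=1e-6, atol=1e-4)
```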