├── .gitignore
├── README.md
├── benchmark.py
├── cpp
│   ├── __init__.py
│   ├── jit.py
│   ├── lltm.cpp
│   ├── lltm.py
│   └── setup.py
├── cuda
│   ├── __init__.py
│   ├── jit.py
│   ├── lltm.py
│   ├── lltm_cuda.cpp
│   ├── lltm_cuda_kernel.cu
│   └── setup.py
├── grad_check.py
└── python
    ├── __init__.py
    ├── lltm.py
    └── lltm_baseline.py

/.gitignore:
--------------------------------------------------------------------------------
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.a*
torch/lib/*.dll*
torch/lib/*.lib
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/jit/generated/*
torch/csrc/autograd/generated/*
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
torch/csrc/generated
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
test/data/linear.pt
.mypy_cache

# IPython notebook checkpoints
.ipynb_checkpoints

# Editor temporaries
*.swn
*.swo
*.swp
*~

# macOS dir files
.DS_Store

# Ninja files
.ninja_deps
.ninja_log
compile_commands.json
*.egg-info/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-cpp-extension

An example of writing a C++ extension for PyTorch.

NOTE: Moved to https://github.com/pytorch/extension-cpp
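For reference, a minimal usage sketch. It assumes the C++ extension has been built first (for example with `python setup.py install` inside `cpp/`) and is run from the repository root; it mirrors what `benchmark.py` and `grad_check.py` do.

```python
import torch
from cpp.lltm import LLTM  # or: from python.lltm import LLTM for the pure-Python version

batch_size, features, state_size = 16, 32, 128
X = torch.randn(batch_size, features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(features, state_size)
new_h, new_C = rnn(X, (h, C))           # forward pass through the custom op
(new_h.sum() + new_C.sum()).backward()  # backward pass through the custom backward
```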
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function

import argparse
import math
import time

import torch

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-f', '--features', type=int, default=32)
parser.add_argument('-s', '--state-size', type=int, default=128)
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()

if options.example == 'py':
    from python.lltm import LLTM
elif options.example == 'cpp':
    from cpp.lltm import LLTM
else:
    from cuda.lltm import LLTM
    options.cuda = True

X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
rnn = LLTM(options.features, options.state_size)

if options.cuda:
    X = X.cuda()
    h = h.cuda()
    C = C.cuda()
    rnn.cuda()

# Force CUDA initialization
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()

forward_min = math.inf
forward_time = 0
backward_min = math.inf
backward_time = 0
for _ in range(options.runs):
    rnn.zero_grad()

    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed

    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[options.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / options.runs * scale
backward_average = backward_time / options.runs * scale

print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    options.scale))

--------------------------------------------------------------------------------
/cpp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goldsborough/pytorch-cpp-extension/3d88f24657154231de5ca8b3996efc067323f7c4/cpp/__init__.py

--------------------------------------------------------------------------------
/cpp/jit.py:
--------------------------------------------------------------------------------
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)
help(lltm_cpp)
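The JIT route compiles the extension the first time `load` is called, so no `setup.py` step is needed. A small sketch of calling the resulting bindings directly (illustrative shapes; it assumes it is run from inside `cpp/` so that `load` finds `lltm.cpp`):

```python
import torch
from torch.utils.cpp_extension import load

lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)

batch, features, state = 2, 4, 3
X = torch.randn(batch, features)
h = torch.randn(batch, state)
C = torch.randn(batch, state)
W = torch.randn(3 * state, features + state)
b = torch.randn(1, 3 * state)

# forward returns [new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights]
outputs = lltm_cpp.forward(X, W, b, h, C)
new_h, new_cell = outputs[:2]
print(new_h.shape, new_cell.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```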
--------------------------------------------------------------------------------
/cpp/lltm.cpp:
--------------------------------------------------------------------------------
#include <torch/torch.h>

#include <vector>

// s'(z) = (1 - s(z)) * s(z)
at::Tensor d_sigmoid(at::Tensor z) {
  auto s = at::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);

  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
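The analytic derivative helpers above have direct Python counterparts in `python/lltm_baseline.py`. A tiny sketch (illustration only) that checks the `d_sigmoid` formula against autograd; the full end-to-end check lives in `grad_check.py`:

```python
import torch

z = torch.randn(8, dtype=torch.double, requires_grad=True)
torch.sigmoid(z).sum().backward()

s = torch.sigmoid(z).detach()
analytic = (1 - s) * s  # same formula as d_sigmoid() above
print(torch.allclose(z.grad, analytic))  # True
```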
--------------------------------------------------------------------------------
/cpp/lltm.py:
--------------------------------------------------------------------------------
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cpp

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)

--------------------------------------------------------------------------------
/cpp/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='lltm_cpp',
    ext_modules=[
        CppExtension('lltm_cpp', ['lltm.cpp']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

--------------------------------------------------------------------------------
/cuda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goldsborough/pytorch-cpp-extension/3d88f24657154231de5ca8b3996efc067323f7c4/cuda/__init__.py

--------------------------------------------------------------------------------
/cuda/jit.py:
--------------------------------------------------------------------------------
from torch.utils.cpp_extension import load
lltm_cuda = load(
    'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
help(lltm_cuda)

--------------------------------------------------------------------------------
/cuda/lltm.py:
--------------------------------------------------------------------------------
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cuda

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cuda.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
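A short sketch of exercising the CUDA version, mirroring `benchmark.py`. It assumes the extension was built with `cuda/setup.py` and that a CUDA device is available; the C++ side checks `is_cuda()` on every tensor, so everything (inputs, state, and the module) must be moved to the GPU first.

```python
import torch
from cuda.lltm import LLTM

assert torch.cuda.is_available(), "this sketch needs a CUDA device"

batch_size, features, state_size = 16, 32, 128
X = torch.randn(batch_size, features).cuda()
h = torch.randn(batch_size, state_size).cuda()
C = torch.randn(batch_size, state_size).cuda()

rnn = LLTM(features, state_size).cuda()
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()
print(new_h.shape)  # torch.Size([16, 128])
```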
--------------------------------------------------------------------------------
/cuda/lltm_cuda.cpp:
--------------------------------------------------------------------------------
#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

std::vector<at::Tensor> lltm_cuda_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell);

std::vector<at::Tensor> lltm_cuda_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  CHECK_CUDA(input);
  CHECK_CUDA(weights);
  CHECK_CUDA(bias);
  CHECK_CUDA(old_h);
  CHECK_CUDA(old_cell);

  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  CHECK_CUDA(grad_h);
  CHECK_CUDA(grad_cell);
  CHECK_CUDA(input_gate);
  CHECK_CUDA(output_gate);
  CHECK_CUDA(candidate_cell);
  CHECK_CUDA(X);
  CHECK_CUDA(gate_weights);
  CHECK_CUDA(weights);

  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
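Because of the `CHECK_CUDA` guards, the bindings refuse CPU tensors outright. A tiny, hypothetical illustration of that behaviour (it assumes the built `lltm_cuda` module is importable; the failed assertion is expected to surface as a Python exception):

```python
import torch
import lltm_cuda

# CPU tensors shaped like (input, weights, bias, old_h, old_cell) for features=4, state=3
cpu_args = [torch.randn(2, 4), torch.randn(9, 7), torch.randn(1, 9),
            torch.randn(2, 3), torch.randn(2, 3)]
try:
    lltm_cuda.forward(*cpu_args)
except Exception as err:  # the AT_ASSERT failure propagates out of the extension
    print("rejected:", err)
```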
--------------------------------------------------------------------------------
/cuda/lltm_cuda_kernel.cu:
--------------------------------------------------------------------------------
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
  const auto s = sigmoid(z);
  return (1.0 - s) * s;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
  const auto t = tanh(z);
  return 1 - (t * t);
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  return fmax(0.0, z) + fmin(0.0, alpha * (exp(z) - 1.0));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
  const auto e = exp(z);
  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const scalar_t* __restrict__ gates,
    const scalar_t* __restrict__ old_cell,
    scalar_t* __restrict__ new_h,
    scalar_t* __restrict__ new_cell,
    scalar_t* __restrict__ input_gate,
    scalar_t* __restrict__ output_gate,
    scalar_t* __restrict__ candidate_cell,
    size_t state_size) {
  const auto column = blockIdx.x * blockDim.x + threadIdx.x;
  const auto index = blockIdx.y * state_size + column;
  // gates is (batch_size, 3 * state_size), so each batch row starts at
  // blockIdx.y * (state_size * 3), not blockIdx.y * state_size.
  const auto gates_row = blockIdx.y * (state_size * 3);
  if (column < state_size) {
    input_gate[index] = sigmoid(gates[gates_row + column]);
    output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
    candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
    new_cell[index] =
        old_cell[index] + candidate_cell[index] * input_gate[index];
    new_h[index] = tanh(new_cell[index]) * output_gate[index];
  }
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
    scalar_t* __restrict__ d_old_cell,
    scalar_t* __restrict__ d_gates,
    const scalar_t* __restrict__ grad_h,
    const scalar_t* __restrict__ grad_cell,
    const scalar_t* __restrict__ new_cell,
    const scalar_t* __restrict__ input_gate,
    const scalar_t* __restrict__ output_gate,
    const scalar_t* __restrict__ candidate_cell,
    const scalar_t* __restrict__ gate_weights,
    size_t state_size) {
  const int column = blockIdx.x * blockDim.x + threadIdx.x;
  const int index = blockIdx.y * state_size + column;
  if (column < state_size) {
    const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
    const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
    const auto d_new_cell =
        d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];

    d_old_cell[index] = d_new_cell;
    const auto d_candidate_cell = input_gate[index] * d_new_cell;
    const auto d_input_gate = candidate_cell[index] * d_new_cell;

    // gate_weights and d_gates are (batch_size, 3 * state_size); index into
    // the batch row the same way as in the forward kernel.
    const auto gates_row = blockIdx.y * (state_size * 3);
    const auto input_gate_index = gates_row + column;
    const auto output_gate_index = gates_row + state_size + column;
    const auto candidate_cell_index = gates_row + 2 * state_size + column;

    d_gates[input_gate_index] =
        d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
    d_gates[output_gate_index] =
        d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
    d_gates[candidate_cell_index] =
        d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
  }
}
} // namespace

std::vector<at::Tensor> lltm_cuda_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);
  auto gates = at::addmm(bias, X, weights.transpose(0, 1));

  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);

  auto new_h = at::zeros_like(old_cell);
  auto new_cell = at::zeros_like(old_cell);
  auto input_gate = at::zeros_like(old_cell);
  auto output_gate = at::zeros_like(old_cell);
  auto candidate_cell = at::zeros_like(old_cell);

  const int threads = 1024;
  // Threads cover the state dimension along x; blockIdx.y selects the batch element.
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        gates.data<scalar_t>(),
        old_cell.data<scalar_t>(),
        new_h.data<scalar_t>(),
        new_cell.data<scalar_t>(),
        input_gate.data<scalar_t>(),
        output_gate.data<scalar_t>(),
        candidate_cell.data<scalar_t>(),
        state_size);
  }));

  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

std::vector<at::Tensor> lltm_cuda_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_old_cell = at::zeros_like(new_cell);
  auto d_gates = at::zeros_like(gate_weights);

  const auto batch_size = new_cell.size(0);
  const auto state_size = new_cell.size(1);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_backward_cuda", ([&] {
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_old_cell.data<scalar_t>(),
        d_gates.data<scalar_t>(),
        grad_h.data<scalar_t>(),
        grad_cell.data<scalar_t>(),
        new_cell.data<scalar_t>(),
        input_gate.data<scalar_t>(),
        output_gate.data<scalar_t>(),
        candidate_cell.data<scalar_t>(),
        gate_weights.data<scalar_t>(),
        state_size);
  }));

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}
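A small Python sketch of the indexing scheme the kernels above use (illustration only): one block row per batch element, threads covering the state dimension, and a row stride of `3 * state_size` for the flattened gate matrix.

```python
def launch_config(batch_size, state_size, threads=1024):
    # mirrors: const dim3 blocks((state_size + threads - 1) / threads, batch_size)
    blocks_x = (state_size + threads - 1) // threads
    return (blocks_x, batch_size), threads

def flat_indices(batch, column, state_size):
    index = batch * state_size + column  # into the (batch, state_size) outputs
    gates_row = batch * 3 * state_size   # row start in the (batch, 3 * state_size) gates
    return (index,
            gates_row + column,                   # input gate
            gates_row + state_size + column,      # output gate
            gates_row + 2 * state_size + column)  # candidate cell

print(launch_config(16, 128))   # ((1, 16), 1024)
print(flat_indices(1, 0, 128))  # (128, 384, 512, 640)
```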
--------------------------------------------------------------------------------
/cuda/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='lltm_cuda',
    ext_modules=[
        CUDAExtension('lltm_cuda', [
            'lltm_cuda.cpp',
            'lltm_cuda_kernel.cu',
        ]),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })

--------------------------------------------------------------------------------
/grad_check.py:
--------------------------------------------------------------------------------
from __future__ import division
from __future__ import print_function

import argparse
import torch

from torch.autograd import Variable, gradcheck

parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=3)
parser.add_argument('-f', '--features', type=int, default=17)
parser.add_argument('-s', '--state-size', type=int, default=5)
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()

if options.example == 'py':
    from python.lltm_baseline import LLTMFunction
elif options.example == 'cpp':
    from cpp.lltm import LLTMFunction
else:
    from cuda.lltm import LLTMFunction
    options.cuda = True

X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
W = torch.randn(3 * options.state_size, options.features + options.state_size)
b = torch.randn(1, 3 * options.state_size)

variables = [X, W, b, h, C]

for i, var in enumerate(variables):
    if options.cuda:
        var = var.cuda()
    variables[i] = Variable(var.double(), requires_grad=True)

print(LLTMFunction.apply(*variables))

if gradcheck(LLTMFunction.apply, variables):
    print('Ok')
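`gradcheck` compares the analytic gradients from the custom backward against numerical (finite-difference) gradients, which is why the inputs above are cast to double precision. A sketch of an equivalent explicit call with illustrative tolerances, using the pure-Python baseline:

```python
import torch
from torch.autograd import gradcheck
from python.lltm_baseline import LLTMFunction

batch, features, state = 3, 17, 5
inputs = [
    torch.randn(batch, features, dtype=torch.double, requires_grad=True),
    torch.randn(3 * state, features + state, dtype=torch.double, requires_grad=True),
    torch.randn(1, 3 * state, dtype=torch.double, requires_grad=True),
    torch.randn(batch, state, dtype=torch.double, requires_grad=True),
    torch.randn(batch, state, dtype=torch.double, requires_grad=True),
]

# eps controls the finite-difference step, atol the absolute comparison tolerance.
print(gradcheck(LLTMFunction.apply, inputs, eps=1e-6, atol=1e-4))
```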
--------------------------------------------------------------------------------
/python/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goldsborough/pytorch-cpp-extension/3d88f24657154231de5ca8b3996efc067323f7c4/python/__init__.py

--------------------------------------------------------------------------------
/python/lltm.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn.functional as F

torch.manual_seed(42)


class LLTM(torch.nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        # 3 * state_size for input gate, output gate and candidate cell gate.
        # input_features + state_size because we will multiply with [input, h].
        self.weights = torch.nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)

        # Compute the input, output and candidate cell gates with one MM.
        gate_weights = F.linear(X, self.weights, self.bias)
        # Split the combined gate weight matrix into its components.
        gates = gate_weights.chunk(3, dim=1)

        input_gate = F.sigmoid(gates[0])
        output_gate = F.sigmoid(gates[1])
        # Here we use an ELU instead of the usual tanh.
        candidate_cell = F.elu(gates[2])

        # Compute the new cell state.
        new_cell = old_cell + candidate_cell * input_gate
        # Compute the new hidden state and output.
        new_h = F.tanh(new_cell) * output_gate

        return new_h, new_cell
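A quick shape walk-through of the gate layout the comments above describe, with small illustrative sizes:

```python
import torch
from python.lltm import LLTM

rnn = LLTM(input_features=4, state_size=3)
X = torch.randn(2, 4)                     # (batch, input_features)
h, C = torch.randn(2, 3), torch.randn(2, 3)

# inside forward: cat([h, X]) is (2, 7); weights are (9, 7);
# gate_weights are (2, 9), chunked into three (2, 3) gates.
new_h, new_C = rnn(X, (h, C))
print(new_h.shape, new_C.shape)           # torch.Size([2, 3]) torch.Size([2, 3])
```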
--------------------------------------------------------------------------------
/python/lltm_baseline.py:
--------------------------------------------------------------------------------
import math

from torch import nn
from torch.autograd import Function
import torch
import torch.nn.functional as F

torch.manual_seed(42)


def d_sigmoid(z):
    s = F.sigmoid(z)
    return (1 - s) * s


def d_tanh(z):
    t = F.tanh(z)
    return 1 - (t * t)


def d_elu(z, alpha=1.0):
    e = z.exp()
    mask = (alpha * (e - 1)) < 0
    return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        X = torch.cat([old_h, input], dim=1)

        gate_weights = F.linear(X, weights, bias)
        gates = gate_weights.chunk(3, dim=1)

        input_gate = F.sigmoid(gates[0])
        output_gate = F.sigmoid(gates[1])
        candidate_cell = F.elu(gates[2])

        new_cell = old_cell + candidate_cell * input_gate
        new_h = F.tanh(new_cell) * output_gate

        ctx.save_for_backward(X, weights, input_gate, output_gate, old_cell,
                              new_cell, candidate_cell, gate_weights)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        X, weights, input_gate, output_gate, old_cell = ctx.saved_variables[:5]
        new_cell, candidate_cell, gate_weights = ctx.saved_variables[5:]

        d_input = d_weights = d_bias = d_old_h = d_old_cell = None

        d_output_gate = F.tanh(new_cell) * grad_h
        d_tanh_new_cell = output_gate * grad_h
        d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell

        d_old_cell = d_new_cell
        d_candidate_cell = input_gate * d_new_cell
        d_input_gate = candidate_cell * d_new_cell

        gates = gate_weights.chunk(3, dim=1)
        d_input_gate *= d_sigmoid(gates[0])
        d_output_gate *= d_sigmoid(gates[1])
        d_candidate_cell *= d_elu(gates[2])

        d_gates = torch.cat(
            [d_input_gate, d_output_gate, d_candidate_cell], dim=1)

        if ctx.needs_input_grad[1]:
            d_weights = d_gates.t().mm(X)
        if ctx.needs_input_grad[2]:
            d_bias = d_gates.sum(dim=0, keepdim=True)
        # d_X carries the gradients of old_h (first state_size columns, forward
        # input 3) and of input (remaining columns, forward input 0).
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
            d_X = d_gates.mm(weights)
            state_size = grad_h.shape[1]
            d_old_h, d_input = d_X[:, :state_size], d_X[:, state_size:]

        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)

--------------------------------------------------------------------------------