├── .gitignore
├── CMakeLists.txt
├── README.md
├── lltm
│   ├── __init__.py
│   ├── lltm.cpp
│   ├── lltm.py
│   └── lltm_cuda.cu
├── requirements.txt
└── setup.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# skbuild related
_skbuild/
.vscode/

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(LLTM LANGUAGES C CXX CUDA VERSION 1.0)
# if CUDA is optional: enable_language(CUDA)

find_package(Torch REQUIRED)
find_package(PythonExtensions REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# tweak torch dependencies, see https://github.com/pytorch/pytorch/issues/33928
get_target_property(TORCH_INTERFACE_LIB torch INTERFACE_LINK_LIBRARIES)
string(REPLACE "/usr/local/cuda" ${CUDA_TOOLKIT_ROOT_DIR} TORCH_INTERFACE_LIB "${TORCH_INTERFACE_LIB}")
set_target_properties(torch PROPERTIES INTERFACE_LINK_LIBRARIES ${TORCH_INTERFACE_LIB})

# add library
add_library(lltm_ext MODULE
  lltm/lltm.cpp
  lltm/lltm_cuda.cu
)
python_extension_module(lltm_ext)
target_link_libraries(lltm_ext ${TORCH_LIBRARIES})
target_include_directories(lltm_ext PRIVATE ${TORCH_INCLUDE_DIRS})
set_property(TARGET lltm_ext PROPERTY CXX_STANDARD 14)

install(TARGETS lltm_ext DESTINATION lltm)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch CMake Example

This repository is an example of building a CMake-based PyTorch CUDA extension. It is modified from the [pytorch extension example](https://github.com/pytorch/extension-cpp) and the [scikit-build example](https://github.com/scikit-build/scikit-build-sample-projects).

To build this repository, install the essential requirements and then run `python setup.py build`. If your CUDA toolkit lives in a custom location (for example, it was installed via `conda install cudatoolkit-dev -c conda-forge`), you can give CMake a hint by passing the `CMAKE_CUDA_COMPILER` definition.
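
Once the extension is built and importable, a rough usage sketch looks like the following (the shapes below are purely illustrative and not prescribed by the package):

```python
import torch
from lltm import LLTM

batch_size, input_features, state_size = 16, 32, 128

X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)  # hidden state
C = torch.randn(batch_size, state_size)  # cell state

rnn = LLTM(input_features, state_size)
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()
```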

--------------------------------------------------------------------------------
/lltm/__init__.py:
--------------------------------------------------------------------------------
import torch  # ensure pytorch library is loaded
from .lltm import LLTM

--------------------------------------------------------------------------------
/lltm/lltm.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#include <vector>

// s'(z) = (1 - s(z)) * s(z)
torch::Tensor d_sigmoid(torch::Tensor z) {
  auto s = torch::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
torch::Tensor d_tanh(torch::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<torch::Tensor> lltm_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);

  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = torch::sigmoid(gates[0]);
  auto output_gate = torch::sigmoid(gates[1]);
  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = torch::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

std::vector<torch::Tensor> lltm_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights) {
  auto d_output_gate = torch::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

// CUDA forward declarations

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell);

std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gate_weights,
    torch::Tensor weights);

PYBIND11_MODULE(lltm_ext, m) {
  m.def("forward_cuda", &lltm_cuda_forward, "LLTM forward (CUDA)");
  m.def("backward_cuda", &lltm_cuda_backward, "LLTM backward (CUDA)");
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}

--------------------------------------------------------------------------------
/lltm/lltm.py:
--------------------------------------------------------------------------------
import math
from torch import nn
from torch.autograd import Function
import torch

from lltm import lltm_ext

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_ext.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        outputs = lltm_ext.backward(
            grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors)
        # the C++ backward returns exactly five tensors
        d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
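

if __name__ == "__main__":
    # Optional self-check (a sketch; shapes are arbitrary assumptions): compare
    # the hand-written backward above against numerical gradients.
    # gradcheck requires double-precision inputs with requires_grad=True.
    # Run after building/installing the extension, e.g.: python -m lltm.lltm
    from torch.autograd import gradcheck

    batch, features, state = 3, 5, 4
    kwargs = {"dtype": torch.float64, "requires_grad": True}
    inputs = [
        torch.randn(batch, features, **kwargs),              # input
        torch.randn(3 * state, features + state, **kwargs),  # weights
        torch.randn(1, 3 * state, **kwargs),                 # bias
        torch.randn(batch, state, **kwargs),                 # old_h
        torch.randn(batch, state, **kwargs),                 # old_cell
    ]
    print("gradcheck:", gradcheck(LLTMFunction.apply, inputs))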

--------------------------------------------------------------------------------
/lltm/lltm_cuda.cu:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
  return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
  const auto s = sigmoid(z);
  return (1.0 - s) * s;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
  const auto t = tanh(z);
  return 1 - (t * t);
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
  const auto e = exp(z);
  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> gates,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> old_cell,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_h,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_cell,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input_gate,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> output_gate,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> candidate_cell) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < gates.size(2)) {
    input_gate[n][c] = sigmoid(gates[n][0][c]);
    output_gate[n][c] = sigmoid(gates[n][1][c]);
    candidate_cell[n][c] = elu(gates[n][2][c]);
    new_cell[n][c] =
        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> d_old_cell,
    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> d_gates,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> grad_h,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> grad_cell,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_cell,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input_gate,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> output_gate,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> candidate_cell,
    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> gate_weights) {
  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < d_gates.size(2)) {
    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
    const auto d_new_cell =
        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];

    d_old_cell[n][c] = d_new_cell;
    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;

    d_gates[n][0][c] =
        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
    d_gates[n][1][c] =
        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
    d_gates[n][2][c] =
        d_candidate_cell * d_elu(gate_weights[n][2][c]);
  }
}
} // namespace

std::vector<torch::Tensor> lltm_cuda_forward(
    torch::Tensor input,
    torch::Tensor weights,
    torch::Tensor bias,
    torch::Tensor old_h,
    torch::Tensor old_cell) {
  auto X = torch::cat({old_h, input}, /*dim=*/1);
  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

  const auto batch_size = old_cell.size(0);
  const auto state_size = old_cell.size(1);

  auto gates = gate_weights.reshape({batch_size, 3, state_size});
  auto new_h = torch::zeros_like(old_cell);
  auto new_cell = torch::zeros_like(old_cell);
  auto input_gate = torch::zeros_like(old_cell);
  auto output_gate = torch::zeros_like(old_cell);
  auto candidate_cell = torch::zeros_like(old_cell);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        old_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        new_h.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        new_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        input_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        output_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        candidate_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>());
  }));

  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

std::vector<torch::Tensor> lltm_cuda_backward(
    torch::Tensor grad_h,
    torch::Tensor grad_cell,
    torch::Tensor new_cell,
    torch::Tensor input_gate,
    torch::Tensor output_gate,
    torch::Tensor candidate_cell,
    torch::Tensor X,
    torch::Tensor gates,
    torch::Tensor weights) {
  auto d_old_cell = torch::zeros_like(new_cell);
  auto d_gates = torch::zeros_like(gates);

  const auto batch_size = new_cell.size(0);
  const auto state_size = new_cell.size(1);

  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_backward_cuda", ([&] {
    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        d_old_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        d_gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
        grad_h.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        grad_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        new_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        input_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        output_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        candidate_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
        gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>());
  }));

  auto d_gate_weights = d_gates.flatten(1, 2);
  auto d_weights = d_gate_weights.t().mm(X);
  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gate_weights.mm(weights);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os, sys

try:
    from skbuild import setup
except ImportError:
    print('scikit-build is required to build from source.', file=sys.stderr)
    print('Please run:', file=sys.stderr)
    print('', file=sys.stderr)
    print('  python -m pip install scikit-build', file=sys.stderr)
    sys.exit(1)

import torch
torch_root = os.path.dirname(torch.__file__)

setup(
    name="lltm-extension",
    version="0.0.1",
    description="a minimal example package for pytorch extension (with pybind11 and scikit-build)",
    license="MIT",
    packages=['lltm'],
    cmake_args=[f'-DCMAKE_PREFIX_PATH={torch_root}']  # to specify CUDA location: -DCMAKE_CUDA_COMPILER=...
)
--------------------------------------------------------------------------------
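
Note: the pybind module above also exposes the raw CUDA entry points `forward_cuda` and `backward_cuda`, which the Python wrapper in `lltm/lltm.py` does not use. A minimal sketch of calling the CUDA forward path directly, assuming a CUDA-enabled build and GPU-resident tensors (shapes are arbitrary):

```python
import torch
from lltm import lltm_ext  # compiled extension installed inside the lltm package

device = torch.device("cuda")
batch, features, state = 16, 32, 128

x = torch.randn(batch, features, device=device)
w = torch.randn(3 * state, features + state, device=device)
b = torch.randn(1, 3 * state, device=device)
h = torch.randn(batch, state, device=device)
c = torch.randn(batch, state, device=device)

# forward_cuda returns (new_h, new_cell, input_gate, output_gate,
# candidate_cell, X, gates); only the first two are the public outputs.
new_h, new_cell = lltm_ext.forward_cuda(x, w, b, h, c)[:2]
```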