├── .gitignore ├── LICENSE ├── README.md ├── csrc ├── swish.cpp ├── swish.h └── swish_kernel.cu ├── external └── CUDAApplyUtils.cuh ├── extra ├── Comparison.ipynb └── package.py ├── setup.py ├── src └── swish_torch │ └── __init__.py └── test └── test_swish.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build.ninja 3 | .gdb_history 4 | crap 5 | 6 | # Prerequisites 7 | *.d 8 | 9 | # Compiled Object files 10 | *.slo 11 | *.lo 12 | *.o 13 | *.obj 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Compiled Dynamic libraries 20 | *.so 21 | *.dylib 22 | *.dll 23 | 24 | # Fortran module files 25 | *.mod 26 | *.smod 27 | 28 | # Compiled Static libraries 29 | *.lai 30 | *.la 31 | *.a 32 | *.lib 33 | 34 | # Executables 35 | *.exe 36 | *.out 37 | *.app 38 | 39 | # Byte-compiled / optimized / DLL files 40 | __pycache__/ 41 | *.py[cod] 42 | *$py.class 43 | 44 | # C extensions 45 | *.so 46 | 47 | # Distribution / packaging 48 | .Python 49 | build/ 50 | develop-eggs/ 51 | dist/ 52 | downloads/ 53 | eggs/ 54 | .eggs/ 55 | lib/ 56 | lib64/ 57 | parts/ 58 | sdist/ 59 | var/ 60 | wheels/ 61 | pip-wheel-metadata/ 62 | share/python-wheels/ 63 | *.egg-info/ 64 | .installed.cfg 65 | *.egg 66 | MANIFEST 67 | 68 | # PyInstaller 69 | # Usually these files are written by a python script from a template 70 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 71 | *.manifest 72 | *.spec 73 | 74 | # Installer logs 75 | pip-log.txt 76 | pip-delete-this-directory.txt 77 | 78 | # Unit test / coverage reports 79 | htmlcov/ 80 | .tox/ 81 | .nox/ 82 | .coverage 83 | .coverage.* 84 | .cache 85 | nosetests.xml 86 | coverage.xml 87 | *.cover 88 | .hypothesis/ 89 | .pytest_cache/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 thomasbrandon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Swish Activation - PyTorch CUDA Implementation
 2 | 
 3 | This is a PyTorch CUDA implementation of the Swish activation function (https://arxiv.org/abs/1710.05941).
 4 | 
 5 | ## Installation
 6 | It is currently distributed as a source-only PyTorch extension, so you need a properly set up toolchain and the CUDA compiler to install it.
 7 | 1) _Toolchain_ - In conda the `gxx_linux-64` package provides an appropriate toolchain. However, there can still be compatibility issues with this depending on your system. You can also try the system toolchain.
 8 | 2) _CUDA Toolkit_ - The [nVidia CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) is required in addition to the drivers, as it provides the needed headers and tools. Get the appropriate version for your Linux distro from nVidia or check for distro-specific instructions otherwise.
 9 | 
10 | _It is important that your CUDA Toolkit matches the version PyTorch is built for, or errors can occur. Currently PyTorch provides builds for v10.0 and v9.2._
11 | 
12 | 
--------------------------------------------------------------------------------
/csrc/swish.cpp:
--------------------------------------------------------------------------------
 1 | 
 2 | #include <torch/extension.h>
 3 | using namespace pybind11::literals;
 4 | 
 5 | // Forward declaration of kernels
 6 | void swish_forward_cuda(torch::Tensor &output, const torch::Tensor &input);
 7 | void swish_backward_cuda(torch::Tensor &grad_inp, const torch::Tensor &input, const torch::Tensor &grad_out);
 8 | 
 9 | torch::Tensor
10 | swish_forward(const torch::Tensor &input, const at::optional<torch::Tensor> out) {
11 |   auto input_arg = torch::TensorArg(input, "input", 0);
12 |   if (out) {
13 |     auto out_arg = torch::TensorArg(*out, "out", 1);
14 |     torch::checkSameType("swish_forward", input_arg, out_arg);
15 |     torch::checkSameSize("swish_forward", input_arg, out_arg);
16 |   }
17 |   auto o = out.value_or(torch::empty_like(input));
18 |   switch (input.device().type()) {
19 |     case c10::kCUDA:
20 |       swish_forward_cuda(o, input);
21 |       break;
22 |     default:
23 |       TORCH_CHECK(false, "Unsupported device type, should be CUDA but got ", input.device().type());
24 |   }
25 |   return o;
26 | }
27 | 
28 | torch::Tensor
29 | swish_backward(const torch::Tensor &input, const torch::Tensor &grad_out) {
30 |   auto input_arg = torch::TensorArg(input, "input", 0);
31 |   auto grad_out_arg = torch::TensorArg(grad_out, "grad_out", 1);
32 |   torch::checkSameType("swish_backward", input_arg, grad_out_arg);
33 | 
34 |   auto grad_inp = torch::empty_like(input);
35 |   switch (input.device().type()) {
36 |     case c10::kCUDA:
37 |       swish_backward_cuda(grad_inp, input, grad_out);
38 |       break;
39 |     default:
40 |       TORCH_CHECK(false, "Unsupported device type, should be CUDA but got ", input.device().type());
41 |   }
42 |   return grad_inp;
43 | }
44 | 
45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
46 |   m.def("swish_forward", &swish_forward, "Swish activation forward", "input"_a, "out"_a = nullptr);
47 |   m.def("swish_backward", &swish_backward, "Swish activation backward", "input"_a, "grad_out"_a);
48 | }
49 | 
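The bindings above only expose raw `swish_forward`/`swish_backward` ops; the package's Python side (`src/swish_torch/__init__.py`, not included in this listing) is what hooks them into autograd. Below is a minimal sketch of how that wiring typically looks. The import name `swish_torch._C` and the class names are illustrative assumptions, not the package's documented API; only the two bound signatures are taken from `csrc/swish.cpp` above.

```python
# Illustrative sketch only -- the real wrapper lives in src/swish_torch/__init__.py (not shown
# in this listing); `swish_torch._C` is an assumed name for the compiled extension module.
import torch
import swish_torch._C as _C  # hypothetical import of the ops bound in csrc/swish.cpp

class SwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)        # backward only needs the original input
        return _C.swish_forward(inp)      # binding: swish_forward(input, out=None)

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        return _C.swish_backward(inp, grad_out)  # binding: swish_backward(input, grad_out)

class Swish(torch.nn.Module):
    def forward(self, x):
        return SwishFunction.apply(x)
```

The pure-PyTorch `SwishFunction` defined later in `extra/Comparison.ipynb` follows exactly this pattern, just with the forward and backward math written in Python rather than dispatched to the CUDA kernels below.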
--------------------------------------------------------------------------------
/csrc/swish.h:
--------------------------------------------------------------------------------
 1 | #include
 2 | 
 3 | #ifdef __CUDACC__
 4 | #include
 5 | #include
 6 | #define GLOBAL_INLINE __forceinline__ __host__ __device__
 7 | #else
 8 | #include
 9 | #define GLOBAL_INLINE __inline__
10 | #endif
11 | 
12 | // TODO: Try and convert these to lambda functions
13 | template <typename scalar_t>
14 | GLOBAL_INLINE
15 | void swish_fwd_func(scalar_t &out, const scalar_t &inp) {
16 |   out = inp / (scalar_t(1.0) + exp(-inp));
17 | };
18 | 
19 | template <typename scalar_t>
20 | GLOBAL_INLINE
21 | void swish_bwd_func(scalar_t &grad_inp, const scalar_t &inp, const scalar_t &grad_out) {
22 |   const scalar_t sig = scalar_t(1.0) / (scalar_t(1.0) + exp(-inp));
23 |   const scalar_t grad = sig * (scalar_t(1.0) + inp * (scalar_t(1.0) - sig));
24 |   grad_inp = grad_out * grad;
25 | };
26 | 
27 | // Specialisations for Half to calculate as float
28 | // Increases precision and works around intrinsics missing for Half
29 | template <>
30 | GLOBAL_INLINE
31 | void swish_fwd_func(c10::Half &out, const c10::Half &inp) {
32 |   float res;
33 |   swish_fwd_func(res, (float)inp);
34 |   out = res;
35 | };
36 | 
37 | template <>
38 | GLOBAL_INLINE
39 | void swish_bwd_func(c10::Half &grad_inp, const c10::Half &inp, const c10::Half &grad_out) {
40 |   float res;
41 |   swish_bwd_func(res, (float)inp, (float)grad_out);
42 |   grad_inp = res;
43 | };
--------------------------------------------------------------------------------
/csrc/swish_kernel.cu:
--------------------------------------------------------------------------------
 1 | #include
 2 | #include
 3 | #include "CUDAApplyUtils.cuh"
 4 | 
 5 | // TORCH_CHECK replaces AT_CHECK in PyTorch 1.2; support 1.1 as well.
 6 | #ifndef TORCH_CHECK
 7 | #define TORCH_CHECK AT_CHECK
 8 | #endif
 9 | 
10 | #ifndef __CUDACC_EXTENDED_LAMBDA__
11 | #error "please compile with --expt-extended-lambda"
12 | #endif
13 | 
14 | namespace kernel {
15 | #include "swish.h"
16 | 
17 | using at::cuda::CUDA_tensor_apply2;
18 | using at::cuda::CUDA_tensor_apply3;
19 | using at::cuda::TensorArgType;
20 | 
21 | template <typename scalar_t>
22 | void
23 | swish_forward(
24 |   torch::Tensor &output,
25 |   const torch::Tensor &input
26 | ) {
27 |   CUDA_tensor_apply2<scalar_t, scalar_t>(
28 |     output, input,
29 |     [=] __host__ __device__ (scalar_t &out, const scalar_t &inp) {
30 |       swish_fwd_func(out, inp);
31 |     },
32 |     TensorArgType::ReadWrite, TensorArgType::ReadOnly
33 |   );
34 | }
35 | 
36 | template <typename scalar_t>
37 | void
38 | swish_backward(
39 |   torch::Tensor &grad_inp,
40 |   const torch::Tensor &input,
41 |   const torch::Tensor &grad_out
42 | ) {
43 |   CUDA_tensor_apply3<scalar_t, scalar_t, scalar_t>(
44 |     grad_inp, input, grad_out,
45 |     [=] __host__ __device__ (scalar_t &grad_inp, const scalar_t &inp, const scalar_t &grad_out) {
46 |       swish_bwd_func(grad_inp, inp, grad_out);
47 |     },
48 |     TensorArgType::ReadWrite, TensorArgType::ReadOnly, TensorArgType::ReadOnly
49 |   );
50 | }
51 | 
52 | } // namespace kernel
53 | 
54 | void
55 | swish_forward_cuda(
56 |   torch::Tensor &output, const torch::Tensor &input
57 | ) {
58 |   auto in_arg  = torch::TensorArg(input, "input", 0),
59 |        out_arg = torch::TensorArg(output, "output", 1);
60 |   torch::checkAllDefined("swish_forward_cuda", {in_arg, out_arg});
61 |   torch::checkAllSameGPU("swish_forward_cuda", {in_arg, out_arg});
62 |   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "swish_forward_cuda", [&] {
63 |     kernel::swish_forward<scalar_t>(output, input);
64 |   });
65 | }
66 | 
67 | void
68 | swish_backward_cuda(
69 |   torch::Tensor &grad_inp, const torch::Tensor &input, const torch::Tensor &grad_out
70 | ) {
71 |   auto gi_arg = torch::TensorArg(grad_inp, "grad_inp", 0),
72 |        in_arg = torch::TensorArg(input, "input", 1),
73 |        go_arg = torch::TensorArg(grad_out, "grad_out", 2);
74 |   torch::checkAllDefined("swish_backward_cuda", {gi_arg, in_arg, go_arg});
75 |
torch::checkAllSameGPU("swish_backward_cuda", {gi_arg, in_arg, go_arg}); 76 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_inp.scalar_type(), "swish_backward_cuda", [&] { 77 | kernel::swish_backward(grad_inp, input, grad_out); 78 | }); 79 | } 80 | 81 | -------------------------------------------------------------------------------- /external/CUDAApplyUtils.cuh: -------------------------------------------------------------------------------- 1 | 2 | // Taken from https://github.com/pytorch/pytorch/blob/ff6cda0da6d8fceadbe0cf31ef73c78d3c9e9bcc/aten/src/ATen/cuda/CUDAApplyUtils.cuh 3 | // Changes marked with TB: 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | // 16 | // This file contains pointwise operation functions and kernels that 17 | // work on both contiguous and non-contiguous tensor arguments of 18 | // arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without 19 | // copying or temporary storage. 20 | // 21 | 22 | /* 23 | NOTE [ CUDA_tensor_applyN helpers ] 24 | 25 | The following CUDA_tensor_applyN (where N currently can be 1, 2, 3, or 4) 26 | functions apply a pointwise operator to N tensor(s). 27 | 28 | The calling convention is 29 | 30 | 1. The template arguments should be, sequentially, 31 | - First N typename args specify the scalar types of each of the N tensors. 32 | - (Optional) `int step` arg specifies the number of elements processed 33 | together at the same time. 34 | Default is 1. 35 | - A usually omitted (i.e., inferred) typename arg specifies the type of the 36 | function/functor applied on `N * step` values in each iteration of each 37 | CUDA thread. 38 | 2. The arguments should be, sequentially, 39 | - N tensors 40 | - op: a function/functor that processes `N * step` values at the same time. 41 | - If `step == 1`, it must have signature 42 | `void(*)(scalar1_t&, scalar2_t&, ..., scalarN_t&)`, where 43 | `scalar*_t`s are the first N typename template args, and the inputs 44 | are the `N` values from the `N` tensors retrieved at a common index. 45 | - Otherwise, it must must have signature 46 | void(*)(int n, scalar1_t&, scalar1_t&, ..., scalar1_t&, // repeat `step` times 47 | scalar2_t&, scalar2_t&, ..., scalar2_t&, // repeat `step` times 48 | ..., 49 | scalarN_t&, scalarN_t&, ..., scalarN_t&) // repeat `step` times 50 | Different from `step == 1` case, it processes `N * step` values taken 51 | from `step` common indices. Moreover, the first input `n` represents the 52 | number of valid indices (it will always have `0 < n <= step`). It will 53 | almost always be `step`, but at the boundary we may not have full `step` 54 | elements and `n` can be a lesser value. 55 | 56 | E.g., if `step == 4` and `N == 2`, `op` could be 57 | 58 | [](int n, scalar1_t &u1, scalar1_t &u2, scalar1_t &u3, scalar1_t &u4, 59 | scalar2_t &v1, scalar2_t &v2, scalar2_t &v3, scalar2_t &v4) { 60 | // Only process u1, ..., un and v1, ..., vn. 61 | // So if `n == 3`, `u4` and `v4` need not to be considered. 62 | } 63 | 64 | In both cases, the references can actually be const, but at least one of 65 | them should be non-const in order to write the output. 66 | - (Optional, but recommended) N TensorArgType args that specify for each 67 | tensor whether `op` reads AND writes ] (i.e., TensorArgType::ReadWrite), 68 | or only reads (i.e., TensorArgType::ReadOnly). 69 | Default is TensorArgType::ReadWrite for first Tensor, and 70 | TensorArgType::ReadOnly for the rest. 
71 | 72 | E.g., 73 | 74 | to compute a = b^2 for a and b of same dtype, we can call 75 | 76 | CUDA_tensor_apply2( 77 | a, b, 78 | [] __device__ (scalar &a_val, const scalar &b_val) { a_val = b_val * b_val; } 79 | ); 80 | 81 | to work on 2 values at the same time, we can call 82 | 83 | CUDA_tensor_apply2( 84 | a, b, 85 | [] __device__ (int n, scalar1 &a_val1, scalar1 &a_val2, 86 | const scalar2 &b_val1, const scalar2 &b_val2) { 87 | // call special vectorized op here, or just do elementwise and enjoy unrolling... 88 | // if n == 1, only process a_val1 and b_val1 89 | } 90 | ); 91 | */ 92 | 93 | namespace at { 94 | namespace cuda { 95 | 96 | // TODO: combine with TensorArg? So far that's been for debugging, and this is functional... 97 | enum class TensorArgType { ReadWrite, ReadOnly }; 98 | 99 | namespace { 100 | 101 | // Rearrange dimensions for pointwise operations so that strides are in 102 | // decreasing order as much as possible, so that kernels have better memory 103 | // access patterns. 104 | // 105 | // For example, consider a binary operation on two "transposed" 2-dim tensors: 106 | // sizes: 256 512 107 | // aInfo->strides: 1 256 108 | // bInfo->strides: 1 256 109 | // 110 | // Given this, each concurrent memory access inside kernelPointwiseApply2() is 111 | // exactly 256 elements apart, resulting in poor performance. 112 | // 113 | // This function exchanges dimensions so that memory access is contiguous: 114 | // sizes: 512 256 115 | // aInfo->strides: 256 1 116 | // bInfo->strides: 256 1 117 | // 118 | // (Actually, it becomes even better because now collapseDims() can turn each 119 | // input into one contiguous array.) 120 | // 121 | // In general, given M (<=4) TensorInfo's with N dimensions, we can view each 122 | // strides[i] (0 <= i < N) as an M-tuple. Given each pair i < j, we exchange 123 | // strides[i] and [j] if 124 | // (1) strides[i][k] < strides[j][k] for some k (0 <= k < M) 125 | // (exchanging them will benefit input #k), and 126 | // (2) strides[i][k] <= strieds[j][k] for all k 127 | // (exchanging them will not make any input worse). 128 | template 130 | inline void rearrangeDims(detail::TensorInfo* aInfo, 131 | detail::TensorInfo* bInfo = nullptr, 132 | detail::TensorInfo* cInfo = nullptr, 133 | detail::TensorInfo* dInfo = nullptr) { 134 | int numInfos = 1; 135 | int dims = aInfo->dims; 136 | IndexType *sizes[4] = { aInfo->sizes, }; 137 | IndexType *strides[4] = { aInfo->strides, }; 138 | 139 | if (bInfo != nullptr) { 140 | ++numInfos; 141 | if (bInfo->dims != dims) return; 142 | sizes[1] = bInfo->sizes; 143 | strides[1] = bInfo->strides; 144 | } 145 | 146 | if (cInfo != nullptr) { 147 | ++numInfos; 148 | if (cInfo->dims != dims) return; 149 | sizes[2] = cInfo->sizes; 150 | strides[2] = cInfo->strides; 151 | } 152 | 153 | if (dInfo != nullptr) { 154 | ++numInfos; 155 | if (dInfo->dims != dims) return; 156 | sizes[3] = dInfo->sizes; 157 | strides[3] = dInfo->strides; 158 | } 159 | 160 | // Bail out if sizes do not match: we are using "deprecated pointwise 161 | // behavior" among tensors of different shapes but same number of elements. 162 | for (int i = 1; i < numInfos; ++i) { 163 | for (int j = 0; j < dims; ++j) { 164 | if (sizes[i][j] != sizes[0][j]) return; 165 | } 166 | } 167 | 168 | for (int i = 0; i < dims - 1; ++i) { 169 | // No need to consider dimensions of size 1. 
170 | if (sizes[0][i] == 1) continue; 171 | 172 | for (int j = i + 1; j < dims; ++j) { 173 | if (sizes[0][j] == 1) continue; 174 | 175 | // Compare the relative sizes of strides between dim #i and dim #j. 176 | bool hasIncreasingStrides = false; 177 | bool hasDecreasingStrides = false; 178 | 179 | for (int k = 0; k < numInfos; k++) { 180 | IndexType stride_i = strides[k][i]; 181 | IndexType stride_j = strides[k][j]; 182 | if (stride_i < stride_j) { 183 | hasIncreasingStrides = true; 184 | } else if (stride_i > stride_j) { 185 | hasDecreasingStrides = true; 186 | } 187 | } 188 | 189 | if (hasIncreasingStrides && !hasDecreasingStrides) { 190 | for (int k = 0; k < numInfos; k++) { 191 | IndexType size = sizes[k][i]; 192 | sizes[k][i] = sizes[k][j]; 193 | sizes[k][j] = size; 194 | 195 | IndexType stride = strides[k][i]; 196 | strides[k][i] = strides[k][j]; 197 | strides[k][j] = stride; 198 | } 199 | } 200 | } 201 | } 202 | } 203 | 204 | // Threads per block for our apply kernel 205 | // FIXME: use occupancy calculator instead 206 | constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512; 207 | constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4; 208 | 209 | // The `remaining_steps` argument is used to support Op that operates on 210 | // multiple elements at the same time. Generally, the strategy of ApplyOpN is to 211 | // 1. Initialize `remaining_steps = step`, where `step` is the template arg of 212 | // CUDA_tensor_applyN helpers. The input arg `n` to `apply()` represents the 213 | // number of elements in bound for this call. It will almost always equal to 214 | // `step` except at boundaries. 215 | // 2. If `remaining_steps > 0` convert the current linearIndex to offset (if in 216 | // bound), and recursively call `ApplyOpN` with `remaining_steps - 1`. 217 | // 3. At `remaining_steps = 0`, 218 | // if `step = 1`, call `op(tensor1_val, tensor2_val, ...)`; 219 | // if `step > 1`, call `op(n, tensor1_val1, tensor1_val2, ..., tesor1_valstep, 220 | // tensor2_val1, tensor2_val2, ..., tesor2_valstep, 221 | // ... 222 | // tensorN_val1, tensorN_val2, ..., tesorN_valstep);` 223 | // 224 | // See NOTE [ CUDA_tensor_applyN helpers ] above for how Op may look like. 225 | 226 | template 232 | struct ApplyOp1 { 233 | __device__ __forceinline__ 234 | static void apply(detail::TensorInfo &a, const Op &op, int n, 235 | IndexType linearIndex, Offsets... aOffsets) { 236 | // Convert `linearIndex` into an offset of `a` 237 | const IndexType aOffset = sizeof...(Offsets) < n ? 238 | detail::IndexToOffset::get(linearIndex, a) : 0; 239 | 240 | ApplyOp1::apply( 241 | a, op, n, linearIndex + 1, aOffsets..., aOffset 242 | ); 243 | } 244 | }; 245 | 246 | // Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). 247 | // We don't need to pass in how many elements need to processed in this case. 248 | template 253 | struct ApplyOp1 { 254 | __device__ __forceinline__ 255 | static void apply(detail::TensorInfo &a, const Op &op, 256 | int n, IndexType linearIndex, Offset offset) { 257 | op(a.data[offset]); 258 | } 259 | }; 260 | 261 | template 266 | struct ApplyOp1 { 267 | __device__ __forceinline__ 268 | static void apply(detail::TensorInfo &a, const Op &op, int n, 269 | IndexType linearIndex, Offsets... 
offsets) { 270 | op(n, a.data[offsets]...); 271 | } 272 | }; 273 | 274 | template 279 | #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ 280 | C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) 281 | #endif 282 | __global__ void kernelPointwiseApply1(detail::TensorInfo a, 283 | IndexType totalElements, const Op op) { 284 | for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; 285 | linearIndex < totalElements; 286 | linearIndex += gridDim.x * blockDim.x * step) { 287 | ApplyOp1::apply( 288 | a, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); 289 | } 290 | } 291 | 292 | 293 | template 301 | struct ApplyOp2 { 302 | __device__ __forceinline__ 303 | static void apply(detail::TensorInfo &a, 304 | detail::TensorInfo &b, 305 | const Op &op, int n, IndexType linearIndex, 306 | Offsets... aOffsets, Offsets... bOffsets) { 307 | // Convert `linearIndex` into an offset of `a` 308 | const IndexType aOffset = sizeof...(Offsets) < n ? 309 | detail::IndexToOffset::get(linearIndex, a) : 0; 310 | 311 | // Convert `linearIndex` into an offset of `b` 312 | const IndexType bOffset = sizeof...(Offsets) < n ? 313 | detail::IndexToOffset::get(linearIndex, b) : 0; 314 | 315 | ApplyOp2::apply( 316 | a, b, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset 317 | ); 318 | } 319 | }; 320 | 321 | // Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). 322 | // We don't need to pass in how many elements need to processed in this case. 323 | template 330 | struct ApplyOp2 { 331 | __device__ __forceinline__ 332 | static void apply(detail::TensorInfo &a, 333 | detail::TensorInfo &b, 334 | const Op &op, int n, IndexType linearIndex, 335 | Offset aOffset, Offset bOffset) { 336 | op(a.data[aOffset], b.data[bOffset]); 337 | } 338 | }; 339 | 340 | template 347 | struct ApplyOp2 { 348 | __device__ __forceinline__ 349 | static void apply(detail::TensorInfo &a, 350 | detail::TensorInfo &b, 351 | const Op &op, int n, IndexType linearIndex, 352 | Offsets... aOffsets, Offsets... bOffsets) { 353 | op(n, a.data[aOffsets]..., b.data[bOffsets]...); 354 | } 355 | }; 356 | 357 | template 363 | #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ 364 | C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) 365 | #endif 366 | __global__ void 367 | kernelPointwiseApply2(detail::TensorInfo a, 368 | detail::TensorInfo b, 369 | IndexType totalElements, 370 | const Op op) { 371 | for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; 372 | linearIndex < totalElements; 373 | linearIndex += gridDim.x * blockDim.x * step) { 374 | ApplyOp2::apply( 375 | a, b, op, ::min(step, static_cast(totalElements - linearIndex)), 376 | linearIndex); 377 | } 378 | } 379 | 380 | 381 | template 391 | struct ApplyOp3 { 392 | __device__ __forceinline__ 393 | static void apply(detail::TensorInfo &a, 394 | detail::TensorInfo &b, 395 | detail::TensorInfo &c, 396 | const Op &op, int n, IndexType linearIndex, 397 | Offsets... aOffsets, Offsets... bOffsets, 398 | Offsets... cOffsets) { 399 | // Convert `linearIndex` into an offset of `a` 400 | const IndexType aOffset = sizeof...(Offsets) < n ? 401 | detail::IndexToOffset::get(linearIndex, a) : 0; 402 | 403 | // Convert `linearIndex` into an offset of `b` 404 | const IndexType bOffset = sizeof...(Offsets) < n ? 
405 | detail::IndexToOffset::get(linearIndex, b) : 0; 406 | 407 | // Convert `linearIndex` into an offset of `c` 408 | const IndexType cOffset = sizeof...(Offsets) < n ? 409 | detail::IndexToOffset::get(linearIndex, c) : 0; 410 | 411 | ApplyOp3::apply( 413 | a, b, c, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset, 414 | cOffsets..., cOffset 415 | ); 416 | } 417 | }; 418 | 419 | // Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). 420 | // We don't need to pass in how many elements need to processed in this case. 421 | template 430 | struct ApplyOp3 { 432 | __device__ __forceinline__ 433 | static void apply(detail::TensorInfo &a, 434 | detail::TensorInfo &b, 435 | detail::TensorInfo &c, 436 | const Op &op, int n, IndexType linearIndex, 437 | Offset aOffset, Offset bOffset, Offset cOffset) { 438 | op(a.data[aOffset], b.data[bOffset], c.data[cOffset]); 439 | } 440 | }; 441 | 442 | template 451 | struct ApplyOp3 { 453 | __device__ __forceinline__ 454 | static void apply(detail::TensorInfo &a, 455 | detail::TensorInfo &b, 456 | detail::TensorInfo &c, 457 | const Op &op, int n, IndexType linearIndex, 458 | Offsets... aOffsets, Offsets... bOffsets, 459 | Offsets... cOffsets) { 460 | op(n, a.data[aOffsets]..., b.data[bOffsets]..., c.data[cOffsets]...); 461 | } 462 | }; 463 | 464 | 465 | template 472 | #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ 473 | C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) 474 | #endif 475 | __global__ void 476 | kernelPointwiseApply3(detail::TensorInfo a, 477 | detail::TensorInfo b, 478 | detail::TensorInfo c, 479 | IndexType totalElements, 480 | const Op op) { 481 | for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; 482 | linearIndex < totalElements; 483 | linearIndex += gridDim.x * blockDim.x * step) { 484 | ApplyOp3::apply( 485 | a, b, c, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); 486 | } 487 | } 488 | 489 | 490 | template 502 | struct ApplyOp4 { 503 | __device__ __forceinline__ 504 | static void apply(detail::TensorInfo &a, 505 | detail::TensorInfo &b, 506 | detail::TensorInfo &c, 507 | detail::TensorInfo &d, 508 | const Op &op, int n, IndexType linearIndex, 509 | Offsets... aOffsets, Offsets... bOffsets, 510 | Offsets... cOffsets, Offsets... dOffsets) { 511 | // Convert `linearIndex` into an offset of `a` 512 | const IndexType aOffset = sizeof...(Offsets) < n ? 513 | detail::IndexToOffset::get(linearIndex, a) : 0; 514 | 515 | // Convert `linearIndex` into an offset of `b` 516 | const IndexType bOffset = sizeof...(Offsets) < n ? 517 | detail::IndexToOffset::get(linearIndex, b) : 0; 518 | 519 | // Convert `linearIndex` into an offset of `c` 520 | const IndexType cOffset = sizeof...(Offsets) < n ? 521 | detail::IndexToOffset::get(linearIndex, c) : 0; 522 | 523 | // Convert `linearIndex` into an offset of `d` 524 | const IndexType dOffset = sizeof...(Offsets) < n ? 525 | detail::IndexToOffset::get(linearIndex, d) : 0; 526 | 527 | ApplyOp4::apply( 529 | a, b, c, d, op, n, linearIndex + 1, aOffsets..., aOffset, bOffsets..., bOffset, 530 | cOffsets..., cOffset, dOffsets..., dOffset 531 | ); 532 | } 533 | }; 534 | 535 | // Specialize `step=1` case (i.e., `remaining_steps=0` and `len(Offsets)=1`). 536 | // We don't need to pass in how many elements need to processed in this case. 
537 | template 548 | struct ApplyOp4 { 550 | __device__ __forceinline__ 551 | static void apply(detail::TensorInfo &a, 552 | detail::TensorInfo &b, 553 | detail::TensorInfo &c, 554 | detail::TensorInfo &d, 555 | const Op &op, int n, IndexType linearIndex, 556 | Offset aOffset, Offset bOffset, 557 | Offset cOffset, Offset dOffset) { 558 | op(a.data[aOffset], b.data[bOffset], c.data[cOffset], d.data[dOffset]); 559 | } 560 | }; 561 | 562 | template 573 | struct ApplyOp4 { 575 | __device__ __forceinline__ 576 | static void apply(detail::TensorInfo &a, 577 | detail::TensorInfo &b, 578 | detail::TensorInfo &c, 579 | detail::TensorInfo &d, 580 | const Op &op, int n, IndexType linearIndex, 581 | Offsets... aOffsets, Offsets... bOffsets, 582 | Offsets... cOffsets, Offsets... dOffsets) { 583 | op(n, a.data[aOffsets]..., b.data[bOffsets]..., c.data[cOffsets]..., d.data[dOffsets]...); 584 | } 585 | }; 586 | 587 | template 595 | #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ 596 | C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) 597 | #endif 598 | __global__ void 599 | kernelPointwiseApply4(detail::TensorInfo a, 600 | detail::TensorInfo b, 601 | detail::TensorInfo c, 602 | detail::TensorInfo d, 603 | IndexType totalElements, 604 | const Op op) { 605 | for (IndexType linearIndex = (blockIdx.x * blockDim.x + threadIdx.x) * step; 606 | linearIndex < totalElements; 607 | linearIndex += gridDim.x * blockDim.x * step) { 608 | ApplyOp4::apply( 610 | a, b, c, d, op, ::min(step, static_cast(totalElements - linearIndex)), linearIndex); 611 | } 612 | } 613 | 614 | } // namespace 615 | 616 | /** 617 | Computes ceil(a / b) 618 | */ 619 | template 620 | __host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) { 621 | return (a + b - 1) / b; 622 | } 623 | 624 | template 625 | inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice) { 626 | if (curDevice == -1) return false; 627 | uint64_t numel_per_thread = static_cast(AT_APPLY_THREADS_PER_BLOCK) * static_cast(step); 628 | uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread); 629 | uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0]; 630 | if (numBlocks > maxGridX) 631 | numBlocks = maxGridX; 632 | grid = dim3(numBlocks); 633 | return true; 634 | } 635 | 636 | inline dim3 getApplyBlock() { 637 | return dim3(AT_APPLY_THREADS_PER_BLOCK); 638 | } 639 | 640 | 641 | template 642 | inline bool CUDA_tensor_apply1(at::Tensor a, 643 | const Op op, 644 | TensorArgType aType = TensorArgType::ReadWrite) { 645 | checkBackend("CUDA_tensor_apply1", {a}, Backend::CUDA); 646 | auto dim = a.dim(); 647 | 648 | /* 649 | Since this is a unary op, we can easily first check for expanded dimensions 650 | (with stride 0), and remove them, to avoid calling .contiguous() in such 651 | case when detail::maybeOverlappingIndices(a) returns true. 
652 | */ 653 | std::vector collapsed_shape; 654 | std::vector collapsed_strides; 655 | collapsed_shape.reserve(dim); 656 | collapsed_strides.reserve(dim); 657 | for (int64_t i = 0; i < dim; i++) { 658 | if (a.stride(i) != 0) { 659 | collapsed_shape.push_back(a.size(i)); 660 | collapsed_strides.push_back(a.stride(i)); 661 | } 662 | } 663 | if (collapsed_shape.size() != dim) { 664 | a = a.as_strided(collapsed_shape, collapsed_strides); 665 | } 666 | 667 | int64_t totalElements = a.numel(); 668 | 669 | if (dim > MAX_TENSORINFO_DIMS) { 670 | return false; 671 | } 672 | 673 | if (totalElements == 0) { 674 | // Empty tensor; do nothing 675 | return true; 676 | } 677 | const dim3 block = getApplyBlock(); 678 | 679 | dim3 grid; 680 | int64_t curDevice = current_device(); 681 | if (curDevice == -1) return false; 682 | if (!getApplyGrid(totalElements, grid, curDevice)) { 683 | return false; 684 | } 685 | 686 | /* 687 | Expands readable/writable tensors whose indices may be "overlapped." 688 | This ensures that each element of the tensor is operated on once and only 689 | once. 690 | */ 691 | Tensor oldA; 692 | 693 | if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { 694 | // Must perform in contiguous space 695 | oldA = a; 696 | a = a.contiguous(); 697 | } 698 | 699 | // It is possible that the tensor dimensions are able to be collapsed, 700 | // and thus we can reduce the actual code complexity of the copy by 701 | // exploiting this knowledge statically, since the div/mod is the 702 | // most expensive part of the operation, more so than memory accesses. 703 | // For instance, when copying a non-contiguous to a contiguous tensor 704 | // (or vice versa), the contiguous tensor can be collapsed to one 705 | // dimension, and the loop to translate the linear index to the array 706 | // index can be similarly collapsed. That is what this unrolling is for. 707 | 708 | #define HANDLE_CASE(TYPE, A) \ 709 | kernelPointwiseApply1 \ 712 | <<>>( \ 713 | aInfo, static_cast(totalElements), op); 714 | 715 | #define HANDLE_A_CASE(TYPE, A) { \ 716 | switch (A) { \ 717 | case 1: \ 718 | HANDLE_CASE(TYPE, 1); \ 719 | break; \ 720 | case 2: \ 721 | HANDLE_CASE(TYPE, 2); \ 722 | break; \ 723 | default: \ 724 | HANDLE_CASE(TYPE, -1); \ 725 | break; \ 726 | } \ 727 | } 728 | 729 | if (detail::canUse32BitIndexMath(a)) { 730 | detail::TensorInfo aInfo = 731 | detail::getTensorInfo(a); 732 | 733 | rearrangeDims(&aInfo); 734 | aInfo.collapseDims(); 735 | 736 | HANDLE_A_CASE(unsigned int, aInfo.dims); 737 | } else { 738 | detail::TensorInfo aInfo = 739 | detail::getTensorInfo(a); 740 | 741 | rearrangeDims(&aInfo); 742 | aInfo.collapseDims(); 743 | 744 | /* 745 | Only instantiates the all 1D special case and the fallback all nD case for 746 | large (64-bit indexed) tensors to reduce compilation time. 747 | */ 748 | if (aInfo.dims == 1) { 749 | HANDLE_CASE(uint64_t, 1); 750 | } else { 751 | HANDLE_CASE(uint64_t, -1); 752 | } 753 | } 754 | #undef HANDLE_CASE 755 | #undef HANDLE_A_CASE 756 | 757 | if (oldA.defined()) { 758 | // TB: No need for _th_copy_ignoring_overlaps_ 759 | oldA.copy_(a); 760 | } 761 | 762 | return true; 763 | } 764 | 765 | /* Provides default step = 1 to CUDA_tensor_apply1. 
*/ 766 | template 767 | inline bool CUDA_tensor_apply1(at::Tensor a, 768 | const Op op, 769 | TensorArgType aType = TensorArgType::ReadWrite) { 770 | return CUDA_tensor_apply1(a, op, aType); 771 | } 772 | 773 | 774 | template 775 | inline bool CUDA_tensor_apply2(at::Tensor a, 776 | at::Tensor b, 777 | const Op op, 778 | TensorArgType aType = TensorArgType::ReadWrite, 779 | TensorArgType bType = TensorArgType::ReadOnly) { 780 | checkBackend("CUDA_tensor_apply2", {a, b}, Backend::CUDA); 781 | int64_t totalElements = a.numel(); 782 | 783 | if (totalElements != b.numel()) { 784 | return false; 785 | } 786 | 787 | if (a.dim() > MAX_TENSORINFO_DIMS || 788 | b.dim() > MAX_TENSORINFO_DIMS) { 789 | return false; 790 | } 791 | 792 | if (a.numel() == 0) { 793 | // Empty tensor; do nothing 794 | return true; 795 | } 796 | const dim3 block = getApplyBlock(); 797 | 798 | dim3 grid; 799 | int64_t curDevice = current_device(); 800 | if (curDevice == -1) return false; 801 | if (!getApplyGrid(totalElements, grid, curDevice)) { 802 | return false; 803 | } 804 | 805 | /* 806 | Expands readable/writable tensors whose indices may be "overlapped." 807 | This ensures that each element of the tensor is operated on once and only 808 | once. 809 | */ 810 | Tensor oldA; 811 | Tensor oldB; 812 | 813 | if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { 814 | // Must perform in contiguous space 815 | oldA = a; 816 | a = a.contiguous(); 817 | } 818 | if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { 819 | // Must perform in contiguous space 820 | oldB = b; 821 | b = b.contiguous(); 822 | } 823 | 824 | // It is possible that the tensor dimensions are able to be collapsed, 825 | // and thus we can reduce the actual code complexity of the copy by 826 | // exploiting this knowledge statically, since the div/mod is the 827 | // most expensive part of the operation, more so than memory accesses. 828 | // For instance, when copying a non-contiguous to a contiguous tensor 829 | // (or vice versa), the contiguous tensor can be collapsed to one 830 | // dimension, and the loop to translate the linear index to the array 831 | // index can be similarly collapsed. That is what this unrolling is for. 
832 | 833 | #define HANDLE_CASE(TYPE, A, B) \ 834 | kernelPointwiseApply2 \ 838 | <<>>( \ 839 | aInfo, bInfo, static_cast(totalElements), op); 840 | 841 | #define HANDLE_B_CASE(TYPE, A, B) { \ 842 | switch (B) { \ 843 | case 1: \ 844 | HANDLE_CASE(TYPE, A, 1); \ 845 | break; \ 846 | case 2: \ 847 | HANDLE_CASE(TYPE, A, 2); \ 848 | break; \ 849 | default: \ 850 | HANDLE_CASE(TYPE, A, -1); \ 851 | break; \ 852 | } \ 853 | } 854 | 855 | #define HANDLE_A_CASE(TYPE, A, B) { \ 856 | switch (A) { \ 857 | case 1: \ 858 | HANDLE_B_CASE(TYPE, 1, B); \ 859 | break; \ 860 | case 2: \ 861 | HANDLE_B_CASE(TYPE, 2, B); \ 862 | break; \ 863 | default: \ 864 | HANDLE_B_CASE(TYPE, -1, B); \ 865 | break; \ 866 | } \ 867 | } 868 | 869 | if (detail::canUse32BitIndexMath(a) && 870 | detail::canUse32BitIndexMath(b)) { 871 | detail::TensorInfo aInfo = 872 | detail::getTensorInfo(a); 873 | 874 | detail::TensorInfo bInfo = 875 | detail::getTensorInfo(b); 876 | rearrangeDims(&aInfo, &bInfo); 877 | aInfo.collapseDims(); 878 | bInfo.collapseDims(); 879 | 880 | HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); 881 | } else { 882 | detail::TensorInfo aInfo = 883 | detail::getTensorInfo(a); 884 | 885 | detail::TensorInfo bInfo = 886 | detail::getTensorInfo(b); 887 | rearrangeDims(&aInfo, &bInfo); 888 | aInfo.collapseDims(); 889 | bInfo.collapseDims(); 890 | 891 | /* 892 | Only instantiates the all 1D special case and the fallback all nD case for 893 | large (64-bit indexed) tensors to reduce compilation time. 894 | */ 895 | if (aInfo.dims == 1 && bInfo.dims == 1) { 896 | HANDLE_CASE(uint64_t, 1, 1); 897 | } else { 898 | HANDLE_CASE(uint64_t, -1, -1); 899 | } 900 | } 901 | #undef HANDLE_CASE 902 | #undef HANDLE_B_CASE 903 | #undef HANDLE_A_CASE 904 | 905 | if (oldA.defined()) { 906 | // TB: No need for _th_copy_ignoring_overlaps_ 907 | oldA.copy_(a); 908 | } 909 | 910 | if (oldB.defined()) { 911 | // TB: No need for _th_copy_ignoring_overlaps_ 912 | oldB.copy_(b); 913 | } 914 | 915 | return true; 916 | } 917 | 918 | /* Provides default step = 1 to CUDA_tensor_apply2. */ 919 | template 920 | inline bool CUDA_tensor_apply2(at::Tensor a, 921 | at::Tensor b, 922 | const Op op, 923 | TensorArgType aType = TensorArgType::ReadWrite, 924 | TensorArgType bType = TensorArgType::ReadOnly) { 925 | return CUDA_tensor_apply2(a, b, op, aType, bType); 926 | } 927 | 928 | 929 | template 930 | inline bool CUDA_tensor_apply3(at::Tensor a, 931 | at::Tensor b, 932 | at::Tensor c, 933 | const Op op, 934 | TensorArgType aType = TensorArgType::ReadWrite, 935 | TensorArgType bType = TensorArgType::ReadOnly, 936 | TensorArgType cType = TensorArgType::ReadOnly) { 937 | checkBackend("CUDA_tensor_apply3", {a, b, c}, Backend::CUDA); 938 | int64_t totalElements = a.numel(); 939 | 940 | if (totalElements != b.numel() || 941 | totalElements != c.numel()) { 942 | return false; 943 | } 944 | 945 | if (a.dim() > MAX_TENSORINFO_DIMS || 946 | b.dim() > MAX_TENSORINFO_DIMS || 947 | c.dim() > MAX_TENSORINFO_DIMS) { 948 | return false; 949 | } 950 | 951 | if (a.numel() == 0) { 952 | // Empty tensor; do nothing 953 | return true; 954 | } 955 | 956 | const dim3 block = getApplyBlock(); 957 | 958 | dim3 grid; 959 | int64_t curDevice = current_device(); 960 | if (curDevice == -1) return false; 961 | if (!getApplyGrid(totalElements, grid, curDevice)) { 962 | return false; 963 | } 964 | 965 | /* 966 | Expands readable/writable tensors whose indices may be "overlapped." 967 | This ensures that each element of the tensor is operated on once and only 968 | once. 
969 | */ 970 | Tensor oldA; 971 | Tensor oldB; 972 | Tensor oldC; 973 | 974 | if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { 975 | // Must perform in contiguous space 976 | oldA = a; 977 | a = a.contiguous(); 978 | } 979 | if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { 980 | // Must perform in contiguous space 981 | oldB = b; 982 | b = b.contiguous(); 983 | } 984 | if (cType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) { 985 | // Must perform in contiguous space 986 | oldC = c; 987 | c = c.contiguous(); 988 | } 989 | 990 | #define HANDLE_CASE(TYPE, A, B, C) \ 991 | kernelPointwiseApply3 \ 996 | <<>>( \ 997 | aInfo, bInfo, cInfo, static_cast(totalElements), op); 998 | 999 | #define HANDLE_C_CASE(TYPE, A, B, C) { \ 1000 | switch (C) { \ 1001 | case 1: \ 1002 | HANDLE_CASE(TYPE, A, B, 1); \ 1003 | break; \ 1004 | case 2: \ 1005 | HANDLE_CASE(TYPE, A, B, 2); \ 1006 | break; \ 1007 | default: \ 1008 | HANDLE_CASE(TYPE, A, B, -1); \ 1009 | break; \ 1010 | } \ 1011 | } 1012 | 1013 | #define HANDLE_B_CASE(TYPE, A, B, C) { \ 1014 | switch (B) { \ 1015 | case 1: \ 1016 | HANDLE_C_CASE(TYPE, A, 1, C); \ 1017 | break; \ 1018 | case 2: \ 1019 | HANDLE_C_CASE(TYPE, A, 2, C); \ 1020 | break; \ 1021 | default: \ 1022 | HANDLE_C_CASE(TYPE, A, -1, C); \ 1023 | break; \ 1024 | } \ 1025 | } 1026 | 1027 | #define HANDLE_A_CASE(TYPE, A, B, C) { \ 1028 | switch (A) { \ 1029 | case 1: \ 1030 | HANDLE_B_CASE(TYPE, 1, B, C); \ 1031 | break; \ 1032 | case 2: \ 1033 | HANDLE_B_CASE(TYPE, 2, B, C); \ 1034 | break; \ 1035 | default: \ 1036 | HANDLE_B_CASE(TYPE, -1, B, C); \ 1037 | break; \ 1038 | } \ 1039 | } 1040 | 1041 | if (detail::canUse32BitIndexMath(a) && 1042 | detail::canUse32BitIndexMath(b) && 1043 | detail::canUse32BitIndexMath(c)) { 1044 | detail::TensorInfo aInfo = 1045 | detail::getTensorInfo(a); 1046 | 1047 | detail::TensorInfo bInfo = 1048 | detail::getTensorInfo(b); 1049 | 1050 | detail::TensorInfo cInfo = 1051 | detail::getTensorInfo(c); 1052 | 1053 | rearrangeDims(&aInfo, &bInfo, &cInfo); 1054 | aInfo.collapseDims(); 1055 | bInfo.collapseDims(); 1056 | cInfo.collapseDims(); 1057 | 1058 | HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims); 1059 | } else { 1060 | detail::TensorInfo aInfo = 1061 | detail::getTensorInfo(a); 1062 | 1063 | detail::TensorInfo bInfo = 1064 | detail::getTensorInfo(b); 1065 | 1066 | detail::TensorInfo cInfo = 1067 | detail::getTensorInfo(c); 1068 | 1069 | rearrangeDims(&aInfo, &bInfo, &cInfo); 1070 | aInfo.collapseDims(); 1071 | bInfo.collapseDims(); 1072 | cInfo.collapseDims(); 1073 | 1074 | /* 1075 | Only instantiates the all 1D special case and the fallback all nD case for 1076 | large (64-bit indexed) tensors to reduce compilation time. 
1077 | */ 1078 | if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1) { 1079 | HANDLE_CASE(uint64_t, 1, 1, 1); 1080 | } else { 1081 | HANDLE_CASE(uint64_t, -1, -1, -1); 1082 | } 1083 | } 1084 | #undef HANDLE_CASE 1085 | #undef HANDLE_C_CASE 1086 | #undef HANDLE_B_CASE 1087 | #undef HANDLE_A_CASE 1088 | 1089 | if (oldA.defined()) { 1090 | // TB: No need for _th_copy_ignoring_overlaps_ 1091 | oldA.copy_(a); 1092 | } 1093 | 1094 | if (oldB.defined()) { 1095 | // TB: No need for _th_copy_ignoring_overlaps_ 1096 | oldB.copy_(b); 1097 | } 1098 | 1099 | if (oldC.defined()) { 1100 | // TB: No need for _th_copy_ignoring_overlaps_ 1101 | oldC.copy_(c); 1102 | } 1103 | 1104 | return true; 1105 | } 1106 | 1107 | /* Provides default step = 1 to CUDA_tensor_apply3. */ 1108 | template 1109 | inline bool CUDA_tensor_apply3(at::Tensor a, 1110 | at::Tensor b, 1111 | at::Tensor c, 1112 | const Op op, 1113 | TensorArgType aType = TensorArgType::ReadWrite, 1114 | TensorArgType bType = TensorArgType::ReadOnly, 1115 | TensorArgType cType = TensorArgType::ReadOnly) { 1116 | return CUDA_tensor_apply3( 1117 | a, b, c, op, aType, bType, cType); 1118 | } 1119 | 1120 | 1121 | template 1123 | inline bool CUDA_tensor_apply4(at::Tensor a, 1124 | at::Tensor b, 1125 | at::Tensor c, 1126 | at::Tensor d, 1127 | const Op op, 1128 | TensorArgType aType = TensorArgType::ReadWrite, 1129 | TensorArgType bType = TensorArgType::ReadOnly, 1130 | TensorArgType cType = TensorArgType::ReadOnly, 1131 | TensorArgType dType = TensorArgType::ReadOnly) { 1132 | checkBackend("CUDA_tensor_apply4", {a, b, c, d}, Backend::CUDA); 1133 | int64_t totalElements = a.numel(); 1134 | 1135 | if (totalElements != b.numel() || 1136 | totalElements != c.numel() || 1137 | totalElements != d.numel()) { 1138 | return false; 1139 | } 1140 | 1141 | if (a.dim() > MAX_TENSORINFO_DIMS || 1142 | b.dim() > MAX_TENSORINFO_DIMS || 1143 | c.dim() > MAX_TENSORINFO_DIMS || 1144 | d.dim() > MAX_TENSORINFO_DIMS) { 1145 | return false; 1146 | } 1147 | 1148 | if (a.numel() == 0) { 1149 | // Empty tensor; do nothing 1150 | return true; 1151 | } 1152 | 1153 | const dim3 block = getApplyBlock(); 1154 | 1155 | dim3 grid; 1156 | int64_t curDevice = current_device(); 1157 | if (curDevice == -1) return false; 1158 | if (!getApplyGrid(totalElements, grid, curDevice)) { 1159 | return false; 1160 | } 1161 | 1162 | /* 1163 | Expands readable/writable tensors whose indices may be "overlapped." 1164 | This ensures that each element of the tensor is operated on once and only 1165 | once. 
1166 | */ 1167 | Tensor oldA; 1168 | Tensor oldB; 1169 | Tensor oldC; 1170 | Tensor oldD; 1171 | 1172 | if (aType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(a)) { 1173 | // Must perform in contiguous space 1174 | oldA = a; 1175 | a = a.contiguous(); 1176 | } 1177 | if (bType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(b)) { 1178 | // Must perform in contiguous space 1179 | oldB = b; 1180 | b = b.contiguous(); 1181 | } 1182 | if (cType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) { 1183 | // Must perform in contiguous space 1184 | oldC = c; 1185 | c = c.contiguous(); 1186 | } 1187 | if (dType == TensorArgType::ReadWrite && detail::maybeOverlappingIndices(c)) { 1188 | // Must perform in contiguous space 1189 | oldD = d; 1190 | d = d.contiguous(); 1191 | } 1192 | 1193 | #define HANDLE_CASE(TYPE, A, B, C, D) \ 1194 | kernelPointwiseApply4 \ 1200 | <<>>( \ 1201 | aInfo, bInfo, cInfo, dInfo, static_cast(totalElements), op); 1202 | 1203 | #define HANDLE_D_CASE(TYPE, A, B, C, D) { \ 1204 | switch (D) { \ 1205 | case 1: \ 1206 | HANDLE_CASE(TYPE, A, B, C, 1); \ 1207 | break; \ 1208 | case 2: \ 1209 | HANDLE_CASE(TYPE, A, B, C, 2); \ 1210 | break; \ 1211 | default: \ 1212 | HANDLE_CASE(TYPE, A, B, C, -1); \ 1213 | break; \ 1214 | } \ 1215 | } 1216 | 1217 | #define HANDLE_C_CASE(TYPE, A, B, C, D) { \ 1218 | switch (C) { \ 1219 | case 1: \ 1220 | HANDLE_D_CASE(TYPE, A, B, 1, D); \ 1221 | break; \ 1222 | case 2: \ 1223 | HANDLE_D_CASE(TYPE, A, B, 2, D); \ 1224 | break; \ 1225 | default: \ 1226 | HANDLE_D_CASE(TYPE, A, B, -1, D); \ 1227 | break; \ 1228 | } \ 1229 | } 1230 | 1231 | #define HANDLE_B_CASE(TYPE, A, B, C, D) { \ 1232 | switch (B) { \ 1233 | case 1: \ 1234 | HANDLE_C_CASE(TYPE, A, 1, C, D); \ 1235 | break; \ 1236 | case 2: \ 1237 | HANDLE_C_CASE(TYPE, A, 2, C, D); \ 1238 | break; \ 1239 | default: \ 1240 | HANDLE_C_CASE(TYPE, A, -1, C, D); \ 1241 | break; \ 1242 | } \ 1243 | } 1244 | 1245 | #define HANDLE_A_CASE(TYPE, A, B, C, D) { \ 1246 | switch (A) { \ 1247 | case 1: \ 1248 | HANDLE_B_CASE(TYPE, 1, B, C, D); \ 1249 | break; \ 1250 | case 2: \ 1251 | HANDLE_B_CASE(TYPE, 2, B, C, D); \ 1252 | break; \ 1253 | default: \ 1254 | HANDLE_B_CASE(TYPE, -1, B, C, D); \ 1255 | break; \ 1256 | } \ 1257 | } 1258 | 1259 | if (detail::canUse32BitIndexMath(a) && 1260 | detail::canUse32BitIndexMath(b) && 1261 | detail::canUse32BitIndexMath(c) && 1262 | detail::canUse32BitIndexMath(d)) { 1263 | detail::TensorInfo aInfo = 1264 | detail::getTensorInfo(a); 1265 | 1266 | detail::TensorInfo bInfo = 1267 | detail::getTensorInfo(b); 1268 | 1269 | detail::TensorInfo cInfo = 1270 | detail::getTensorInfo(c); 1271 | 1272 | detail::TensorInfo dInfo = 1273 | detail::getTensorInfo(d); 1274 | 1275 | rearrangeDims(&aInfo, &bInfo, &cInfo, &dInfo); 1276 | aInfo.collapseDims(); 1277 | bInfo.collapseDims(); 1278 | cInfo.collapseDims(); 1279 | dInfo.collapseDims(); 1280 | 1281 | HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims, dInfo.dims); 1282 | } else { 1283 | detail::TensorInfo aInfo = 1284 | detail::getTensorInfo(a); 1285 | 1286 | detail::TensorInfo bInfo = 1287 | detail::getTensorInfo(b); 1288 | 1289 | detail::TensorInfo cInfo = 1290 | detail::getTensorInfo(c); 1291 | 1292 | detail::TensorInfo dInfo = 1293 | detail::getTensorInfo(d); 1294 | 1295 | rearrangeDims(&aInfo, &bInfo, &cInfo, &dInfo); 1296 | aInfo.collapseDims(); 1297 | bInfo.collapseDims(); 1298 | cInfo.collapseDims(); 1299 | dInfo.collapseDims(); 1300 | 1301 | /* 1302 | Only instantiates 
the all 1D special case and the fallback all nD case for 1303 | large (64-bit indexed) tensors to reduce compilation time. 1304 | */ 1305 | if (aInfo.dims == 1 && bInfo.dims == 1 && cInfo.dims == 1 && dInfo.dims == 1) { 1306 | HANDLE_CASE(uint64_t, 1, 1, 1, 1); 1307 | } else { 1308 | HANDLE_CASE(uint64_t, -1, -1, -1, -1); 1309 | } 1310 | } 1311 | #undef HANDLE_CASE 1312 | #undef HANDLE_D_CASE 1313 | #undef HANDLE_C_CASE 1314 | #undef HANDLE_B_CASE 1315 | #undef HANDLE_A_CASE 1316 | 1317 | if (oldA.defined()) { 1318 | // TB: No need for _th_copy_ignoring_overlaps_ 1319 | oldA.copy_(a); 1320 | } 1321 | 1322 | if (oldB.defined()) { 1323 | // TB: No need for _th_copy_ignoring_overlaps_ 1324 | oldB.copy_(b); 1325 | } 1326 | 1327 | if (oldC.defined()) { 1328 | // TB: No need for _th_copy_ignoring_overlaps_ 1329 | oldC.copy_(c); 1330 | } 1331 | 1332 | if (oldD.defined()) { 1333 | // TB: No need for _th_copy_ignoring_overlaps_ 1334 | oldD.copy_(d); 1335 | } 1336 | 1337 | return true; 1338 | } 1339 | 1340 | /* Provides default step = 1 to CUDA_tensor_apply4. */ 1341 | template 1343 | inline bool CUDA_tensor_apply4(at::Tensor a, 1344 | at::Tensor b, 1345 | at::Tensor c, 1346 | at::Tensor d, 1347 | const Op op, 1348 | TensorArgType aType = TensorArgType::ReadWrite, 1349 | TensorArgType bType = TensorArgType::ReadOnly, 1350 | TensorArgType cType = TensorArgType::ReadOnly, 1351 | TensorArgType dType = TensorArgType::ReadOnly) { 1352 | return CUDA_tensor_apply4( 1353 | a, b, c, d, op, aType, bType, cType); 1354 | } 1355 | 1356 | } // cuda 1357 | } // at 1358 | -------------------------------------------------------------------------------- /extra/Comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "Collapsed": "false" 7 | }, 8 | "source": [ 9 | "# Swish Implementation Comparison" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "Collapsed": "false" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# Minimal fork of https://github.com/rwightman/gen-efficientnet-pytorch\n", 21 | "# Adds setup and lets you set the activation function\n", 22 | "# Note changes on setup branch\n", 23 | "# !pip install git+https://github.com/thomasbrandon/gen-efficientnet-pytorch@setup" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "metadata": { 30 | "Collapsed": "false" 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "from fastai.vision import *\n", 35 | "from gen_efficientnet.gen_efficientnet import efficientnet_b0, model_urls\n", 36 | "import swish_torch" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "Collapsed": "false" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "SIZE = 256 # Resize crop to 256x256\n", 48 | "BS = 48 # Could probably be a little higher for CUDA/Function but will use same for all\n", 49 | "LR=1e-3" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "Collapsed": "false" 56 | }, 57 | "source": [ 58 | "## Setup" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": { 65 | "Collapsed": "false" 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "PATH = untar_data(URLs.IMAGEWOOF_320)\n", 70 | "data = (ImageList\n", 71 | " .from_folder(PATH)\n", 72 | " .split_by_folder(valid='val')\n", 73 | " .label_from_folder()\n", 74 | " .transform(([flip_lr(p=0.5)], []), size=SIZE)\n", 75 | " 
.databunch(bs=BS, num_workers=6)\n", 76 | " .presize(SIZE, scale=(0.35,1))\n", 77 | " .normalize(imagenet_stats))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": { 84 | "Collapsed": "false" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "class PeakMemMetric(LearnerCallback):\n", 89 | " \"Callback that measures used and peak GPU memory.\"\n", 90 | " _order=-20 # Needs to run before the recorder\n", 91 | "\n", 92 | " def __init__(self, learn:Learner, device=None):\n", 93 | " super().__init__(learn)\n", 94 | " assert torch.cuda.is_available(), \"pytorch CUDA is required\"\n", 95 | " self._dev = ifnone(device, torch.cuda.current_device())\n", 96 | "\n", 97 | " def on_train_begin(self, **kwargs):\n", 98 | " self.learn.recorder.add_metric_names(['cache MB', 'alloc MB'])\n", 99 | "\n", 100 | " def on_epoch_begin(self, **kwargs):\n", 101 | " torch.cuda.reset_max_memory_cached(self._dev)\n", 102 | " torch.cuda.reset_max_memory_allocated(self._dev)\n", 103 | " \n", 104 | " def on_epoch_end(self, last_metrics, **kwargs):\n", 105 | " b2mb = lambda num: int(num/2**20)\n", 106 | " cache = torch.cuda.max_memory_cached(self._dev)\n", 107 | " alloc = torch.cuda.max_memory_allocated(self._dev)\n", 108 | " return add_metrics(last_metrics, [b2mb(cache), b2mb(alloc)])" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 6, 114 | "metadata": { 115 | "Collapsed": "false" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "def load_pretrained(mdl):\n", 120 | " # Load pretrained data, except for differently size linear layers\n", 121 | " state_dict = torch.utils.model_zoo.load_url(model_urls['efficientnet_b0'])\n", 122 | " for attr in ['weight','bias']: state_dict[f'classifier.{attr}'] = getattr(mdl.classifier, attr)\n", 123 | " mdl.load_state_dict(state_dict)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 7, 129 | "metadata": { 130 | "Collapsed": "false" 131 | }, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "ImageDataBunch;\n", 137 | "\n", 138 | "Train: LabelList (12454 items)\n", 139 | "x: ImageList\n", 140 | "Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)\n", 141 | "y: CategoryList\n", 142 | "n02111889,n02111889,n02111889,n02111889,n02111889\n", 143 | "Path: /home/user/.fastai/data/imagewoof-320;\n", 144 | "\n", 145 | "Valid: LabelList (500 items)\n", 146 | "x: ImageList\n", 147 | "Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)\n", 148 | "y: CategoryList\n", 149 | "n02111889,n02111889,n02111889,n02111889,n02111889\n", 150 | "Path: /home/user/.fastai/data/imagewoof-320;\n", 151 | "\n", 152 | "Test: None" 153 | ] 154 | }, 155 | "execution_count": 7, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "# https://github.com/fastai/imagenette\n", 162 | "# Subset of 10 dog breeds from Imagenet, 320px shortest side\n", 163 | "data" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "Collapsed": "false" 170 | }, 171 | "source": [ 172 | "## Original Implementation" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 8, 178 | "metadata": { 179 | "Collapsed": "false" 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "mdl = efficientnet_b0(num_classes=data.c)\n", 184 | "load_pretrained(mdl)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | 
"metadata": { 191 | "Collapsed": "false" 192 | }, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "\u001b[0;31mSignature:\u001b[0m \u001b[0mmdl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 198 | "\u001b[0;31mDocstring:\u001b[0m \n", 199 | "\u001b[0;31mSource:\u001b[0m \n", 200 | "\u001b[0;32mdef\u001b[0m \u001b[0mswish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", 201 | "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", 202 | "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", 203 | "\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", 204 | "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 205 | "\u001b[0;31mFile:\u001b[0m ~/.conda/envs/fastai/lib/python3.7/site-packages/gen_efficientnet/efficientnet_builder.py\n", 206 | "\u001b[0;31mType:\u001b[0m function\n" 207 | ] 208 | }, 209 | "metadata": {}, 210 | "output_type": "display_data" 211 | } 212 | ], 213 | "source": [ 214 | "mdl.act_fn??" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 10, 220 | "metadata": { 221 | "Collapsed": "false" 222 | }, 223 | "outputs": [ 224 | { 225 | "data": { 226 | "text/html": [ 227 | "\n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | "
epoch  train_loss  valid_loss  accuracy  cache MB  alloc MB  time
0      0.400987    0.370652    0.890000  7204      6890      01:12
1      0.439666    0.385724    0.890000  7106      6879      01:11
2      0.298581    0.274652    0.910000  7106      6879      01:12
3      0.136597    0.231383    0.918000  7106      6879      01:11
4      0.075961    0.211751    0.932000  7106      6879      01:11
" 287 | ], 288 | "text/plain": [ 289 | "" 290 | ] 291 | }, 292 | "metadata": {}, 293 | "output_type": "display_data" 294 | } 295 | ], 296 | "source": [ 297 | "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", 298 | "lrn.fit_one_cycle(5, LR)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": { 305 | "Collapsed": "false" 306 | }, 307 | "outputs": [], 308 | "source": [ 309 | "lrn.destroy()\n", 310 | "del lrn, mdl" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "Collapsed": "false" 317 | }, 318 | "source": [ 319 | "## Autograd Function Implementation" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 8, 325 | "metadata": { 326 | "Collapsed": "false" 327 | }, 328 | "outputs": [ 329 | { 330 | "data": { 331 | "text/html": [ 332 | "\n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | "
epoch  train_loss  valid_loss  accuracy  cache MB  alloc MB  time
0      0.450081    0.593470    0.882000  6432      5421      01:14
1      0.436954    0.368458    0.880000  6432      5421      01:13
2      0.262158    0.368661    0.890000  6432      5421      01:14
3      0.142793    0.246673    0.928000  6432      5421      01:14
4      0.075377    0.240533    0.924000  6432      5421      01:14
" 392 | ], 393 | "text/plain": [ 394 | "" 395 | ] 396 | }, 397 | "metadata": {}, 398 | "output_type": "display_data" 399 | } 400 | ], 401 | "source": [ 402 | "class SwishFunction(torch.autograd.Function):\n", 403 | " @staticmethod\n", 404 | " def forward(ctx, i):\n", 405 | " result = i * torch.sigmoid(i)\n", 406 | " ctx.save_for_backward(i)\n", 407 | " return result\n", 408 | "\n", 409 | " @staticmethod\n", 410 | " def backward(ctx, grad_output):\n", 411 | " i, = ctx.saved_tensors\n", 412 | " if not ctx.needs_input_grad[0]: return (None,)\n", 413 | " sigmoid_i = torch.sigmoid(i)\n", 414 | " return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n", 415 | " \n", 416 | "# Activation function for gen_efficientnet has an inplace keyword\n", 417 | "# Can't be inplace so just ignore\n", 418 | "def swish_function(x, inplace=False): return SwishFunction.apply(x)\n", 419 | "\n", 420 | "mdl = efficientnet_b0(num_classes=data.c, act_fn=swish_function)\n", 421 | "load_pretrained(mdl)\n", 422 | "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", 423 | "lrn.fit_one_cycle(5, LR)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": { 430 | "Collapsed": "false" 431 | }, 432 | "outputs": [], 433 | "source": [ 434 | "lrn.destroy()\n", 435 | "del lrn, mdl" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "Collapsed": "false" 442 | }, 443 | "source": [ 444 | "## CUDA Implementation" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 8, 450 | "metadata": { 451 | "Collapsed": "false" 452 | }, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/html": [ 457 | "\n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | "
epoch  train_loss  valid_loss  accuracy  cache MB  alloc MB  time
0      0.444761    0.394772    0.874000  5934      5400      01:02
1      0.441538    0.434501    0.866000  5934      5400      01:01
2      0.293320    0.276060    0.906000  5934      5400      01:02
3      0.149419    0.245342    0.918000  5934      5400      01:02
4      0.061624    0.258465    0.918000  5934      5400      01:02
" 517 | ], 518 | "text/plain": [ 519 | "" 520 | ] 521 | }, 522 | "metadata": {}, 523 | "output_type": "display_data" 524 | } 525 | ], 526 | "source": [ 527 | "# Activation function for gen_efficientnet has an inplace keyword\n", 528 | "# Can't be inplace so just ignore\n", 529 | "def swish_cuda_fn(x, inplace=False): return swish_torch.swish(x)\n", 530 | "\n", 531 | "mdl = efficientnet_b0(num_classes=data.c, act_fn=swish_cuda_fn)\n", 532 | "load_pretrained(mdl)\n", 533 | "lrn = Learner(data, mdl, callback_fns=[PeakMemMetric], metrics=[accuracy])\n", 534 | "lrn.fit_one_cycle(5, LR)" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": { 541 | "Collapsed": "false" 542 | }, 543 | "outputs": [], 544 | "source": [ 545 | "lrn.destroy()\n", 546 | "del lrn, mdl" 547 | ] 548 | }, 549 | { 550 | "cell_type": "markdown", 551 | "metadata": { 552 | "Collapsed": "false" 553 | }, 554 | "source": [ 555 | "# Results\n", 556 | "```\n", 557 | "\t\t train_loss valid_loss accuracy cache MB alloc MB time\n", 558 | "Original 0.075961 0.211751 0.932000 7106 6879 01:11\n", 559 | "Autograd 0.075377 0.240533 0.924000 6432 5421 01:14\n", 560 | "CUDA 0.061624 0.258465 0.918000 5934 5400 01:02\n", 561 | "```\n", 562 | "\n", 563 | "So the CUDA version is (slightly) faster than the original with the memory usage of the Autoigrad version." 564 | ] 565 | } 566 | ], 567 | "metadata": { 568 | "kernelspec": { 569 | "display_name": "Python [conda env:.conda-fastai]", 570 | "language": "python", 571 | "name": "conda-env-.conda-fastai-py" 572 | }, 573 | "language_info": { 574 | "codemirror_mode": { 575 | "name": "ipython", 576 | "version": 3 577 | }, 578 | "file_extension": ".py", 579 | "mimetype": "text/x-python", 580 | "name": "python", 581 | "nbconvert_exporter": "python", 582 | "pygments_lexer": "ipython3", 583 | "version": "3.7.4" 584 | } 585 | }, 586 | "nbformat": 4, 587 | "nbformat_minor": 4 588 | } 589 | -------------------------------------------------------------------------------- /extra/package.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This script will package the extension into a single file for inline JIT loading. 3 | 4 | from sys import exit 5 | from pathlib import Path 6 | from argparse import ArgumentParser 7 | from zlib import compress 8 | from base64 import b64encode 9 | from itertools import chain 10 | import re 11 | 12 | parser = ArgumentParser(description="Package the extension into a single file for inline JIT loading.") 13 | parser.add_argument('-p', '--path', default=None, help="Location of source files, defaults based on this scripts location.") 14 | parser.add_argument('-o', '--output', default = "swish_inline.py", help="File to write output to, default: ./__init__.py") 15 | parser.add_argument('-s', '--stdout', action="store_true", help="Weite output to stdout instead of file") 16 | args = parser.parse_args() 17 | 18 | path = Path(__file__).absolute().parent.parent if args.path is None else Path(args.path).absolute() 19 | if not path.exists(): exit("Input path doesn't exist.") 20 | if not (path/'csrc').exists(): exit(f"Path doesn't appear to contain extension sources. 
Couldn't find {(path/'csrc').absolute()}.") 21 | cpp_files = list((path/'csrc').glob('*.cpp')) 22 | cu_files = list((path/'csrc').glob('*.cu')) 23 | incs = {p.name: p.read_text() for p in chain(*[(path/d).glob(p) for d in ['csrc','external'] for p in ['*.h','*.cuh']])} 24 | 25 | def proc_src(f): 26 | src = f.read_text() 27 | res,pos = f"\n// From {f}\n\n",0 28 | for m in re.finditer(r'#include "([^"]+)"\n', src): 29 | res += src[pos:m.start()] 30 | inc = m.group(1) 31 | if inc not in incs: exit(f"Couldn't find included file '{inc}' in {f}") 32 | res += f"\n// Include: {f}\n" + incs[inc] + f"\n// End Include: {f}\n\n" 33 | pos = m.end() 34 | res += src[pos:] 35 | return res 36 | 37 | cpp_srcs = [proc_src(f) for f in cpp_files] 38 | cu_srcs = [proc_src(f) for f in cu_files] 39 | 40 | m = re.search(r"""version=['"]([^'"]+)['"]""", (path/'setup.py').read_text()) 41 | if not m: exit("Unable to find version in setup.py") 42 | ver = m.group(1) 43 | 44 | src = f""" 45 | __all__ = ['Swish','SwishFunction','swish','__version__'] 46 | from base64 import b64decode 47 | from zlib import decompress 48 | import torch 49 | from torch.utils.cpp_extension import load_inline 50 | 51 | __version__='{ver}' 52 | 53 | def load_module(): 54 | print("Compiling swish_torch module...") 55 | cpp_comp = [{','.join((f"'{b64encode(compress(s.encode(),9)).decode()}'" for s in cpp_srcs))}] 56 | cu_comp = [{','.join((f"'{b64encode(compress(s.encode(),9)).decode()}'" for s in cu_srcs ))}] 57 | cpp_srcs = [decompress(b64decode(src)).decode() for src in cpp_comp] 58 | cu_srcs = [decompress(b64decode(src)).decode() for src in cu_comp] 59 | 60 | swish_mod = load_inline("swish_torch_inline", cpp_sources=cpp_srcs, cuda_sources=cu_srcs, extra_cuda_cflags=['--expt-extended-lambda']) 61 | return swish_mod 62 | 63 | if not torch.cuda.is_available(): 64 | print("CUDA not available but is required for swish_torch") 65 | swish_mod = None 66 | else: 67 | swish_mod = load_module() 68 | 69 | class SwishFunction(torch.autograd.Function): 70 | @staticmethod 71 | def forward(ctx, inp): 72 | ctx.save_for_backward(inp) 73 | return swish_mod.swish_forward(inp) 74 | 75 | @staticmethod 76 | def backward(ctx, grad_out): 77 | inp, = ctx.saved_tensors 78 | if not ctx.needs_input_grad[0]: return (None,) 79 | return swish_mod.swish_backward(inp, grad_out) 80 | 81 | class Swish(torch.nn.Module): 82 | '''Swish Activation Function - Inline PyTorch CUDA Version''' 83 | def forward(self, inp): return SwishFunction.apply(inp) 84 | 85 | swish = SwishFunction.apply 86 | 87 | if swish_mod is not None: 88 | print(f"Successfully loaded swish-torch inline version {{__version__}}") 89 | 90 | """ 91 | 92 | src = src.replace("$CPP_SRCS$", 93 | ','.join((f"'{b64encode(compress(s.encode(),9)).decode()}'" for s in cpp_srcs))) 94 | src = src.replace("$CU_SRCS$", 95 | ','.join((f"'{b64encode(compress(s.encode(),9)).decode()}'" for s in cu_srcs))) 96 | 97 | if args.stdout: 98 | print(src) 99 | else: 100 | with Path(args.output).open('w') as out: 101 | out.write(src) 102 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | EXT_SRCS = [ 5 | 'csrc/swish.cpp', 6 | 'csrc/swish_kernel.cu', 7 | ] 8 | 9 | setup( 10 | name='swish_torch', 11 | version='0.0.1', 12 | packages=find_packages('src'), 13 | package_dir={'': 'src'}, 14 |
include_package_data=True, 15 | zip_safe=False, 16 | install_requires=['torch>=1.2'], 17 | ext_modules=[ 18 | CUDAExtension( 19 | 'swish_torch._C', 20 | EXT_SRCS, 21 | extra_compile_args={ 22 | 'cxx': [], 23 | 'nvcc': ['--expt-extended-lambda'] 24 | }, 25 | include_dirs=['external'] 26 | ) 27 | ], 28 | cmdclass={ 29 | 'build_ext': BuildExtension 30 | }) 31 | -------------------------------------------------------------------------------- /src/swish_torch/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Swish','SwishFunction','swish'] 2 | 3 | import torch # Must import torch before C extension 4 | from ._C import swish_forward, swish_backward 5 | 6 | class SwishFunction(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, inp): 9 | ctx.save_for_backward(inp) 10 | return swish_forward(inp) 11 | 12 | @staticmethod 13 | def backward(ctx, grad_out): 14 | inp, = ctx.saved_tensors 15 | if not ctx.needs_input_grad[0]: return (None,) 16 | return swish_backward(inp, grad_out) 17 | 18 | class Swish(torch.nn.Module): 19 | """Swish Activation Function - PyTorch CUDA Version""" 20 | def forward(self, inp): return SwishFunction.apply(inp) 21 | 22 | swish = SwishFunction.apply -------------------------------------------------------------------------------- /test/test_swish.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.testing import assert_allclose 5 | 6 | swish_forward_pt = lambda x: x.mul(torch.sigmoid(x)) 7 | 8 | class SwishPT(torch.nn.Module): 9 | def forward(self, x): return swish_forward_pt(x) 10 | 11 | def get_input_params(): 12 | assert torch.cuda.is_available() and torch.cuda.device_count() > 0 13 | devs = ['cuda:0'] # TODO: Allow other devices 14 | dev_types = [(dtype,device) 15 | for dtype in [torch.float16,torch.float32,torch.float64] 16 | for device in devs 17 | # Basic ops not supported on CPU/Half, could test by converting but skip for now 18 | if not (dtype==torch.float16 and torch.device(device).type == 'cpu')] 19 | inputs = [(ndim,dtype,device) 20 | for (dtype,device) in dev_types 21 | for ndim in [1,2,3,4,8]] 22 | return inputs 23 | 24 | @pytest.fixture(params=get_input_params()) 25 | def test_input(request): 26 | ndim,dtype,device = request.param 27 | sz = (2,) * (ndim-1) + (10,) 28 | if device == 'cpu' and dtype == torch.float16: 29 | t = torch.randn(*sz).half() # No randn for half on CPU 30 | else: t = torch.randn(*sz, device=device, dtype=dtype) 31 | return t + torch.randint(-1000, 1000, sz, device=device, dtype=dtype) 32 | 33 | def test_forward(test_input): 34 | from swish_torch import swish_forward 35 | res = swish_forward(test_input) 36 | exp = swish_forward_pt(test_input) 37 | assert_allclose(res, exp) 38 | 39 | def get_grads(inp): 40 | y = swish_forward_pt(inp) 41 | l = y.mean() 42 | grad_out, = torch.autograd.grad(l, y, retain_graph=True) 43 | exp, = torch.autograd.grad(y, inp, grad_out, retain_graph=True) 44 | return grad_out, exp 45 | 46 | def test_backward(test_input): 47 | from swish_torch import swish_backward 48 | x = test_input.requires_grad_() 49 | grad_out,exp = get_grads(test_input) 50 | res = swish_backward(test_input.detach(), grad_out) 51 | assert_allclose(res, exp) 52 | 53 | def test_function(test_input): 54 | from swish_torch import SwishFunction 55 | x1,x2 = (test_input.clone().requires_grad_() for i in range(2)) 56 | 57 | y1 = swish_forward_pt(x1) 58 | l1 = y1.mean() 59
| exp, = torch.autograd.grad(l1, x1) 60 | 61 | y2 = SwishFunction.apply(x2) 62 | l2 = y2.mean() 63 | res, = torch.autograd.grad(l2, x2) 64 | assert_allclose(res, exp) 65 | 66 | def test_module(test_input): 67 | from swish_torch import Swish 68 | x1,x2 = (test_input.clone().requires_grad_() for i in range(2)) 69 | 70 | m1 = SwishPT() 71 | y1 = m1(x1) 72 | l1 = y1.mean() 73 | exp, = torch.autograd.grad(l1, x1) 74 | 75 | m2 = Swish() 76 | y2 = m2(x2) 77 | l2 = y2.mean() 78 | res, = torch.autograd.grad(l2, x2) 79 | assert_allclose(res, exp) 80 | 81 | def test_gradient(): 82 | from swish_torch import SwishFunction 83 | inp = torch.randn(10, 10, dtype=torch.float64, requires_grad=True, device='cuda:0') 84 | assert torch.autograd.gradcheck(SwishFunction.apply, inp) 85 | 86 | def test_gradgrad(): 87 | from swish_torch import SwishFunction 88 | inp = torch.randn(10, 10, dtype=torch.float64, requires_grad=True, device='cuda:0') 89 | assert torch.autograd.gradgradcheck(SwishFunction.apply, inp) 90 | 91 | def test_overlapping(): 92 | '''Test handling of overlapping output tensors''' 93 | from swish_torch import swish_forward 94 | t = torch.randn(2, 10, device='cuda:0') 95 | t_o = t.as_strided((3,10), (5,1)) # OVerlapping 96 | t_c = t_o.contiguous() # Contiguous 97 | o_o = swish_forward(t_o, torch.empty_like(t_o)) 98 | o_c = swish_forward(t_c, torch.empty_like(t_c)) 99 | assert torch.equal(o_o, o_c) 100 | --------------------------------------------------------------------------------
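For quick reference, the snippet below is a minimal usage sketch of the built extension as exported by `src/swish_torch/__init__.py`. It assumes the extension compiled successfully and that a CUDA device is available; the tensor shape and device string are illustrative only, not part of the library.

```python
import torch
from swish_torch import Swish, swish  # module and functional forms exported above

# Pure-PyTorch reference, matching swish_forward_pt in test/test_swish.py
def swish_ref(x):
    return x * torch.sigmoid(x)

x = torch.randn(8, 16, device='cuda:0', requires_grad=True)  # illustrative shape/device

act = Swish()   # drop-in nn.Module, e.g. inside an nn.Sequential
y = act(x)      # same result as the functional form: swish(x)

# The CUDA kernel should agree with the reference implementation
assert torch.allclose(y, swish_ref(x))

# Gradients flow through the custom autograd Function
y.mean().backward()
print(x.grad.shape)  # torch.Size([8, 16])
```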