├── .gitignore
├── COPYING.txt
├── LICENSE.txt
├── Makefile
├── README.md
├── build_ffi.py
├── include
│   ├── shiftnet_cuda.h
│   └── shiftnet_cuda_kernels.h
├── module_test.py
├── nn.py
├── requirements.txt
├── src
│   ├── shiftnet_cuda.c
│   └── shiftnet_cuda_kernels.cu
└── test_shiftnet.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
build
shiftnet_cuda
__init__.py

--------------------------------------------------------------------------------
/COPYING.txt:
--------------------------------------------------------------------------------
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Unix commands.
PYTHON := python
NVCC := /usr/local/cuda-8.0/bin/nvcc
NVCC_COMPILE := $(NVCC) -arch=sm_30 \
	-gencode=arch=compute_30,code=sm_30 \
	-gencode=arch=compute_50,code=sm_50 \
	-gencode=arch=compute_52,code=sm_52 \
	-gencode=arch=compute_60,code=sm_60 \
	-gencode=arch=compute_61,code=sm_61 \
	-gencode=arch=compute_61,code=compute_61 -c -o
RM_RF := rm -rf

# Library compilation rules.
NVCC_FLAGS := -x cu -Xcompiler -fPIC -shared

# File structure.
BUILD_DIR := build
INCLUDE_DIRS := include
TORCH_FFI_BUILD := build_ffi.py
MATHUTIL_KERNEL := $(BUILD_DIR)/shiftnet_cuda_kernels.so
TORCH_FFI_TARGET := $(BUILD_DIR)/shiftnet_cuda/_shiftnet_cuda.so

INCLUDE_FLAGS := $(foreach d, $(INCLUDE_DIRS), -I$d)

all: $(TORCH_FFI_TARGET)

$(TORCH_FFI_TARGET): $(MATHUTIL_KERNEL) $(TORCH_FFI_BUILD)
	$(PYTHON) $(TORCH_FFI_BUILD)

$(BUILD_DIR)/%.so: src/%.cu
	@ mkdir -p $(BUILD_DIR)
	# Separate C++ shared library that will be loaded by the extern "C" FFI.
	$(NVCC_COMPILE) $@ $? $(NVCC_FLAGS) $(INCLUDE_FLAGS)

clean:
	$(RM_RF) $(BUILD_DIR) shiftnet_cuda/__init__.py shiftnet_cuda/_shiftnet_cuda.so *.pyc
--------------------------------------------------------------------------------
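A note on the Makefile above: `NVCC` and the other toolchain paths are ordinary `make` variables, so they can be overridden per invocation instead of editing the file, e.g. `make NVCC=/usr/local/cuda/bin/nvcc` (the CUDA path here is illustrative; the default targets CUDA 8.0).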
/README.md:
--------------------------------------------------------------------------------
# Shift Operation CUDA Implementation

Created by Peter Jin

Tradeoffs and further analysis can be found in the paper. If you find this work useful for your research, please consider citing:

    @article{shift,
        Author = {Bichen Wu and Alvin Wan and Xiangyu Yue and Peter Jin and Sicheng Zhao and Noah Golmant and Amir Gholaminejad and Joseph Gonzalez and Kurt Keutzer},
        Title = {Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions},
        Journal = {arXiv:1711.08141},
        Year = {2017}
    }

Uses of Shift:
* [ShiftResNet](http://github.com/alvinwan/shiftresnet-cifar) for CIFAR-10, CIFAR-100 classification

# Installation

> If you have included this `shift` repository as a submodule in a separate repository, feel free to skip down to step 5.

1. If you have not already, set up a virtual environment with Python 3, and activate it.

    ```
    virtualenv shift --python=python3
    source shift/bin/activate
    ```

    Your prompt should now be prefaced with `(shift)`, as in

    ```
    (shift) [user@server:~]$
    ```

2. Install `pytorch` and `torchvision`. Access [pytorch.org](http://pytorch.org), scroll down to the "Getting Started" section, and select the appropriate OS, package manager, Python, and CUDA build. For example, selecting Linux, pip, Python 3.5, and CUDA 8 gives the following, as of this writing:

    ```
    pip3 install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
    pip3 install torchvision
    ```

3. Clone this repository.

    ```
    git clone git@github.com:peterhj/shiftnet_cuda_v2.git
    ```

4. `cd` into the root of this repository.

    ```
    cd shiftnet_cuda_v2
    ```

5. Install the Python requirements for this package.

    ```
    pip3 install -r requirements.txt
    ```

6. Compile the shift layer's C/CUDA implementation.

    ```
    make
    ```

    > **Getting `invalid device function`?** Update the architecture flags in [`models/shiftnet_cuda_v2/Makefile`](https://github.com/alvinwan/shiftresnet-cifar/blob/master/models/shiftnet_cuda_v2/Makefile#L4), which are currently configured for a Titan X; a Tesla K80, for example, needs `sm_37`.

Your custom CUDA layer is now installed.

# Test

To check that the build completed successfully, run the test script:

```
python test_shiftnet.py
```

After ~3s, the script should output a number of different tensors, where the last printed column block of the final tensor has non-zero values only in its first column.

```
Columns 13 to 17
   89    0    0    0    0
  107    0    0    0    0
  125    0    0    0    0
  143    0    0    0    0
  161    0    0    0    0
  179    0    0    0    0
  197    0    0    0    0
  215    0    0    0    0
  233    0    0    0    0
  251    0    0    0    0
  269    0    0    0    0
  287    0    0    0    0
  305    0    0    0    0
  323    0    0    0    0
    0    0    0    0    0
    0    0    0    0    0
    0    0    0    0    0
    0    0    0    0    0
[torch.FloatTensor of size 18x18]
```
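
# Example usage

Once built, the modules in `nn.py` drop into a model like any other `torch.nn` layer. A minimal sketch (assuming the repository root is on your `PYTHONPATH`; the shapes are illustrative):

```
import torch
from torch.autograd import Variable
from nn import Shift3x3_cuda, GenericShift_cuda  # this repo's nn.py

x = Variable(torch.randn(32, 64, 18, 18).cuda(), requires_grad=True)
shift3 = Shift3x3_cuda()                   # fixed 3x3 shift
shift5 = GenericShift_cuda(kernel_size=5)  # 5x5 shift, dilation 1
y = shift5(shift3(x))                      # zero FLOPs, zero learnable parameters
y.sum().backward()                         # gradients flow through both shifts
```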
--------------------------------------------------------------------------------
/build_ffi.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.ffi import create_extension
import os

abs_path = os.path.dirname(os.path.realpath(__file__))
extra_objects = [os.path.join(abs_path, 'build/shiftnet_cuda_kernels.so')]

ffi = create_extension(
    'shiftnet_cuda',
    headers=['include/shiftnet_cuda.h'],
    sources=['src/shiftnet_cuda.c'],
    define_macros=[('WITH_CUDA', None)],
    relative_to=__file__,
    with_cuda=True,
    extra_objects=extra_objects,
    include_dirs=[os.path.join(abs_path, 'include')]
)

if __name__ == '__main__':
    assert torch.cuda.is_available(), 'Please install CUDA for GPU support.'
    ffi.build()
--------------------------------------------------------------------------------
/include/shiftnet_cuda.h:
--------------------------------------------------------------------------------
/*
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

int moduloshift3x3_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor);
int moduloshift3x3bwd_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor);
int moduloshiftgeneric_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor, int kernel_size, int dilate_factor, int direction);
--------------------------------------------------------------------------------
/include/shiftnet_cuda_kernels.h:
--------------------------------------------------------------------------------
/*
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <cuda_runtime.h>

void shiftnet_cuda_moduloshift3x3_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    cudaStream_t stream);

void shiftnet_cuda_moduloshift3x3bwd_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    cudaStream_t stream);

void shiftnet_cuda_moduloshiftgeneric_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    int kernel_size,
    int dilate_factor,
    int direction,
    cudaStream_t stream);
--------------------------------------------------------------------------------
/module_test.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import shiftnet_cuda

from torch.autograd import Function, Variable


class ShiftFn(Function):
    @staticmethod
    def forward(ctx, src):
        dst = torch.cuda.FloatTensor(src.size())
        ret = shiftnet_cuda.moduloshift3x3_nchw(src, dst)
        assert ret == 1
        return dst

    @staticmethod
    def backward(ctx, grad_dst):
        grad_src = torch.cuda.FloatTensor(grad_dst.data.size())
        ret = shiftnet_cuda.moduloshift3x3bwd_nchw(grad_dst.data, grad_src)
        assert ret == 1
        return Variable(grad_src, requires_grad=grad_dst.requires_grad)


class Shift3x3(nn.Module):
    def __init__(self):
        super(Shift3x3, self).__init__()

    def forward(self, src):
        print("DEBUG: fwd:", type(src))
        return ShiftFn.apply(src)


class Shift3x3_cuda(nn.Module):
    def __init__(self):
        super(Shift3x3_cuda, self).__init__()

    def forward(self, src):
        return ShiftFn.apply(src)


if __name__ == "__main__":
    # Fill every (batch, channel) plane with the same 5x5 ramp pattern.
    pattern = np.arange(25).reshape(5, 5)
    src_buf = np.zeros((16, 10, 5, 5)).astype(np.float32)
    for bnr in range(16):
        for ch in range(10):
            src_buf[bnr, ch, :, :] = pattern
    src = Variable(torch.from_numpy(src_buf).cuda(), requires_grad=True, volatile=False)
    print("DEBUG: src:", src.requires_grad)
    print(src.data.cpu().numpy()[0, :, :, :])

    shift = Shift3x3()

    out = shift(src)
    print("DEBUG: out:", out.requires_grad)
    print(out.data.cpu().numpy()[0, :, :, :])

    # Backprop a tensor of ones to exercise the backward (transposed) shift.
    sink = Variable(torch.ones(out.size()).cuda())
    out.backward(sink)
    print("DEBUG: src grad:")
    print(src.grad.data.cpu().numpy()[0, :, :, :])
--------------------------------------------------------------------------------
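The 3x3 shift is a fixed linear map (a zero-padded relocation of pixels), so its backward kernel must compute the transpose: `<shift(x), y>` should equal `<x, shift_bwd(y)>` for any `x`, `y`. A quick numerical sanity check of the two kernels used by `module_test.py` above, as a sketch (assumes the compiled `shiftnet_cuda` extension is importable; shapes are illustrative):

```
import torch
import shiftnet_cuda

x = torch.randn(4, 18, 16, 16).cuda()
y = torch.randn(4, 18, 16, 16).cuda()
sx = torch.cuda.FloatTensor(x.size())   # will hold shift(x)
sty = torch.cuda.FloatTensor(y.size())  # will hold shift^T(y), via the backward kernel
assert shiftnet_cuda.moduloshift3x3_nchw(x, sx) == 1
assert shiftnet_cuda.moduloshift3x3bwd_nchw(y, sty) == 1
# The two inner products should agree up to float32 rounding.
print((sx * y).sum(), (x * sty).sum())
```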
/nn.py:
--------------------------------------------------------------------------------
"""
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import shiftnet_cuda
import torch
from torch.autograd import Function, Variable

class ShiftFn(Function):
    @staticmethod
    def forward(ctx, src):
        dst = torch.cuda.FloatTensor(src.size())
        ret = shiftnet_cuda.moduloshift3x3_nchw(src, dst)
        assert ret == 1
        return dst

    @staticmethod
    def backward(ctx, grad_dst):
        grad_src = torch.cuda.FloatTensor(grad_dst.data.size())
        ret = shiftnet_cuda.moduloshift3x3bwd_nchw(grad_dst.data, grad_src)
        assert ret == 1
        return Variable(grad_src, requires_grad=grad_dst.requires_grad)

class GenericShiftFn(Function):
    @staticmethod
    def forward(ctx, src, kernel_size, dilate_factor):
        ctx.kernel_size = kernel_size
        ctx.dilate_factor = dilate_factor
        dst = torch.cuda.FloatTensor(src.size())
        ret = shiftnet_cuda.moduloshiftgeneric_nchw(src, dst, kernel_size, dilate_factor, 1)
        assert ret == 1, "GenericShiftFn: forward: invalid args; your kernel or dilation is probably too large"
        return dst

    @staticmethod
    def backward(ctx, grad_dst):
        grad_src = torch.cuda.FloatTensor(grad_dst.data.size())
        ret = shiftnet_cuda.moduloshiftgeneric_nchw(grad_dst.data, grad_src, ctx.kernel_size, ctx.dilate_factor, -1)
        assert ret == 1, "GenericShiftFn: backward: invalid args; your kernel or dilation is probably too large"
        return Variable(grad_src, requires_grad=grad_dst.requires_grad), None, None

class Shift3x3_cuda(torch.nn.Module):
    def __init__(self):
        super(Shift3x3_cuda, self).__init__()

    def forward(self, x):
        return ShiftFn.apply(x)

class GenericShift_cuda(torch.nn.Module):
    def __init__(self, kernel_size, dilate_factor=1):
        super(GenericShift_cuda, self).__init__()
        self._kernel_size = kernel_size
        self._dilate_factor = dilate_factor

    def forward(self, x):
        return GenericShiftFn.apply(x, self._kernel_size, self._dilate_factor)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
cffi==1.11.2
--------------------------------------------------------------------------------
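For reference, the channel-to-displacement assignment that `GenericShift_cuda` relies on, and that the CUDA sources below implement, can be transcribed into a few lines of Python (`channel_shift` is an illustrative helper, not part of the package):

```
def channel_shift(ch, channels, kernel_size=3, dilate_factor=1):
    # Channels are assigned displacements in row-major order over the
    # kernel window; leftover channels (channels % kernel_size**2) keep
    # their input unshifted.
    groups = kernel_size * kernel_size
    rd_chans = (channels // groups) * groups
    if ch >= rd_chans:
        return (0, 0)
    half = kernel_size // 2
    dy = dilate_factor * ((ch // kernel_size) % kernel_size - half)
    dx = dilate_factor * (ch % kernel_size - half)
    return (dy, dx)  # forward: dst[i][j] = src[i + dy][j + dx], zero-padded
```

For the 3x3 case this reproduces the nine displacements in {-1, 0, 1} x {-1, 0, 1}, cycled across groups of nine channels.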
/src/shiftnet_cuda.c:
--------------------------------------------------------------------------------
/*
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "shiftnet_cuda_kernels.h"

#include <THC/THC.h>

#include <stdio.h>

extern THCState *state;

// Shared validation for all entry points: both tensors must be 4-D, have
// identical sizes, and be densely packed in NCHW order. Returns 1 on
// success and 0 otherwise (the Python wrappers assert on the result).
static int check_4d_packed_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor) {
  // Check tensors are 4D.
  if (4 != THCudaTensor_nDimension(state, src_tensor)) {
    return 0;
  }
  if (4 != THCudaTensor_nDimension(state, dst_tensor)) {
    return 0;
  }

  // Check tensor sizes match in every dimension.
  for (int dim = 0; dim < 4; dim++) {
    if (THCudaTensor_size(state, src_tensor, dim) != THCudaTensor_size(state, dst_tensor, dim)) {
      return 0;
    }
  }

  // Check tensor strides are packed, innermost (width) dimension first.
  long packed_stride = 1;
  for (int dim = 3; dim >= 0; dim--) {
    if (packed_stride != THCudaTensor_stride(state, src_tensor, dim)) {
      return 0;
    }
    if (packed_stride != THCudaTensor_stride(state, dst_tensor, dim)) {
      return 0;
    }
    packed_stride *= THCudaTensor_size(state, src_tensor, dim);
  }

  return 1;
}

int moduloshift3x3_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor) {
  cudaStream_t stream = THCState_getCurrentStream(state);

  // TODO: support for generic storage types.
  float *src_dptr = THCudaTensor_data(state, src_tensor);
  float *dst_dptr = THCudaTensor_data(state, dst_tensor);

  if (!check_4d_packed_nchw(src_tensor, dst_tensor)) {
    return 0;
  }

  shiftnet_cuda_moduloshift3x3_nchw_float32(
      src_dptr,
      dst_dptr,
      THCudaTensor_size(state, src_tensor, 0),
      THCudaTensor_size(state, src_tensor, 1),
      THCudaTensor_size(state, src_tensor, 2),
      THCudaTensor_size(state, src_tensor, 3),
      stream);

  return 1;
}

int moduloshift3x3bwd_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor) {
  cudaStream_t stream = THCState_getCurrentStream(state);

  // TODO: support for generic storage types.
  float *src_dptr = THCudaTensor_data(state, src_tensor);
  float *dst_dptr = THCudaTensor_data(state, dst_tensor);

  if (!check_4d_packed_nchw(src_tensor, dst_tensor)) {
    return 0;
  }

  shiftnet_cuda_moduloshift3x3bwd_nchw_float32(
      src_dptr,
      dst_dptr,
      THCudaTensor_size(state, src_tensor, 0),
      THCudaTensor_size(state, src_tensor, 1),
      THCudaTensor_size(state, src_tensor, 2),
      THCudaTensor_size(state, src_tensor, 3),
      stream);

  return 1;
}

int moduloshiftgeneric_nchw(THCudaTensor *src_tensor, THCudaTensor *dst_tensor, int kernel_size, int dilate_factor, int direction) {
  cudaStream_t stream = THCState_getCurrentStream(state);

  // TODO: support for generic storage types.
  float *src_dptr = THCudaTensor_data(state, src_tensor);
  float *dst_dptr = THCudaTensor_data(state, dst_tensor);

  if (!check_4d_packed_nchw(src_tensor, dst_tensor)) {
    return 0;
  }

  if (kernel_size <= 0) {
    return 0;
  }
  if (dilate_factor <= 0) {
    return 0;
  }
  if (direction != 1 && direction != -1) {
    return 0;
  }

  // Each thread block covers a fixed 16x16 input tile; trimming the dilated
  // half-kernel halo from both sides must leave a positive output tile.
  int dilated_half_kernel_size = dilate_factor * (kernel_size / 2);
  int tile_out_size = 16 - 2 * dilated_half_kernel_size;
  if (tile_out_size <= 0) {
    return 0;
  }

  shiftnet_cuda_moduloshiftgeneric_nchw_float32(
      src_dptr,
      dst_dptr,
      THCudaTensor_size(state, src_tensor, 0),
      THCudaTensor_size(state, src_tensor, 1),
      THCudaTensor_size(state, src_tensor, 2),
      THCudaTensor_size(state, src_tensor, 3),
      kernel_size,
      dilate_factor,
      direction,
      stream);

  return 1;
}
--------------------------------------------------------------------------------
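The `tile_out_size` guard in `moduloshiftgeneric_nchw` above is what bounds the supported kernel/dilation combinations. The same arithmetic as a standalone sketch:

```
def generic_shift_args_ok(kernel_size, dilate_factor):
    halo = dilate_factor * (kernel_size // 2)  # dilated half-kernel
    return 16 - 2 * halo > 0                   # output tile per 16x16 input tile

assert generic_shift_args_ok(7, 2)      # halo 6: 4x4 output tile (the test_shiftnet.py case)
assert not generic_shift_args_ok(9, 2)  # halo 8: nothing left, so the wrapper returns 0
```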
/src/shiftnet_cuda_kernels.cu:
--------------------------------------------------------------------------------
/*
Copyright 2017 the shiftnet_cuda_v2 authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <cuda_runtime.h>

//#include <algorithm>

#define MAX_BLOCKS 128

//using std::min;

// Forward 3x3 shift. Each 256-thread block stages one 16x16 input tile in
// shared memory and writes the central 14x14 output tile; the one-pixel halo
// supplies the shifted-in values (zero outside the image). The channel index
// selects one of the nine displacements in the 3x3 window; channels beyond
// the largest multiple of nine are left unshifted.
__global__ void shiftnet_cuda_moduloshift3x3_nchw_float32_kernel_tilein16x16_tileout14x14(
    float *src,
    float *dst,
    int num_h_tiles,
    int num_w_tiles,
    int batch_sz,
    int channels,
    int height,
    int width)
{
  __shared__ float cache[256];
  const int num_blocks = batch_sz * channels * num_h_tiles * num_w_tiles;
  const int num_threads = blockDim.x * num_blocks;
  const int rd_chans = (channels / 9) * 9;
  for (int idx = threadIdx.x + blockDim.x * blockIdx.x;
       idx < num_threads; idx += blockDim.x * gridDim.x)
  {
    const int w_tile_idx = (idx / 256) % num_w_tiles;
    const int h_tile_idx = ((idx / 256) / num_w_tiles) % num_h_tiles;
    const int tile_ch = (((idx / 256) / num_w_tiles) / num_h_tiles) % channels;
    const int tile_batch_idx = ((((idx / 256) / num_w_tiles) / num_h_tiles) / channels) % batch_sz;
    const int w_shift = ((tile_ch % 3) - 1) * (tile_ch < rd_chans);
    const int h_shift = (((tile_ch / 3) % 3) - 1) * (tile_ch < rd_chans);
    const int w_tile_off = threadIdx.x % 16;
    const int h_tile_off = threadIdx.x / 16;
    const int w_idx = w_tile_off - 1 + 14 * w_tile_idx;
    const int h_idx = h_tile_off - 1 + 14 * h_tile_idx;
    const int buf_idx = w_idx + width * (h_idx + height * (tile_ch + channels * tile_batch_idx));
    // Stage the input tile, zero-filling outside the image bounds.
    if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
      cache[threadIdx.x] = src[buf_idx];
    } else {
      cache[threadIdx.x] = 0.0f;
    }
    __syncthreads();
    // Interior (non-halo) threads write the shifted value back out.
    if (w_tile_off >= 1 && w_tile_off < 15 && h_tile_off >= 1 && h_tile_off < 15) {
      if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
        const int cache_idx = (w_tile_off + w_shift) + 16 * (h_tile_off + h_shift);
        dst[buf_idx] = cache[cache_idx];
      }
    }
    __syncthreads();
  }
}

extern "C" void shiftnet_cuda_moduloshift3x3_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    cudaStream_t stream)
{
  int num_h_tiles = (height + 14 - 1) / 14;
  int num_w_tiles = (width + 14 - 1) / 14;
  int num_blocks = min(MAX_BLOCKS, batch_sz * channels * num_h_tiles * num_w_tiles);
  shiftnet_cuda_moduloshift3x3_nchw_float32_kernel_tilein16x16_tileout14x14<<<num_blocks, 256, 0, stream>>>(
      src, dst, num_h_tiles, num_w_tiles, batch_sz, channels, height, width);
}

// Backward 3x3 shift: identical tiling, with the displacement negated
// (the transpose of the forward shift).
__global__ void shiftnet_cuda_moduloshift3x3bwd_nchw_float32_kernel_tilein16x16_tileout14x14(
    float *src,
    float *dst,
    int num_h_tiles,
    int num_w_tiles,
    int batch_sz,
    int channels,
    int height,
    int width)
{
  __shared__ float cache[256];
  const int num_blocks = batch_sz * channels * num_h_tiles * num_w_tiles;
  const int num_threads = blockDim.x * num_blocks;
  const int rd_chans = (channels / 9) * 9;
  for (int idx = threadIdx.x + blockDim.x * blockIdx.x;
       idx < num_threads; idx += blockDim.x * gridDim.x)
  {
    const int w_tile_idx = (idx / 256) % num_w_tiles;
    const int h_tile_idx = ((idx / 256) / num_w_tiles) % num_h_tiles;
    const int tile_ch = (((idx / 256) / num_w_tiles) / num_h_tiles) % channels;
    const int tile_batch_idx = ((((idx / 256) / num_w_tiles) / num_h_tiles) / channels) % batch_sz;
    const int w_shift = (1 - (tile_ch % 3)) * (tile_ch < rd_chans);
    const int h_shift = (1 - ((tile_ch / 3) % 3)) * (tile_ch < rd_chans);
    const int w_tile_off = threadIdx.x % 16;
    const int h_tile_off = threadIdx.x / 16;
    const int w_idx = w_tile_off - 1 + 14 * w_tile_idx;
    const int h_idx = h_tile_off - 1 + 14 * h_tile_idx;
    const int buf_idx = w_idx + width * (h_idx + height * (tile_ch + channels * tile_batch_idx));
    if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
      cache[threadIdx.x] = src[buf_idx];
    } else {
      cache[threadIdx.x] = 0.0f;
    }
    __syncthreads();
    if (w_tile_off >= 1 && w_tile_off < 15 && h_tile_off >= 1 && h_tile_off < 15) {
      if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
        const int cache_idx = (w_tile_off + w_shift) + 16 * (h_tile_off + h_shift);
        dst[buf_idx] = cache[cache_idx];
      }
    }
    __syncthreads();
  }
}

extern "C" void shiftnet_cuda_moduloshift3x3bwd_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    cudaStream_t stream)
{
  int num_h_tiles = (height + 14 - 1) / 14;
  int num_w_tiles = (width + 14 - 1) / 14;
  int num_blocks = min(MAX_BLOCKS, batch_sz * channels * num_h_tiles * num_w_tiles);
  shiftnet_cuda_moduloshift3x3bwd_nchw_float32_kernel_tilein16x16_tileout14x14<<<num_blocks, 256, 0, stream>>>(
      src, dst, num_h_tiles, num_w_tiles, batch_sz, channels, height, width);
}

// Generic kernel-size/dilation variant: the halo grows to
// dilate_factor * (kernel_size / 2) and the output tile shrinks accordingly;
// direction is +1 for the forward shift and -1 for its transpose.
__global__ void shiftnet_cuda_moduloshiftgeneric_nchw_float32_kernel_tilein16x16(
    float *src,
    float *dst,
    int num_h_tiles,
    int num_w_tiles,
    int batch_sz,
    int channels,
    int height,
    int width,
    int kernel_size,
    int dilate_factor,
    int direction)
{
  __shared__ float cache[256];
  const int num_blocks = batch_sz * channels * num_h_tiles * num_w_tiles;
  const int num_threads = blockDim.x * num_blocks;
  const int rd_chans = (channels / (kernel_size * kernel_size)) * (kernel_size * kernel_size);
  const int half_kernel_size = kernel_size / 2;
  const int dilated_half_kernel_size = dilate_factor * half_kernel_size;
  for (int idx = threadIdx.x + blockDim.x * blockIdx.x;
       idx < num_threads; idx += blockDim.x * gridDim.x)
  {
    const int w_tile_idx = (idx / 256) % num_w_tiles;
    const int h_tile_idx = ((idx / 256) / num_w_tiles) % num_h_tiles;
    const int tile_ch = (((idx / 256) / num_w_tiles) / num_h_tiles) % channels;
    const int tile_batch_idx = ((((idx / 256) / num_w_tiles) / num_h_tiles) / channels) % batch_sz;
    const int w_shift = direction * dilate_factor * (tile_ch < rd_chans) * ((tile_ch % kernel_size) - half_kernel_size);
    const int h_shift = direction * dilate_factor * (tile_ch < rd_chans) * (((tile_ch / kernel_size) % kernel_size) - half_kernel_size);
    const int w_tile_off = threadIdx.x % 16;
    const int h_tile_off = threadIdx.x / 16;
    const int w_idx = w_tile_off - dilated_half_kernel_size + (16 - 2 * dilated_half_kernel_size) * w_tile_idx;
    const int h_idx = h_tile_off - dilated_half_kernel_size + (16 - 2 * dilated_half_kernel_size) * h_tile_idx;
    const int buf_idx = w_idx + width * (h_idx + height * (tile_ch + channels * tile_batch_idx));
    if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
      cache[threadIdx.x] = src[buf_idx];
    } else {
      cache[threadIdx.x] = 0.0f;
    }
    __syncthreads();
    if (w_tile_off >= dilated_half_kernel_size &&
        w_tile_off < (16 - dilated_half_kernel_size) &&
        h_tile_off >= dilated_half_kernel_size &&
        h_tile_off < (16 - dilated_half_kernel_size))
    {
      if (w_idx >= 0 && w_idx < width && h_idx >= 0 && h_idx < height) {
        const int cache_idx = (w_tile_off + w_shift) + 16 * (h_tile_off + h_shift);
        dst[buf_idx] = cache[cache_idx];
      }
    }
    __syncthreads();
  }
}

extern "C" void shiftnet_cuda_moduloshiftgeneric_nchw_float32(
    float *src,
    float *dst,
    int batch_sz,
    int channels,
    int height,
    int width,
    int kernel_size,
    int dilate_factor,
    int direction,
    cudaStream_t stream)
{
  int dilated_half_kernel_size = dilate_factor * (kernel_size / 2);
  int tile_out_size = 16 - 2 * dilated_half_kernel_size;
  int num_h_tiles = (height + tile_out_size - 1) / tile_out_size;
  int num_w_tiles = (width + tile_out_size - 1) / tile_out_size;
  int num_blocks = min(MAX_BLOCKS, batch_sz * channels * num_h_tiles * num_w_tiles);
  shiftnet_cuda_moduloshiftgeneric_nchw_float32_kernel_tilein16x16<<<num_blocks, 256, 0, stream>>>(
      src, dst, num_h_tiles, num_w_tiles, batch_sz, channels, height, width, kernel_size, dilate_factor, direction);
}
--------------------------------------------------------------------------------
/test_shiftnet.py:
--------------------------------------------------------------------------------
import sys
sys.path.append("./")

import shiftnet_cuda

import numpy as np
import torch
import torch.cuda

def main():
    # Fill every (batch, channel) plane with the same 18x18 ramp pattern.
    pattern = np.arange(18 * 18).reshape(18, 18)
    src_buf = np.zeros((32, 64, 18, 18)).astype(np.float32)
    for bnr in range(32):
        for ch in range(64):
            src_buf[bnr, ch, :, :] = pattern

    x_hin = torch.zeros(32, 64, 18, 18).type(torch.FloatTensor)
    #x_hin[:,:,1:4,1:4] = 1.0
    x_hin.copy_(torch.from_numpy(src_buf))

    y_hin = torch.zeros(32, 64, 18, 18).type(torch.FloatTensor)

    x = x_hin.cuda()
    y = y_hin.cuda()

    # 7x7 shift with dilation 2, applied in the inverse (backward) direction.
    #ret = shiftnet_cuda.moduloshift3x3_nchw(x, y)
    ret = shiftnet_cuda.moduloshiftgeneric_nchw(x, y, 7, 2, -1)
    assert ret == 1

    x_hout = x.cpu()
    y_hout = y.cpu()

    print(x_hout[0, 0, :18, :18])
    for ch in range(9):
        print(y_hout[0, ch, :18, :18])

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
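For an independent check of any kernel output, the expected result can be computed on the CPU with plain NumPy slicing. A sketch of the zero-padded 2-D shift that the kernels apply per channel (`shift2d` is a hypothetical reference helper, not part of the repo):

```
import numpy as np

def shift2d(plane, dy, dx):
    # out[i, j] = plane[i + dy, j + dx] where that index is in bounds,
    # else 0: the same zero-fill convention as the CUDA kernels.
    out = np.zeros_like(plane)
    h, w = plane.shape
    src_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, dx), min(w, w + dx))
    dst_y = slice(max(0, -dy), min(h, h - dy))
    dst_x = slice(max(0, -dx), min(w, w - dx))
    out[dst_y, dst_x] = plane[src_y, src_x]
    return out
```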